UMCU committed (verified)
Commit 772b2e0 · Parent: c2e82c6

Update app.py

Files changed (1):
  1. app.py +74 -12
app.py CHANGED
@@ -25,7 +25,7 @@ except ImportError:
 }
 MODEL_SETTINGS = {"max_length": 512}
 VIZ_SETTINGS = {
-    "max_perplexity_display": 1000.0,
+    "max_perplexity_display": 5000.0,
     "color_scheme": {
         "low_perplexity": {"r": 46, "g": 204, "b": 113},
         "medium_perplexity": {"r": 241, "g": 196, "b": 15},
@@ -97,6 +97,45 @@ cached_models = {}
 cached_tokenizers = {}
 
 
+def is_special_character(token):
+    """
+    Check if a token is only special characters/punctuation.
+
+    Args:
+        token: The token string to check
+
+    Returns:
+        True if token contains only special characters, False otherwise
+
+    Examples:
+        >>> is_special_character(".")
+        True
+        >>> is_special_character(",")
+        True
+        >>> is_special_character("hello")
+        False
+        >>> is_special_character("Ġ,")
+        True
+        >>> is_special_character("##!")
+        True
+    """
+    # Clean up common tokenizer artifacts
+    clean_token = (
+        token.replace("</w>", "")
+        .replace("##", "")
+        .replace("Ġ", "")
+        .replace("Ċ", "")
+        .strip()
+    )
+
+    # Check if empty after cleaning
+    if not clean_token:
+        return True
+
+    # Check if token contains only punctuation and special characters
+    return all(not c.isalnum() for c in clean_token)
+
+
 def load_model_and_tokenizer(model_name, model_type):
     """Load and cache model and tokenizer"""
     cache_key = f"{model_name}_{model_type}"
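Note: the docstring's doctest examples double as a quick sanity check for the new helper; a minimal sketch, assuming app.py is importable from the working directory:

from app import is_special_character  # assumption: app.py is on the import path

# Cases taken from the docstring above
assert is_special_character(".")
assert is_special_character(",")
assert not is_special_character("hello")
assert is_special_character("Ġ,")    # byte-level BPE space marker + comma
assert is_special_character("##!")   # WordPiece continuation marker + punctuation
print("is_special_character behaves as documented")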
@@ -184,17 +223,23 @@ def calculate_decoder_perplexity(text, model, tokenizer):
     # Get tokens (excluding the first one since we predict next tokens)
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])
 
-    # Clean up tokens for display
+    # Clean up tokens for display and filter special characters
     cleaned_tokens = []
-    for token in tokens:
+    filtered_perplexities = []
+    for token, token_perp in zip(tokens, token_perplexities):
+        # Skip special characters
+        if is_special_character(token):
+            continue
+
         if token.startswith("Ġ"):
             cleaned_tokens.append(token[1:])  # Remove Ġ prefix
         elif token.startswith("##"):
             cleaned_tokens.append(token[2:])  # Remove ## prefix
         else:
             cleaned_tokens.append(token)
+        filtered_perplexities.append(token_perp)
 
-    return perplexity, cleaned_tokens, token_perplexities
+    return perplexity, cleaned_tokens, np.array(filtered_perplexities)
 
 
 def calculate_encoder_perplexity(
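Note: because the loop now zips tokens with token_perplexities and appends to both output lists in the same iteration, the i-th cleaned token keeps the i-th perplexity. A small illustration of that invariant (token strings and values invented for the example):

import numpy as np
from app import is_special_character  # assumption: app.py is importable

tokens = ["ĠThe", "Ġcat", "Ġ.", "Ġsat"]
token_perplexities = np.array([12.0, 35.5, 3.1, 48.2])

cleaned_tokens, filtered_perplexities = [], []
for token, token_perp in zip(tokens, token_perplexities):
    if is_special_character(token):  # drops "Ġ." and its perplexity together
        continue
    cleaned_tokens.append(token[1:] if token.startswith("Ġ") else token)
    filtered_perplexities.append(token_perp)

print(cleaned_tokens)                   # ['The', 'cat', 'sat']
print(np.array(filtered_perplexities))  # [12.  35.5 48.2]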
@@ -303,15 +348,23 @@ def calculate_encoder_perplexity(
303
  # Fallback if no samples collected (shouldn't happen with proper min_samples)
304
  token_perplexities.append(2.0)
305
 
306
- # Clean up tokens for display
307
  cleaned_tokens = []
308
- for token in tokens:
 
 
 
 
 
 
 
309
  if token.startswith("##"):
310
  cleaned_tokens.append(token[2:])
311
  else:
312
  cleaned_tokens.append(token)
 
313
 
314
- return overall_perplexity, cleaned_tokens, np.array(token_perplexities)
315
 
316
 
317
  def perplexity_to_color(perplexity, min_perp=1, max_perp=1000):
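Note: this hunk also checks input_ids against special_token_ids, which is built elsewhere in app.py and not shown in this diff; with a transformers tokenizer that set would typically come from all_special_ids. A sketch, assuming a BERT-style checkpoint:

from transformers import AutoTokenizer

# assumption: any encoder checkpoint; the actual model names live in app.py's config
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
special_token_ids = set(tokenizer.all_special_ids)  # ids of [CLS], [SEP], [PAD], [UNK], [MASK]
print(special_token_ids)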
@@ -365,7 +418,7 @@ def create_visualization(tokens, perplexities):
         return "<p>No tokens to visualize.</p>"
 
     # Cap perplexities for better visualization
-    max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"])
+    max_perplexity = np.max(perplexities)
 
     # Normalize perplexities to 0-1 range for color mapping
    normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)
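Note: dropping the cap changes how outliers map to colors: previously any token above max_perplexity_display saturated at 1.0, while now the single worst token defines the top of the scale. A comparison on invented values:

import numpy as np

perplexities = np.array([5.0, 80.0, 7500.0])

# Old behaviour: cap at VIZ_SETTINGS["max_perplexity_display"] (5000.0 after this commit)
capped_max = min(np.max(perplexities), 5000.0)
print(np.clip(perplexities / capped_max, 0, 1))  # [0.001 0.016 1.   ] -- outlier saturates

# New behaviour: normalize by the observed maximum
print(np.clip(perplexities / np.max(perplexities), 0, 1))  # [~0.0007 ~0.0107 1.]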
@@ -389,20 +442,29 @@ def create_visualization(tokens, perplexities):
         if not token.strip():
             continue
 
+        # Skip special characters (already filtered in calculation functions)
+        if is_special_character(token):
+            continue
+
         # Clean token for display
+        # </w>, ##, Ġ, Ċ
         clean_token = (
-            token.replace("</w>", "").replace("##", "").replace("Ġ", "").strip()
+            token.replace("</w>", "")
+            .replace("##", "")
+            .replace("Ġ", "")
+            .replace("Ċ", "")
+            .strip()
         )
         if not clean_token:
             continue
 
         # Add space before token if needed
-        if i > 0 and not clean_token[0] in ".,!?;:":
+        if i > 0 and clean_token[0] not in ".,!?;:":
             html_parts.append(" ")
 
         # Get color thresholds from configuration
-        low_thresh = VIZ_SETTINGS.get("thresholds", {}).get("low_threshold", 0.3)
-        high_thresh = VIZ_SETTINGS.get("thresholds", {}).get("high_threshold", 0.7)
+        # low_thresh = VIZ_SETTINGS.get("thresholds", {}).get("low_threshold", 0.3)
+        # high_thresh = VIZ_SETTINGS.get("thresholds", {}).get("high_threshold", 0.7)
 
         # Get colors from configuration
         # low_color = VIZ_SETTINGS["color_scheme"]["low_perplexity"]
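Note: the four markers stripped by the chained replace calls come from different tokenizer families: Ġ (byte-level BPE word boundary) and Ċ (byte-level BPE newline) as in GPT-2, ## (WordPiece continuation) as in BERT, and </w> (word-end marker in classic BPE vocabularies). A quick demonstration of the cleanup:

for raw in ["Ġhello", "Ċ", "##ing", "world</w>"]:
    clean = (
        raw.replace("</w>", "")
        .replace("##", "")
        .replace("Ġ", "")
        .replace("Ċ", "")
        .strip()
    )
    print(repr(raw), "->", repr(clean))
# 'Ġhello' -> 'hello', 'Ċ' -> '', '##ing' -> 'ing', 'world</w>' -> 'world'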