critique fix: logit_bias.py
Browse files
tensegrity/graft/logit_bias.py
CHANGED
|
@@ -261,9 +261,19 @@ class TensegrityLogitsProcessor:
|
|
| 261 |
token_scores = self.hypothesis_token_scores.get(hyp_id, {})
|
| 262 |
|
| 263 |
if prob <= self.suppress_threshold:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
for tid in token_ids:
|
| 265 |
if 0 <= tid < self.vocab_size:
|
| 266 |
-
bias[tid] =
|
| 267 |
suppressed += 1
|
| 268 |
else:
|
| 269 |
b = self.scale * math.log(max(float(prob), 1e-12) / p_uniform)
|
|
@@ -377,7 +387,7 @@ class StaticLogitBiasBuilder:
|
|
| 377 |
|
| 378 |
if prob <= self.suppress_threshold:
|
| 379 |
for tid in token_ids:
|
| 380 |
-
bias[tid] = -
|
| 381 |
else:
|
| 382 |
b = self.scale * math.log(max(prob, 1e-9) / p_uniform)
|
| 383 |
b = max(-self.max_bias, min(self.max_bias, b))
|
|
|
|
| 261 |
token_scores = self.hypothesis_token_scores.get(hyp_id, {})
|
| 262 |
|
| 263 |
if prob <= self.suppress_threshold:
|
| 264 |
+
# Dynamic temperature scaling instead of hard -inf.
|
| 265 |
+
# The review correctly identified that hard suppression
|
| 266 |
+
# to -inf collides with the LLM's syntactic expectations,
|
| 267 |
+
# causing broken grammar when suppressed tokens are
|
| 268 |
+
# structurally necessary (pronouns, conjunctions, etc.).
|
| 269 |
+
#
|
| 270 |
+
# Instead: apply a strong but finite negative bias that
|
| 271 |
+
# makes the token very unlikely but not impossible. The
|
| 272 |
+
# LLM can still use it if syntactic context demands it.
|
| 273 |
+
suppress_bias = -self.max_bias # e.g., -8.0 instead of -inf
|
| 274 |
for tid in token_ids:
|
| 275 |
if 0 <= tid < self.vocab_size:
|
| 276 |
+
bias[tid] = suppress_bias
|
| 277 |
suppressed += 1
|
| 278 |
else:
|
| 279 |
b = self.scale * math.log(max(float(prob), 1e-12) / p_uniform)
|
|
|
|
| 387 |
|
| 388 |
if prob <= self.suppress_threshold:
|
| 389 |
for tid in token_ids:
|
| 390 |
+
bias[tid] = -self.max_bias # Finite suppress, not -100
|
| 391 |
else:
|
| 392 |
b = self.scale * math.log(max(prob, 1e-9) / p_uniform)
|
| 393 |
b = max(-self.max_bias, min(self.max_bias, b))
|