import onnxruntime as ort
import numpy as np
import torch
import time
import argparse
from typing import Set, Optional

from .model import ByteTokenizer
# Strings that should break DRY sequence matching; the processor itself expects the
# corresponding token IDs.
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) logits processor that penalizes repetitive sequences.

    A token that would extend a repeated suffix of length ``match_length`` (at least
    ``allowed_length``) is penalized by
    ``multiplier * base ** (match_length - allowed_length)``.
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len)
            scores: Array of shape (vocab_size,) with logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)

        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()
            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue

                next_token = input_ids_row[i + 1].item()
                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached
                    if match_length + 1 > len(input_ids_row):
                        break  # End of input reached
                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached
                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached
                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()
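

# Sketch (not part of the original file): a helper that maps `sequence_breaker_strings`
# to the token-ID set DRYLogitsProcessor expects. It assumes the byte-level tokenizer's
# encode() accepts raw bytes and an `add_special_tokens` flag, mirroring how it is
# called in generate_text below.
def build_dry_sequence_breakers(tokenizer) -> Set[int]:
    return {
        # Use the last token ID of each breaker string (each is a single byte here).
        tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)[-1]
        for s in sequence_breaker_strings
    }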


def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 bytes total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since this is byte level.
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)

    generated_token_ids = []
    start_time = time.time()

    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create a causal mask for the current sequence length.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}

        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Bail out and report the failure to the caller
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]

        # Apply DRY penalty if enabled
        if dry_multiplier > 0:
            dry_processor = DRYLogitsProcessor(
                multiplier=dry_multiplier,
                base=dry_base,
                allowed_length=dry_allowed_length,
                sequence_breakers=dry_sequence_breakers,
                range=dry_range,
            )
            logits = dry_processor(input_ids, logits)

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)

        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated text. With a
                    # byte-level tokenizer this assumes one token per character,
                    # which holds for ASCII stop sequences.
                    generated_token_ids = generated_token_ids[: -len(seq)]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break

    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps
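

# Usage sketch (illustrative only, not from the original file; the real entry point,
# if any, may differ). It assumes the model was exported to "model.onnx" and that
# ByteTokenizer takes no constructor arguments:
#
#     session = ort.InferenceSession("model.onnx")
#     tokenizer = ByteTokenizer()
#     breakers = {
#         tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)[-1]
#         for s in sequence_breaker_strings
#     }
#     text, tps = generate_text(
#         session,
#         tokenizer,
#         "Once upon a time",
#         dry_multiplier=0.8,
#         dry_sequence_breakers=breakers,
#     )
#     print(text, f"({tps:.1f} tok/s)")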