Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| import torch.quantization | |
| from .model import ( | |
| DiffTransformerLLM, | |
| ByteTokenizer, | |
| IM_START_TOKEN, | |
| IM_END_TOKEN, | |
| PAD_TOKEN, | |
| ) | |
| force_CPU = True | |
| def list_checkpoints(checkpoint_dir="checkpoints"): | |
| """List all available checkpoints in the directory.""" | |
| if not os.path.exists(checkpoint_dir): | |
| print(f"Checkpoint directory {checkpoint_dir} not found.") | |
| return [] | |
| checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")] | |
| return sorted(checkpoints) | |
| def load_model(checkpoint_path, device=None, fp16=True): | |
| """Load a trained model from a checkpoint, applying optimizations as needed.""" | |
| import torch | |
| if device is None: | |
| device = torch.device( | |
| "cuda" if torch.cuda.is_available() and not force_CPU else "cpu" | |
| ) | |
| print(f"Loading checkpoint from {checkpoint_path}") | |
| checkpoint = torch.load(checkpoint_path, map_location="cpu") | |
| # Hyperparams | |
| vocab_size = 259 # 256 bytes + 3 special tokens | |
| embed_dim = 768 | |
| num_layers = 28 | |
| num_heads = 12 | |
| ffn_hidden_dim = embed_dim * 4 | |
| max_seq_len = 512 | |
| dropout = 0.1 # For inference you can set dropout=0 | |
| # Model | |
| model = DiffTransformerLLM( | |
| vocab_size=vocab_size, | |
| embed_dim=embed_dim, | |
| num_layers=num_layers, | |
| num_heads=num_heads, | |
| ffn_hidden_dim=ffn_hidden_dim, | |
| max_seq_len=max_seq_len, | |
| dropout=dropout, | |
| ) | |
| # The checkpoint is the state dict itself | |
| state_dict = checkpoint | |
| # Load the state dict into the float32 model first | |
| model.load_state_dict(state_dict) | |
| model.eval() | |
| # Apply device-specific optimizations | |
| if device.type == "cpu": | |
| print("Optimizing for CPU with dynamic quantization (int8).") | |
| # Set the quantization engine | |
| torch.backends.quantized.engine = "qnnpack" | |
| # Quantize the linear layers to int8 for performance | |
| model = torch.quantization.quantize_dynamic( | |
| model, {torch.nn.Linear}, dtype=torch.qint8 | |
| ) | |
| elif device.type == "cuda" and fp16: | |
| print("Casting model to fp16 for CUDA.") | |
| model = model.half() | |
| model = model.to(device) | |
| print("Model loaded successfully.") | |
| return model | |
| def generate_text( | |
| model, | |
| tokenizer, | |
| prompt, | |
| max_new_tokens=100, | |
| temperature=1.0, | |
| top_k=0, | |
| top_p=0.9, | |
| repetition_penalty=1.0, | |
| device=None, | |
| stop_sequences=[], | |
| ): | |
| """ | |
| Generate text from a prompt using the trained model. | |
| Args: | |
| model: The trained DiffTransformerLLM model | |
| tokenizer: ByteTokenizer instance | |
| prompt: Text prompt to start generation (as a string) | |
| max_new_tokens: Maximum number of new tokens to generate | |
| temperature: Controls randomness. Lower is more deterministic. | |
| top_k: If > 0, only sample from the top k most likely tokens | |
| top_p: If > 0, sample from the smallest set of tokens whose cumulative probability exceeds p | |
| repetition_penalty: Penalize repetition. 1.0 means no penalty. | |
| device: Device to run inference on | |
| Returns: | |
| The generated text as a string | |
| """ | |
| if device is None: | |
| device = torch.device( | |
| "cuda" if torch.cuda.is_available() and not force_CPU else "cpu" | |
| ) | |
| # Convert prompt to bytes and tokenize - process as-is without adding special tokens | |
| prompt_bytes = prompt.encode("utf-8", errors="replace") | |
| input_ids = ( | |
| torch.tensor( | |
| tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long | |
| ) | |
| .unsqueeze(0) | |
| .to(device) | |
| ) | |
| stop_sequences = [ | |
| tokenizer.encode( | |
| seq.encode("utf-8", errors="replace"), add_special_tokens=False | |
| ) | |
| for seq in stop_sequences | |
| ] | |
| # Track generated token IDs | |
| generated_ids = input_ids.clone() | |
| generated_bytes = b"" | |
| # Set the model to evaluation mode | |
| model.eval() | |
| with torch.no_grad(): | |
| for _ in range(max_new_tokens): | |
| # Only use the last max_seq_len tokens if we exceed the model's context length | |
| if generated_ids.size(1) > model.max_seq_len: | |
| input_ids = generated_ids[:, -model.max_seq_len :] | |
| else: | |
| input_ids = generated_ids | |
| # Forward pass to get logits for the next token | |
| logits = model(input_ids) | |
| # Get logits for the next token (last position) | |
| next_token_logits = logits[:, -1, :].squeeze(0) | |
| # Apply temperature | |
| if temperature > 0: | |
| next_token_logits = next_token_logits / temperature | |
| # Apply repetition penalty | |
| if repetition_penalty > 1.0: | |
| for token_id in set(generated_ids[0].tolist()): | |
| next_token_logits[token_id] /= repetition_penalty | |
| # Apply top-k filtering | |
| if top_k > 0: | |
| top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k) | |
| next_token_logits = torch.full_like(next_token_logits, float("-inf")) | |
| next_token_logits.scatter_(0, top_k_indices, top_k_logits) | |
| # Apply top-p (nucleus) filtering | |
| if 0 < top_p < 1.0: | |
| sorted_logits, sorted_indices = torch.sort( | |
| next_token_logits, descending=True | |
| ) | |
| cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0) | |
| # Remove tokens with cumulative probability above the threshold | |
| sorted_indices_to_remove = cumulative_probs > top_p | |
| # Shift the indices to the right to keep the first token above the threshold | |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |
| ..., :-1 | |
| ].clone() | |
| sorted_indices_to_remove[..., 0] = 0 | |
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |
| next_token_logits[indices_to_remove] = float("-inf") | |
| # Sample from the filtered distribution | |
| probs = F.softmax(next_token_logits, dim=0) | |
| next_token = torch.multinomial(probs, 1) | |
| # Append the generated token to the sequence | |
| generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) | |
| # Check if IM_END_TOKEN has been generated | |
| token_bytes = tokenizer.decode([next_token.item()]) | |
| generated_bytes += token_bytes | |
| try: | |
| print(token_bytes.decode("utf-8", errors="replace"), end="", flush=True) | |
| except Exception as e: | |
| print(f"<Error decoding token: {e}>", end="", flush=True) | |
| stop_generated = False | |
| stop_seq = None | |
| for stop_seq in stop_sequences: | |
| if generated_ids.tolist()[0][-len(stop_seq) :] == stop_seq: | |
| stop_generated = True | |
| break | |
| if stop_generated: | |
| # Remove the stop sequence from the generated IDs | |
| generated_ids = generated_ids[:, : -len(stop_seq)] | |
| generated_bytes = generated_bytes[: -len(stop_seq)] | |
| break | |
| # Decode to bytes and then to string | |
| try: | |
| generated_text = generated_bytes.decode("utf-8", errors="replace") | |
| except Exception as e: | |
| print(f"\nError decoding generated text: {e}") | |
| generated_text = "<decoding error>" | |
| return generated_text, prompt + generated_text | |
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="Text generation with DiffAttention LLM" | |
| ) | |
| parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file") | |
| parser.add_argument( | |
| "--prompt", | |
| type=str, | |
| default="""\nHow many 'b's are in "barber"? \n""", | |
| ) | |
| parser.add_argument( | |
| "--max_tokens", | |
| type=int, | |
| default=500, | |
| help="Maximum number of tokens to generate", | |
| ) | |
| parser.add_argument( | |
| "--temperature", type=float, default=0.7, help="Sampling temperature" | |
| ) | |
| parser.add_argument( | |
| "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)" | |
| ) | |
| parser.add_argument( | |
| "--top_p", | |
| type=float, | |
| default=0.9, | |
| help="Top-p (nucleus) sampling parameter (0 to disable)", | |
| ) | |
| parser.add_argument( | |
| "--repetition_penalty", | |
| type=float, | |
| default=1.2, | |
| help="Repetition penalty (1.0 for no penalty)", | |
| ) | |
| parser.add_argument( | |
| "--list_checkpoints", | |
| action="store_true", | |
| help="List available checkpoints and exit", | |
| ) | |
| args = parser.parse_args() | |
| # List checkpoints if requested | |
| if args.list_checkpoints: | |
| print("Available checkpoints:") | |
| checkpoints = list_checkpoints() | |
| for i, ckpt in enumerate(checkpoints): | |
| print(f"{i+1}. {ckpt}") | |
| return | |
| # If no checkpoint specified, use the latest one | |
| if not args.checkpoint: | |
| checkpoints = list_checkpoints() | |
| if not checkpoints: | |
| print("No checkpoints found. Please train the model first.") | |
| return | |
| # Find the latest epoch_end checkpoint | |
| end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt] | |
| if end_checkpoints: | |
| latest_checkpoint = max(end_checkpoints) | |
| else: | |
| latest_checkpoint = max(checkpoints) | |
| checkpoint_path = os.path.join("checkpoints", latest_checkpoint) | |
| else: | |
| checkpoint_path = args.checkpoint | |
| # Set device | |
| device = torch.device( | |
| "cuda" if torch.cuda.is_available() and not force_CPU else "cpu" | |
| ) | |
| print(f"Using device: {device}") | |
| # Initialize tokenizer | |
| tokenizer = ByteTokenizer() | |
| # Load model | |
| model = load_model(checkpoint_path, device) | |
| # Generate text | |
| print(f"\nGenerating text with prompt: '{args.prompt}'") | |
| print( | |
| f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}" | |
| ) | |
| print("\nGenerating...") | |
| generated_text, full_text = generate_text( | |
| model=model, | |
| tokenizer=tokenizer, | |
| prompt=args.prompt, | |
| max_new_tokens=args.max_tokens, | |
| temperature=args.temperature, | |
| top_k=args.top_k, | |
| top_p=args.top_p, | |
| repetition_penalty=args.repetition_penalty, | |
| device=device, | |
| ) | |
| print("\n\nGenerated completion only:") | |
| print("-" * 40) | |
| print(generated_text) | |
| print("-" * 40) | |
| print("\nFull generated text (prompt + completion):") | |
| print("-" * 40) | |
| print(full_text) | |
| print("-" * 40) | |
| if __name__ == "__main__": | |
| import argparse | |
| main() | |