import argparse
import os

import torch
import torch.nn.functional as F
import torch.quantization

from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = True


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True):
    """Load a trained model from a checkpoint, applying optimizations as needed."""
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Hyperparameters (must match the values used during training)
    vocab_size = 259  # 256 bytes + 3 special tokens
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
    max_seq_len = 512
    dropout = 0.1  # Inactive in eval mode; can also be set to 0 for inference

    # Build the model
    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # The checkpoint is the state dict itself
    state_dict = checkpoint

    # Load the state dict into the float32 model first
    model.load_state_dict(state_dict)
    model.eval()

    # Apply device-specific optimizations
    if device.type == "cpu":
        print("Optimizing for CPU with dynamic quantization (int8).")
        # Set the quantization engine
        torch.backends.quantized.engine = "qnnpack"
        # Quantize the linear layers to int8 for performance
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda" and fp16:
        print("Casting model to fp16 for CUDA.")
        model = model.half()

    model = model.to(device)
    print("Model loaded successfully.")
    return model


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=None,
):
    """
    Generate text from a prompt using the trained model.

    Args:
        model: The trained DiffTransformerLLM model
        tokenizer: ByteTokenizer instance
        prompt: Text prompt to start generation (as a string)
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Controls randomness. Lower is more deterministic.
        top_k: If > 0, only sample from the top k most likely tokens
        top_p: If between 0 and 1, sample from the smallest set of tokens whose
            cumulative probability exceeds p (nucleus sampling)
        repetition_penalty: Penalize repetition. 1.0 means no penalty.
        device: Device to run inference on
        stop_sequences: Optional list of strings; generation stops when one of
            them is produced, and the stop sequence is stripped from the output

    Returns:
        A tuple of (generated completion, prompt + completion), both as strings
    """
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    if stop_sequences is None:
        stop_sequences = []

    # Convert prompt to bytes and tokenize - process as-is without adding special tokens
    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )

    # Encode each stop sequence the same way as the prompt
    stop_sequences = [
        tokenizer.encode(
            seq.encode("utf-8", errors="replace"), add_special_tokens=False
        )
        for seq in stop_sequences
    ]

    # Track generated token IDs
    generated_ids = input_ids.clone()
    generated_bytes = b""

    # Set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Only use the last max_seq_len tokens if we exceed the model's context length
            if generated_ids.size(1) > model.max_seq_len:
                input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                input_ids = generated_ids

            # Forward pass to get logits for the next token
            logits = model(input_ids)

            # Get logits for the next token (last position)
            next_token_logits = logits[:, -1, :].squeeze(0)

            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Apply repetition penalty: push already-generated tokens toward lower
            # probability regardless of the sign of their logit
            if repetition_penalty > 1.0:
                for token_id in set(generated_ids[0].tolist()):
                    score = next_token_logits[token_id]
                    next_token_logits[token_id] = (
                        score * repetition_penalty
                        if score < 0
                        else score / repetition_penalty
                    )

            # Apply top-k filtering
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                next_token_logits = torch.full_like(next_token_logits, float("-inf"))
                next_token_logits.scatter_(0, top_k_indices, top_k_logits)

            # Apply top-p (nucleus) filtering
            if 0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample from the filtered distribution
            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Append the generated token to the sequence
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)

            # Stream the new token to stdout (errors="replace" never raises)
            token_bytes = tokenizer.decode([next_token.item()])
            generated_bytes += token_bytes
            print(token_bytes.decode("utf-8", errors="replace"), end="", flush=True)

            # Stop if any stop sequence has just been generated
            stop_generated = False
            stop_seq = None
            for stop_seq in stop_sequences:
                if generated_ids.tolist()[0][-len(stop_seq) :] == stop_seq:
                    stop_generated = True
                    break
            if stop_generated:
                # Remove the stop sequence from the generated IDs and bytes
                generated_ids = generated_ids[:, : -len(stop_seq)]
                generated_bytes = generated_bytes[: -len(stop_seq)]
                break

    # Decode the generated bytes to a string (errors="replace" never raises)
    generated_text = generated_bytes.decode("utf-8", errors="replace")

    return generated_text, prompt + generated_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    # If no checkpoint is specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Prefer the latest epoch-end checkpoint if one exists
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
    )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, "
        f"top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()
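
# Illustrative sketch of calling this module directly from Python instead of the CLI.
# The checkpoint filename, the module's import path, and the stop sequence below are
# assumptions for illustration only and will differ in a real setup.
#
#   from .model import ByteTokenizer
#   # from <this_module> import load_model, generate_text   # adjust to the package layout
#
#   tokenizer = ByteTokenizer()
#   model = load_model("checkpoints/epoch_10_end.pt")        # hypothetical checkpoint
#   completion, full_text = generate_text(
#       model,
#       tokenizer,
#       prompt="How many 'b's are in \"barber\"?",
#       max_new_tokens=100,
#       temperature=0.7,
#       top_k=10,
#       top_p=0.9,
#       repetition_penalty=1.2,
#       stop_sequences=["\n\n"],                             # hypothetical stop sequence
#   )
#   print(completion)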