"""Command-line text generation with the DiffAttention LLM."""

import argparse
import os

import torch

from inference.inference import generate_text, list_checkpoints, load_model
from inference.model import ByteTokenizer


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default=(
            "<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n"
            "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
        ),
        help="Prompt to condition generation on",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Force CPU inference even if CUDA is available",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i + 1}. {ckpt}")
        return

    # If no checkpoint is specified, fall back to the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Prefer the latest epoch-end checkpoint over intermediate saves
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Select device; --force_cpu overrides CUDA auto-detection
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
    )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, "
        f"top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)
    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()
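
# Example invocations (a sketch: the script filename "generate.py" and the
# checkpoint filename below are assumptions for illustration, not names
# confirmed by this repo):
#
#   python generate.py --list_checkpoints
#   python generate.py --max_tokens 200 --temperature 0.8 --top_k 40
#   python generate.py --checkpoint checkpoints/epoch_1_end.pt --prompt "Hi"
#   python generate.py --force_cpu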