Spaces:
Sleeping
Sleeping
| from inference.inference import generate_text, list_checkpoints, load_model | |
| import argparse | |
| import torch | |
| from inference.model import ByteTokenizer | |
def main():
    """Command-line entry point: load a DiffAttention checkpoint and generate text.

    Parses CLI arguments. With ``--list_checkpoints`` it prints the names
    returned by ``list_checkpoints()`` and exits. Otherwise it resolves a
    checkpoint: the explicit ``--checkpoint`` path, or the newest name from
    ``list_checkpoints()`` (preferring names containing ``"end.pt"``). It
    then loads the model on CUDA (unless unavailable or ``--force_cpu``) or
    CPU and prints both the completion alone and the full prompt + completion.
    Returns ``None``; all output goes to stdout.
    """
    # Function-scope imports: the original referenced ``os`` without ever
    # importing it (NameError), and ``re`` is needed for the natural sort key.
    import os
    import re

    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n""",
        help="Prompt text to complete",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    # NOTE(review): if generate_text implements standard top-k filtering, a
    # default of 1 keeps only the single most likely token (greedy decoding),
    # which makes temperature/top_p irrelevant — confirm this is intended.
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    # Replaces the undefined global ``force_CPU`` the original referenced.
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Run on CPU even if CUDA is available",
    )
    args = parser.parse_args()

    # List checkpoints if requested, then exit.
    if args.list_checkpoints:
        print("Available checkpoints:")
        for i, ckpt in enumerate(list_checkpoints(), start=1):
            print(f"{i}. {ckpt}")
        return

    if args.checkpoint:
        checkpoint_path = args.checkpoint
    else:
        # No checkpoint specified: use the latest one.
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        def natural_key(name):
            # re.split with a capturing group puts digit runs at odd indices,
            # so positions always hold the same type across names and the
            # int/str lists compare safely. This orders "epoch_10" after
            # "epoch_2", which plain string comparison gets wrong.
            parts = re.split(r"(\d+)", name)
            return [int(tok) if idx % 2 else tok for idx, tok in enumerate(parts)]

        # Prefer names containing "end.pt" (presumably end-of-epoch saves —
        # the naming convention is not visible here); fall back to all.
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        latest_checkpoint = max(end_checkpoints or checkpoints, key=natural_key)
        # Assumes list_checkpoints() returns names relative to ./checkpoints
        # — TODO confirm against inference.inference.list_checkpoints.
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)

    # Set device.
    use_cuda = torch.cuda.is_available() and not args.force_cpu
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Initialize tokenizer and load model.
    tokenizer = ByteTokenizer()
    model = load_model(checkpoint_path, device)

    # Generate text.
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")
    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)
    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)
# Script entry point. argparse is already imported at module top, so the
# original redundant re-import here is dropped.
if __name__ == "__main__":
    main()