# DAT-Byte-Demo / inference / inference.py
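"""Inference utilities for the DAT-Byte DiffTransformer demo.

Loads a trained DiffTransformerLLM checkpoint, applies device-specific
optimizations (int8 dynamic quantization on CPU, fp16 on CUDA), and generates
text from a byte-level prompt using temperature, top-k, top-p, and
repetition-penalty sampling.

Illustrative invocation (the exact module path depends on your checkout,
since this file uses a relative import of .model):
    python -m inference.inference --checkpoint checkpoints/model.pt --prompt "Hello"
"""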
import argparse
import os

import torch
import torch.nn.functional as F
import torch.quantization
from .model import (
DiffTransformerLLM,
ByteTokenizer,
IM_START_TOKEN,
IM_END_TOKEN,
PAD_TOKEN,
)
force_CPU = True  # Force CPU inference even when CUDA is available
def list_checkpoints(checkpoint_dir="checkpoints"):
"""List all available checkpoints in the directory."""
if not os.path.exists(checkpoint_dir):
print(f"Checkpoint directory {checkpoint_dir} not found.")
return []
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
return sorted(checkpoints)
def load_model(checkpoint_path, device=None, fp16=True):
"""Load a trained model from a checkpoint, applying optimizations as needed."""
if device is None:
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
)
print(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Model hyperparameters (must match the configuration the checkpoint was trained with)
vocab_size = 259 # 256 bytes + 3 special tokens
embed_dim = 768
num_layers = 28
num_heads = 12
ffn_hidden_dim = embed_dim * 4
max_seq_len = 512
    dropout = 0.1  # Inactive under model.eval(); value kept to match the training config
# Model
model = DiffTransformerLLM(
vocab_size=vocab_size,
embed_dim=embed_dim,
num_layers=num_layers,
num_heads=num_heads,
ffn_hidden_dim=ffn_hidden_dim,
max_seq_len=max_seq_len,
dropout=dropout,
)
# The checkpoint is the state dict itself
state_dict = checkpoint
# Load the state dict into the float32 model first
model.load_state_dict(state_dict)
model.eval()
# Apply device-specific optimizations
if device.type == "cpu":
print("Optimizing for CPU with dynamic quantization (int8).")
# Set the quantization engine
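        # qnnpack targets ARM CPUs; on x86 hosts, "fbgemm" is the more common engine choice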
torch.backends.quantized.engine = "qnnpack"
# Quantize the linear layers to int8 for performance
model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
elif device.type == "cuda" and fp16:
print("Casting model to fp16 for CUDA.")
model = model.half()
model = model.to(device)
print("Model loaded successfully.")
return model
def generate_text(
model,
tokenizer,
prompt,
max_new_tokens=100,
temperature=1.0,
top_k=0,
top_p=0.9,
repetition_penalty=1.0,
device=None,
    stop_sequences=None,
):
"""
Generate text from a prompt using the trained model.
Args:
model: The trained DiffTransformerLLM model
tokenizer: ByteTokenizer instance
prompt: Text prompt to start generation (as a string)
max_new_tokens: Maximum number of new tokens to generate
temperature: Controls randomness. Lower is more deterministic.
top_k: If > 0, only sample from the top k most likely tokens
top_p: If > 0, sample from the smallest set of tokens whose cumulative probability exceeds p
repetition_penalty: Penalize repetition. 1.0 means no penalty.
        device: Device to run inference on
        stop_sequences: List of strings; generation stops (and the match is trimmed) when any is produced
    Returns:
        A tuple (generated_text, full_text): the completion alone and the prompt plus completion
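    Example (illustrative; assumes a loaded model and ByteTokenizer):
        completion, full_text = generate_text(
            model, tokenizer, "Hello, world",
            max_new_tokens=50, temperature=0.7, top_k=10, top_p=0.9,
        )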
"""
if device is None:
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
)
# Convert prompt to bytes and tokenize - process as-is without adding special tokens
prompt_bytes = prompt.encode("utf-8", errors="replace")
input_ids = (
torch.tensor(
tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
)
.unsqueeze(0)
.to(device)
)
    # Pre-encode stop sequences into token-ID lists so they can be matched against
    # the tail of the generated IDs during the decoding loop
    stop_sequences = [
        tokenizer.encode(
            seq.encode("utf-8", errors="replace"), add_special_tokens=False
        )
        for seq in (stop_sequences or [])
    ]
# Track generated token IDs
generated_ids = input_ids.clone()
generated_bytes = b""
# Set the model to evaluation mode
model.eval()
with torch.no_grad():
for _ in range(max_new_tokens):
# Only use the last max_seq_len tokens if we exceed the model's context length
if generated_ids.size(1) > model.max_seq_len:
input_ids = generated_ids[:, -model.max_seq_len :]
else:
input_ids = generated_ids
# Forward pass to get logits for the next token
logits = model(input_ids)
# Get logits for the next token (last position)
next_token_logits = logits[:, -1, :].squeeze(0)
            # Apply temperature scaling (temperature <= 0 is handled as greedy decoding below)
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            # Apply repetition penalty: divide positive logits and multiply negative ones
            # so that already-generated tokens always become less likely
            if repetition_penalty > 1.0:
                for token_id in set(generated_ids[0].tolist()):
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty
# Apply top-k filtering
if top_k > 0:
top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
next_token_logits = torch.full_like(next_token_logits, float("-inf"))
next_token_logits.scatter_(0, top_k_indices, top_k_logits)
# Apply top-p (nucleus) filtering
if 0 < top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(
next_token_logits, descending=True
)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
next_token_logits[indices_to_remove] = float("-inf")
            # Sample from the filtered distribution (greedy argmax when temperature <= 0)
            if temperature > 0:
                probs = F.softmax(next_token_logits, dim=0)
                next_token = torch.multinomial(probs, 1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Append the generated token to the sequence
generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
            # Decode the new token to bytes for streaming output and stop-sequence matching
token_bytes = tokenizer.decode([next_token.item()])
generated_bytes += token_bytes
try:
print(token_bytes.decode("utf-8", errors="replace"), end="", flush=True)
except Exception as e:
print(f"<Error decoding token: {e}>", end="", flush=True)
stop_generated = False
stop_seq = None
for stop_seq in stop_sequences:
if generated_ids.tolist()[0][-len(stop_seq) :] == stop_seq:
stop_generated = True
break
if stop_generated:
# Remove the stop sequence from the generated IDs
generated_ids = generated_ids[:, : -len(stop_seq)]
generated_bytes = generated_bytes[: -len(stop_seq)]
break
# Decode to bytes and then to string
try:
generated_text = generated_bytes.decode("utf-8", errors="replace")
except Exception as e:
print(f"\nError decoding generated text: {e}")
generated_text = "<decoding error>"
return generated_text, prompt + generated_text
def main():
parser = argparse.ArgumentParser(
description="Text generation with DiffAttention LLM"
)
parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
        help="Prompt text to start generation",
    )
parser.add_argument(
"--max_tokens",
type=int,
default=500,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temperature", type=float, default=0.7, help="Sampling temperature"
)
parser.add_argument(
"--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
)
parser.add_argument(
"--top_p",
type=float,
default=0.9,
help="Top-p (nucleus) sampling parameter (0 to disable)",
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.2,
help="Repetition penalty (1.0 for no penalty)",
)
parser.add_argument(
"--list_checkpoints",
action="store_true",
help="List available checkpoints and exit",
)
args = parser.parse_args()
# List checkpoints if requested
if args.list_checkpoints:
print("Available checkpoints:")
checkpoints = list_checkpoints()
for i, ckpt in enumerate(checkpoints):
print(f"{i+1}. {ckpt}")
return
# If no checkpoint specified, use the latest one
if not args.checkpoint:
checkpoints = list_checkpoints()
if not checkpoints:
print("No checkpoints found. Please train the model first.")
return
# Find the latest epoch_end checkpoint
end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
if end_checkpoints:
latest_checkpoint = max(end_checkpoints)
else:
latest_checkpoint = max(checkpoints)
checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
else:
checkpoint_path = args.checkpoint
# Set device
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
)
print(f"Using device: {device}")
# Initialize tokenizer
tokenizer = ByteTokenizer()
# Load model
model = load_model(checkpoint_path, device)
# Generate text
print(f"\nGenerating text with prompt: '{args.prompt}'")
print(
f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
)
print("\nGenerating...")
generated_text, full_text = generate_text(
model=model,
tokenizer=tokenizer,
prompt=args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
device=device,
)
print("\n\nGenerated completion only:")
print("-" * 40)
print(generated_text)
print("-" * 40)
print("\nFull generated text (prompt + completion):")
print("-" * 40)
print(full_text)
print("-" * 40)
if __name__ == "__main__":
import argparse
main()