import torch
import torch.nn.functional as F
from typing import Generator, Optional
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import get_tokenizer
from aetheris.utils import load_latest_checkpoint
class InferenceEngine:
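    """Lightweight inference wrapper around a HybridMambaMoE checkpoint.

    Loads the model config, tokenizer, and latest checkpoint, then exposes
    sampling-based text generation with optional token streaming.
    """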
    def __init__(
        self,
        config_path: str = "configs/default.yaml",
        checkpoint_dir: str = "checkpoints",
        checkpoint_name: str = "checkpoint_current.pth",
        device: Optional[str] = None,
    ):
        self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.config = AetherisConfig.from_yaml(config_path)
        self.tokenizer = get_tokenizer()
        self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)
        # Load the latest checkpoint. load_latest_checkpoint expects an optimizer
        # and a grad scaler, but for inference we can pass None for both.
        load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
        self.model.eval()
    def generate(self,
                 prompt: str,
                 max_new_tokens: int = 100,
                 temperature: float = 0.8,
                 top_k: int = 0,
                 top_p: float = 0.9,
                 repetition_penalty: float = 1.0,
                 stream: bool = False) -> Generator[str, None, None] | str:
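        """Generate text from `prompt` by sampling one token at a time.

        Applies an optional repetition penalty, temperature scaling, and
        nucleus (top-p) or top-k filtering (top-p takes precedence when
        top_p < 1.0), then samples until `max_new_tokens` tokens are produced
        or the EOS token is drawn.

        Returns a generator of decoded token strings when `stream=True`,
        otherwise the full generated string.
        """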
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        generated_ids = input_ids.clone()
        history_ids = set(input_ids[0].tolist())

        def token_generator():
            nonlocal generated_ids
            for _ in range(max_new_tokens):
                # Forward pass without gradient tracking. Autocast is only
                # enabled when the model runs in a reduced-precision dtype;
                # for float32 it is skipped entirely.
                use_autocast = self.config.torch_dtype != torch.float32
                with torch.no_grad():
                    with torch.amp.autocast(
                        'cuda' if self.device.type == 'cuda' else 'cpu',
                        dtype=self.config.torch_dtype,
                        enabled=use_autocast,
                    ):
                        outputs = self.model(generated_ids)
                logits = outputs['logits']
                next_token_logits = logits[:, -1, :]
                # Repetition penalty: dampen logits of tokens seen so far
                # (positive logits are divided, negative logits multiplied).
                for token_id in history_ids:
                    if token_id < next_token_logits.size(-1):
                        logit = next_token_logits[0, token_id].item()
                        if logit > 0:
                            next_token_logits[0, token_id] = logit / repetition_penalty
                        else:
                            next_token_logits[0, token_id] = logit * repetition_penalty
                # Temperature scaling
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                # Nucleus (top-p) filtering takes precedence; top-k is only
                # applied when top_p >= 1.0.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    # Mask tokens whose cumulative probability exceeds top_p,
                    # shifted right by one so the first token over the threshold is kept.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
                elif top_k > 0:
                    # Keep only the top_k highest logits; mask out the rest.
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, top_k_indices, top_k_logits)
                # Sample the next token from the filtered distribution.
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                next_token_item = next_token.item()
                if next_token_item == self.tokenizer.eos_token_id:
                    break
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                history_ids.add(next_token_item)
                new_token_text = self.tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
                yield new_token_text

        if stream:
            return token_generator()
        else:
            return "".join(token_generator())
    def generate_full(self,
                      prompt: str,
                      max_new_tokens: int = 100,
                      temperature: float = 0.8,
                      top_k: int = 0,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.0) -> str:
        # Convenience wrapper that always returns the full completion as a string.
        return self.generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, stream=False)
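

# Minimal usage sketch (not part of the library API): the prompt text and
# sampling settings below are placeholders; paths fall back to the defaults
# declared in __init__ ("configs/default.yaml", "checkpoints").
if __name__ == "__main__":
    engine = InferenceEngine()

    # Streaming: print each decoded token as soon as it is sampled.
    for token_text in engine.generate("Once upon a time", max_new_tokens=50, stream=True):
        print(token_text, end="", flush=True)
    print()

    # Non-streaming: get the full completion as a single string.
    print(engine.generate_full("Once upon a time", max_new_tokens=50, temperature=0.7))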