import torch
import torch.nn.functional as F
from typing import Optional, Generator

from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import get_tokenizer
from aetheris.utils import load_latest_checkpoint


class InferenceEngine:
    def __init__(self,
                 config_path: str = "configs/default.yaml",
                 checkpoint_dir: str = "checkpoints",
                 checkpoint_name: str = "checkpoint_current.pth",
                 device: Optional[str] = None):
        self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.config = AetherisConfig.from_yaml(config_path)
        self.tokenizer = get_tokenizer()
        self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)

        # Load checkpoint.
        # Note: load_latest_checkpoint expects an optimizer and scaler, but for inference we can pass None.
        load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
        self.model.eval()

    def generate(self,
                 prompt: str,
                 max_new_tokens: int = 100,
                 temperature: float = 0.8,
                 top_k: int = 0,
                 top_p: float = 0.9,
                 repetition_penalty: float = 1.0,
                 stream: bool = False) -> Generator[str, None, None] | str:
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        generated_ids = input_ids.clone()
        history_ids = set(input_ids[0].tolist())

        def token_generator():
            nonlocal generated_ids
            for _ in range(max_new_tokens):
                # Skip autocast when the model already runs in float32.
                use_autocast = self.config.torch_dtype != torch.float32

                with torch.no_grad():
                    if use_autocast:
                        with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu',
                                                dtype=self.config.torch_dtype):
                            outputs = self.model(generated_ids)
                    else:
                        outputs = self.model(generated_ids)
                logits = outputs['logits']
                next_token_logits = logits[:, -1, :]

                # Repetition penalty: dampen logits of tokens that have already appeared.
                for token_id in history_ids:
                    if token_id < next_token_logits.size(-1):
                        logit = next_token_logits[0, token_id].item()
                        if logit > 0:
                            next_token_logits[0, token_id] = logit / repetition_penalty
                        else:
                            next_token_logits[0, token_id] = logit * repetition_penalty

                # Temperature scaling.
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                # Top-p (nucleus) filtering, falling back to top-k when top_p is disabled.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token that crosses the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
                elif top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, top_k_indices, top_k_logits)

                # Sample the next token from the filtered distribution.
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                next_token_item = next_token.item()

                if next_token_item == self.tokenizer.eos_token_id:
                    break

                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                history_ids.add(next_token_item)

                new_token_text = self.tokenizer.decode([next_token_item], skip_special_tokens=True)
                yield new_token_text

        if stream:
            return token_generator()
        return "".join(token_generator())

    def generate_full(self,
                      prompt: str,
                      max_new_tokens: int = 100,
                      temperature: float = 0.8,
                      top_k: int = 0,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.0) -> str:
        return self.generate(prompt, max_new_tokens, temperature, top_k, top_p,
                             repetition_penalty, stream=False)
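

# Example usage (illustrative sketch, not part of the original module): assumes the default
# config and a checkpoint named "checkpoint_current.pth" exist under "checkpoints/"; the
# prompt and sampling parameters below are arbitrary placeholders.
if __name__ == "__main__":
    engine = InferenceEngine()

    # Blocking generation: returns the full completion as a single string.
    text = engine.generate_full("Once upon a time", max_new_tokens=50, temperature=0.7)
    print(text)

    # Streaming generation: yields decoded tokens one at a time as they are sampled.
    for piece in engine.generate("Once upon a time", max_new_tokens=50, stream=True):
        print(piece, end="", flush=True)
    print()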