from __future__ import annotations

import torch


def sample_token(
    logits: torch.Tensor,
    *,
    temp: float,
    top_k: int = 0,
) -> torch.Tensor:
    """Pick one token id per row of ``logits``.

    The last dimension of ``logits`` is treated as the vocabulary axis; all
    leading dimensions are treated as independent rows.

    Args:
        logits: Unnormalised scores. Converted to float32 before any maths.
        temp: Sampling temperature. ``temp <= 0`` selects the argmax
            (greedy decoding); positive values are floored at ``1e-6``
            before dividing the logits.
        top_k: When ``0 < top_k < vocab``, sampling is restricted to the
            ``top_k`` most probable tokens of each row. Any other value
            samples from the full distribution.

    Returns:
        An integer tensor of shape ``logits.shape[:-1] + (1,)`` holding the
        chosen token indices.

    Notes:
        * ``+inf`` scaled logits (for example from overflowed half-precision
          activations, or a large logit divided by a small temperature) are
          clamped to the largest finite float32 so that they dominate the
          softmax instead of turning the whole row into NaN.
        * Rows whose probabilities still collapse to zero mass (NaN logits,
          or every logit equal to ``-inf``) fall back to token index 0.
    """
    logits32 = logits.to(torch.float32)

    # Greedy decoding: no randomness involved.
    if temp <= 0.0:
        return torch.argmax(logits32, dim=-1, keepdim=True)

    scaled = logits32 / max(temp, 1e-6)
    # softmax computes x - max(x); with a +inf entry that is inf - inf = NaN,
    # which would poison the row and send it to the token-0 fallback even
    # though the greedy path would have picked the +inf token. Clamping to
    # the largest finite value keeps that token dominant. NaN inputs are not
    # altered here and are handled by the fallback below.
    scaled = torch.clamp(scaled, max=torch.finfo(torch.float32).max)

    probs = torch.softmax(scaled, dim=-1)
    probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
    probs = torch.clamp_min(probs, 0.0)

    # Work on a 2-D (rows, vocab) view; the leading shape is restored at the end.
    vocab = probs.shape[-1]
    flat = probs.reshape(-1, vocab)

    # Renormalise each row, remembering which rows carry no probability mass.
    norm = flat.sum(dim=-1, keepdim=True)
    zero_mask = norm <= 0
    flat = flat / norm.clamp_min(1e-12)

    # Degenerate rows become a one-hot on token 0 so multinomial always has a
    # valid distribution. The replacement is applied unconditionally through
    # broadcasting, which avoids a data-dependent Python branch on a tensor.
    fallback = torch.zeros(vocab, dtype=flat.dtype, device=flat.device)
    fallback[0] = 1.0
    flat = torch.where(zero_mask, fallback, flat)

    if 0 < top_k < vocab:
        # Keep the k most probable tokens, renormalise, sample a position
        # within that shortlist, then map it back to a vocabulary index.
        topv, indices = torch.topk(flat, top_k, dim=-1)
        topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        draws = torch.multinomial(topv, num_samples=1)
        picks = torch.gather(indices, dim=-1, index=draws)
    else:
        picks = torch.multinomial(flat, num_samples=1)

    return picks.reshape(*probs.shape[:-1], 1)