# gidd-unif-3b / modeling_gidd.py
import typing as tp
import warnings
from functools import partial
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention
from transformers import PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import GenerationMixin
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .configuration_gidd import GiddConfig
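# GIDD: a diffusion language model with a transformer backbone (RMSNorm, RoPE,
# soft-capped attention, relu^2 MLP) and a block-wise denoising sampler in
# GiddForDiffusionLM.generate().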
@dataclass
class AttentionLayerOutput:
hidden_states: torch.Tensor
attentions: tp.Optional[torch.Tensor] = None
past_key_values: tp.Optional[tp.List[tp.Tuple[torch.Tensor, torch.Tensor]]] = None
@dataclass
class DecoderLayerOutput:
hidden_states: torch.Tensor
attentions: tp.Optional[torch.Tensor] = None
past_key_values: tp.Optional[tp.List[tp.Tuple[torch.Tensor, torch.Tensor]]] = None
def promote_dtype(args: tuple, *, dtype: torch.dtype | None = None) -> tuple:
return tuple(
torch.as_tensor(x, dtype=dtype) if x is not None else None
for x in args
)
class ScaledLinear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
*,
scale: float | tp.Literal["fan_in", "fan_out"] = 1.0,
use_bias: bool = True,
dtype: torch.dtype | None = None,
):
super().__init__()
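        # "fan_in"/"fan_out" resolve to 1/sqrt(in_features) or 1/sqrt(out_features); the
        # resulting scale is applied to the matmul output at forward time rather than
        # being folded into the weight.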
if scale == "fan_in":
scale = in_features**-0.5
elif scale == "fan_out":
scale = out_features**-0.5
if scale != 1.0:
def _scale_operator(x):
return x * scale
else:
def _scale_operator(x):
return x
self._scale_operator = _scale_operator
self.in_features = in_features
self.out_features = out_features
self.use_bias = use_bias
weight_shape = (out_features, in_features)
weight = torch.zeros(weight_shape, dtype=dtype)
self.weight = nn.Parameter(weight)
if use_bias:
bias = torch.zeros((out_features,), dtype=dtype)
self.bias = nn.Parameter(bias)
else:
self.bias = None
def forward(
self,
inputs: torch.Tensor,
w: torch.Tensor | None = None,
) -> torch.Tensor:
dtype = inputs.dtype
weight = self.weight if w is None else w
bias = self.bias if self.use_bias else None
if bias is not None:
inputs, weight, bias = promote_dtype((inputs, weight, bias), dtype=dtype)
else:
inputs, weight = promote_dtype((inputs, weight), dtype=dtype)
y = torch.matmul(
inputs,
weight.T,
)
y = self._scale_operator(y)
if bias is not None:
y = y + bias.reshape((1,) * (y.ndim - 1) + (-1,))
return y
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(2).to(dtype=x.dtype)
sin = sin.unsqueeze(2).to(dtype=x.dtype)
assert sin.ndim == x.ndim
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).reshape(x.shape)
def apply_basic_rope(
query: torch.Tensor,
key: torch.Tensor,
positions: torch.Tensor,
frequencies: torch.Tensor,
rotary_dim: int,
is_neox_style: bool,
offsets: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
):
if offsets is not None:
positions = positions + offsets
cos, sin = torch.chunk(frequencies[positions], 2, dim=-1)
if rotary_dim != query.shape[-1]:
query_rot = _apply_rotary_emb(query[..., :rotary_dim], cos, sin, is_neox_style)
query = torch.cat((query_rot, query[..., rotary_dim:]), dim=-1)
key_rot = _apply_rotary_emb(key[..., :rotary_dim], cos, sin, is_neox_style)
key = torch.cat((key_rot, key[..., rotary_dim:]), dim=-1)
return query.to(dtype), key.to(dtype), cos, sin
else:
query = _apply_rotary_emb(query, cos, sin, is_neox_style)
key = _apply_rotary_emb(key, cos, sin, is_neox_style)
return query.to(dtype), key.to(dtype), cos, sin
def compute_basic_frequencies(
base: int,
rotary_dim: int,
max_position_embeddings: int,
):
inv = 1.0 / torch.pow(
base,
torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim,
)
freqs = torch.einsum(
"i,j->ij",
torch.arange(max_position_embeddings, dtype=torch.float32),
inv,
)
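    # cos and sin tables are concatenated along the last dim; apply_basic_rope chunks
    # them back into their (cos, sin) halves.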
freqs = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
return freqs
class RotaryEmbedding(nn.Module):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
):
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: torch.Tensor | None = None,
frequencies: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if frequencies is None:
frequencies = compute_basic_frequencies(
base=self.base,
rotary_dim=self.rotary_dim,
max_position_embeddings=self.max_position_embeddings,
)
if hasattr(frequencies, "value"):
frequencies = frequencies.value
return apply_basic_rope(
query=query,
key=key,
positions=positions,
frequencies=frequencies,
rotary_dim=self.rotary_dim,
is_neox_style=self.is_neox_style,
offsets=offsets,
dtype=self.dtype,
)
class GiddRMSNorm(nn.Module):
def __init__(
self,
config: GiddConfig,
dtype=torch.float32,
):
super().__init__()
self.config = config
self.epsilon = self.config.rms_norm_eps
self.weight = nn.Parameter(torch.zeros(self.config.hidden_size, dtype=dtype))
# self.bias = nn.Parameter(torch.zeros(self.config.hidden_size, dtype=dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32)
variance = variance.pow(2.0)
variance = variance.mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
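        # The scale is parameterized as (1 + weight) with a zero-initialized weight, so the
        # layer starts out as a pure RMS normalization.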
hidden_states = ((1 + self.weight) * hidden_states)
return hidden_states.to(dtype)
ALL_LAYERNORM_LAYERS.append(GiddRMSNorm)
class GiddMLP(nn.Module):
def __init__(
self,
config: GiddConfig,
dtype=torch.float32,
):
super().__init__()
self.config = config
self.dtype = dtype
linear_class = partial(
ScaledLinear,
scale=config.weight_scaling,
dtype=dtype,
use_bias=self.config.mlp_bias,
)
self.up_proj = linear_class(config.hidden_size, config.intermediate_size)
self.down_proj = linear_class(config.intermediate_size, config.hidden_size)
def forward(self, h: torch.Tensor) -> torch.Tensor:
h = self.up_proj(h)
h = torch.relu(h) ** 2
h = self.down_proj(h)
return h
class FlexSoftcapAttention(nn.Module):
def __init__(self, head_dim, n_heads, softmax_scale, soft_cap):
super().__init__()
self.d_model = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = float(softmax_scale)
self.soft_cap = float(soft_cap)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor | None = None,
):
B, _, L = q.shape[:3]
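        # score_mod applies logit soft-capping, soft_cap * tanh(score / soft_cap), and then
        # masks out disallowed query/key pairs before the softmax inside flex_attention.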
def score_mod(score, b, h, q_idx, kv_idx):
soft_cap = self.soft_cap
score = soft_cap * torch.tanh(score / soft_cap)
keep = attention_mask[b, q_idx, kv_idx]
return torch.where(keep, score, torch.finfo(score.dtype).min)
out = flex_attention(
q,
k,
v,
score_mod=score_mod,
scale=self.scale,
)
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return out, None
class VanillaSoftcapAttention(nn.Module):
def __init__(self, head_dim, n_heads, softmax_scale, soft_cap):
super().__init__()
self.d_model = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = float(softmax_scale)
self.soft_cap = float(soft_cap)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor | None = None,
):
B, _, L = q.shape[:3]
scores = torch.einsum(
"bhqd,bhkd->bhqk",
q * self.scale,
k,
)
scores = self.soft_cap * torch.tanh(scores / self.soft_cap)
if attention_mask is not None:
scores = scores.masked_fill(~attention_mask.unsqueeze(1), torch.finfo(scores.dtype).min)
probs = torch.softmax(scores.to(torch.float32), dim=-1).to(scores.dtype)
out = torch.einsum(
"bhqk,bhkd->bhqd",
probs,
v,
)
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return out, probs
class GiddAttention(nn.Module):
def __init__(
self,
config: GiddConfig,
layer_idx: int,
dtype=torch.float32,
):
super().__init__()
self.hidden_size = config.hidden_size
head_dim = config.hidden_size // config.num_attention_heads
self.head_dim = getattr(config, "head_dim", head_dim)
self.num_attention_heads = self.hidden_size // self.head_dim
self.is_causal = config.is_causal
self.layer_idx = layer_idx
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
self.q_norm = GiddRMSNorm(config, dtype=torch.float32)
self.k_norm = GiddRMSNorm(config, dtype=torch.float32)
else:
self.q_norm = None
self.k_norm = None
self.attention_bias = config.attention_bias
if self.attention_bias:
self.k_bias = nn.Parameter(
torch.zeros((self.num_attention_heads, self.head_dim), dtype=dtype),
)
self.v_bias = nn.Parameter(
torch.zeros((self.num_attention_heads, self.head_dim), dtype=dtype),
)
else:
self.k_bias = None
self.v_bias = None
linear_class = partial(
ScaledLinear,
scale=config.weight_scaling,
dtype=dtype,
use_bias=False,
)
self.q_proj = linear_class(
self.hidden_size,
self.num_attention_heads * self.head_dim,
)
self.k_proj = linear_class(
self.hidden_size,
self.num_attention_heads * self.head_dim,
)
self.v_proj = linear_class(
self.hidden_size,
self.num_attention_heads * self.head_dim,
)
self.o_proj = linear_class(
self.num_attention_heads * self.head_dim,
self.hidden_size,
)
self.rotary = RotaryEmbedding(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
is_neox_style=True,
dtype=dtype,
)
if config.attn_performer == "flex":
self.attention_performer = FlexSoftcapAttention(
head_dim=self.head_dim,
n_heads=self.num_attention_heads,
softmax_scale=self.head_dim**-0.5,
soft_cap=config.attn_soft_cap,
)
elif config.attn_performer == "eager":
self.attention_performer = VanillaSoftcapAttention(
head_dim=self.head_dim,
n_heads=self.num_attention_heads,
softmax_scale=self.head_dim**-0.5,
soft_cap=config.attn_soft_cap,
)
else:
raise ValueError(f"Unknown attn_performer: {config.attn_performer}")
def concatenate(
self,
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
past_key_values: tp.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
):
assert query.shape[1] == key.shape[1], "Query and Key lengths must match for GIDD attention."
if attention_mask is not None:
if attention_mask.dtype != torch.bool:
                warnings.warn("attention_mask should be a boolean tensor", stacklevel=1)
attention_mask = (attention_mask == 1)
batch_size = query.shape[0]
# shape of attention_mask: (batch_size, seq_len)
# or (batch_size, query_len, kv_len)
if attention_mask.ndim == 2:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.expand(-1, query.shape[1], -1)
elif attention_mask.ndim == 3:
# already in correct shape
pass
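        # With attention_bias enabled, a learned key/value pair is prepended as an
        # always-visible slot, so the mask gets an extra all-ones column at kv index 0.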
if self.attention_bias:
ones = torch.ones(
attention_mask.shape[:2] + (1,),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat(
[
ones,
attention_mask,
],
dim=-1,
)
if past_key_values is not None:
past_keys, past_values = past_key_values
key = torch.cat([past_keys, key], dim=1)
value = torch.cat([past_values, value], dim=1)
elif self.attention_bias:
n_heads = self.num_attention_heads
bias_shape = (batch_size, 1, n_heads, self.head_dim)
k_bias = self.k_bias.view(1, 1, n_heads, self.head_dim).expand(bias_shape)
v_bias = self.v_bias.view(1, 1, n_heads, self.head_dim).expand(bias_shape)
key = torch.cat([k_bias, key], dim=1)
value = torch.cat([v_bias, value], dim=1)
        # shape of attention_mask: (batch_size, query_len, kv_len), with one extra
        # leading kv column when attention_bias is enabled
return query, key, value, attention_mask, (key, value)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: tp.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
frequencies: tp.Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> AttentionLayerOutput:
batch_size, sequence_length = hidden_states.shape[:2]
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if self.use_qk_norm:
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
qshape = (
batch_size,
sequence_length,
self.num_attention_heads,
self.head_dim,
)
kv_shape = (
batch_size,
sequence_length,
self.num_attention_heads,
self.head_dim,
)
query_states = query_states.view(qshape)
key_states = key_states.view(kv_shape)
value_states = value_states.view(kv_shape)
query_states, key_states, cos, sin = self.rotary(
positions=position_ids,
query=query_states,
key=key_states,
frequencies=frequencies,
)
(
query_states,
key_states,
value_states,
attention_mask,
past_key_values,
) = self.concatenate(
query=query_states,
key=key_states,
value=value_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
attention_out, attentions = self.attention_performer.forward(
q=query_states.transpose(1, 2),
k=key_states.transpose(1, 2),
v=value_states.transpose(1, 2),
attention_mask=attention_mask,
)
attn_output = self.o_proj(attention_out)
return AttentionLayerOutput(
hidden_states=attn_output,
attentions=attentions if output_attentions else None,
past_key_values=past_key_values,
)
class GiddLayer(nn.Module):
def __init__(
self,
config: GiddConfig,
layer_idx: int,
dtype=torch.float32,
resid_scale: float = 1.0,
):
super().__init__()
self.config = config
self.resid_scale = resid_scale
self.layer_idx = layer_idx
self.self_attn = GiddAttention(
layer_idx=layer_idx,
config=config,
dtype=dtype,
)
self.mlp = GiddMLP(
config=config,
dtype=dtype,
)
self.attn_layernorm = GiddRMSNorm(
config=config,
dtype=torch.float32,
)
self.mlp_layernorm = GiddRMSNorm(
config=config,
dtype=torch.float32,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: tp.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
frequencies: tp.Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> DecoderLayerOutput:
attn_inputs = self.attn_layernorm(hidden_states)
attn_outputs = self.self_attn(
attn_inputs,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
frequencies=frequencies,
output_attentions=output_attentions,
)
hidden_states = hidden_states + self.resid_scale * attn_outputs.hidden_states
mlp_inputs = self.mlp_layernorm(hidden_states)
mlp_output = self.mlp(mlp_inputs)
hidden_states = hidden_states + self.resid_scale * mlp_output
return DecoderLayerOutput(
hidden_states=hidden_states,
attentions=attn_outputs.attentions,
past_key_values=attn_outputs.past_key_values,
)
class GiddPreTrainedModel(PreTrainedModel):
config_class = GiddConfig
base_model_prefix = "model"
supports_gradient_checkpointing = False
_no_split_modules = ["GiddLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = False
_supports_sdpa = False
_supports_flex_attn = False
_can_compile_fullgraph = False
_supports_attention_backend = False
_can_record_outputs = {
"hidden_states": GiddLayer,
"attentions": GiddAttention,
}
    def _init_weights(self, module):
        super()._init_weights(module)
        # Guard: only modules that own a custom weight here (ScaledLinear) are
        # re-initialized; container modules have no `.weight` attribute.
        if isinstance(module, ScaledLinear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
class GiddModel(GiddPreTrainedModel):
def __init__(
self,
config: GiddConfig,
):
super().__init__(config=config)
self.resid_scale = config.resid_scale / config.num_hidden_layers
dtype = config.torch_dtype
self.embed_tokens = nn.Embedding(
num_embeddings=self.config.vocab_size,
embedding_dim=self.config.hidden_size,
)
self.embed_tokens.weight.data = self.embed_tokens.weight.data.to(dtype)
nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=self.config.emb_init_scale)
freqs = compute_basic_frequencies(
base=config.rope_theta,
rotary_dim=config.hidden_size // config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
)
self.frequencies = nn.Buffer(freqs, persistent=False)
self.layers = nn.ModuleList(
[
GiddLayer(
config=config,
layer_idx=i,
resid_scale=self.resid_scale,
dtype=dtype,
)
for i in range(self.config.num_hidden_layers)
]
)
self.norm = GiddRMSNorm(
config=config,
dtype=torch.float32,
)
def forward(
self,
input_ids: tp.Optional[torch.Tensor] = None,
inputs_embeds: tp.Optional[torch.Tensor] = None,
attention_mask: tp.Optional[torch.Tensor] = None,
position_ids: tp.Optional[torch.Tensor] = None,
past_key_values: tp.Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
cache_position: tp.Optional[torch.LongTensor] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids.to(torch.long))
if use_cache and past_key_values is None:
past_key_values = [None] * self.config.num_hidden_layers
elif past_key_values is not None:
past_key_values = list(past_key_values)
if position_ids is None:
past_seen_tokens = 0
if past_key_values is not None and any(past_key_values):
past_seen_tokens = [kv[0].shape[1] for kv in past_key_values if kv is not None][0]
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
position_ids = cache_position.unsqueeze(0)
batch_size, sequence_length, _ = inputs_embeds.shape
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Sequence length exceeds max_position_embeddings (expected <= {self.config.max_position_embeddings}, got {sequence_length})"
        )
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, sequence_length),
dtype=torch.bool,
device=inputs_embeds.device,
)
else:
if attention_mask.dtype != torch.bool:
attention_mask = (attention_mask == 1)
if position_ids is None:
position_ids = torch.arange(
inputs_embeds.shape[-2],
dtype=torch.int32,
device=inputs_embeds.device,
)
position_ids = position_ids.unsqueeze(0).expand(inputs_embeds.shape[:-1])
hidden_states = inputs_embeds
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for idx, block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
frequencies=self.frequencies,
past_key_values=past_key_values[idx] if past_key_values is not None else None,
)
hidden_states = layer_outputs.hidden_states
if output_attentions:
all_attentions += (layer_outputs.attentions,)
if use_cache:
past_key_values[idx] = layer_outputs.past_key_values
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
past_key_values=past_key_values,
)
class GiddForDiffusionLM(GiddPreTrainedModel, GenerationMixin):
def __init__(
self,
config: GiddConfig,
):
super().__init__(config=config)
self.model = GiddModel(config=config)
self.lm_head = ScaledLinear(
config.hidden_size,
config.vocab_size,
scale=config.head_scaling,
dtype=config.torch_dtype,
use_bias=False,
)
def forward(
self,
input_ids: tp.Optional[torch.Tensor] = None,
inputs_embeds: tp.Optional[torch.Tensor] = None,
attention_mask: tp.Optional[torch.Tensor] = None,
position_ids: tp.Optional[torch.Tensor] = None,
past_key_values: tp.Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
) -> CausalLMOutputWithPast:
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
)
hidden_states = outputs.last_hidden_state
if self.config.tie_word_embeddings:
logits = hidden_states @ self.model.embed_tokens.weight.t()
else:
logits = self.lm_head(hidden_states)
return CausalLMOutputWithPast(
loss=None,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
past_key_values=outputs.past_key_values,
)
def _sample_prior(self, shape: tuple[int, ...], device: torch.device, mask_token_id: int = 3) -> torch.Tensor:
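        # Sample the fully-noised prior: each position is independently a uniform random
        # token with probability sigmoid(min_log_snr + noise_type) and the mask token otherwise.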
p_unif = torch.sigmoid(
torch.ones(shape, device=device) * self.config.min_log_snr + self.config.noise_type
)
r = torch.rand(shape, device=device)
unif = torch.randint(0, self.config.vocab_size, shape, device=device)
samples = torch.where(r < p_unif, unif, mask_token_id)
return samples
def _probs_with_topk_topp(self, logits, temperature: float, top_p: float | None, top_k: int | None):
if temperature == 0.0:
probs = torch.zeros_like(logits)
indices = torch.argmax(logits, dim=-1, keepdim=True)
probs.scatter_(-1, indices, 1.0)
return probs
x = logits / temperature
if top_k is not None and 0 < top_k < x.size(-1):
kth = torch.topk(x, top_k, dim=-1).values[..., -1, None]
x = torch.where(x < kth, torch.full_like(x, float("-inf")), x)
if top_p is not None and 0.0 < top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(x, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumprobs = sorted_probs.cumsum(dim=-1)
remove = cumprobs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
x = x.scatter(-1, sorted_idx, sorted_logits)
probs = torch.softmax(x, dim=-1)
return probs
def _pi_lambda(self, log_snr, mask_token_id=3):
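        # Mixing distribution pi(lambda): mass alpha = sigmoid(log_snr + noise_type) spread
        # uniformly over the non-mask vocabulary, and the remaining 1 - alpha on the mask token.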
unif_vec = torch.ones((self.config.vocab_size,), device=log_snr.device) / (self.config.vocab_size - 1)
unif_vec[mask_token_id] = 0.0
alpha = torch.sigmoid(log_snr + self.config.noise_type)
pi = alpha * unif_vec
pi[..., mask_token_id] = 1.0 - alpha
return pi
def _sample_ancestral(
self,
z: torch.Tensor,
x_hat: torch.Tensor,
log_snr_t: torch.Tensor,
log_snr_s: torch.Tensor,
mask_token_id: int = 3,
):
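        # One ancestral denoising step: sample z_s from the posterior p(z_s | z_t, x_hat),
        # which is proportional to q(z_t | z_s) * q(z_s | x_hat) under the forward process
        # defined by alpha(log_snr) and the mixing distribution pi(lambda).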
alpha_s = log_snr_s.sigmoid()
alpha_t = log_snr_t.sigmoid()
beta_s, beta_t = 1.0 - alpha_s, 1.0 - alpha_t
alpha_t_s = alpha_t / alpha_s
pi_s = self._pi_lambda(log_snr_s, mask_token_id=mask_token_id)
pi_t = self._pi_lambda(log_snr_t, mask_token_id=mask_token_id)
beta_pi_t_s = beta_t * pi_t - alpha_t_s * beta_s * pi_s
# beta_pi_t_s_at_z = beta_pi_t_s[z]
q_t = alpha_t * x_hat + beta_t * pi_t[None, None, :]
q_s = alpha_s * x_hat + beta_s * pi_s[None, None, :]
q_t_at_z = q_t.gather(-1, z.unsqueeze(-1)).squeeze(-1)
z_vec = torch.nn.functional.one_hot(z, num_classes=self.config.vocab_size).to(q_t.dtype)
q_t_s_at_z = alpha_t_s * z_vec + beta_pi_t_s[z, None]
p_s_t = q_s * q_t_s_at_z / q_t_at_z[..., None]
z_next = torch.multinomial(p_s_t.flatten(0, 1), num_samples=1).view_as(z)
return z_next
def _sample_adaptive(
self,
z: torch.Tensor,
logits: torch.Tensor,
log_snr: torch.Tensor,
n_tokens: int = 1,
mask_token_id: int = 3,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
):
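        # Adaptive sampling: score each position by how much probability the model would gain
        # by replacing the current token with its argmax, weighted by how likely the current
        # token is noise; resample the top `n_tokens` positions from the filtered logits.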
pi_vec = self._pi_lambda(log_snr, mask_token_id=mask_token_id)
p_noise = pi_vec[z]
p_noise = p_noise / p_noise.sum(dim=-1, keepdim=True)
x_hat = logits.softmax(dim=-1)
p_max = x_hat.max(dim=-1).values
p_curr = x_hat.gather(-1, z.unsqueeze(-1)).squeeze(-1)
p_delta = (p_max - p_curr) * p_noise
next_poss = torch.topk(p_delta, n_tokens, dim=-1).indices
probs = self._probs_with_topk_topp(
logits=logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
next_tokens = torch.multinomial(probs.flatten(0, 1), num_samples=1).view_as(z)
z_next = z.clone()
batch_indices = torch.arange(z.shape[0], device=z.device).unsqueeze(-1)
z_next[batch_indices, next_poss] = next_tokens[batch_indices, next_poss]
return z_next
@torch.no_grad()
def generate(
self,
inputs: tp.Optional[torch.Tensor] = None,
max_length: int = 2048,
min_length: int = 0,
temperature: float = 1.0,
block_length: int = 128,
steps: int = 128,
top_p: tp.Optional[float] = None,
top_k: tp.Optional[int] = None,
bos_token_id: int = 0,
eos_token_id: int = 1,
pad_token_id: int = 2,
mask_token_id: int = 3,
sampling_method: tp.Literal["ancestral", "adaptive"] = "ancestral",
noise_schedule: tp.Literal["linear", "cosine"] | tp.Callable[[torch.Tensor], torch.Tensor] = "cosine",
tokens_per_step: int = 1,
show_progress: bool = False,
):
r"""
Generates tokens with block-wise denoising diffusion.
Parameters:
inputs (`torch.Tensor`):
The token sequence used as a prompt for the generation.
            temperature (`float`, *optional*, defaults to 1.0):
                The value used to modulate the next-token probabilities. A value of 0.0 corresponds to greedy decoding.
            block_length (`int`, *optional*, defaults to 128):
                The size of each generation block. The model denoises tokens in parallel within these blocks, so this
                is a key parameter for controlling the granularity of the generation process.
            steps (`int`, *optional*, defaults to 128):
                The number of denoising steps to perform for each block.
max_length (`int`, *optional*, defaults to 2048):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 0):
The minimum length of the sequence to be generated.
top_p (`float`, *optional*):
If set to a float value between 0 and 1, only the most probable tokens with probabilities that add up to
`top_p` or higher are kept for generation (nucleus sampling).
top_k (`int`, *optional*):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
bos_token_id (`int`, *optional*, defaults to 0):
The token ID for the beginning-of-sequence token.
eos_token_id (`int`, *optional*, defaults to 1):
The token ID for the end-of-sequence token.
pad_token_id (`int`, *optional*, defaults to 2):
The token ID for the padding token.
mask_token_id (`int`, *optional*, defaults to 3):
The token ID used as a placeholder for tokens that are yet to be generated.
        Return:
            `torch.Tensor`: The generated token IDs, including the prompt and truncated to
            `max_length + prompt_length`, with all positions from the first `eos_token_id`
            onward replaced by `pad_token_id`.
        """
if sampling_method not in ["ancestral", "adaptive"]:
raise ValueError(f"Unsupported sampling method: {sampling_method}")
if noise_schedule not in ["linear", "cosine"] and not callable(noise_schedule):
raise ValueError("noise_schedule must be 'linear', 'cosine', or a callable function.")
        if inputs is None:
            # Fall back to a single BOS token as the prompt so the prefix is never empty.
            inputs = torch.tensor([[bos_token_id]], device=self.device, dtype=torch.long)
            batch_size = 1
            prompt_length = 1
else:
batch_size = inputs.shape[0]
prompt_length = inputs.shape[1]
if eos_token_id in inputs:
warnings.warn("Input prompt contains eos_token_id. Generation may stop earlier than expected.", stacklevel=1)
input_ids = inputs.to(self.device)
total_length = self.config.max_position_embeddings
if noise_schedule == "linear":
noise_schedule_fn = lambda t: 1.0 - t
elif noise_schedule == "cosine":
noise_schedule_fn = lambda t: 0.5 + 0.5 * torch.cos(t * torch.pi)
else:
noise_schedule_fn = noise_schedule
x_prior = self._sample_prior(
shape=(batch_size, total_length),
device=self.device,
mask_token_id=mask_token_id,
)
x = x_prior.clone()
if prompt_length > 0:
x[:, :prompt_length] = input_ids.clone()
position_ids = torch.arange(total_length, device=self.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
noise_mask = torch.ones_like(x, dtype=torch.bool)
noise_mask[:, :prompt_length] = False
min_log_snr = torch.tensor(self.config.min_log_snr, device=self.device)
max_log_snr = torch.tensor(self.config.max_log_snr, device=self.device)
alpha_min = torch.sigmoid(min_log_snr)
alpha_max = torch.sigmoid(max_log_snr)
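        # Build a monotone log-SNR grid from the noise schedule: index 0 is the cleanest
        # level (t=0) and index `steps` the noisiest (t=1); each denoising step moves from
        # log_snrs[step] to log_snrs[step - 1].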
ts = torch.linspace(0.0, 1.0, steps=steps + 1, device=self.device)
alpha_t = (alpha_max - alpha_min) * noise_schedule_fn(ts) + alpha_min
log_snrs = torch.log(alpha_t / (1.0 - alpha_t)).clip(min_log_snr, max_log_snr)
if show_progress:
import tqdm.auto as tqdm
est_num_blocks = (max_length + block_length - 1) // block_length
est_num_steps = est_num_blocks * steps
pbar = tqdm.tqdm(total=est_num_steps)
update_pbar = lambda n: pbar.update(n)
def stop_pbar():
pbar.total = pbar.n
pbar.refresh()
close_pbar = lambda: pbar.close()
else:
update_pbar = lambda n: None
stop_pbar = lambda: None
close_pbar = lambda: None
try:
num_blocks = 0
while True:
current_window_start = prompt_length + num_blocks * block_length
current_window_end = current_window_start + block_length
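                # Noisy positions (noise_mask=True) may attend to every position, while
                # already-denoised prompt/block positions only attend to other denoised positions.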
attn_mask = (noise_mask[..., :, None] >= noise_mask[..., None, :])
keep_logits = False
past_key_values = None
for step in range(steps, 0, -1):
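                    # Encode the clean prefix (prompt + finished blocks) once per block and
                    # cache its keys/values; only the noisy suffix is re-run at every step.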
if past_key_values is None:
output = self.forward(
input_ids=x[:, :current_window_start],
attention_mask=attn_mask[:, :current_window_start, :current_window_start],
position_ids=position_ids[:, :current_window_start],
use_cache=True,
)
past_key_values = output.past_key_values
if not keep_logits:
logits = self.forward(
input_ids=x[:, current_window_start:],
attention_mask=attn_mask[:, current_window_start:],
position_ids=position_ids[:, current_window_start:],
past_key_values=past_key_values,
).logits
active_logits = logits[:, :block_length, :]
# logits = self.forward(
# input_ids=x,
# attention_mask=attn_mask,
# position_ids=position_ids,
# past_key_values=None
# ).logits
# active_logits = logits[:, current_window_start:current_window_end, :]
active_logits[..., mask_token_id] = float("-inf")
min_eos_idx = max(0, min_length + prompt_length - current_window_start)
active_logits[:, :min_eos_idx, eos_token_id] = float("-inf")
z_t = x[:, current_window_start:current_window_end]
if sampling_method == "ancestral":
x_hat = self._probs_with_topk_topp(
active_logits.to(torch.float32),
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
z_s = self._sample_ancestral(
z=z_t,
x_hat=x_hat,
log_snr_t=log_snrs[step],
log_snr_s=log_snrs[step - 1],
mask_token_id=mask_token_id,
)
elif sampling_method == "adaptive":
z_s = self._sample_adaptive(
z=z_t,
logits=active_logits.to(torch.float32),
log_snr=log_snrs[step],
n_tokens=tokens_per_step,
mask_token_id=mask_token_id,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
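                    # If no token changed this step, the model inputs are unchanged, so the
                    # logits can be reused on the next iteration.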
keep_logits = (z_s == z_t).all().item()
x[:, current_window_start:current_window_end] = z_s.clone()
update_pbar(1)
num_blocks += 1
noise_mask[:, :current_window_end] = False
                has_eos = (x[:, :current_window_end] == eos_token_id).any(-1).all().item()
all_done = current_window_end >= max_length + prompt_length or has_eos
if all_done:
stop_pbar()
break
finally:
close_pbar()
generated_answer = x[:, :max_length + prompt_length]
eos_idx = (generated_answer == eos_token_id).int().argmax(dim=-1)
for i, idx in enumerate(eos_idx):
if idx > 0:
generated_answer[i, idx:] = pad_token_id
return generated_answer