attnmask #13
by Kamichanw - opened
- modeling_llada.py +15 -11
- tokenizer_config.json +1 -1
modeling_llada.py CHANGED

@@ -654,7 +654,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=
+            attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=False,
         )

@@ -665,6 +665,7 @@ class LLaDABlock(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:

@@ -712,7 +713,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=
+            attn_mask=attention_mask,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             is_causal=False,
         )

@@ -785,6 +786,7 @@ class LLaDASequentialBlock(LLaDABlock):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:

@@ -805,10 +807,10 @@ class LLaDASequentialBlock(LLaDABlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache)

         # Add attention scores.
         # shape: (B, T, C)

@@ -887,6 +889,7 @@ class LLaDALlamaBlock(LLaDABlock):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:

@@ -905,10 +908,10 @@ class LLaDALlamaBlock(LLaDABlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache)

         # Add attention scores.
         # shape: (B, T, C)

@@ -977,6 +980,7 @@ class LLaDABlockGroup(nn.ModuleList):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:

@@ -1001,11 +1005,11 @@ class LLaDABlockGroup(nn.ModuleList):
             ):
                 # shape: (batch_size, seq_len, d_model)
                 x, cache = self._activation_checkpoint_fn(  # type: ignore
-                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                    block, x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache
                 )
             else:
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                x, cache = block(x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache)
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)

@@ -1308,11 +1312,11 @@ class LLaDAModel(nn.Module):
                 ):
                     # shape: (batch_size, seq_len, d_model)
                     x, cache = self._activation_checkpoint_fn(
-                        block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                        block, x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache
                     )
                 else:
                     # shape: (batch_size, seq_len, d_model)
-                    x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                    x, cache = block(x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache)
                 if attn_key_values is not None:
                     assert cache is not None
                     attn_key_values.append(cache)

@@ -1330,7 +1334,7 @@ class LLaDAModel(nn.Module):
                     ]
                 )
                 x, cache = block_group(
-                    x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
+                    x, attention_bias=attention_bias, attention_mask=attention_mask, layers_past=layers_past, use_cache=use_cache
                 )
                 if attn_key_values is not None:
                     assert cache is not None
tokenizer_config.json CHANGED

@@ -2164,7 +2164,7 @@
     "<|number_end|>"
   ],
   "bos_token": "<|startoftext|>",
-  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
+  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{%- endif %}",
   "clean_up_tokenization_spaces": false,
   "cls_token": "[CLS]",
   "eos_token": "<|endoftext|>",
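The template change only affects the trailing assistant header: with the new template it is emitted only when add_generation_prompt is set. A quick check with the standard apply_chat_template API (the repo path below is a placeholder, assuming the tokenizer is loaded from this repository):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True)
    messages = [{"role": "user", "content": "Hello"}]

    # New template: '<|start_header_id|>assistant<|end_header_id|>\n\n' is appended
    # only when add_generation_prompt=True.
    with_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    without_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    print(with_prompt.endswith("<|end_header_id|>\n\n"))  # True
    print(without_prompt.endswith("<|eot_id|>"))          # True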