Support transformers>=4.56 (#4)
- Support `transformers>=4.56` (e8e271ef786e4e498d90a6c1ab214ac12a0796de)
- modeling_plamo.py +21 -5
modeling_plamo.py
CHANGED
@@ -19,6 +19,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 from transformers import PretrainedConfig, PreTrainedModel
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
@@ -327,7 +328,8 @@ class Plamo2Cache(torch.nn.Module):
                     if sequence_length is not None
                     else layer_cache.key.shape[2]
                 )
-        assert sequence_length is not None
+        if sequence_length is None:
+            return 0
         return sequence_length
 
     def get_max_length(self) -> int | None:
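The hunk above lets `Plamo2Cache.get_seq_length` report 0 when no attention layer has cached anything yet, instead of failing on an empty cache. A minimal standalone sketch of that guard (the helper `seq_length_from_layers` is hypothetical, for illustration only; it is not part of modeling_plamo.py):

```python
from typing import Optional

def seq_length_from_layers(lengths: list[Optional[int]]) -> int:
    """Aggregate per-layer cached lengths; None marks a layer with no attention cache."""
    sequence_length: Optional[int] = None
    for length in lengths:
        if length is not None:
            sequence_length = length if sequence_length is None else max(sequence_length, length)
    if sequence_length is None:
        return 0  # empty cache: report zero length instead of failing
    return sequence_length

assert seq_length_from_layers([]) == 0
assert seq_length_from_layers([None, 5, 7]) == 7
```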
@@ -1387,7 +1389,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Plamo2Cache] = None,
+        past_key_values: Optional[Plamo2Cache | DynamicCache] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         image_features: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,

@@ -1419,6 +1421,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values is not None:
+            # In some `transformers` versions, `past_key_values` may be a `DynamicCache` object.
+            if not isinstance(past_key_values, Plamo2Cache):
+                past_key_values_prev = past_key_values
+                past_key_values = Plamo2Cache(self.config)
+
+                # If `past_key_values` is a `DynamicCache` object, it must be empty or all of its layer caches must have zero sequence length.
+                assert len(past_key_values_prev) == 0 or not any(
+                    layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
+                )
+            assert isinstance(past_key_values, Plamo2Cache)
             past_key_values_length = past_key_values.get_seq_length()
             seq_length_with_past = seq_length_with_past + past_key_values_length
             assert cache_position is None, "cache_position is not supported yet"
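For context on the conversion above: from `transformers` v4.54 onward, `generate` hands the model a `DynamicCache` at the prefill step, and `Plamo2Model.forward` swaps it for a fresh `Plamo2Cache` only when that incoming cache is still effectively empty. A minimal sketch of the state the assert accepts, using only the public `transformers` API (attribute details such as `.layers` may vary slightly across versions):

```python
from transformers.cache_utils import DynamicCache

cache = DynamicCache()         # what the prefill step passes in
print(cache.get_seq_length())  # 0 — nothing has been cached yet
# The assert in the hunk accepts exactly this state: no layers at all,
# or only layers whose cached sequence length is still zero.
print(len(cache) == 0 or not any(layer.get_seq_length() for layer in cache.layers))  # True
```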
@@ -1434,7 +1446,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
         require_attn_mask = False
         if not self.training or past_key_values is not None:
             require_attn_mask = True
-        if seq_length_with_past
+        if seq_length_with_past > self.config.attention_window_size + 1:
             require_attn_mask = True
         if require_attn_mask and attention_mask is None:
             attention_mask = torch.ones(

@@ -1623,7 +1635,11 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
         image_features: Optional[torch.Tensor] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        if past_key_values:
+        # Starting from transformers v4.54, `DynamicCache` is passed to `past_key_values` during the prefill stage,
+        # and its length becomes non-zero from v4.56 onward.
+        # `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
+        # so we use the type of `past_key_values` to distinguish between the prefill and decode stages.
+        if isinstance(past_key_values, Plamo2Cache):
             input_ids = input_ids[:, -1:]
             if image_features is not None:
                 image_features = image_features[:, -1:, :]

@@ -1633,7 +1649,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
+            if isinstance(past_key_values, Plamo2Cache):
                 position_ids = position_ids[:, -1].unsqueeze(-1)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
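Taken together, the changes let generation run under `transformers>=4.56`: the prefill call arrives with a `DynamicCache`, `Plamo2Model.forward` converts it into a `Plamo2Cache`, and `prepare_inputs_for_generation` uses that type switch to tell prefill from decode. A rough end-to-end check, assuming a PLaMo 2 checkpoint that ships this modeling_plamo.py (the `pfnet/plamo-2-1b` id below is an assumption; substitute your checkpoint):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "pfnet/plamo-2-1b"  # assumed checkpoint id; any repo shipping this modeling file applies
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    # Prefill receives a DynamicCache; later decode steps see the converted Plamo2Cache.
    output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```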