yhirokawa commited on
Commit
cae8da3
·
verified ·
1 Parent(s): 067619a

Support transformers>=4.56 (#4)

Browse files

- Support `transformers>=4.56` (e8e271ef786e4e498d90a6c1ab214ac12a0796de)

Files changed (1) hide show
  1. modeling_plamo.py +21 -5
modeling_plamo.py CHANGED
@@ -19,6 +19,7 @@ import torch
19
  from torch import nn
20
  from torch.nn import functional as F
21
  from transformers import PretrainedConfig, PreTrainedModel
 
22
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
 
24
 
@@ -327,7 +328,8 @@ class Plamo2Cache(torch.nn.Module):
327
  if sequence_length is not None
328
  else layer_cache.key.shape[2]
329
  )
330
- assert sequence_length is not None
 
331
  return sequence_length
332
 
333
  def get_max_length(self) -> int | None:
@@ -1387,7 +1389,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
1387
  input_ids: Optional[torch.LongTensor] = None,
1388
  attention_mask: Optional[torch.Tensor] = None,
1389
  position_ids: Optional[torch.Tensor] = None,
1390
- past_key_values: Optional[Plamo2Cache] = None,
1391
  inputs_embeds: Optional[torch.Tensor] = None,
1392
  image_features: Optional[torch.Tensor] = None,
1393
  use_cache: Optional[bool] = None,
@@ -1419,6 +1421,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
1419
  seq_length_with_past = seq_length
1420
  past_key_values_length = 0
1421
  if past_key_values is not None:
 
 
 
 
 
 
 
 
 
 
1422
  past_key_values_length = past_key_values.get_seq_length()
1423
  seq_length_with_past = seq_length_with_past + past_key_values_length
1424
  assert cache_position is None, "cache_position is not supported yet"
@@ -1434,7 +1446,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
1434
  require_attn_mask = False
1435
  if not self.training or past_key_values is not None:
1436
  require_attn_mask = True
1437
- if seq_length_with_past >= self.config.attention_window_size:
1438
  require_attn_mask = True
1439
  if require_attn_mask and attention_mask is None:
1440
  attention_mask = torch.ones(
@@ -1623,7 +1635,11 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1623
  image_features: Optional[torch.Tensor] = None,
1624
  **kwargs: Any,
1625
  ) -> Dict[str, Any]:
1626
- if past_key_values:
 
 
 
 
1627
  input_ids = input_ids[:, -1:]
1628
  if image_features is not None:
1629
  image_features = image_features[:, -1:, :]
@@ -1633,7 +1649,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1633
  # create position_ids on the fly for batch generation
1634
  position_ids = attention_mask.long().cumsum(-1) - 1
1635
  position_ids.masked_fill_(attention_mask == 0, 1)
1636
- if past_key_values:
1637
  position_ids = position_ids[:, -1].unsqueeze(-1)
1638
 
1639
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 
19
  from torch import nn
20
  from torch.nn import functional as F
21
  from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.cache_utils import DynamicCache
23
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
24
 
25
 
 
328
  if sequence_length is not None
329
  else layer_cache.key.shape[2]
330
  )
331
+ if sequence_length is None:
332
+ return 0
333
  return sequence_length
334
 
335
  def get_max_length(self) -> int | None:
 
1389
  input_ids: Optional[torch.LongTensor] = None,
1390
  attention_mask: Optional[torch.Tensor] = None,
1391
  position_ids: Optional[torch.Tensor] = None,
1392
+ past_key_values: Optional[Plamo2Cache | DynamicCache] = None,
1393
  inputs_embeds: Optional[torch.Tensor] = None,
1394
  image_features: Optional[torch.Tensor] = None,
1395
  use_cache: Optional[bool] = None,
 
1421
  seq_length_with_past = seq_length
1422
  past_key_values_length = 0
1423
  if past_key_values is not None:
1424
+ # In some `transformers` versions, `past_key_values` may be a `DynamicCache` object.
1425
+ if not isinstance(past_key_values, Plamo2Cache):
1426
+ past_key_values_prev = past_key_values
1427
+ past_key_values = Plamo2Cache(self.config)
1428
+
1429
+ # If `past_key_values` is a `DynamicCache` object, it must be empty or all layer caches have zero sequence length.
1430
+ assert len(past_key_values_prev) == 0 or not any(
1431
+ layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
1432
+ )
1433
+ assert isinstance(past_key_values, Plamo2Cache)
1434
  past_key_values_length = past_key_values.get_seq_length()
1435
  seq_length_with_past = seq_length_with_past + past_key_values_length
1436
  assert cache_position is None, "cache_position is not supported yet"
 
1446
  require_attn_mask = False
1447
  if not self.training or past_key_values is not None:
1448
  require_attn_mask = True
1449
+ if seq_length_with_past > self.config.attention_window_size + 1:
1450
  require_attn_mask = True
1451
  if require_attn_mask and attention_mask is None:
1452
  attention_mask = torch.ones(
 
1635
  image_features: Optional[torch.Tensor] = None,
1636
  **kwargs: Any,
1637
  ) -> Dict[str, Any]:
1638
+ # Starting from transformers v4.54, `DynamicCache` is passed to `past_key_values` during the prefill stage,
1639
+ # and its length becomes non-zero from v4.56 onward.
1640
+ # `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
1641
+ # se we use the type of `past_key_values` to distinguish between the prefill and decode stages.
1642
+ if isinstance(past_key_values, Plamo2Cache):
1643
  input_ids = input_ids[:, -1:]
1644
  if image_features is not None:
1645
  image_features = image_features[:, -1:, :]
 
1649
  # create position_ids on the fly for batch generation
1650
  position_ids = attention_mask.long().cumsum(-1) - 1
1651
  position_ids.masked_fill_(attention_mask == 0, 1)
1652
+ if isinstance(past_key_values, Plamo2Cache):
1653
  position_ids = position_ids[:, -1].unsqueeze(-1)
1654
 
1655
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step