prithivMLmods committed on
Commit
f3c6f63
·
verified ·
1 Parent(s): 200ae8f

update config (Adding the Attributes) (#1)

Browse files

- update config (Adding the Attributes) (c0cae42dcb497ec01e4624565cabda1defd65486)

Files changed (1) hide show
  1. configuration_dots.py +78 -77
configuration_dots.py CHANGED
@@ -1,77 +1,78 @@
1
- from typing import Any, Optional
2
- from transformers.configuration_utils import PretrainedConfig
3
- from transformers.models.qwen2 import Qwen2Config
4
- from transformers import Qwen2_5_VLProcessor, AutoProcessor
5
- from transformers.models.auto.configuration_auto import CONFIG_MAPPING
6
-
7
-
8
- class DotsVisionConfig(PretrainedConfig):
9
- model_type: str = "dots_vit"
10
-
11
- def __init__(
12
- self,
13
- embed_dim: int = 1536, # vision encoder embed size
14
- hidden_size: int = 1536, # after merger hidden size
15
- intermediate_size: int = 4224,
16
- num_hidden_layers: int = 42,
17
- num_attention_heads: int = 12,
18
- num_channels: int = 3,
19
- patch_size: int = 14,
20
- spatial_merge_size: int = 2,
21
- temporal_patch_size: int = 1,
22
- rms_norm_eps: float = 1e-5,
23
- use_bias: bool = False,
24
- attn_implementation="flash_attention_2", # "eager","sdpa","flash_attention_2"
25
- initializer_range=0.02,
26
- init_merger_std=0.02,
27
- is_causal=False, # ve causal forward
28
- post_norm=True,
29
- gradient_checkpointing=False,
30
- **kwargs: Any,
31
- ):
32
- super().__init__(**kwargs)
33
- self.embed_dim = embed_dim
34
- self.hidden_size = hidden_size
35
- self.intermediate_size = intermediate_size
36
- self.num_hidden_layers = num_hidden_layers
37
- self.num_attention_heads = num_attention_heads
38
- self.num_channels = num_channels
39
- self.patch_size = patch_size
40
- self.spatial_merge_size = spatial_merge_size
41
- self.temporal_patch_size = temporal_patch_size
42
- self.rms_norm_eps = rms_norm_eps
43
- self.use_bias = use_bias
44
- self.attn_implementation = attn_implementation
45
- self.initializer_range = initializer_range
46
- self.init_merger_std = init_merger_std
47
- self.is_causal = is_causal
48
- self.post_norm = post_norm
49
- self.gradient_checkpointing = gradient_checkpointing
50
-
51
-
52
-
53
- class DotsOCRConfig(Qwen2Config):
54
- model_type = "dots_ocr"
55
- def __init__(self,
56
- image_token_id = 151665,
57
- video_token_id = 151656,
58
- vision_config: Optional[dict] = None, *args, **kwargs):
59
- super().__init__(*args, **kwargs)
60
- self.image_token_id = image_token_id
61
- self.video_token_id = video_token_id
62
- self.vision_config = DotsVisionConfig(**(vision_config or {}))
63
-
64
- def save_pretrained(self, save_directory, **kwargs):
65
- self._auto_class = None
66
- super().save_pretrained(save_directory, **kwargs)
67
-
68
-
69
- class DotsVLProcessor(Qwen2_5_VLProcessor):
70
- def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
71
- super().__init__(image_processor, tokenizer, chat_template=chat_template)
72
- self.image_token = "<|imgpad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
73
- self.image_token_id = 151665 if not hasattr(tokenizer, "image_token_id") else tokenizer.image_token_id
74
-
75
-
76
- AutoProcessor.register("dots_ocr", DotsVLProcessor)
77
- CONFIG_MAPPING.register("dots_ocr", DotsOCRConfig)
 
 
1
+ from typing import Any, Optional
2
+ from transformers.configuration_utils import PretrainedConfig
3
+ from transformers.models.qwen2 import Qwen2Config
4
+ from transformers import Qwen2_5_VLProcessor, AutoProcessor
5
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
6
+
7
+
8
class DotsVisionConfig(PretrainedConfig):
    """Configuration holder for the vision encoder (``model_type`` is ``"dots_vit"``).

    Each named constructor argument is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are passed
    through to ``PretrainedConfig.__init__``.

    Args:
        embed_dim: Vision encoder embed size.
        hidden_size: Hidden size after the merger.
        attn_implementation: One of ``"eager"``, ``"sdpa"`` or
            ``"flash_attention_2"``.
        is_causal: Whether the vision encoder forward pass is causal.
        **kwargs: Forwarded to the parent configuration class.

    The remaining arguments are plain hyper-parameters whose meaning is
    defined by the model code that consumes this configuration.
    """

    model_type: str = "dots_vit"

    def __init__(
        self,
        embed_dim: int = 1536,  # vision encoder embed size
        hidden_size: int = 1536,  # after merger hidden size
        intermediate_size: int = 4224,
        num_hidden_layers: int = 42,
        num_attention_heads: int = 12,
        num_channels: int = 3,
        patch_size: int = 14,
        spatial_merge_size: int = 2,
        temporal_patch_size: int = 1,
        rms_norm_eps: float = 1e-5,
        use_bias: bool = False,
        attn_implementation="flash_attention_2",  # "eager","sdpa","flash_attention_2"
        initializer_range=0.02,
        init_merger_std=0.02,
        is_causal=False,  # ve causal forward
        post_norm=True,
        gradient_checkpointing=False,
        **kwargs: Any,
    ):
        # Let the base class consume the generic configuration kwargs first.
        super().__init__(**kwargs)

        # Layer widths, depth and head count.
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input channel / patch settings.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size

        # Per-layer options.
        self.rms_norm_eps = rms_norm_eps
        self.use_bias = use_bias
        self.attn_implementation = attn_implementation

        # Weight-initialisation scales.
        self.initializer_range = initializer_range
        self.init_merger_std = init_merger_std

        # Behaviour switches.
        self.is_causal = is_causal
        self.post_norm = post_norm
        self.gradient_checkpointing = gradient_checkpointing
50
+
51
+
52
+
53
class DotsOCRConfig(Qwen2Config):
    """``Qwen2Config`` extended with image/video token ids and a nested vision config.

    Args:
        image_token_id: Token id stored as ``self.image_token_id``.
        video_token_id: Token id stored as ``self.video_token_id``.
        vision_config: Settings for the nested ``DotsVisionConfig``. May be
            ``None`` (use the defaults), a mapping of keyword arguments, or
            an already constructed ``DotsVisionConfig`` instance.
        *args: Forwarded positionally to ``Qwen2Config.__init__``.
        **kwargs: Forwarded to ``Qwen2Config.__init__``.
    """

    model_type = "dots_ocr"

    def __init__(
        self,
        image_token_id=151665,
        video_token_id=151656,
        vision_config: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id

        if isinstance(vision_config, DotsVisionConfig):
            # A ready-made config cannot be unpacked with ``**``; keep it as is.
            self.vision_config = vision_config
        else:
            # ``None`` or an empty mapping falls back to the default settings.
            self.vision_config = DotsVisionConfig(**(vision_config or {}))

    def save_pretrained(self, save_directory, **kwargs):
        """Save the configuration to ``save_directory`` via the parent class.

        ``_auto_class`` is set to ``None`` on this instance before saving and
        is not restored afterwards.
        """
        # NOTE(review): presumably this stops the parent from handling the
        # config as a registered custom auto class while saving - confirm.
        self._auto_class = None
        super().save_pretrained(save_directory, **kwargs)
67
+
68
+
69
class DotsVLProcessor(Qwen2_5_VLProcessor):
    """``Qwen2_5_VLProcessor`` variant that pins the image placeholder token.

    Args:
        image_processor: Passed positionally to the parent constructor.
        tokenizer: Passed positionally to the parent constructor; also
            consulted for ``image_token`` / ``image_token_id``.
        chat_template: Passed to the parent constructor by keyword.
        **kwargs: Accepted but not forwarded to the parent constructor.
    """

    # Only these two components are declared on the processor.
    attributes = ["image_processor", "tokenizer"]

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

        # Prefer the tokenizer's own values; otherwise use the built-in defaults.
        self.image_token = getattr(tokenizer, "image_token", "<|imgpad|>")
        self.image_token_id = getattr(tokenizer, "image_token_id", 151665)
75
+
76
+
77
# Import-time side effect: make the custom classes known to the Auto* machinery
# under the "dots_ocr" key.
# NOTE(review): AutoProcessor.register is documented as taking a configuration
# class as its first argument; a string key is passed here - confirm intended.
AutoProcessor.register("dots_ocr", DotsVLProcessor)
CONFIG_MAPPING.register("dots_ocr", DotsOCRConfig)