psidharth567 committed on
Commit
a69bde0
·
verified ·
1 Parent(s): 3de860e

Upload model

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,204 @@
+ ---
+ language:
+ - en
+ - as
+ - bn
+ - brx
+ - doi
+ - gu
+ - hi
+ - kn
+ - ks
+ - kok
+ - mai
+ - ml
+ - mni
+ - mr
+ - ne
+ - or
+ - pa
+ - sa
+ - sat
+ - sd
+ - ta
+ - te
+ - ur
+ license: mit
+ tags:
+ - punctuation-restoration
+ - multilingual
+ - indic-languages
+ - ai4bharat
+ datasets:
+ - ai4bharat/sangraha
+ - HuggingFaceFW/fineweb-2
+ - ai4bharat/indicvoices_r
+ - ai4bharat/IndicCorpV2
+ metrics:
+ - f1
+ pipeline_tag: token-classification
+ library_name: cadence-punctuation
+ base_model:
+ - google/gemma-3-1b-pt
+ widget:
+ - text: hello world how are you today
+   example_title: English Punctuation
+ - text: यह एक हिंदी वाक्य है
+   example_title: Hindi Punctuation
+ - text: cadence is a great model for punctuation
+   example_title: Another English Example
+ ---
+
+ # Cadence-Fast
+
+ This is a multilingual punctuation restoration model based on Gemma-3-270M, fine-tuned for punctuation prediction in English and Indic languages.
+
+ <a href="https://arxiv.org/abs/2506.03793v1" target="_blank" rel="noopener noreferrer" style="text-decoration: none; color: inherit;">
+ <span style="display: inline-flex; align-items: center; gap: 0.3em;">
+ <img src="https://huggingface.co/ai4bharat/Cadence/resolve/main/arxiv_logo.svg" alt="arXiv" style="height: 1em;">
+ <span>Mark My Words: A Robust Multilingual Model for Punctuation in Text and Speech Transcripts</span>
+ </span>
+ </a>
+
+ ## Model Description
+
+ - **Model Type**: Token Classification (Punctuation Prediction)
+ - **Base Model**: Gemma-3-270M
+ - **Languages**: English + 22 Indic Languages
+ - **Task**: Automatic punctuation restoration
+ ## Installation (Optional)
+ The Python package adds features such as sliding-window decoding, rule-based capitalisation of English text, and some rule-based corrections for errors made by the model.
+
+ ```bash
+ pip install cadence-punctuation
+ ```
+
+ ## Usage
+
+ ```python
+ from Cadence import PunctuationModel
+
+ # Load model from local path
+ model = PunctuationModel(model="Cadence-Fast", model_path="path/to/model")
+
+ # Punctuate a single text
+ text = "hello world how are you today"
+ result = model.punctuate([text])
+ print(result[0])  # "Hello world, how are you today?"
+
+ # Punctuate multiple texts
+ texts = [
+     "hello world how are you",
+     "this is another test sentence",
+     "यह एक हिंदी वाक्य है"  # Hindi example
+ ]
+ results = model.punctuate(texts, batch_size=8)
+ for original, punctuated in zip(texts, results):
+     print(f"Original: {original}")
+     print(f"Punctuated: {punctuated}")
+     print()
+ ```
+
+ ### Using AutoModel
+
+ ```python
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+
+ # Load model and tokenizer
+ model_name = "ai4bharat/Cadence-Fast"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+ id2label = model.config.id2label
+
+ text = "यह एक वाक्य है इसका क्या मतलब है"
+ # text = "this is a test sentence what do you think"
+
+ # Tokenize input and prepare for model
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+ input_ids = inputs['input_ids'][0]  # Get input_ids for the first (and only) sentence
+
+ with torch.no_grad():
+     outputs = model(**inputs)
+     predictions_for_sentence = torch.argmax(outputs.logits, dim=-1)[0]
+
+ result_tokens_and_punctuation = []
+ all_token_strings = tokenizer.convert_ids_to_tokens(input_ids.tolist())  # Get all token strings
+
+ for i, token_id_value in enumerate(input_ids.tolist()):
+     # Process only non-padding tokens based on the attention mask
+     if inputs['attention_mask'][0][i] == 0:
+         continue
+     current_token_string = all_token_strings[i]
+     is_special_token = token_id_value in tokenizer.all_special_ids
+
+     if not is_special_token:
+         result_tokens_and_punctuation.append(current_token_string)
+
+     predicted_punctuation_id = predictions_for_sentence[i].item()
+     punctuation_character = id2label[predicted_punctuation_id]
+     if punctuation_character != "O" and not is_special_token:
+         result_tokens_and_punctuation.append(punctuation_character)
+
+ punctuated_text = tokenizer.convert_tokens_to_string(result_tokens_and_punctuation)
+ print(f"Original Text: {text}")
+ print(f"Punctuated Text: {punctuated_text}")
+ ```
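+
+ If you prefer not to use `trust_remote_code`, a minimal sketch (assuming the repository's `modeling_gemma3_punctuation.py`, which is added later in this commit, has been downloaded next to your script) is to import that file so it registers the `cadence_punctuation` architecture with `AutoConfig`/`AutoModel`:
+
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ # Importing the module runs the AutoConfig.register / AutoModel.register calls at the bottom of the file.
+ import modeling_gemma3_punctuation  # noqa: F401
+
+ tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Cadence-Fast")
+ model = AutoModel.from_pretrained("ai4bharat/Cadence-Fast")
+
+ inputs = tokenizer("hello world how are you", return_tensors="pt")
+ with torch.no_grad():
+     logits = model(**inputs).logits  # shape: (batch, sequence_length, num_labels)
+ print(logits.shape)
+ ```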
+
+ ## Officially Supported Languages
+ - English, Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri, Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali, Sindhi, Tamil, Telugu, Urdu
+
+ The tokenizer does not handle Manipuri's Meitei (Meetei Mayek) script well; the model can still punctuate Manipuri text if it is first transliterated into the Bengali script.
+
+ You can also try the model on languages not listed above, though performance may vary (see the sketch below).
+
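+ A quick sketch with an unlisted language (German here; output quality is not guaranteed):
+
+ ```python
+ from Cadence import PunctuationModel
+
+ # German is outside the officially supported set, so treat the output as best-effort.
+ model = PunctuationModel(model="Cadence-Fast")
+ print(model.punctuate(["hallo wie geht es dir heute"])[0])
+ ```
+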
+ ## Supported Punctuation
+ The model can predict the following punctuation marks:
+ - Period (.)
+ - Comma (,)
+ - Question mark (?)
+ - Exclamation mark (!)
+ - Semicolon (;)
+ - Colon (:)
+ - Hyphen (-)
+ - Quotes (" and ')
+ - Ellipsis (...)
+ - Parentheses ()
+ - Hindi Danda (।)
+ - Urdu punctuation (۔، ؟)
+ - Arabic punctuation (٬ ،)
+ - Santali punctuation (᱾ ᱾।)
+ - Sanskrit punctuation (॥)
+ - And various combinations
+
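+ The full label inventory, including composed labels such as `."` and `।"`, is stored in the model's `id2label` mapping. A minimal sketch for listing it (assumes only `huggingface_hub` is installed; the repo id is the same one used in the examples above):
+
+ ```python
+ import json
+ from huggingface_hub import hf_hub_download
+
+ # Download just config.json and print every label the classification head can emit.
+ config_path = hf_hub_download("ai4bharat/Cadence-Fast", "config.json")
+ with open(config_path, encoding="utf-8") as f:
+     id2label = json.load(f)["id2label"]
+ for label_id, label in sorted(id2label.items(), key=lambda kv: int(kv[0])):
+     print(label_id, repr(label))
+ ```
+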
+ ## Configuration Options for cadence-punctuation
+
+ ### PunctuationModel Parameters
+ All parameters are optional.
+ - `model`: Choice between "Cadence" (based on Gemma-3-1B) and "Cadence-Fast" (based on Gemma-3-270M). (default: "Cadence")
+ - `model_path`: Path to a local directory where model weights will be downloaded and cached, or from which pre-downloaded weights will be loaded. If None, weights are downloaded to the default Hugging Face cache location.
+ - `gpu_id`: Specific GPU device ID to use (e.g., 0, 1). If None, the model will attempt to auto-detect and use an available GPU. Ignored if `cpu` is True. (default: None)
+ - `cpu`: If True, forces the model to run on the CPU, even if a GPU is available. (default: False)
+ - `max_length`: Maximum sequence length the model can process at once. If `sliding_window` is True, this value is used as the width of each sliding window; if False, texts longer than `max_length` are truncated. (default: 300)
+ - `attn_implementation`: The attention implementation to use. (default: "eager")
+ - `sliding_window`: If True, enables a sliding-window mechanism to process texts longer than `max_length` by splitting them into overlapping chunks of `max_length`. If False, longer texts are truncated. (default: True)
+ - `verbose`: Enable verbose logging. (default: False)
+ - `d_type`: Precision with which weights are loaded. (default: bfloat16)
+ - `batch_size` (for the `punctuate()` method): Batch size to use. (default: 8)
+
+ ```python
+ # Custom configuration
+ model = PunctuationModel(
+     model="Cadence",
+     model_path="path/to/download/weights",
+     gpu_id=0,  # Use a specific GPU
+     max_length=512,  # truncation length; also used as window size when sliding_window=True
+     attn_implementation="flash_attention_2",
+     sliding_window=True,  # Handle long texts
+     verbose=False,  # Quiet mode
+     d_type="bfloat16"
+ )
+ batch_size = 32
+
+ # Process long texts with the sliding window
+ long_text = "Your very long text here..." * 100
+ short_text = "a short text"
+ result = model.punctuate([long_text, short_text], batch_size=batch_size)
+ ```
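+
+ For CPU-only environments the same API applies; a minimal sketch using the documented `cpu` flag:
+
+ ```python
+ from Cadence import PunctuationModel
+
+ # Force CPU execution; slower, but works where no GPU is available.
+ cpu_model = PunctuationModel(model="Cadence-Fast", cpu=True)
+ print(cpu_model.punctuate(["running on cpu is slower but it works"])[0])
+ ```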
+
+ ## License
+ MIT License
+
arxiv_logo (1).svg ADDED
config.json ADDED
@@ -0,0 +1,124 @@
+ {
+   "_sliding_window_pattern": 6,
+   "architectures": [
+     "Gemma3ForTokenClassification"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "attn_logit_softcapping": null,
+   "bos_token_id": 2,
+   "cache_implementation": "hybrid",
+   "classifier_dropout_prob": 0.0,
+   "dtype": "float32",
+   "eos_token_id": 1,
+   "final_logit_softcapping": null,
+   "head_dim": 256,
+   "hidden_activation": "gelu_pytorch_tanh",
+   "hidden_size": 640,
+   "id2label": {
+     "0": "O",
+     "1": ".",
+     "10": "\"",
+     "11": "\u0964",
+     "12": "(",
+     "13": ")",
+     "14": ":",
+     "15": "\u066c",
+     "16": "\u06d4",
+     "17": "\u061f",
+     "18": ".\"",
+     "19": ").",
+     "2": ",",
+     "20": "),",
+     "21": "\",",
+     "22": "\".",
+     "23": "?\"",
+     "24": "\"?",
+     "25": "\u0964\"",
+     "26": "\"\u0964",
+     "27": "\u060c",
+     "28": "\u1c7e",
+     "29": "\u0965",
+     "3": "?",
+     "30": "\u1c7e\u0964",
+     "4": "-",
+     "5": ";",
+     "6": "_",
+     "7": "!",
+     "8": "'",
+     "9": "..."
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 2048,
+   "label2id": {
+     "!": 7,
+     "\"": 10,
+     "\",": 21,
+     "\".": 22,
+     "\"?": 24,
+     "\"\u0964": 26,
+     "'": 8,
+     "(": 12,
+     ")": 13,
+     "),": 20,
+     ").": 19,
+     ",": 2,
+     "-": 4,
+     ".": 1,
+     ".\"": 18,
+     "...": 9,
+     ":": 14,
+     ";": 5,
+     "?": 3,
+     "?\"": 23,
+     "O": 0,
+     "_": 6,
+     "\u060c": 27,
+     "\u061f": 17,
+     "\u066c": 15,
+     "\u06d4": 16,
+     "\u0964": 11,
+     "\u0964\"": 25,
+     "\u0965": 29,
+     "\u1c7e": 28,
+     "\u1c7e\u0964": 30
+   },
+   "layer_types": [
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention"
+   ],
+   "max_position_embeddings": 32768,
+   "model_type": "cadence_punctuation",
+   "num_attention_heads": 4,
+   "num_hidden_layers": 18,
+   "num_key_value_heads": 1,
+   "pad_token_id": 0,
+   "query_pre_attn_scalar": 256,
+   "rms_norm_eps": 1e-06,
+   "rope_local_base_freq": 10000.0,
+   "rope_scaling": null,
+   "rope_theta": 1000000.0,
+   "sliding_window": 512,
+   "sliding_window_pattern": 6,
+   "transformers_version": "4.57.1",
+   "use_bidirectional_attention": false,
+   "use_cache": false,
+   "use_non_causal_attention": true,
+   "vocab_size": 262144
+ }
generation_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 2,
+   "cache_implementation": "hybrid",
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.57.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b60302bc4aeb9ca487e3733819aabf9bcf970792d2d8f10c8ee0a5f144d41552
+ size 1072498892
modeling_gemma3_punctuation.py ADDED
@@ -0,0 +1,273 @@
+ """
+ Change the attention of Gemma3 to be bidirectional.
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, List, Dict, Any
+ from functools import partial
+
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers import Gemma3ForCausalLM, Gemma3TextConfig
+ from transformers.models.gemma3.modeling_gemma3 import (
+     Gemma3Attention,
+     Gemma3DecoderLayer,
+     Gemma3TextModel,
+
+ )
+
+ from transformers.modeling_outputs import TokenClassifierOutput
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class Gemma3PunctuationConfig(Gemma3TextConfig):
+     """
+     Configuration class for Gemma3 punctuation model.
+     """
+     model_type = "cadence_punctuation"
+
+     def __init__(
+         self,
+         num_labels: int = 31,
+         classifier_dropout_prob: float = 0.0,
+         use_non_causal_attention: bool = True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.classifier_dropout_prob = classifier_dropout_prob
+         self.use_non_causal_attention = use_non_causal_attention
+         self.num_labels = num_labels
+
+
+ # ============ Token Classification Model Components ============
+
+ class NonCausalGemma3Attention(Gemma3Attention):
+     """Gemma3Attention configured for non-causal token classification."""
+     def __init__(self, config, layer_idx: int):
+         super().__init__(config, layer_idx)
+         self.is_causal = False
+         self.sliding_window = None
+
+
+ class NonCausalGemma3DecoderLayer(Gemma3DecoderLayer):
+     """Decoder layer with non-causal attention for token classification."""
+     def __init__(self, config, layer_idx: int):
+         super().__init__(config, layer_idx)
+         self.self_attn = NonCausalGemma3Attention(config, layer_idx)
+
+
+ class Gemma3TokenClassificationModel(Gemma3TextModel):
+     """Gemma3 base model configured for token classification."""
+     _no_split_modules = ["NonCausalGemma3DecoderLayer"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         if getattr(config, 'use_non_causal_attention', True):
+             # Replace layers with non-causal versions
+             self.layers = nn.ModuleList(
+                 [
+                     NonCausalGemma3DecoderLayer(config, layer_idx)
+                     for layer_idx in range(config.num_hidden_layers)
+                 ]
+             )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values = None,
+         output_attentions: bool = False,
+     ):
+         """Override to create bidirectional attention mask (no causal masking)."""
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         past_seen_tokens = (
+             past_key_values.get_seq_length() if past_key_values is not None else 0
+         )
+         using_static_cache = isinstance(past_key_values, type(None)) is False and hasattr(past_key_values, 'get_max_length')
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         min_dtype = torch.finfo(dtype).min
+         sequence_length = input_tensor.shape[1]
+
+         if using_static_cache:
+             target_length = past_key_values.get_max_length()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+             if attention_mask.max() != 0:
+                 raise ValueError(
+                     "Custom 4D attention mask should be passed in inverted form with max==0`"
+                 )
+             causal_mask = attention_mask
+         else:
+             # KEY CHANGE: Start with zeros (attend to all) instead of min_dtype (mask all)
+             causal_mask = torch.zeros(
+                 (sequence_length, target_length), dtype=dtype, device=device
+             )
+             # REMOVED: Causal masking lines that would make it lower triangular
+             # if sequence_length != 1:
+             #     causal_mask = torch.triu(causal_mask, diagonal=1)
+
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(
+                 input_tensor.shape[0], 1, -1, -1
+             )
+
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+
+         # Handle SDPA-specific optimizations if needed
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             try:
+                 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+                 causal_mask = AttentionMaskConverter._unmask_unattended(
+                     causal_mask, min_dtype
+                 )
+             except ImportError:
+                 pass  # Fallback for older transformers versions
+
+         return causal_mask
+
+
+ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
+     """
+     Gemma3 model for token classification (punctuation prediction).
+     Uses class-based architecture without monkey patching.
+     """
+
+     config_class = Gemma3PunctuationConfig
+
+     def __init__(self, config):
+         # Initialize with base Gemma3ForCausalLM structure
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         # Replace the base model with token classification version
+         if getattr(config, 'use_non_causal_attention', True):
+             self.model = Gemma3TokenClassificationModel(config)
+
+         # Replace the lm_head with classification head
+         classifier_dropout_prob = getattr(config, 'classifier_dropout_prob', 0.0)
+         self.lm_head = nn.Sequential(
+             nn.Dropout(classifier_dropout_prob),
+             nn.Linear(config.hidden_size, config.num_labels)
+         )
+
+         # Update config for classification
+         self.config.num_labels = config.num_labels
+
+         # Initialize weights for the new head
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> TokenClassifierOutput:
+         """Forward pass for token classification."""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Get hidden states from the model
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+         # Get the hidden states from the model output
+         sequence_output = outputs[0]
+
+         # Apply the classification head (which is now self.lm_head)
+         logits = self.lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ # ============ Model Registration ============
+
+ from transformers import AutoConfig, AutoModel
+
+ # Register the punctuation config and model
+ AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
+ AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
+
+
+ # ============ Utility Functions ============
+
+
+ def create_token_classification_model(config: Gemma3PunctuationConfig):
+     """Create a token classification model with non-causal attention."""
+     return Gemma3ForTokenClassification(config)
+
+
+ def load_from_pretrained_with_config_detection(model_path: str, **kwargs):
+     """
+     Load model and auto-detect whether it's for token classification or bidirectional tasks
+     based on the config.
+     """
+     from transformers import AutoConfig
+
+     config = AutoConfig.from_pretrained(model_path)
+
+     if hasattr(config, 'model_type') and config.model_type == "cadence_punctuation":
+         # Token classification model
+         return Gemma3ForTokenClassification.from_pretrained(model_path, config=config, **kwargs)
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "boi_token": "<start_of_image>",
+   "bos_token": {
+     "content": "<bos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eoi_token": "<end_of_image>",
+   "eos_token": {
+     "content": "<eos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "image_token": "<image_soft_token>",
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff