Upload model
Browse files
- .gitattributes +1 -0
- README.md +204 -0
- arxiv_logo (1).svg +1 -0
- config.json +124 -0
- generation_config.json +8 -0
- model.safetensors +3 -0
- modeling_gemma3_punctuation.py +273 -0
- special_tokens_map.json +33 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,204 @@
---
language:
- en
- as
- bn
- brx
- doi
- gu
- hi
- kn
- ks
- kok
- mai
- ml
- mni
- mr
- ne
- or
- pa
- sa
- sat
- sd
- ta
- te
- ur
license: mit
tags:
- punctuation-restoration
- multilingual
- indic-languages
- ai4bharat
datasets:
- ai4bharat/sangraha
- HuggingFaceFW/fineweb-2
- ai4bharat/indicvoices_r
- ai4bharat/IndicCorpV2
metrics:
- f1
pipeline_tag: token-classification
library_name: cadence-punctuation
base_model:
- google/gemma-3-270m
widget:
- text: hello world how are you today
  example_title: English Punctuation
- text: यह एक हिंदी वाक्य है
  example_title: Hindi Punctuation
- text: cadence is a great model for punctuation
  example_title: Another English Example
---

# Cadence-Fast

This is a multilingual punctuation restoration model based on Gemma-3-270M, fine-tuned for punctuation prediction in English and 22 Indic languages.

<a href="https://arxiv.org/abs/2506.03793v1" target="_blank" rel="noopener noreferrer" style="text-decoration: none; color: inherit;">
  <span style="display: inline-flex; align-items: center; gap: 0.3em;">
    <img src="https://huggingface.co/ai4bharat/Cadence/resolve/main/arxiv_logo.svg" alt="arXiv" style="height: 1em;">
    <span>Mark My Words: A Robust Multilingual Model for Punctuation in Text and Speech Transcripts</span>
  </span>
</a>

## Model Description

- **Model Type**: Token Classification (Punctuation Prediction)
- **Base Model**: Gemma-3-270M
- **Languages**: English + 22 Indic Languages
- **Task**: Automatic punctuation restoration

## Installation (Optional)

The `cadence-punctuation` Python package adds sliding-window decoding, rule-based capitalisation of English text, and rule-based corrections for some common model errors.

```bash
pip install cadence-punctuation
```

## Usage

```python
from Cadence import PunctuationModel

# Load the model (weights are downloaded to model_path, or to the default Hugging Face cache if it is omitted)
model = PunctuationModel(model="Cadence-Fast", model_path="path/to/model")

# Punctuate a single text
text = "hello world how are you today"
result = model.punctuate([text])
print(result[0])  # "Hello world, how are you today?"

# Punctuate multiple texts
texts = [
    "hello world how are you",
    "this is another test sentence",
    "यह एक हिंदी वाक्य है"  # Hindi example
]
results = model.punctuate(texts, batch_size=8)
for original, punctuated in zip(texts, results):
    print(f"Original: {original}")
    print(f"Punctuated: {punctuated}")
    print()
```

### Using AutoModel

```python
from transformers import AutoTokenizer, AutoModel
import torch

# Load model and tokenizer
model_name = "ai4bharat/Cadence-Fast"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
id2label = model.config.id2label

text = "यह एक वाक्य है इसका क्या मतलब है"
# text = "this is a test sentence what do you think"

# Tokenize input and prepare for model
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs['input_ids'][0]  # Get input_ids for the first (and only) sentence

with torch.no_grad():
    outputs = model(**inputs)

predictions_for_sentence = torch.argmax(outputs.logits, dim=-1)[0]

result_tokens_and_punctuation = []
all_token_strings = tokenizer.convert_ids_to_tokens(input_ids.tolist())  # Get all token strings

for i, token_id_value in enumerate(input_ids.tolist()):
    # Process only non-padding tokens based on the attention mask
    if inputs['attention_mask'][0][i] == 0:
        continue

    current_token_string = all_token_strings[i]
    is_special_token = token_id_value in tokenizer.all_special_ids

    if not is_special_token:
        result_tokens_and_punctuation.append(current_token_string)

    predicted_punctuation_id = predictions_for_sentence[i].item()
    punctuation_character = id2label[predicted_punctuation_id]
    if punctuation_character != "O" and not is_special_token:
        result_tokens_and_punctuation.append(punctuation_character)

punctuated_text = tokenizer.convert_tokens_to_string(result_tokens_and_punctuation)
print(f"Original Text: {text}")
print(f"Punctuated Text: {punctuated_text}")
```
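The loop above decodes one sentence at a time. The sketch below is not part of the original card; it reuses the `tokenizer`, `model`, and `id2label` objects defined in the block above and applies the same decoding to a padded batch, skipping padding and special tokens per sentence.

```python
import torch

# A sketch of the same decoding loop for a batch of sentences (reuses tokenizer, model, id2label from above).
sentences = [
    "this is a test sentence what do you think",
    "यह एक वाक्य है इसका क्या मतलब है",
]
batch = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    batch_logits = model(**batch).logits
batch_predictions = torch.argmax(batch_logits, dim=-1)

for b, sentence in enumerate(sentences):
    pieces = []
    for i, token_id in enumerate(batch["input_ids"][b].tolist()):
        # Skip padding positions and special tokens
        if batch["attention_mask"][b][i] == 0 or token_id in tokenizer.all_special_ids:
            continue
        pieces.append(tokenizer.convert_ids_to_tokens(token_id))
        label = id2label[batch_predictions[b][i].item()]
        if label != "O":
            pieces.append(label)
    print(f"Original:   {sentence}")
    print(f"Punctuated: {tokenizer.convert_tokens_to_string(pieces)}")
```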

## Officially Supported Languages

- English, Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri, Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali, Sindhi, Tamil, Telugu, Urdu

The tokenizer does not handle Manipuri's Meitei Mayek script well; the model can still punctuate Manipuri text if it is transliterated into the Bengali script (see the tokenizer check sketched below).

You can also try this model on languages not listed above; performance may vary.
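To see why transliteration helps, the short sketch below (not from the original card; the sample phrases are only illustrative and assumed to be rough renderings of the same Manipuri words) tokenizes Meetei Mayek text and Bengali-script text and compares how finely the tokenizer splits them. A script with poor coverage tends to fall apart into many byte-level pieces.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Cadence-Fast")

# Illustrative samples only: the same Manipuri phrase written in Meetei Mayek and in Bengali script.
samples = {
    "Meetei Mayek": "ꯃꯤꯇꯩ ꯂꯣꯟ",
    "Bengali script": "মৈতৈ লোন",
}

for script, text in samples.items():
    tokens = tokenizer.tokenize(text)
    # A long run of byte-fallback pieces signals poor tokenizer coverage for that script.
    print(f"{script}: {len(tokens)} tokens -> {tokens}")
```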

## Supported Punctuation

The model can predict the following punctuation marks (the full label inventory can also be read from the model config, as sketched after this list):
- Period (.)
- Comma (,)
- Question mark (?)
- Exclamation mark (!)
- Semicolon (;)
- Colon (:)
- Hyphen (-)
- Quotes (" and ')
- Ellipsis (...)
- Parentheses ( and )
- Hindi Danda (।)
- Urdu punctuation (۔، ؟)
- Arabic punctuation (٬ ،)
- Santali punctuation (᱾ ᱾।)
- Sanskrit punctuation (॥)
- And various combinations of the above
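The full inventory, including the combined marks, is stored in the model's `config.json` as `id2label`. A minimal sketch to list every label the classification head can emit, reusing the `model` object loaded in the AutoModel example above ("O" means no punctuation after the token):

```python
# Reuses the `model` loaded via AutoModel in the usage example above.
for label_id in sorted(model.config.id2label, key=int):
    print(f"{label_id}: {model.config.id2label[label_id]!r}")
```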
## Configuration Options for cadence-punctuation

### PunctuationModel Parameters

All parameters are optional.
- `model`: Choose between "Cadence" (based on Gemma-3-1B) and "Cadence-Fast" (based on Gemma-3-270M). (default: "Cadence")
- `model_path`: Path to a local directory where model weights will be downloaded and cached, or from which pre-downloaded weights will be loaded. If None, weights are downloaded to the default Hugging Face cache location.
- `gpu_id`: Specific GPU device ID to use (e.g., 0, 1). If None, the model attempts to auto-detect and use an available GPU. Ignored if `cpu` is True. (default: None)
- `cpu`: If True, forces the model to run on the CPU even if a GPU is available. (default: False)
- `max_length`: Maximum sequence length the model processes at once. If `sliding_window` is True, this value is used as the width of each sliding window; if False, texts longer than `max_length` are truncated. (default: 300)
- `attn_implementation`: The attention implementation to use. (default: "eager")
- `sliding_window`: If True, enables the sliding-window mechanism to process texts longer than `max_length` by splitting them into overlapping chunks of `max_length`. If False, longer texts are truncated. (default: True)
- `verbose`: Enable verbose logging. (default: False)
- `d_type`: Precision with which weights are loaded. (default: "bfloat16")
- `batch_size` (for the `punctuate()` method): Batch size to use. (default: 8)

```python
# Custom configuration
model = PunctuationModel(
    model="Cadence",
    model_path="path/to/download/weights",
    gpu_id=0,                  # Use a specific GPU
    max_length=512,            # length for truncation; also used as window size when sliding_window=True
    attn_implementation="flash_attention_2",
    sliding_window=True,       # Handle long texts
    verbose=False,             # Quiet mode
    d_type="bfloat16"
)
batch_size = 32

# Process long texts with the sliding window
long_text = "Your very long text here..." * 100
short_text = "a short text"
result = model.punctuate([long_text, short_text], batch_size=batch_size)
```

## License

MIT License
arxiv_logo (1).svg
ADDED
config.json
ADDED
@@ -0,0 +1,124 @@
{
  "_sliding_window_pattern": 6,
  "architectures": [
    "Gemma3ForTokenClassification"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": null,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "classifier_dropout_prob": 0.0,
  "dtype": "float32",
  "eos_token_id": 1,
  "final_logit_softcapping": null,
  "head_dim": 256,
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 640,
  "id2label": {
    "0": "O",
    "1": ".",
    "10": "\"",
    "11": "\u0964",
    "12": "(",
    "13": ")",
    "14": ":",
    "15": "\u066c",
    "16": "\u06d4",
    "17": "\u061f",
    "18": ".\"",
    "19": ").",
    "2": ",",
    "20": "),",
    "21": "\",",
    "22": "\".",
    "23": "?\"",
    "24": "\"?",
    "25": "\u0964\"",
    "26": "\"\u0964",
    "27": "\u060c",
    "28": "\u1c7e",
    "29": "\u0965",
    "3": "?",
    "30": "\u1c7e\u0964",
    "4": "-",
    "5": ";",
    "6": "_",
    "7": "!",
    "8": "'",
    "9": "..."
  },
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "label2id": {
    "!": 7,
    "\"": 10,
    "\",": 21,
    "\".": 22,
    "\"?": 24,
    "\"\u0964": 26,
    "'": 8,
    "(": 12,
    ")": 13,
    "),": 20,
    ").": 19,
    ",": 2,
    "-": 4,
    ".": 1,
    ".\"": 18,
    "...": 9,
    ":": 14,
    ";": 5,
    "?": 3,
    "?\"": 23,
    "O": 0,
    "_": 6,
    "\u060c": 27,
    "\u061f": 17,
    "\u066c": 15,
    "\u06d4": 16,
    "\u0964": 11,
    "\u0964\"": 25,
    "\u0965": 29,
    "\u1c7e": 28,
    "\u1c7e\u0964": 30
  },
  "layer_types": [
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention"
  ],
  "max_position_embeddings": 32768,
  "model_type": "cadence_punctuation",
  "num_attention_heads": 4,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "query_pre_attn_scalar": 256,
  "rms_norm_eps": 1e-06,
  "rope_local_base_freq": 10000.0,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": 512,
  "sliding_window_pattern": 6,
  "transformers_version": "4.57.1",
  "use_bidirectional_attention": false,
  "use_cache": false,
  "use_non_causal_attention": true,
  "vocab_size": 262144
}
generation_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "_from_model_config": true,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.57.1"
}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b60302bc4aeb9ca487e3733819aabf9bcf970792d2d8f10c8ee0a5f144d41552
size 1072498892
modeling_gemma3_punctuation.py
ADDED
@@ -0,0 +1,273 @@
"""
Change the attention of Gemma3 to be bidirectional.
"""

import torch
import torch.nn as nn
from typing import Optional, List, Dict, Any
from functools import partial

from transformers import PretrainedConfig, PreTrainedModel
from transformers import Gemma3ForCausalLM, Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3Attention,
    Gemma3DecoderLayer,
    Gemma3TextModel,
)

from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Gemma3PunctuationConfig(Gemma3TextConfig):
    """
    Configuration class for Gemma3 punctuation model.
    """
    model_type = "cadence_punctuation"

    def __init__(
        self,
        num_labels: int = 31,
        classifier_dropout_prob: float = 0.0,
        use_non_causal_attention: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.classifier_dropout_prob = classifier_dropout_prob
        self.use_non_causal_attention = use_non_causal_attention
        self.num_labels = num_labels


# ============ Token Classification Model Components ============

class NonCausalGemma3Attention(Gemma3Attention):
    """Gemma3Attention configured for non-causal token classification."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.is_causal = False
        self.sliding_window = None


class NonCausalGemma3DecoderLayer(Gemma3DecoderLayer):
    """Decoder layer with non-causal attention for token classification."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = NonCausalGemma3Attention(config, layer_idx)


class Gemma3TokenClassificationModel(Gemma3TextModel):
    """Gemma3 base model configured for token classification."""
    _no_split_modules = ["NonCausalGemma3DecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        if getattr(config, 'use_non_causal_attention', True):
            # Replace layers with non-causal versions
            self.layers = nn.ModuleList(
                [
                    NonCausalGemma3DecoderLayer(config, layer_idx)
                    for layer_idx in range(config.num_hidden_layers)
                ]
            )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values=None,
        output_attentions: bool = False,
    ):
        """Override to create bidirectional attention mask (no causal masking)."""
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, type(None)) is False and hasattr(past_key_values, 'get_max_length')

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]

        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0`"
                )
            causal_mask = attention_mask
        else:
            # KEY CHANGE: Start with zeros (attend to all) instead of min_dtype (mask all)
            causal_mask = torch.zeros(
                (sequence_length, target_length), dtype=dtype, device=device
            )
            # REMOVED: Causal masking lines that would make it lower triangular
            # if sequence_length != 1:
            #     causal_mask = torch.triu(causal_mask, diagonal=1)

            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )

            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        # Handle SDPA-specific optimizations if needed
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            try:
                from transformers.modeling_attn_mask_utils import AttentionMaskConverter
                causal_mask = AttentionMaskConverter._unmask_unattended(
                    causal_mask, min_dtype
                )
            except ImportError:
                pass  # Fallback for older transformers versions

        return causal_mask


class Gemma3ForTokenClassification(Gemma3ForCausalLM):
    """
    Gemma3 model for token classification (punctuation prediction).
    Uses class-based architecture without monkey patching.
    """

    config_class = Gemma3PunctuationConfig

    def __init__(self, config):
        # Initialize with base Gemma3ForCausalLM structure
        super().__init__(config)
        self.num_labels = config.num_labels

        # Replace the base model with token classification version
        if getattr(config, 'use_non_causal_attention', True):
            self.model = Gemma3TokenClassificationModel(config)

        # Replace the lm_head with classification head
        classifier_dropout_prob = getattr(config, 'classifier_dropout_prob', 0.0)
        self.lm_head = nn.Sequential(
            nn.Dropout(classifier_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels)
        )

        # Update config for classification
        self.config.num_labels = config.num_labels

        # Initialize weights for the new head
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> TokenClassifierOutput:
        """Forward pass for token classification."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get hidden states from the model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Get the hidden states from the model output
        sequence_output = outputs[0]

        # Apply the classification head (which is now self.lm_head)
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# ============ Model Registration ============

from transformers import AutoConfig, AutoModel

# Register the punctuation config and model
AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)


# ============ Utility Functions ============


def create_token_classification_model(config: Gemma3PunctuationConfig):
    """Create a token classification model with non-causal attention."""
    return Gemma3ForTokenClassification(config)


def load_from_pretrained_with_config_detection(model_path: str, **kwargs):
    """
    Load model and auto-detect whether it's for token classification or bidirectional tasks
    based on the config.
    """
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_path)

    if hasattr(config, 'model_type') and config.model_type == "cadence_punctuation":
        # Token classification model
        return Gemma3ForTokenClassification.from_pretrained(model_path, config=config, **kwargs)
special_tokens_map.json
ADDED
@@ -0,0 +1,33 @@
{
  "boi_token": "<start_of_image>",
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eoi_token": "<end_of_image>",
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "image_token": "<image_soft_token>",
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
size 33384568
tokenizer_config.json
ADDED
The diff for this file is too large to render. See raw diff.