ggunio committed on
Commit 0a81e70 · verified · 1 Parent(s): 2fb6cef

Upload src/core/byte_tokenizer_v6.py with huggingface_hub

Files changed (1)
  1. src/core/byte_tokenizer_v6.py +263 -0
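
The commit message indicates the file was pushed with huggingface_hub. Below is a minimal sketch of how such an upload is typically done with HfApi.upload_file; the repo id is a placeholder, not taken from this page, and a valid token (e.g. via `huggingface-cli login`) is assumed.

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="src/core/byte_tokenizer_v6.py",
        path_in_repo="src/core/byte_tokenizer_v6.py",
        repo_id="ggunio/REPO_NAME",  # placeholder: the target repo is not named on this page
        commit_message="Upload src/core/byte_tokenizer_v6.py with huggingface_hub",
    )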
src/core/byte_tokenizer_v6.py ADDED
@@ -0,0 +1,263 @@
+ """
+ Byte-Level Tokenizer V6 - Pure Learning Based
+ No vocabulary, no language rules - just bytes
+ """
+
+ import torch
+ from typing import List, Dict, Union, Optional
+ import numpy as np
+
+
+ class ByteTokenizerV6:
+     """
+     Pure byte-level tokenizer
+     - No vocabulary needed (bytes are 0-255)
+     - No language-specific rules
+     - Model learns all patterns from data
+     """
+
+     def __init__(self, max_seq_len: int = 512):
+         """Initialize byte tokenizer"""
+
+         self.max_seq_len = max_seq_len
+
+         # Special tokens (beyond byte range 0-255)
+         self.PAD = 256
+         self.BOS = 257
+         self.EOS = 258
+         self.MASK = 259
+
+         # Total vocabulary size = 256 bytes + 4 special tokens
+         self.vocab_size = 260
+
+         print(f"Byte tokenizer initialized (vocab_size={self.vocab_size})")
+
+     def encode(self, text: str, add_special_tokens: bool = True) -> Dict:
+         """
+         Encode text to byte IDs
+
+         Args:
+             text: Input text
+             add_special_tokens: Whether to add BOS/EOS
+
+         Returns:
+             dict with 'input_ids', 'attention_mask', 'length'
+         """
+         # Convert text to UTF-8 bytes (pure bytes, no rules)
+         byte_sequence = list(text.encode('utf-8'))
+
+         # Truncate if necessary
+         max_len = self.max_seq_len - 2 if add_special_tokens else self.max_seq_len
+         if len(byte_sequence) > max_len:
+             byte_sequence = byte_sequence[:max_len]
+
+         # Add special tokens
+         if add_special_tokens:
+             input_ids = [self.BOS] + byte_sequence + [self.EOS]
+         else:
+             input_ids = byte_sequence
+
+         # Create attention mask (1 for real tokens, 0 for padding)
+         attention_mask = [1] * len(input_ids)
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'length': len(input_ids)
+         }
+
+     def encode_batch(self, texts: List[str], add_special_tokens: bool = True) -> Dict:
+         """
+         Encode multiple texts with padding
+
+         Args:
+             texts: List of input texts
+             add_special_tokens: Whether to add special tokens
+
+         Returns:
+             Batched tensors with padding
+         """
+         encoded_texts = []
+         max_length = 0
+
+         # Encode each text
+         for text in texts:
+             encoded = self.encode(text, add_special_tokens)
+             encoded_texts.append(encoded)
+             max_length = max(max_length, encoded['length'])
+
+         # Limit to max sequence length
+         max_length = min(max_length, self.max_seq_len)
+
+         # Initialize batch tensors
+         batch_size = len(texts)
+         input_ids = np.full((batch_size, max_length), self.PAD, dtype=np.int64)
+         attention_mask = np.zeros((batch_size, max_length), dtype=np.float32)
+
+         # Fill batch tensors
+         for i, encoded in enumerate(encoded_texts):
+             seq_len = min(encoded['length'], max_length)
+             input_ids[i, :seq_len] = encoded['input_ids'][:seq_len]
+             attention_mask[i, :seq_len] = 1.0
+
+         return {
+             'input_ids': torch.tensor(input_ids, dtype=torch.long),
+             'attention_mask': torch.tensor(attention_mask, dtype=torch.float32),
+             'lengths': torch.tensor([e['length'] for e in encoded_texts], dtype=torch.long)
+         }
+
+     def decode(self, input_ids: Union[List[int], torch.Tensor, np.ndarray],
+                skip_special_tokens: bool = True) -> str:
+         """
+         Decode byte IDs back to text
+
+         Args:
+             input_ids: Byte ID sequence
+             skip_special_tokens: Whether to skip special tokens
+
+         Returns:
+             Decoded text string
+         """
+         # Convert to list if needed
+         if isinstance(input_ids, torch.Tensor):
+             input_ids = input_ids.cpu().numpy().tolist()
+         elif isinstance(input_ids, np.ndarray):
+             input_ids = input_ids.tolist()
+
+         # Filter special tokens if requested
+         if skip_special_tokens:
+             # Only keep actual bytes (0-255)
+             input_ids = [b for b in input_ids if 0 <= b <= 255]
+         else:
+             # Replace special tokens with readable markers
+             processed = []
+             for b in input_ids:
+                 if b == self.PAD:
+                     continue  # Skip padding
+                 elif b == self.BOS:
+                     processed.append(ord('['))  # Use [ for BOS
+                 elif b == self.EOS:
+                     processed.append(ord(']'))  # Use ] for EOS
+                 elif b == self.MASK:
+                     processed.append(ord('*'))  # Use * for MASK
+                 elif 0 <= b <= 255:
+                     processed.append(b)
+             input_ids = processed
+
+         # Convert bytes to text
+         try:
+             # Try UTF-8 decoding
+             byte_array = bytes(input_ids)
+             text = byte_array.decode('utf-8', errors='replace')
+             return text
+         except Exception as e:
+             # Fallback: convert directly to chars
+             return "".join([chr(b) if b < 128 else '?' for b in input_ids])
+
+     def decode_batch(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
+         """
+         Decode a batch of byte sequences
+
+         Args:
+             input_ids: Batch of byte IDs (batch_size, seq_len)
+             skip_special_tokens: Whether to skip special tokens
+
+         Returns:
+             List of decoded texts
+         """
+         texts = []
+         for i in range(input_ids.shape[0]):
+             text = self.decode(input_ids[i], skip_special_tokens)
+             texts.append(text)
+         return texts
+
+     def tokenize(self, text: str) -> List[int]:
+         """
+         Simple tokenization to byte IDs (no special tokens)
+
+         Args:
+             text: Input text
+
+         Returns:
+             List of byte IDs
+         """
+         return list(text.encode('utf-8'))
+
+     def detokenize(self, byte_ids: List[int]) -> str:
+         """
+         Simple detokenization from byte IDs
+
+         Args:
+             byte_ids: List of byte IDs
+
+         Returns:
+             Decoded text
+         """
+         try:
+             return bytes(byte_ids).decode('utf-8', errors='replace')
+         except (ValueError, TypeError):
+             return "".join([chr(b) if b < 128 else '?' for b in byte_ids])
+
+     def get_vocab_size(self) -> int:
+         """Get vocabulary size"""
+         return self.vocab_size
+
+     def get_special_tokens(self) -> Dict[str, int]:
+         """Get special token IDs"""
+         return {
+             'pad_id': self.PAD,
+             'bos_id': self.BOS,
+             'eos_id': self.EOS,
+             'mask_id': self.MASK
+         }
+
+
+ # Test code
+ if __name__ == "__main__":
+     # Initialize tokenizer
+     tokenizer = ByteTokenizerV6()
+
+     # Test texts in multiple languages
+     test_texts = [
+         "Hello World!",
+         "안녕하세요",
+         "你好世界",
+         "こんにちは",
+         "مرحبا بالعالم",
+         "Здравствуй мир"
+     ]
+
+     print("=" * 50)
+     print("Single Text Encoding/Decoding Test")
+     print("=" * 50)
+
+     for text in test_texts:
+         print(f"\nOriginal: {text}")
+
+         # Encode
+         encoded = tokenizer.encode(text)
+         print(f"Encoded length: {encoded['length']}")
+         print(f"First 10 bytes: {encoded['input_ids'][:10]}")
+
+         # Decode
+         decoded = tokenizer.decode(encoded['input_ids'])
+         print(f"Decoded: {decoded}")
+         print(f"Match: {decoded == text}")
+
+     print("\n" + "=" * 50)
+     print("Batch Encoding/Decoding Test")
+     print("=" * 50)
+
+     # Batch test
+     batch_result = tokenizer.encode_batch(test_texts)
+     print(f"Batch shape: {batch_result['input_ids'].shape}")
+     print(f"Attention mask shape: {batch_result['attention_mask'].shape}")
+
+     # Decode batch
+     decoded_texts = tokenizer.decode_batch(batch_result['input_ids'])
+     print("\nBatch decoding results:")
+     for orig, dec in zip(test_texts, decoded_texts):
+         print(f"Original: {orig}")
+         print(f"Decoded: {dec}")
+         print(f"Match: {orig == dec}")
+         print()
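
Because the file defines a fixed vocabulary of 260 IDs (bytes 0-255 plus PAD/BOS/EOS/MASK), the only wiring a downstream model needs is an embedding table of that size, with the PAD ID usable as padding_idx. A minimal sketch, assuming the repo root is on sys.path; the embedding dimension of 128 is illustrative, not taken from this file.

    import torch.nn as nn

    from src.core.byte_tokenizer_v6 import ByteTokenizerV6

    tokenizer = ByteTokenizerV6(max_seq_len=512)
    batch = tokenizer.encode_batch(["Hello World!", "안녕하세요"])

    # One embedding row per ID (256 bytes + 4 specials); the PAD row stays zero.
    special = tokenizer.get_special_tokens()
    embed = nn.Embedding(tokenizer.get_vocab_size(), 128, padding_idx=special['pad_id'])

    hidden = embed(batch['input_ids'])                        # (batch, seq_len, 128)
    hidden = hidden * batch['attention_mask'].unsqueeze(-1)   # zero out padded positions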