klemenk committed on
Commit 1defa8d · verified · 1 Parent(s): 4cec4ed

Create modeling_wavtokenizer.py

Files changed (1)
  1. modeling_wavtokenizer.py +813 -0
modeling_wavtokenizer.py ADDED
@@ -0,0 +1,813 @@
"""
WavTokenizer Model for HuggingFace Transformers

This module contains the complete implementation of WavTokenizer,
an acoustic discrete codec tokenizer for audio language modeling.
All dependencies are included to avoid external imports.

The architecture follows the original WavTokenizer implementation:
- Encoder: Strided convolutions for audio compression
- VQ: Vector quantization with a single codebook
- Decoder: Vocos-style backbone with ConvNeXt blocks + iSTFT head

Reference: https://github.com/jishengpeng/WavTokenizer
Paper: "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling"
"""

import math
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import weight_norm, remove_weight_norm

from transformers import PreTrainedModel
from transformers.tokenization_utils import BatchEncoding

from .configuration_wavtokenizer import WavTokenizerConfig


# ==============================================================================
# Utility Functions
# ==============================================================================

def convert_audio(wav: Tensor, sr: int, target_sr: int, target_channels: int) -> Tensor:
    """
    Convert audio to the target sample rate and number of channels.

    Args:
        wav: Input waveform [C, T] or [T]
        sr: Source sample rate
        target_sr: Target sample rate
        target_channels: Target number of channels (1 for mono, 2 for stereo)

    Returns:
        Converted waveform [target_channels, T']
    """
    import torchaudio

    # Ensure 2D
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)

    # Convert channels
    if wav.size(0) > target_channels:
        wav = wav.mean(dim=0, keepdim=True)
    elif wav.size(0) < target_channels:
        wav = wav.expand(target_channels, -1)

    # Resample if needed
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    return wav
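

# --- Illustrative usage (not part of the original implementation) ------------
# A minimal sketch of how `convert_audio` is typically called before encoding,
# assuming a torchaudio-loadable file at the hypothetical path "speech.wav" and
# the 24 kHz mono setup WavTokenizer expects.
def _example_convert_audio():  # pragma: no cover - illustrative only
    import torchaudio

    wav, sr = torchaudio.load("speech.wav")  # [C, T] at the file's native rate
    wav = convert_audio(wav, sr, target_sr=24000, target_channels=1)
    return wav  # [1, T'] mono at 24 kHz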


# ==============================================================================
# Encoder Components (DAC-style)
# ==============================================================================

def WNConv1d(*args, **kwargs):
    """Weight-normalized Conv1d."""
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    """Weight-normalized ConvTranspose1d."""
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class ResidualUnit(nn.Module):
    """Residual unit with dilated convolution."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + self.block(x)


class EncoderBlock(nn.Module):
    """Encoder block with residual units and downsampling."""

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            nn.ELU(),
            WNConv1d(
                dim // 2, dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class Encoder(nn.Module):
    """
    DAC-style encoder that compresses a waveform to a latent representation.
    Uses strided convolutions for downsampling.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: List[int] = [8, 5, 4, 2],
        d_latent: int = 512,
    ):
        super().__init__()

        # Initial conv
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Encoder blocks with increasing channels
        for stride in strides:
            d_model *= 2
            self.block.append(EncoderBlock(d_model, stride=stride))

        # Final projection
        self.block.extend([
            nn.ELU(),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ])

        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


# ==============================================================================
# Vector Quantization
# ==============================================================================

class VectorQuantize(nn.Module):
    """
    Vector quantization with a learned codebook and straight-through gradients.

    Uses L2-normalized codes for better stability; the encoder side is
    regularized with a commitment loss.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        commitment: float = 0.25,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        # Projections
        requires_projection = input_dim != codebook_dim
        self.project_in = nn.Linear(input_dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, input_dim) if requires_projection else nn.Identity()

        # Codebook
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        nn.init.uniform_(self.codebook.weight, -1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Forward pass.

        Args:
            z: Input [B, D, T]

        Returns:
            z_q: Quantized [B, D, T]
            commitment_loss: Loss scalar
            indices: Codes [B, T]
        """
        # [B, D, T] -> [B, T, D]
        z = z.transpose(1, 2)
        z_e = self.project_in(z)

        # L2 normalize
        z_e_norm = F.normalize(z_e, dim=-1)
        codebook_norm = F.normalize(self.codebook.weight, dim=-1)

        # Find nearest codes
        dist = (
            z_e_norm.pow(2).sum(-1, keepdim=True)
            + codebook_norm.pow(2).sum(-1)
            - 2 * torch.einsum('btd,kd->btk', z_e_norm, codebook_norm)
        )
        indices = dist.argmin(dim=-1)

        # Look up quantized values
        z_q = F.embedding(indices, codebook_norm)

        # Commitment loss
        commitment_loss = F.mse_loss(z_e_norm, z_q.detach()) * self.commitment

        # Straight-through
        z_q = z_e_norm + (z_q - z_e_norm).detach()

        # Project out and transpose back
        z_q = self.project_out(z_q)
        z_q = z_q.transpose(1, 2)  # [B, D, T]

        return z_q, commitment_loss, indices

    def decode(self, indices: Tensor) -> Tensor:
        """Decode indices to vectors."""
        codebook = F.normalize(self.codebook.weight, dim=-1)
        z_q = F.embedding(indices, codebook)
        z_q = self.project_out(z_q)
        return z_q.transpose(1, 2)
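

# --- Illustrative usage (not part of the original implementation) ------------
# A minimal sketch of the quantizer's round trip on random data, assuming the
# 512-d latent / 8-d codebook / 4096-entry setup used elsewhere in this file.
# It shows that gradients reach the encoder input through the straight-through
# estimator even though `indices` comes from a hard argmin, and that `decode`
# reproduces the same quantized vectors from the integer codes alone.
def _example_vector_quantize():  # pragma: no cover - illustrative only
    vq = VectorQuantize(input_dim=512, codebook_size=4096, codebook_dim=8)
    z = torch.randn(2, 512, 75, requires_grad=True)  # [B, D, T]
    z_q, _, indices = vq(z)
    assert z_q.shape == z.shape and indices.shape == (2, 75)
    z_q.sum().backward()  # gradient reaches z despite the argmin
    assert z.grad is not None
    assert torch.allclose(vq.decode(indices), z_q.detach(), atol=1e-6)
    return indices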


class ResidualVectorQuantize(nn.Module):
    """Residual VQ with multiple codebooks (typically 1 for WavTokenizer)."""

    def __init__(
        self,
        input_dim: int = 512,
        codebook_size: int = 4096,
        codebook_dim: int = 8,
        num_quantizers: int = 1,
        commitment: float = 0.25,
    ):
        super().__init__()

        self.num_quantizers = num_quantizers
        self.quantizers = nn.ModuleList([
            VectorQuantize(input_dim, codebook_size, codebook_dim, commitment)
            for _ in range(num_quantizers)
        ])

    def forward(
        self, z: Tensor, n_quantizers: Optional[int] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        n_q = n_quantizers or self.num_quantizers

        residual = z
        z_q = torch.zeros_like(z)
        all_indices = []
        all_losses = []

        for i, quantizer in enumerate(self.quantizers[:n_q]):
            _z_q, loss, indices = quantizer(residual)
            residual = residual - _z_q
            z_q = z_q + _z_q
            all_indices.append(indices)
            all_losses.append(loss)

        codes = torch.stack(all_indices, dim=0)  # [N_q, B, T]
        commitment_loss = sum(all_losses)

        return z_q, commitment_loss, codes

    def decode(self, codes: Tensor) -> Tensor:
        """Decode codes to vectors."""
        if codes.dim() == 2:
            codes = codes.unsqueeze(0)

        z_q = None
        for i, quantizer in enumerate(self.quantizers[:codes.size(0)]):
            _z_q = quantizer.decode(codes[i])
            z_q = _z_q if z_q is None else z_q + _z_q

        return z_q
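

# --- Illustrative usage (not part of the original implementation) ------------
# A minimal sketch, assuming the single-codebook configuration WavTokenizer
# uses: codes come back stacked as [N_q, B, T'], and decode() accepts either
# that shape or a squeezed [B, T'].
def _example_residual_vq():  # pragma: no cover - illustrative only
    rvq = ResidualVectorQuantize(input_dim=512, codebook_size=4096, codebook_dim=8, num_quantizers=1)
    z = torch.randn(2, 512, 75)  # [B, D, T]
    z_q, _, codes = rvq(z)
    assert codes.shape == (1, 2, 75)  # [N_q, B, T]
    assert torch.allclose(rvq.decode(codes), z_q, atol=1e-6)
    assert torch.allclose(rvq.decode(codes.squeeze(0)), z_q, atol=1e-6)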


# ==============================================================================
# Decoder Components (Vocos-style)
# ==============================================================================

class ConvNeXtBlock(nn.Module):
    """ConvNeXt block with depthwise conv + pointwise expansion."""

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()

        padding = (kernel_size - 1) // 2
        self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim)
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones(dim)
        ) if layer_scale_init_value > 0 else None

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # [B, T, D]
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # [B, D, T]
        return residual + x
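

# --- Illustrative usage (not part of the original implementation) ------------
# A minimal shape check: the block is shape-preserving, so any number of them
# can be stacked in the backbone. Dimensions mirror the defaults used below.
def _example_convnext_block():  # pragma: no cover - illustrative only
    block = ConvNeXtBlock(dim=512, intermediate_dim=1536)
    x = torch.randn(2, 512, 75)  # [B, D, T]
    assert block(x).shape == x.shape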


class VocosBackbone(nn.Module):
    """Vocos backbone with attention and ConvNeXt blocks."""

    def __init__(
        self,
        input_dim: int,
        dim: int,
        intermediate_dim: int,
        num_blocks: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
        use_attention: bool = True,
        num_heads: int = 8,
        num_attention_layers: int = 1,
    ):
        super().__init__()

        # Input projection
        self.input_conv = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim)

        # Attention layers
        self.use_attention = use_attention
        if use_attention:
            self.attention = nn.ModuleList([
                nn.MultiheadAttention(dim, num_heads, batch_first=True)
                for _ in range(num_attention_layers)
            ])
            self.attn_norms = nn.ModuleList([
                nn.LayerNorm(dim) for _ in range(num_attention_layers)
            ])

        # ConvNeXt blocks
        self.convnext = nn.ModuleList([
            ConvNeXtBlock(dim, intermediate_dim, kernel_size, layer_scale_init_value)
            for _ in range(num_blocks)
        ])

        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        # Input projection
        x = self.input_conv(x)
        x = x.transpose(1, 2)  # [B, T, D]
        x = self.norm(x)
        x = x.transpose(1, 2)  # [B, D, T]

        # Attention
        if self.use_attention:
            for attn, norm in zip(self.attention, self.attn_norms):
                x_t = x.transpose(1, 2)  # [B, T, D]
                residual = x_t
                x_t = norm(x_t)
                x_t, _ = attn(x_t, x_t, x_t)
                x_t = residual + x_t
                x = x_t.transpose(1, 2)  # [B, D, T]

        # ConvNeXt blocks
        for block in self.convnext:
            x = block(x)

        # Final norm
        x = x.transpose(1, 2)
        x = self.final_norm(x)
        x = x.transpose(1, 2)

        return x


class ISTFTHead(nn.Module):
    """Inverse STFT head for waveform synthesis."""

    def __init__(
        self,
        dim: int,
        n_fft: int,
        hop_length: int,
        padding: str = "center",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.padding = padding

        self.out_dim = n_fft // 2 + 1
        self.proj = nn.Conv1d(dim, self.out_dim * 2, kernel_size=1)

        # Register window buffer
        self.register_buffer(
            "window",
            torch.hann_window(n_fft),
            persistent=False,
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: [B, D, T]
        Returns:
            wav: [B, 1, T']
        """
        x = self.proj(x)

        # Split mag/phase
        mag, phase = x.chunk(2, dim=1)

        # Process
        mag = torch.exp(mag)
        phase = torch.sin(phase)

        # Complex spectrum
        S = torch.complex(mag * torch.cos(phase * math.pi), mag * torch.sin(phase * math.pi))

        # Ensure window is on same device
        window = self.window.to(x.device)

        # iSTFT
        wav = torch.istft(
            S,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            normalized=False,
            onesided=True,
            return_complex=False,
        )

        return wav.unsqueeze(1)
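

# --- Illustrative usage (not part of the original implementation) ------------
# A minimal sketch of the head's frame-to-sample arithmetic, assuming the
# n_fft=1280 / hop_length=320 setup referenced in `from_pretrained0802` below.
# With center=True, torch.istft returns roughly hop_length * (T - 1) samples,
# so 75 decoder frames map to about one second of 24 kHz audio.
def _example_istft_head():  # pragma: no cover - illustrative only
    head = ISTFTHead(dim=512, n_fft=1280, hop_length=320)
    x = torch.randn(1, 512, 75)  # [B, D, T] backbone output
    wav = head(x)  # [B, 1, ~hop_length * (T - 1)]
    assert wav.shape[:2] == (1, 1)
    return wav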


# ==============================================================================
# Feature Extractor (Mel Spectrogram)
# ==============================================================================

class MelSpectrogramFeatures(nn.Module):
    """Extract mel spectrogram features from audio."""

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        padding: str = "center",
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.padding = padding

        # Mel filterbank
        import torchaudio
        mel_fb = torchaudio.functional.melscale_fbanks(
            n_freqs=n_fft // 2 + 1,
            f_min=f_min,
            f_max=f_max or sample_rate // 2,
            n_mels=n_mels,
            sample_rate=sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer("mel_fb", mel_fb, persistent=False)
        self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

    def forward(self, wav: Tensor) -> Tensor:
        """
        Args:
            wav: [B, 1, T] or [B, T]
        Returns:
            mel: [B, n_mels, T']
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)

        # STFT
        stft = torch.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window.to(wav.device),
            center=True,
            return_complex=True,
        )

        # Power spectrum
        power = stft.abs().pow(2)

        # Mel spectrogram
        mel = torch.matmul(self.mel_fb.T.to(power.device), power)

        # Log scale
        mel = torch.log(mel.clamp(min=1e-5))

        return mel
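

# --- Illustrative usage (not part of the original implementation) ------------
# A minimal shape sketch for the feature extractor (defined here for
# completeness; the codec path above operates on raw waveforms). One second of
# 24 kHz audio with hop_length=256 yields roughly 24000 / 256 ≈ 94 mel frames.
def _example_mel_features():  # pragma: no cover - illustrative only
    mel_extractor = MelSpectrogramFeatures(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100)
    wav = torch.randn(1, 24000)  # [B, T], one second of audio
    mel = mel_extractor(wav)  # [B, n_mels, T']
    assert mel.shape[0] == 1 and mel.shape[1] == 100
    return mel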


# ==============================================================================
# Main WavTokenizer Model
# ==============================================================================

class WavTokenizer(PreTrainedModel):
    """
    WavTokenizer: Efficient acoustic discrete codec tokenizer.

    Architecture:
    - Encoder: Strided convolutions for audio compression
    - VQ: Single-codebook vector quantization (4096 codes)
    - Decoder: Vocos backbone (ConvNeXt + attention) + iSTFT head

    Usage:
    ```python
    model = WavTokenizer.from_pretrained("TuKoResearch/WavTokenizerSmall", trust_remote_code=True)

    # Encode
    features, codes = model.encode_infer(wav, bandwidth_id=torch.tensor([0]))

    # Decode
    wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))

    # Or use codes directly
    features = model.codes_to_features(codes)
    wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
    ```
    """

    config_class = WavTokenizerConfig

    def __init__(self, config: WavTokenizerConfig):
        super().__init__(config)

        self.sample_rate = config.sample_rate
        self.hop_length = config.hop_length

        # Encoder
        self.encoder = Encoder(
            d_model=config.encoder_dim,
            strides=config.encoder_rates,
            d_latent=config.latent_dim,
        )

        # Quantizer
        self.quantizer = ResidualVectorQuantize(
            input_dim=config.latent_dim,
            codebook_size=config.codebook_size,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
        )

        # Feature projection for decoder
        self.feature_proj = nn.Conv1d(config.latent_dim, config.backbone_dim, 1)

        # Decoder backbone
        self.backbone = VocosBackbone(
            input_dim=config.backbone_dim,
            dim=config.backbone_dim,
            intermediate_dim=config.backbone_intermediate_dim,
            num_blocks=config.backbone_num_blocks,
            kernel_size=config.backbone_kernel_size,
            layer_scale_init_value=config.backbone_layer_scale_init_value,
            use_attention=config.use_attention,
            num_heads=config.attention_heads,
            num_attention_layers=config.attention_layers,
        )

        # iSTFT head
        self.head = ISTFTHead(
            dim=config.backbone_dim,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            padding=config.padding,
        )

        # Bandwidth embedding
        self.bandwidth_emb = nn.Embedding(4, config.backbone_dim)

        self.post_init()

    @property
    def vocab_size(self) -> int:
        return self.config.codebook_size

    @property
    def frame_rate(self) -> float:
        return self.config.sample_rate / self.config.hop_length

    def encode(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encode waveform to quantized features.

        Args:
            wav: [B, 1, T] or [B, T]
            bandwidth_id: Optional bandwidth ID

        Returns:
            z_q: Quantized features [B, D, T']
            commitment_loss: VQ loss
            codes: Discrete codes [N_q, B, T']
        """
        if wav.dim() == 2:
            wav = wav.unsqueeze(1)

        z = self.encoder(wav)
        z_q, loss, codes = self.quantizer(z)

        return z_q, loss, codes

    @torch.no_grad()
    def encode_infer(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode waveform to features and codes (inference).

        Args:
            wav: [B, 1, T] or [1, T] or [B, T]
            bandwidth_id: Optional bandwidth ID

        Returns:
            features: [B, D, T']
            codes: [B, T'] (squeezed if single quantizer)
        """
        if wav.dim() == 2:
            if wav.size(0) == 1:
                wav = wav.unsqueeze(0)  # [1, T] -> [1, 1, T]
            else:
                wav = wav.unsqueeze(1)  # [B, T] -> [B, 1, T]

        z = self.encoder(wav)
        z_q, _, codes = self.quantizer(z)

        # Squeeze for single quantizer
        if codes.size(0) == 1:
            codes = codes.squeeze(0)

        return z_q, codes
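
    # --- Illustrative note (not part of the original implementation) ---------
    # A minimal sketch of the token-rate arithmetic implied by `frame_rate` and
    # the encoder strides, assuming the 24 kHz / hop_length=320 configuration
    # used in `from_pretrained0802` below: strides (8, 5, 4, 2) multiply to 320,
    # so one second of audio (24000 samples) becomes 24000 / 320 = 75 codes.
    def _example_token_rate(self) -> float:  # pragma: no cover - illustrative only
        wav = torch.randn(1, 1, self.config.sample_rate)  # one second of audio
        _, codes = self.encode_infer(wav)
        assert codes.shape[-1] == int(self.frame_rate)  # e.g. 75 tokens at hop 320
        return self.frame_rate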

    def decode(
        self, features: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tensor:
        """
        Decode features to waveform.

        Args:
            features: [B, D, T']
            bandwidth_id: Optional bandwidth ID

        Returns:
            wav: [B, 1, T]
        """
        x = self.feature_proj(features)

        if bandwidth_id is not None:
            bw_emb = self.bandwidth_emb(bandwidth_id)
            x = x + bw_emb.unsqueeze(-1)

        x = self.backbone(x)
        wav = self.head(x)

        return wav

    @torch.no_grad()
    def codes_to_features(self, codes: Tensor) -> Tensor:
        """
        Convert codes to features.

        Args:
            codes: [N_q, B, T'] or [B, T']

        Returns:
            features: [B, D, T']
        """
        return self.quantizer.decode(codes)

    def forward(
        self,
        wav: Optional[Tensor] = None,
        codes: Optional[Tensor] = None,
        bandwidth_id: Optional[Tensor] = None,
        **kwargs,
    ) -> Union[BatchEncoding, Tensor]:
        """
        Forward pass.

        If `wav` is provided: encode it to discrete tokens.
        If `codes` is provided: decode them back to a waveform.
        """
        if wav is not None:
            features, codes = self.encode_infer(wav, bandwidth_id)
            return BatchEncoding({
                "input_values": features,
                "input_ids": codes,
            })
        elif codes is not None:
            features = self.codes_to_features(codes)
            return self.decode(features, bandwidth_id)
        else:
            raise ValueError("Provide either 'wav' or 'codes'")

    @classmethod
    def from_pretrained0802(
        cls,
        config_path: str,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> "WavTokenizer":
        """
        Load from an original WavTokenizer checkpoint.

        Args:
            config_path: Path to YAML config
            checkpoint_path: Path to .ckpt file
            device: Device to load to

        Returns:
            Loaded model
        """
        import yaml

        # Load YAML config
        with open(config_path, 'r') as f:
            yaml_cfg = yaml.safe_load(f)

        # Extract config params
        model_args = yaml_cfg.get('model', {}).get('init_args', {})

        # Create HF config
        config = WavTokenizerConfig(
            sample_rate=24000,
            n_fft=model_args.get('head', {}).get('init_args', {}).get('n_fft', 1280),
            hop_length=model_args.get('head', {}).get('init_args', {}).get('hop_length', 320),
            feature_dim=model_args.get('backbone', {}).get('init_args', {}).get('dim', 512),
            latent_dim=model_args.get('backbone', {}).get('init_args', {}).get('input_channels', 512),
            backbone_dim=model_args.get('backbone', {}).get('init_args', {}).get('dim', 512),
            backbone_intermediate_dim=model_args.get('backbone', {}).get('init_args', {}).get('intermediate_dim', 1536),
            backbone_num_blocks=model_args.get('backbone', {}).get('init_args', {}).get('num_layers', 8),
            codebook_size=model_args.get('quantizer', {}).get('init_args', {}).get('codebook_size', 4096),
            codebook_dim=model_args.get('quantizer', {}).get('init_args', {}).get('codebook_dim', 8),
            num_quantizers=model_args.get('quantizer', {}).get('init_args', {}).get('num_quantizers', 1),
            use_attention=True,
            attention_dim=model_args.get('backbone', {}).get('init_args', {}).get('dim', 512),
            attention_heads=8,
            attention_layers=1,
        )

        # Create model
        model = cls(config)

        # Load checkpoint
        ckpt = torch.load(checkpoint_path, map_location=device)
        state_dict = ckpt.get('state_dict', ckpt)

        # Clean state dict
        new_state_dict = {}
        for k, v in state_dict.items():
            # Remove 'model.' prefix if present
            if k.startswith('model.'):
                k = k[6:]
            new_state_dict[k] = v

        # Load (non-strict to handle mismatches)
        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)

        if missing:
            print(f"Missing keys: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys: {len(unexpected)}")

        return model.to(device)
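

# --- Illustrative end-to-end sketch (not part of the original implementation) -
# Assumes the checkpoint named in the class docstring
# ("TuKoResearch/WavTokenizerSmall") and a 24 kHz mono input; it mirrors the
# encode -> codes -> decode flow shown there.
def _example_round_trip():  # pragma: no cover - illustrative only
    model = WavTokenizer.from_pretrained(
        "TuKoResearch/WavTokenizerSmall", trust_remote_code=True
    ).eval()

    wav = torch.randn(1, 1, 24000)  # stand-in for real 24 kHz audio
    bandwidth_id = torch.tensor([0])

    features, codes = model.encode_infer(wav, bandwidth_id=bandwidth_id)
    features_from_codes = model.codes_to_features(codes)  # rebuild features from codes only
    wav_out = model.decode(features_from_codes, bandwidth_id=bandwidth_id)
    return codes, wav_out  # [B, T'] tokens, [B, 1, T] audio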