klemenk committed · Commit 1b82420 · verified · 1 Parent(s): 1defa8d

Create configuration_wavtokenizer.py

Files changed (1)
  1. configuration_wavtokenizer.py +163 -0
configuration_wavtokenizer.py ADDED
@@ -0,0 +1,163 @@
+ """
+ WavTokenizer Configuration for HuggingFace Transformers
+
+ This configuration class defines all the hyperparameters for WavTokenizer,
+ an acoustic discrete codec tokenizer for audio language modeling.
+ """
+
+ from transformers import PretrainedConfig
+
+
+ class WavTokenizerConfig(PretrainedConfig):
+     """
+     Configuration class for WavTokenizer model.
+
+     WavTokenizer is a SOTA discrete acoustic codec model that compresses audio
+     into discrete tokens (40 or 75 tokens per second) while maintaining high
+     reconstruction quality.
+
+     Args:
+         sample_rate (`int`, *optional*, defaults to 24000):
+             The sample rate of input audio.
+         n_fft (`int`, *optional*, defaults to 1280):
+             FFT size for STFT.
+         hop_length (`int`, *optional*, defaults to 320):
+             Hop length for STFT (determines frame rate: 24000/320 = 75 fps).
+         n_mels (`int`, *optional*, defaults to 128):
+             Number of mel filterbank channels.
+         padding (`str`, *optional*, defaults to "center"):
+             Padding mode for STFT ("center" or "same").
+
+         feature_dim (`int`, *optional*, defaults to 512):
+             Dimension of the feature backbone.
+         encoder_dim (`int`, *optional*, defaults to 64):
+             Dimension of encoder output.
+         encoder_rates (`list[int]`, *optional*, defaults to [8, 5, 4, 2]):
+             Downsampling rates for the encoder.
+         latent_dim (`int`, *optional*):
+             Dimension of the latent space (defaults to feature_dim).
+
+         codebook_size (`int`, *optional*, defaults to 4096):
+             Size of the VQ codebook.
+         codebook_dim (`int`, *optional*, defaults to 8):
+             Dimension of codebook vectors.
+         num_quantizers (`int`, *optional*, defaults to 1):
+             Number of residual vector quantizers.
+
+         backbone_type (`str`, *optional*, defaults to "vocos"):
+             Type of decoder backbone ("vocos").
+         backbone_dim (`int`, *optional*, defaults to 512):
+             Dimension of the decoder backbone.
+         backbone_num_blocks (`int`, *optional*, defaults to 8):
+             Number of ConvNeXt blocks in the backbone.
+         backbone_intermediate_dim (`int`, *optional*, defaults to 1536):
+             Intermediate dimension in ConvNeXt blocks.
+         backbone_kernel_size (`int`, *optional*, defaults to 7):
+             Kernel size for depthwise convolutions.
+         backbone_layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
+             Initial value for layer scale.
+
+         head_type (`str`, *optional*, defaults to "istft"):
+             Type of waveform synthesis head ("istft").
+         head_dim (`int`, *optional*, defaults to 1025):
+             Output dimension for the head (n_fft // 2 + 1).
+
+         use_attention (`bool`, *optional*, defaults to True):
+             Whether to use attention in the decoder.
+         attention_dim (`int`, *optional*, defaults to 512):
+             Dimension for attention layers.
+         attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads.
+         attention_layers (`int`, *optional*, defaults to 1):
+             Number of attention layers.
+     """
+
+     model_type = "wavtokenizer"
+
+     def __init__(
+         self,
+         # Audio parameters
+         sample_rate: int = 24000,
+         n_fft: int = 1280,
+         hop_length: int = 320,
+         n_mels: int = 128,
+         padding: str = "center",
+
+         # Feature dimensions
+         feature_dim: int = 512,
+         encoder_dim: int = 64,
+         encoder_rates: list = None,
+         latent_dim: int = None,
+
+         # Quantizer parameters
+         codebook_size: int = 4096,
+         codebook_dim: int = 8,
+         num_quantizers: int = 1,
+
+         # Backbone parameters
+         backbone_type: str = "vocos",
+         backbone_dim: int = 512,
+         backbone_num_blocks: int = 8,
+         backbone_intermediate_dim: int = 1536,
+         backbone_kernel_size: int = 7,
+         backbone_layer_scale_init_value: float = 1e-6,
+
+         # Head parameters
+         head_type: str = "istft",
+         head_dim: int = 1025,
+
+         # Attention parameters
+         use_attention: bool = True,
+         attention_dim: int = 512,
+         attention_heads: int = 8,
+         attention_layers: int = 1,
+
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         # Audio
+         self.sample_rate = sample_rate
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.n_mels = n_mels
+         self.padding = padding
+
+         # Feature dimensions
+         self.feature_dim = feature_dim
+         self.encoder_dim = encoder_dim
+         self.encoder_rates = encoder_rates if encoder_rates is not None else [8, 5, 4, 2]
+         self.latent_dim = latent_dim if latent_dim is not None else feature_dim
+
+         # Quantizer
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.num_quantizers = num_quantizers
+
+         # Backbone
+         self.backbone_type = backbone_type
+         self.backbone_dim = backbone_dim
+         self.backbone_num_blocks = backbone_num_blocks
+         self.backbone_intermediate_dim = backbone_intermediate_dim
+         self.backbone_kernel_size = backbone_kernel_size
+         self.backbone_layer_scale_init_value = backbone_layer_scale_init_value
+
+         # Head
+         self.head_type = head_type
+         self.head_dim = head_dim
+
+         # Attention
+         self.use_attention = use_attention
+         self.attention_dim = attention_dim
+         self.attention_heads = attention_heads
+         self.attention_layers = attention_layers
+
+     @property
+     def vocab_size(self) -> int:
+         """Returns the vocabulary size (codebook size)."""
+         return self.codebook_size
+
+     @property
+     def frame_rate(self) -> float:
+         """Returns the frame rate (tokens per second)."""
+         return self.sample_rate / self.hop_length
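
For context, a minimal usage sketch of the configuration class added above (not part of this commit): it instantiates the config, reads the derived `frame_rate` and `vocab_size` properties, and round-trips it through the standard `PretrainedConfig` serialization. The import path `configuration_wavtokenizer` and the output directory name are assumptions for illustration.

# Minimal usage sketch; module path and directory name are assumed, not part of the commit.
from transformers import AutoConfig

from configuration_wavtokenizer import WavTokenizerConfig

# Make AutoConfig aware of the custom "wavtokenizer" model type.
AutoConfig.register("wavtokenizer", WavTokenizerConfig)

# Default configuration: 24 kHz audio with hop length 320 -> 75 tokens per second.
config = WavTokenizerConfig()
print(config.frame_rate)  # 75.0
print(config.vocab_size)  # 4096 (codebook size)

# Standard PretrainedConfig round-trip through config.json.
config.save_pretrained("wavtokenizer-config")
reloaded = AutoConfig.from_pretrained("wavtokenizer-config")
assert isinstance(reloaded, WavTokenizerConfig)
assert reloaded.codebook_size == config.codebook_size

Because `model_type = "wavtokenizer"` is written into `config.json`, registering the class with `AutoConfig` is what lets `from_pretrained` resolve the saved file back to `WavTokenizerConfig`.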