klemenk committed
Commit 85ac35b · verified · 1 parent: 9b7a9cb

Create convert_wavtokenizer.py

Files changed (1)
  1. convert_wavtokenizer.py +173 -0
convert_wavtokenizer.py ADDED
@@ -0,0 +1,173 @@
#!/usr/bin/env python
"""
Convert original WavTokenizer checkpoint to HuggingFace format.

Usage:
    python convert_wavtokenizer.py \
        --config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \
        --checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \
        --output_dir ./wavtokenizer_hf_converted

This will create a HuggingFace-compatible model directory that can be loaded with:
    model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True)
"""

import argparse
import json
import os
import shutil
from pathlib import Path

import torch
import yaml


def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str):
    """Convert WavTokenizer checkpoint to HuggingFace format."""

    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Load YAML config
    with open(config_path, 'r') as f:
        yaml_cfg = yaml.safe_load(f)

    # Extract model parameters
    model_args = yaml_cfg.get('model', {}).get('init_args', {})

    # Get specific component configs
    head_args = model_args.get('head', {}).get('init_args', {})
    backbone_args = model_args.get('backbone', {}).get('init_args', {})
    quantizer_args = model_args.get('quantizer', {}).get('init_args', {})
    feature_extractor_args = model_args.get('feature_extractor', {}).get('init_args', {})
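
    # For reference, this script expects the YAML to be laid out roughly as in the
    # sketch below (illustrative only; the keys shown are the ones this file reads
    # via the .get() calls, and real WavTokenizer configs carry additional fields):
    #
    #   model:
    #     init_args:
    #       feature_extractor:
    #         init_args: {sample_rate: 24000, n_mels: 128, ...}
    #       backbone:
    #         init_args: {input_channels: 512, dim: 512, intermediate_dim: 1536, num_layers: 8, ...}
    #       head:
    #         init_args: {n_fft: 1280, hop_length: 320, padding: center, ...}
    #       quantizer:
    #         init_args: {codebook_size: 4096, codebook_dim: 8, num_quantizers: 1, ...}
    #
    # Any key missing from the YAML falls back to the WavTokenizer-small defaults below.
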
    # Create HuggingFace config
    hf_config = {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer"
        },
        "model_type": "wavtokenizer",

        # Audio parameters
        "sample_rate": feature_extractor_args.get('sample_rate', 24000),
        "n_fft": head_args.get('n_fft', 1280),
        "hop_length": head_args.get('hop_length', 320),
        "n_mels": feature_extractor_args.get('n_mels', 128),
        "padding": head_args.get('padding', 'center'),

        # Feature dimensions
        "feature_dim": backbone_args.get('dim', 512),
        "encoder_dim": 64,  # Default DAC encoder
        "encoder_rates": [8, 5, 4, 2],  # Default DAC encoder rates
        "latent_dim": backbone_args.get('input_channels', 512),

        # Quantizer parameters
        "codebook_size": quantizer_args.get('codebook_size', 4096),
        "codebook_dim": quantizer_args.get('codebook_dim', 8),
        "num_quantizers": quantizer_args.get('num_quantizers', 1),

        # Backbone parameters
        "backbone_type": "vocos",
        "backbone_dim": backbone_args.get('dim', 512),
        "backbone_num_blocks": backbone_args.get('num_layers', 8),
        "backbone_intermediate_dim": backbone_args.get('intermediate_dim', 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,

        # Head parameters
        "head_type": "istft",
        "head_dim": head_args.get('n_fft', 1280) // 2 + 1,

        # Attention parameters
        "use_attention": True,
        "attention_dim": backbone_args.get('dim', 512),
        "attention_heads": 8,
        "attention_layers": 1,

        "torch_dtype": "float32",
        "transformers_version": "4.40.0"
    }

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Save config.json
    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, 'w') as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    # Load checkpoint
    print("Loading checkpoint...")
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)

    # Clean state dict keys
    new_state_dict = {}
    for k, v in state_dict.items():
        # Remove 'model.' prefix if present
        if k.startswith('model.'):
            k = k[6:]
        new_state_dict[k] = v
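
    # For example, a Lightning-style key such as "model.backbone.<...>" is stored
    # as "backbone.<...>"; keys without the "model." prefix are kept unchanged.
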
    # Save as pytorch_model.bin
    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    # Copy Python files
    script_dir = Path(__file__).parent

    # Copy configuration file
    config_py = script_dir / "configuration_wavtokenizer.py"
    if config_py.exists():
        shutil.copy(config_py, output_dir)
        print("Copied: configuration_wavtokenizer.py")

    # Copy modeling file
    modeling_py = script_dir / "modeling_wavtokenizer.py"
    if modeling_py.exists():
        shutil.copy(modeling_py, output_dir)
        print("Copied: modeling_wavtokenizer.py")

    # Copy README
    readme = script_dir / "README.md"
    if readme.exists():
        shutil.copy(readme, output_dir)
        print("Copied: README.md")

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f'    model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')


def main():
    parser = argparse.ArgumentParser(description="Convert WavTokenizer checkpoint to HuggingFace format")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to WavTokenizer YAML config file"
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to WavTokenizer .ckpt checkpoint file"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./wavtokenizer_hf_converted",
        help="Output directory for HuggingFace model"
    )

    args = parser.parse_args()
    convert_wavtokenizer(args.config_path, args.checkpoint_path, args.output_dir)


if __name__ == "__main__":
    main()
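
For reference, a minimal loading sketch after conversion, assuming the converter was run with the default --output_dir and that configuration_wavtokenizer.py / modeling_wavtokenizer.py (not shown in this commit) expose the fields written above:

    from transformers import AutoConfig, AutoModel

    out_dir = "./wavtokenizer_hf_converted"
    config = AutoConfig.from_pretrained(out_dir, trust_remote_code=True)  # reads config.json
    model = AutoModel.from_pretrained(out_dir, trust_remote_code=True)    # loads pytorch_model.bin
    model.eval()

    # With the defaults written by convert_wavtokenizer.py:
    print(config.sample_rate)    # expected: 24000
    print(config.codebook_size)  # expected: 4096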