f32
Browse files
app.py
CHANGED
|
@@ -230,13 +230,13 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 230 |
|
| 231 |
# 这一步将模型送入显卡
|
| 232 |
def safe_move_model(m, dev):
|
| 233 |
-
logger.info("🛡️ Moving model
|
| 234 |
for name, child in m.named_children():
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
m.to(dev, dtype=torch.float16)
|
| 239 |
return m
|
|
|
|
| 240 |
|
| 241 |
model = safe_move_model(model, device)
|
| 242 |
model.eval()
|
|
@@ -257,7 +257,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 257 |
torch.manual_seed(int(seed))
|
| 258 |
np.random.seed(int(seed))
|
| 259 |
|
| 260 |
-
wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.
|
| 261 |
|
| 262 |
batch = {
|
| 263 |
"audio_id": [Path(audio_file).stem],
|
|
@@ -272,7 +272,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 272 |
|
| 273 |
logger.info("Inference running...")
|
| 274 |
t0 = time.time()
|
| 275 |
-
with torch.no_grad()
|
| 276 |
out = model.inference(scheduler=scheduler, **batch)
|
| 277 |
|
| 278 |
|
|
|
|
| 230 |
|
| 231 |
# 这一步将模型送入显卡
|
| 232 |
def safe_move_model(m, dev):
|
| 233 |
+
logger.info("🛡️ Moving model to GPU in FP32...")
|
| 234 |
for name, child in m.named_children():
|
| 235 |
+
child.to(dev, dtype=torch.float32)
|
| 236 |
+
logger.info(f"Moving {name} to GPU (fp32)...")
|
| 237 |
+
m.to(dev, dtype=torch.float32)
|
|
|
|
| 238 |
return m
|
| 239 |
+
|
| 240 |
|
| 241 |
model = safe_move_model(model, device)
|
| 242 |
model.eval()
|
|
|
|
| 257 |
torch.manual_seed(int(seed))
|
| 258 |
np.random.seed(int(seed))
|
| 259 |
|
| 260 |
+
wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)
|
| 261 |
|
| 262 |
batch = {
|
| 263 |
"audio_id": [Path(audio_file).stem],
|
|
|
|
| 272 |
|
| 273 |
logger.info("Inference running...")
|
| 274 |
t0 = time.time()
|
| 275 |
+
with torch.no_grad():
|
| 276 |
out = model.inference(scheduler=scheduler, **batch)
|
| 277 |
|
| 278 |
|