CocoBro committed
Commit 62c2ea1 · 1 Parent(s): 11cf650
Files changed (1):
  1. app.py +7 -7
app.py CHANGED
@@ -230,13 +230,13 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
 
     # This step moves the model onto the GPU
     def safe_move_model(m, dev):
-        logger.info("🛡️ Moving model layer by layer to avoid RAM spike...")
+        logger.info("🛡️ Moving model to GPU in FP32...")
         for name, child in m.named_children():
-            # Move layer by layer: free a little CPU RAM -> fill a little GPU VRAM
-            child.to(dev, dtype=torch.float16)
-            logger.info(f"Moving {name} to GPU...")
-        m.to(dev, dtype=torch.float16)
+            child.to(dev, dtype=torch.float32)
+            logger.info(f"Moving {name} to GPU (fp32)...")
+        m.to(dev, dtype=torch.float32)
         return m
+
 
     model = safe_move_model(model, device)
     model.eval()
@@ -257,7 +257,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
     torch.manual_seed(int(seed))
     np.random.seed(int(seed))
 
-    wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float16)
+    wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)
 
     batch = {
         "audio_id": [Path(audio_file).stem],
@@ -272,7 +272,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
 
     logger.info("Inference running...")
     t0 = time.time()
-    with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
+    with torch.no_grad():
         out = model.inference(scheduler=scheduler, **batch)
 
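
Below is a minimal, self-contained sketch of the pattern this commit converges on: move the model to the GPU one top-level child at a time (the original log message gives "avoid RAM spike" as the rationale), cast everything to fp32, and run inference under torch.no_grad() without autocast. The nn.Sequential toy model and print logging are stand-ins for the app's real model and logger, not part of the repo.

import torch
import torch.nn as nn

def safe_move_model(m: nn.Module, dev: torch.device) -> nn.Module:
    # Move one top-level child at a time so CPU RAM is released
    # incrementally as GPU VRAM fills (the "avoid RAM spike" idea
    # from the original log message).
    for name, child in m.named_children():
        child.to(dev, dtype=torch.float32)
        print(f"Moved {name} to {dev} (fp32)")  # stand-in for logger.info
    # Catch any parameters/buffers registered directly on the root module.
    m.to(dev, dtype=torch.float32)
    return m

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = safe_move_model(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)), device)
model.eval()
with torch.no_grad():  # plain fp32 inference; no torch.autocast, as in the new code
    out = model(torch.randn(1, 8, device=device))
print(out.shape)  # torch.Size([1, 4])

Dropping torch.autocast("cuda", dtype=torch.float16) means every op now runs in genuine fp32: roughly double the weight and activation memory compared with fp16, in exchange for avoiding fp16 precision and overflow issues.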