zhendery commited on
Commit
eec53ba
·
1 Parent(s): 5476ef5

fix: 漏掉音频下载

Browse files
Files changed (1) hide show
  1. api.py +34 -34
api.py CHANGED
@@ -33,6 +33,38 @@ class GenerateRequest(BaseModel):
33
  do_normalize: bool = True
34
  denoise: bool = True
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # 队列相关变量
37
  task_queue = queue.Queue()
38
  output_dir = "./output"
@@ -70,7 +102,8 @@ async def process_queue():
70
 
71
  if model is None:
72
  raise RuntimeError("Failed to initialize model")
73
-
 
74
  print_with_time(f"Generating audio for : '{text[:60]}...'")
75
  with open(f"./voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
76
  wav = model.generate(
@@ -120,39 +153,6 @@ async def get_generate_result(task_id: str, token: str = Depends(verify_token)):
120
  except Exception as e:
121
  raise HTTPException(status_code=500, detail=f"Failed to read result file: {str(e)}")
122
 
123
-
124
- def download_voices():
125
- # 检查 /workspace/voices/ 目录中是否有 .pmt 文件
126
- voices_dir = "/workspace/voices"
127
- if not os.path.exists(voices_dir):
128
- os.makedirs(voices_dir)
129
-
130
- pmt_files = [f for f in os.listdir(voices_dir) if f.endswith(".pmt")]
131
- if not pmt_files:
132
- # 如果没有 .pmt 文件,尝试从远程下载
133
- voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
134
-
135
- if voice_download_url:
136
- try:
137
- response = requests.get(voice_download_url)
138
- response.raise_for_status()
139
-
140
- # 保存下载的zip文件
141
- zip_path = f"{voices_dir}/voices.zip"
142
- with open(zip_path, "wb") as f:
143
- f.write(response.content)
144
-
145
- # 解压zip文件
146
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
147
- zip_ref.extractall(voices_dir)
148
-
149
- # 删除临时zip文件
150
- os.remove(zip_path)
151
-
152
- except Exception as e:
153
- print_with_time(f"Failed to download and extract voices: {e}")
154
- raise HTTPException(status_code=500, detail="Failed to download voice files")
155
-
156
  @app.post("/upload_voice")
157
  def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
158
  # 保存wav文件
 
33
  do_normalize: bool = True
34
  denoise: bool = True
35
 
36
+ def download_voices(bForce=False):
37
+ # 检查 /workspace/voices/ 目录中是否有 .pmt 文件
38
+ voices_dir = "/workspace/voices"
39
+ if not os.path.exists(voices_dir):
40
+ os.makedirs(voices_dir)
41
+
42
+ pmt_files = [f for f in os.listdir(voices_dir) if f.endswith(".pmt")]
43
+ if bForce or not pmt_files:
44
+ # 如果没有 .pmt 文件,尝试从远程下载
45
+ voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
46
+
47
+ if voice_download_url:
48
+ try:
49
+ response = requests.get(voice_download_url)
50
+ response.raise_for_status()
51
+
52
+ # 保存下载的zip文件
53
+ zip_path = f"{voices_dir}/voices.zip"
54
+ with open(zip_path, "wb") as f:
55
+ f.write(response.content)
56
+
57
+ # 解压zip文件
58
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
59
+ zip_ref.extractall(voices_dir)
60
+
61
+ # 删除临时zip文件
62
+ os.remove(zip_path)
63
+
64
+ except Exception as e:
65
+ print_with_time(f"Failed to download and extract voices: {e}")
66
+ raise HTTPException(status_code=500, detail="Failed to download voice files")
67
+
68
  # 队列相关变量
69
  task_queue = queue.Queue()
70
  output_dir = "./output"
 
102
 
103
  if model is None:
104
  raise RuntimeError("Failed to initialize model")
105
+
106
+ download_voices()
107
  print_with_time(f"Generating audio for : '{text[:60]}...'")
108
  with open(f"./voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
109
  wav = model.generate(
 
153
  except Exception as e:
154
  raise HTTPException(status_code=500, detail=f"Failed to read result file: {str(e)}")
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  @app.post("/upload_voice")
157
  def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
158
  # 保存wav文件