zhendery
commited on
Commit
·
eec53ba
1
Parent(s):
5476ef5
fix: 漏掉音频下载
Browse files
api.py
CHANGED
|
@@ -33,6 +33,38 @@ class GenerateRequest(BaseModel):
|
|
| 33 |
do_normalize: bool = True
|
| 34 |
denoise: bool = True
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# 队列相关变量
|
| 37 |
task_queue = queue.Queue()
|
| 38 |
output_dir = "./output"
|
|
@@ -70,7 +102,8 @@ async def process_queue():
|
|
| 70 |
|
| 71 |
if model is None:
|
| 72 |
raise RuntimeError("Failed to initialize model")
|
| 73 |
-
|
|
|
|
| 74 |
print_with_time(f"Generating audio for : '{text[:60]}...'")
|
| 75 |
with open(f"./voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
|
| 76 |
wav = model.generate(
|
|
@@ -120,39 +153,6 @@ async def get_generate_result(task_id: str, token: str = Depends(verify_token)):
|
|
| 120 |
except Exception as e:
|
| 121 |
raise HTTPException(status_code=500, detail=f"Failed to read result file: {str(e)}")
|
| 122 |
|
| 123 |
-
|
| 124 |
-
def download_voices():
|
| 125 |
-
# 检查 /workspace/voices/ 目录中是否有 .pmt 文件
|
| 126 |
-
voices_dir = "/workspace/voices"
|
| 127 |
-
if not os.path.exists(voices_dir):
|
| 128 |
-
os.makedirs(voices_dir)
|
| 129 |
-
|
| 130 |
-
pmt_files = [f for f in os.listdir(voices_dir) if f.endswith(".pmt")]
|
| 131 |
-
if not pmt_files:
|
| 132 |
-
# 如果没有 .pmt 文件,尝试从远程下载
|
| 133 |
-
voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
|
| 134 |
-
|
| 135 |
-
if voice_download_url:
|
| 136 |
-
try:
|
| 137 |
-
response = requests.get(voice_download_url)
|
| 138 |
-
response.raise_for_status()
|
| 139 |
-
|
| 140 |
-
# 保存下载的zip文件
|
| 141 |
-
zip_path = f"{voices_dir}/voices.zip"
|
| 142 |
-
with open(zip_path, "wb") as f:
|
| 143 |
-
f.write(response.content)
|
| 144 |
-
|
| 145 |
-
# 解压zip文件
|
| 146 |
-
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 147 |
-
zip_ref.extractall(voices_dir)
|
| 148 |
-
|
| 149 |
-
# 删除临时zip文件
|
| 150 |
-
os.remove(zip_path)
|
| 151 |
-
|
| 152 |
-
except Exception as e:
|
| 153 |
-
print_with_time(f"Failed to download and extract voices: {e}")
|
| 154 |
-
raise HTTPException(status_code=500, detail="Failed to download voice files")
|
| 155 |
-
|
| 156 |
@app.post("/upload_voice")
|
| 157 |
def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
|
| 158 |
# 保存wav文件
|
|
|
|
| 33 |
do_normalize: bool = True
|
| 34 |
denoise: bool = True
|
| 35 |
|
| 36 |
+
def download_voices(bForce=False):
|
| 37 |
+
# 检查 /workspace/voices/ 目录中是否有 .pmt 文件
|
| 38 |
+
voices_dir = "/workspace/voices"
|
| 39 |
+
if not os.path.exists(voices_dir):
|
| 40 |
+
os.makedirs(voices_dir)
|
| 41 |
+
|
| 42 |
+
pmt_files = [f for f in os.listdir(voices_dir) if f.endswith(".pmt")]
|
| 43 |
+
if bForce or not pmt_files:
|
| 44 |
+
# 如果没有 .pmt 文件,尝试从远程下载
|
| 45 |
+
voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
|
| 46 |
+
|
| 47 |
+
if voice_download_url:
|
| 48 |
+
try:
|
| 49 |
+
response = requests.get(voice_download_url)
|
| 50 |
+
response.raise_for_status()
|
| 51 |
+
|
| 52 |
+
# 保存下载的zip文件
|
| 53 |
+
zip_path = f"{voices_dir}/voices.zip"
|
| 54 |
+
with open(zip_path, "wb") as f:
|
| 55 |
+
f.write(response.content)
|
| 56 |
+
|
| 57 |
+
# 解压zip文件
|
| 58 |
+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 59 |
+
zip_ref.extractall(voices_dir)
|
| 60 |
+
|
| 61 |
+
# 删除临时zip文件
|
| 62 |
+
os.remove(zip_path)
|
| 63 |
+
|
| 64 |
+
except Exception as e:
|
| 65 |
+
print_with_time(f"Failed to download and extract voices: {e}")
|
| 66 |
+
raise HTTPException(status_code=500, detail="Failed to download voice files")
|
| 67 |
+
|
| 68 |
# 队列相关变量
|
| 69 |
task_queue = queue.Queue()
|
| 70 |
output_dir = "./output"
|
|
|
|
| 102 |
|
| 103 |
if model is None:
|
| 104 |
raise RuntimeError("Failed to initialize model")
|
| 105 |
+
|
| 106 |
+
download_voices()
|
| 107 |
print_with_time(f"Generating audio for : '{text[:60]}...'")
|
| 108 |
with open(f"./voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
|
| 109 |
wav = model.generate(
|
|
|
|
| 153 |
except Exception as e:
|
| 154 |
raise HTTPException(status_code=500, detail=f"Failed to read result file: {str(e)}")
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
@app.post("/upload_voice")
|
| 157 |
def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
|
| 158 |
# 保存wav文件
|