zhendery commited on
Commit
1102626
·
1 Parent(s): 2603711

feat: 接口修改为异步

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -2
  2. api.py +95 -27
Dockerfile CHANGED
@@ -4,13 +4,13 @@ RUN ln -sf /share/zoneinfo/Asia/Shanghai /etc/localtime && \
4
  echo "Asia/Shanghai" > /etc/timezone
5
 
6
  RUN apt-get update && apt-get install -y \
7
- curl wget unzip git git-lfs ffmpeg && \
8
  apt-get clean && rm -rf /var/lib/apt/lists/*
9
 
10
  RUN pip install voxcpm && pip cache purge
11
 
12
  WORKDIR /workspace
13
- COPY . .
14
 
15
  ENV API_TOKEN my_secret_token
16
  ENV VOICE_DOWNLOAD_URL http://localhost/voices.zip
 
4
  echo "Asia/Shanghai" > /etc/timezone
5
 
6
  RUN apt-get update && apt-get install -y \
7
+ unzip ffmpeg && \
8
  apt-get clean && rm -rf /var/lib/apt/lists/*
9
 
10
  RUN pip install voxcpm && pip cache purge
11
 
12
  WORKDIR /workspace
13
+ COPY utils.py api.py ./
14
 
15
  ENV API_TOKEN my_secret_token
16
  ENV VOICE_DOWNLOAD_URL http://localhost/voices.zip
api.py CHANGED
@@ -8,10 +8,11 @@ import os
8
  import requests
9
  import zipfile
10
  from utils import *
 
 
 
 
11
 
12
- print_with_time("Loading VoxCPM model...")
13
- model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
14
- print_with_time("VoxCPM model loaded.")
15
 
16
  security = HTTPBearer()
17
  app = FastAPI()
@@ -32,28 +33,91 @@ class GenerateRequest(BaseModel):
32
  do_normalize: bool = True
33
  denoise: bool = True
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @app.post("/generate")
36
- def generate_tts(request: GenerateRequest, token: str = Depends(verify_token)):
37
- download_voices()
38
- text = (request.text or "").strip()
39
- if len(text) == 0:
40
- raise ValueError("Please input text to synthesize.")
41
- print_with_time(f"Generating audio for text: '{text[:60]}...'")
42
 
43
- with open(f"/workspace/voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
44
- wav = model.generate(
45
- text=text,
46
- prompt_wav_path=f"/workspace/voices/{request.voice}.wav",
47
- prompt_text=f.read(),
48
- cfg_value=request.cfg_value,
49
- inference_timesteps=request.inference_timesteps,
50
- normalize=request.do_normalize,
51
- denoise=request.denoise
52
- )
53
-
54
- sf.write("output.wav", wav, 16000)
55
- print_with_time(f"Audio generated, saving to output.wav")
56
- return Response(content=open("output.wav", 'rb').read(), media_type="audio/wav")
 
 
 
 
57
 
58
 
59
  def download_voices():
@@ -118,22 +182,26 @@ def delete_voice(name: str, token: str = Depends(verify_token)):
118
  def get_voices(token: str = Depends(verify_token)):
119
  download_voices()
120
  # 获取所有 .pmt 文件
121
- pmt_files = [f for f in os.listdir("/workspace/voices") if f.endswith(".pmt")]
122
  # 提取文件名(去掉 .pmt 后缀)
123
  voices = [f.split(".")[0] for f in pmt_files]
124
  # 确保对应的 .wav 文件也存在
125
  valid_voices = []
126
  for voice in voices:
127
- if os.path.exists(f"/workspace/voices/{voice}.wav"):
128
  valid_voices.append(voice)
129
  return {"voices": valid_voices}
130
 
131
 
132
  # ↓↓↓↓↓↓↓↓↓无需验证↓↓↓↓↓↓↓↓
133
  @app.get("/")
 
134
  def health_check():
135
  return {"status": "health"}
136
 
137
- if __name__ == "__main__":
138
  import uvicorn
139
- uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
 
 
 
 
8
  import requests
9
  import zipfile
10
  from utils import *
11
+ import uuid
12
+ import queue
13
+ import threading
14
+ import asyncio
15
 
 
 
 
16
 
17
  security = HTTPBearer()
18
  app = FastAPI()
 
33
  do_normalize: bool = True
34
  denoise: bool = True
35
 
36
+ # 队列相关变量
37
+ task_queue = queue.Queue()
38
+ output_dir = "./output"
39
+ max_output_files = 10
40
+
41
+ # 确保输出目录存在
42
+ os.makedirs(output_dir, exist_ok=True)
43
+
44
+ # 处理函数
45
+ def cleanup_old_files():
46
+ """清理最老的文件,保持最多10个"""
47
+ try:
48
+ files = [(f, os.path.getctime(os.path.join(output_dir, f))) for f in os.listdir(output_dir) if f.endswith('.wav')]
49
+ files.sort(key=lambda x: x[1]) # 按创建时间排序
50
+
51
+ # 删除最老的文件直到只剩10个
52
+ while len(files) > max_output_files:
53
+ oldest_file = files.pop(0)[0]
54
+ os.remove(os.path.join(output_dir, oldest_file))
55
+ except Exception as e:
56
+ print_with_time(f"Error cleaning up old files: {e}")
57
+
58
+ async def process_queue():
59
+ print_with_time("Loading VoxCPM model...")
60
+ model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
61
+ print_with_time("VoxCPM model loaded.")
62
+
63
+ while True:
64
+ try:
65
+ task_data = task_queue.get_nowait()
66
+ request = task_data["request"]
67
+ text = (request.text or "").strip()
68
+ if len(text) == 0:
69
+ continue
70
+
71
+ if model is None:
72
+ raise RuntimeError("Failed to initialize model")
73
+
74
+ print_with_time(f"Generating audio for : '{text[:60]}...'")
75
+ with open(f"./voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
76
+ wav = model.generate(
77
+ text=text,
78
+ prompt_wav_path=f"./voices/{request.voice}.wav",
79
+ prompt_text=f.read(),
80
+ cfg_value=request.cfg_value,
81
+ inference_timesteps=request.inference_timesteps,
82
+ normalize=request.do_normalize,
83
+ denoise=request.denoise
84
+ )
85
+ sf.write(os.path.join(output_dir, f"{task_data['task_id']}.wav"), wav, 16000)
86
+
87
+ # 清理旧文件
88
+ cleanup_old_files()
89
+
90
+ task_queue.task_done()
91
+ await asyncio.sleep(0.6)
92
+ except queue.Empty:
93
+ await asyncio.sleep(0.6)
94
+ except Exception as e:
95
+ print_with_time(f"Error processing queue item: {e}")
96
+ await asyncio.sleep(0.6)
97
+
98
+
99
  @app.post("/generate")
100
+ async def generate_tts_async(request: GenerateRequest, token: str = Depends(verify_token)):
101
+ task_id = str(uuid.uuid4())
 
 
 
 
102
 
103
+ # 将任务添加到队列
104
+ task_data = {"task_id": task_id, "request": request}
105
+ task_queue.put(task_data)
106
+
107
+ return {"task_id": task_id}
108
+
109
+ @app.get("/tts/{task_id}")
110
+ async def get_generate_result(task_id: str, token: str = Depends(verify_token)):
111
+ filepath = os.path.join(output_dir, f"{task_id}.wav")
112
+
113
+ if not os.path.exists(filepath):
114
+ raise HTTPException(status_code=404, detail="Result file not found")
115
+ try:
116
+ with open(filepath, 'rb') as f:
117
+ content = f.read()
118
+ return Response(content=content, media_type="audio/wav")
119
+ except Exception as e:
120
+ raise HTTPException(status_code=500, detail=f"Failed to read result file: {str(e)}")
121
 
122
 
123
  def download_voices():
 
182
  def get_voices(token: str = Depends(verify_token)):
183
  download_voices()
184
  # 获取所有 .pmt 文件
185
+ pmt_files = [f for f in os.listdir("./voices") if f.endswith(".pmt")]
186
  # 提取文件名(去掉 .pmt 后缀)
187
  voices = [f.split(".")[0] for f in pmt_files]
188
  # 确保对应的 .wav 文件也存在
189
  valid_voices = []
190
  for voice in voices:
191
+ if os.path.exists(f"./voices/{voice}.wav"):
192
  valid_voices.append(voice)
193
  return {"voices": valid_voices}
194
 
195
 
196
  # ↓↓↓↓↓↓↓↓↓无需验证↓↓↓↓↓↓↓↓
197
  @app.get("/")
198
+ @app.get("/health")
199
  def health_check():
200
  return {"status": "health"}
201
 
202
+ def start_api_server():
203
  import uvicorn
204
+ uvicorn.run(app, host="0.0.0.0", port=7860)
205
+
206
+ threading.Thread(target=start_api_server, daemon=True).start()
207
+ asyncio.run(process_queue())