Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

为api_v2增加接受base64编码的reference audio文件的功能 #1811

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
{
"text": "", # str.(required) text to be synthesized
"text_lang": "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"ref_audio_path": "", # str.(optional) reference audio path, required if `ref_audio` not set
"ref_audio": "", # str.(optional) reference audio encoded as base64, required if `ref_audio_path` not set
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
Expand All @@ -42,6 +43,53 @@
}
```

#### 使用base64发送reference audio示例代码
```python
import requests
import base64
import json

# Reference clip to upload, and the transcript of what is spoken in it.
ref_wav_path = "/src/GPT-SoVITS/reference/news_reporter.wav"
ref_transcript = "系统内随后发售的其它车次车票依然十分抢手,仅个别车次仍有余票。"

# Load the clip's raw bytes and turn them into a Base64 text string.
with open(ref_wav_path, "rb") as wav_handle:
    ref_wav_b64 = base64.b64encode(wav_handle.read()).decode('utf-8')

# Request body: `ref_audio` carries the encoded clip, so no `ref_audio_path` is sent.
request_body = {
    "text": "四渡赤水中,毛泽东险些辞职。因为他的一个打法,违反了最基本的战场原则——先弱后强,遭到了所有人的反对。但在周恩来的支持下,其他人还是按他说的打了。果然,这不仅避开了蒋中正的陷阱,还让红军找到了生机。",
    "text_lang": "zh",
    "ref_audio": ref_wav_b64,
    "prompt_text": ref_transcript,
    "prompt_lang": "zh",
    "top_k": 15,
    "top_p": 0.7,
    "temperature": 0.85,
    "text_split_method": "cut5",
    "batch_size": 1,
    "media_type": "wav",
    "streaming_mode": False,
    "speed_factor": 1.1
}

# POST the JSON-serialized body to the /tts endpoint.
endpoint = 'http://localhost:9880/tts'
response = requests.post(
    endpoint,
    headers={'Content-Type': 'application/json'},
    data=json.dumps(request_body),
)

# Any status other than 200 is reported; on 200 the response body is saved as a WAV file.
if response.status_code != 200:
    print(f"请求失败,状态码: {response.status_code}")
    print(f"响应内容: {response.text}")
else:
    with open('output.wav', 'wb') as wav_out:
        wav_out.write(response.content)
    print("音频文件已保存为 output.wav")
```

RESP:
成功: 直接返回 wav 音频流, http code 200
失败: 返回包含错误信息的 json, http code 400
Expand Down Expand Up @@ -99,6 +147,8 @@
import sys
import traceback
from typing import Generator
import base64
import tempfile

now_dir = os.getcwd()
sys.path.append(now_dir)
Expand Down Expand Up @@ -147,6 +197,7 @@ class TTS_Request(BaseModel):
text: str = None
text_lang: str = None
ref_audio_path: str = None
ref_audio: str = None
aux_ref_audio_paths: list = None
prompt_lang: str = None
prompt_text: str = ""
Expand Down Expand Up @@ -246,8 +297,8 @@ def check_params(req:dict):
prompt_lang:str = req.get("prompt_lang", "")
text_split_method:str = req.get("text_split_method", "cut5")

if ref_audio_path in [None, ""]:
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
if ref_audio_path in [None, ""] and ref_audio in [None, ""]:
return JSONResponse(status_code=400, content={"message": "ref_audio_path or ref_audio is required"})
if text in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text is required"})
if (text_lang in [None, ""]) :
Expand Down Expand Up @@ -367,6 +418,7 @@ async def tts_get_endpoint(
parallel_infer:bool = True,
repetition_penalty:float = 1.35
):

req = {
"text": text,
"text_lang": text_lang.lower(),
Expand Down Expand Up @@ -395,6 +447,13 @@ async def tts_get_endpoint(
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
    """Handle ``POST /tts``, taking the reference audio as a path or inline Base64.

    When ``ref_audio`` is non-empty it is Base64-decoded, written to a
    temporary file, and that file's path replaces ``ref_audio_path`` before
    the request dict is handed to ``tts_handle``.

    Args:
        request: The parsed request body.

    Returns:
        The result of ``tts_handle(req)``, or a 400 ``JSONResponse`` when
        ``ref_audio`` cannot be decoded as Base64.
    """
    req = request.dict()
    encoded_audio = req.get("ref_audio")
    if encoded_audio:
        # The payload is intentionally not printed: it is the whole audio
        # file as text and would flood stdout on every request.
        try:
            decoded_audio = base64.b64decode(encoded_audio)
        except ValueError:
            # binascii.Error (bad padding / malformed data) is a ValueError
            # subclass; report it as a client error instead of a 500.
            return JSONResponse(status_code=400, content={"message": "ref_audio is not valid base64"})
        # delete=False keeps the file on disk after the `with` closes it, so
        # it can be reopened by path downstream.
        # NOTE(review): nothing visible here ever removes this file — confirm
        # where cleanup should happen (it must outlive any lazily-read response).
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(decoded_audio)
        req["ref_audio_path"] = temp_file.name
    return await tts_handle(req)


Expand Down