# indextts2-api / app.py
# --- Hugging Face file-viewer header (page residue, not program text) ---
# ataberkkilavuzcu's picture
# huggingface files.
# be85c0f
# raw
# history blame
# 4.09 kB
import base64
import hmac
import os
import tempfile
import uuid
from pathlib import Path
from typing import Optional

import requests
import torch
from fastapi import Body, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse, JSONResponse, Response
from pydantic import BaseModel, Field, HttpUrl
from TTS.api import TTS
# Shared secret compared against the request's API-key header; when it is
# unset or empty the key check is skipped entirely.
SPACE_API_KEY = os.environ.get("SPACE_API_KEY")

# Maximum number of characters accepted in a request's ``text`` field.
MAX_TEXT_LENGTH = 1000

# Language code used when a request does not provide one.
DEFAULT_LANGUAGE = "en"

# Run on the GPU whenever torch reports that CUDA is available.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# Build the XTTS v2 model a single time, at import, so every request reuses it.
# NOTE(review): the weights are presumably cached by the hosting Space between
# restarts — that cannot be confirmed from this file.
try:
    tts_model = TTS(
        "tts_models/multilingual/multi-dataset/xtts_v2",
        gpu=(DEVICE == "cuda"),
    )
except Exception as load_error:  # pragma: no cover - startup failure path
    # Abort startup with a readable message; the original error stays chained.
    raise RuntimeError(f"Failed to load XTTS v2 model: {load_error}") from load_error
# Application object on which the route decorators below register handlers.
app = FastAPI(
    title="xtts-v2-api",
    version="1.0.0",
)
class GenerateRequest(BaseModel):
    """JSON body accepted by the ``/generate`` endpoint."""

    # Text to synthesise; required, between 1 and MAX_TEXT_LENGTH characters.
    text: str = Field(
        default=...,
        min_length=1,
        max_length=MAX_TEXT_LENGTH,
    )
    # Required reference voice sample: a URL or base64-encoded audio data.
    speaker_wav: str = Field(
        default=...,
        description="HTTPS URL or base64-encoded WAV/MP3/M4A",
    )
    # Optional language code; DEFAULT_LANGUAGE applies when it is omitted.
    language: Optional[str] = Field(
        default=DEFAULT_LANGUAGE,
        description="ISO language code, default en",
    )
def _require_api_key(x_api_key: Optional[str]):
    """Reject the request unless it carries the configured API key.

    The check is skipped when ``SPACE_API_KEY`` is unset or empty, in which
    case every request is accepted.

    Args:
        x_api_key: Key supplied by the caller, or ``None`` when none was sent.

    Raises:
        HTTPException: 401 when a key is configured and the supplied value is
            missing or does not match it.
    """
    if not SPACE_API_KEY:
        return

    # Compare as bytes: compare_digest only accepts ASCII-only str objects,
    # and a header value is not guaranteed to be ASCII.
    supplied = (x_api_key or "").encode("utf-8")
    expected = SPACE_API_KEY.encode("utf-8")

    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed key was correct.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
def _write_temp_audio_from_url(url: HttpUrl) -> str:
    """Download the speaker sample at ``url`` into a temporary file.

    Args:
        url: Location of the audio file to fetch.

    Returns:
        Path of the temporary file holding the downloaded bytes. The file is
        not deleted automatically; the caller must remove it.

    Raises:
        HTTPException: 400 when the request fails, or when the server answers
            with a status code of 400 or above.
    """
    # Reuse the extension found in the URL path, falling back to ".wav".
    # ``or ""`` keeps this working for a URL object whose path is None.
    suffix = Path(url.path or "").suffix or ".wav"
    tmp_path = None
    try:
        # NOTE(review): this fetches an arbitrary caller-supplied URL with no
        # host restriction and no cap on the download size — consider adding
        # both if the service is reachable by untrusted clients.
        # The context manager releases the streamed connection on every path.
        with requests.get(str(url), stream=True, timeout=30) as response:
            if response.status_code >= 400:
                raise HTTPException(
                    status_code=400,
                    detail=f"Could not fetch speaker audio: {response.status_code}",
                )
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp_path = tmp.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
    except Exception as exc:
        # The path has not been handed to the caller yet, so a partially
        # written file would otherwise be left behind.
        if tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)
        if isinstance(exc, requests.RequestException):
            # Connection errors and timeouts are reported like a bad status.
            raise HTTPException(
                status_code=400,
                detail=f"Could not fetch speaker audio: {exc}",
            ) from exc
        raise
    return tmp_path
def _write_temp_audio_from_base64(payload: str) -> str:
    """Decode a base64 speaker sample and store it in a temporary file.

    Args:
        payload: Base64 text. A leading ``data:...;base64,`` data-URI header
            and embedded whitespace or line breaks are tolerated.

    Returns:
        Path of the temporary ``.wav`` file holding the decoded bytes. The
        file is not deleted automatically; the caller must remove it.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or decodes to
            no data at all.
    """
    encoded = payload.strip()
    # Accept a data URI: only the part after the first comma is base64.
    if encoded[:5].lower() == "data:":
        _, _, encoded = encoded.partition(",")
    # Remove whitespace and line breaks used to wrap long base64 text, since
    # strict validation below would otherwise reject them.
    encoded = "".join(encoded.split())

    try:
        # validate=True rejects characters outside the base64 alphabet rather
        # than silently skipping them and decoding to corrupt audio.
        raw = base64.b64decode(encoded, validate=True)
    except ValueError as exc:  # pragma: no cover - malformed base64
        # binascii.Error is a ValueError subclass, so this covers both.
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc

    if not raw:
        # An empty sample cannot be used as a voice reference.
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(raw)
    return tmp.name
def _temp_speaker_file(speaker_wav: str) -> str:
    """Store the speaker reference in a temporary file and return its path.

    Args:
        speaker_wav: Either an ``http://`` / ``https://`` URL to download, or
            base64-encoded audio data.

    Returns:
        Path of the temporary file; the caller is responsible for deleting it.

    Raises:
        HTTPException: 400 when the value looks like a URL but is malformed,
            plus any HTTPException raised while downloading or decoding.
    """
    reference = speaker_wav.strip()
    # Look only at a short prefix so a large base64 payload is not copied just
    # to lower-case it; the scheme check is case-insensitive.
    if reference[:8].lower().startswith(("http://", "https://")):
        try:
            url = HttpUrl(reference)
        except ValueError as exc:
            # pydantic validation errors derive from ValueError; report a bad
            # URL as a client error instead of an internal failure.
            raise HTTPException(status_code=400, detail="Invalid speaker_wav URL") from exc
        return _write_temp_audio_from_url(url)
    return _write_temp_audio_from_base64(reference)
@app.post("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    """Return a small status document once the caller's API key is accepted."""
    _require_api_key(x_api_key)
    status_report = {
        "status": "ok",
        "model": "xtts_v2",
        "device": DEVICE,
    }
    return status_report
@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    x_api_key: Optional[str] = Header(default=None),
):
    """Synthesise ``payload.text`` in the voice of the supplied speaker sample.

    Args:
        payload: Request body with the text, the speaker reference and an
            optional language code.
        x_api_key: API key supplied by the caller, if any.

    Returns:
        The generated WAV audio as an ``audio/wav`` attachment named
        ``output.wav``, or a JSON body ``{"error": ...}`` with status 500 when
        synthesis fails unexpectedly.

    Raises:
        HTTPException: Propagated unchanged from the API-key check and from
            preparing the speaker file.
    """
    _require_api_key(x_api_key)

    speaker_file = None
    output_file = Path(tempfile.gettempdir()) / f"xtts-{uuid.uuid4()}.wav"
    try:
        speaker_file = _temp_speaker_file(payload.speaker_wav)
        # NOTE(review): no per-call ``use_cuda`` flag is passed. The device is
        # already chosen when the model is loaded, and extra keyword arguments
        # are presumably forwarded to the model, which may reject unknown
        # ones — confirm against the TTS API documentation.
        tts_model.tts_to_file(
            text=payload.text,
            file_path=str(output_file),
            speaker_wav=speaker_file,
            language=payload.language or DEFAULT_LANGUAGE,
            split_sentences=True,
        )
        # Load the audio before the cleanup below runs. A FileResponse reads
        # from disk only after this function has returned, by which time the
        # ``finally`` clause would already have deleted the file.
        audio_bytes = output_file.read_bytes()
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover - runtime failure path
        # Surface readable errors to client
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        # Both temp files are request-scoped; missing_ok covers the cases
        # where they were never created.
        if speaker_file:
            Path(speaker_file).unlink(missing_ok=True)
        output_file.unlink(missing_ok=True)

    return Response(
        content=audio_bytes,
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="output.wav"'},
    )
@app.get("/")
def root():
    """Return the service name together with the paths of its endpoints."""
    endpoint_paths = ["/health", "/generate"]
    return {"name": "xtts-v2-api", "endpoints": endpoint_paths}