# Spaces: Running
# (Status banner captured from the hosting page; not part of the program.)
import base64
import hmac
import os
import tempfile
import uuid
from pathlib import Path
from typing import Optional

import requests
import torch
from fastapi import Body, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse, JSONResponse, Response
from pydantic import BaseModel, Field, HttpUrl
from TTS.api import TTS
# Shared secret read from the environment; None when the variable is unset.
SPACE_API_KEY = os.environ.get("SPACE_API_KEY")

# Upper bound on the length of the text accepted for synthesis.
MAX_TEXT_LENGTH = 1000

# Language code used when a request does not specify one.
DEFAULT_LANGUAGE = "en"

# Prefer CUDA whenever torch reports it as available; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Load the XTTS v2 model a single time, when this module is imported.
# Hugging Face Spaces caches model weights on persistent storage
try:
    tts_model = TTS(
        "tts_models/multilingual/multi-dataset/xtts_v2",
        gpu=(DEVICE == "cuda"),
    )
except Exception as load_error:  # pragma: no cover - startup failure path
    # Abort startup immediately; Spaces will show the error in logs
    raise RuntimeError(f"Failed to load XTTS v2 model: {load_error}") from load_error
# ASGI application object for this service.
app = FastAPI(
    title="xtts-v2-api",
    version="1.0.0",
)
class GenerateRequest(BaseModel):
    # Request body consumed by the generate handler.

    # Text to synthesize: required, between 1 and MAX_TEXT_LENGTH characters.
    text: str = Field(
        default=...,
        min_length=1,
        max_length=MAX_TEXT_LENGTH,
    )

    # Reference voice sample: required, either a URL or base64-encoded audio.
    speaker_wav: str = Field(
        default=...,
        description="HTTPS URL or base64-encoded WAV/MP3/M4A",
    )

    # Optional language code; DEFAULT_LANGUAGE is used when omitted.
    language: Optional[str] = Field(
        default=DEFAULT_LANGUAGE,
        description="ISO language code, default en",
    )
def _require_api_key(x_api_key: Optional[str]) -> None:
    """Reject the request unless it carries the configured API key.

    When ``SPACE_API_KEY`` is unset or empty, authentication is disabled and
    every request is accepted.

    Args:
        x_api_key: Key supplied by the client, or ``None`` when absent.

    Raises:
        HTTPException: 401 when a key is configured and the supplied key is
            missing or does not match.
    """
    if not SPACE_API_KEY:
        return
    # Compare as bytes in constant time: a plain `!=` can leak how many
    # leading characters matched, and compare_digest rejects non-ASCII str.
    supplied = (x_api_key or "").encode("utf-8")
    expected = SPACE_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
def _write_temp_audio_from_url(url: HttpUrl) -> str:
    """Download speaker audio from ``url`` into a temporary file.

    The file is created with ``delete=False``; the caller owns it and must
    remove it when done.

    Args:
        url: Location of the reference audio.

    Returns:
        Path of the temporary file holding the downloaded bytes. Its suffix
        is taken from the URL path, falling back to ``.wav``.

    Raises:
        HTTPException: 400 when the server answers with a status >= 400 or
            the request itself fails (connection error, timeout, ...).
    """
    # NOTE(review): the download size is unbounded and the target host is
    # unrestricted; consider a size cap and an allow-list before exposing
    # this to untrusted callers.
    suffix = Path(url.path or "").suffix or ".wav"
    tmp_path = None
    completed = False
    try:
        # The context manager closes the streamed response on every path,
        # including the early error exit below.
        with requests.get(str(url), stream=True, timeout=30) as response:
            if response.status_code >= 400:
                raise HTTPException(status_code=400, detail=f"Could not fetch speaker audio: {response.status_code}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp_path = tmp.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
        completed = True
        return tmp_path
    except requests.RequestException as exc:
        # Unreachable or misbehaving hosts are a problem with the request
        # input, not a server fault.
        raise HTTPException(status_code=400, detail="Could not fetch speaker audio") from exc
    finally:
        # Do not leave a partially written file behind when the download fails.
        if not completed and tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)
| def _write_temp_audio_from_base64(payload: str) -> str: | |
| try: | |
| raw = base64.b64decode(payload) | |
| except Exception as exc: # pragma: no cover - malformed base64 | |
| raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: | |
| tmp.write(raw) | |
| return tmp.name | |
def _temp_speaker_file(speaker_wav: str) -> str:
    """Materialize the speaker reference as a temporary file.

    Values starting with an ``http://`` or ``https://`` scheme (matched
    case-insensitively, ignoring surrounding whitespace) are downloaded;
    anything else is treated as base64-encoded audio.

    Args:
        speaker_wav: URL or base64 payload supplied by the client.

    Returns:
        Path of the temporary file; the caller is responsible for deleting it.

    Raises:
        HTTPException: 400 when the URL is malformed, or whatever the
            download / decode helper raises.
    """
    source = speaker_wav.strip()
    # URL schemes are case-insensitive, so "HTTPS://..." must not fall
    # through to the base64 branch.
    if source.lower().startswith(("http://", "https://")):
        try:
            url = HttpUrl(source)
        except ValueError as exc:
            # A malformed URL is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="Invalid speaker_wav URL") from exc
        return _write_temp_audio_from_url(url)
    return _write_temp_audio_from_base64(source)
@app.get("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    """Report that the service is up and which device the model runs on.

    Registered on ``app`` at ``/health``, the path listed by ``root()``;
    without the registration the handler is unreachable.

    Args:
        x_api_key: API key taken from the request headers, if any.

    Returns:
        A dict with the keys ``status``, ``model`` and ``device``.

    Raises:
        HTTPException: 401 when an API key is configured and the supplied
            key does not match.
    """
    _require_api_key(x_api_key)
    return {"status": "ok", "model": "xtts_v2", "device": DEVICE}
@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    x_api_key: Optional[str] = Header(default=None),
):
    """Synthesize ``payload.text`` in the voice of ``payload.speaker_wav``.

    Registered on ``app`` at ``/generate``, the path listed by ``root()``.

    Args:
        payload: Text, speaker reference (URL or base64) and language.
        x_api_key: API key taken from the request headers, if any.

    Returns:
        An ``audio/wav`` response named ``output.wav`` on success, or a JSON
        response with status 500 and an ``error`` message when synthesis
        fails.

    Raises:
        HTTPException: 401 for a bad API key, or any HTTP error raised while
            preparing the speaker file.
    """
    _require_api_key(x_api_key)
    speaker_file = None
    output_file = None
    try:
        speaker_file = _temp_speaker_file(payload.speaker_wav)
        output_file = os.path.join(tempfile.gettempdir(), f"xtts-{uuid.uuid4()}.wav")
        tts_model.tts_to_file(
            text=payload.text,
            file_path=output_file,
            speaker_wav=speaker_file,
            language=payload.language or DEFAULT_LANGUAGE,
            split_sentences=True,
            use_cuda=DEVICE == "cuda",
        )
        # Read the audio now: a FileResponse opens the file only after this
        # function has returned, by which time `finally` has deleted it.
        audio_bytes = Path(output_file).read_bytes()
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover - runtime failure path
        # Surface readable errors to client
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        # Temp files are removed on every path; missing_ok already covers
        # files that were never created.
        for leftover in (speaker_file, output_file):
            if leftover:
                Path(leftover).unlink(missing_ok=True)
    return Response(
        content=audio_bytes,
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="output.wav"'},
    )
@app.get("/")
def root():
    """Describe the service: its name and the endpoints it exposes.

    Registered on ``app`` so the handler is reachable.

    Returns:
        A dict with the service ``name`` and the list of ``endpoints``.
    """
    return {"name": "xtts-v2-api", "endpoints": ["/health", "/generate"]}