File size: 4,085 Bytes
be85c0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import base64
import logging
import os
import secrets
import tempfile
import uuid
from pathlib import Path
from typing import Optional

import requests
import torch
from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field, HttpUrl
from TTS.api import TTS

# Optional shared secret. When unset or empty, _require_api_key lets every request through.
SPACE_API_KEY = os.getenv("SPACE_API_KEY")
# Upper bound on GenerateRequest.text, enforced by pydantic field validation.
MAX_TEXT_LENGTH = 1000
# Language used when a request omits `language` or sends a falsy value.
DEFAULT_LANGUAGE = "en"

# Pick CUDA if available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the XTTS v2 model once at import time; every /generate call reuses this
# single instance (tts_model is a module global).
# Hugging Face Spaces caches model weights on persistent storage
# NOTE(review): the caching remark above is a deployment assumption — confirm
# persistent storage is enabled for this Space.
# NOTE(review): newer Coqui TTS releases deprecate `gpu=` in favour of
# `TTS(...).to(device)`, and XTTS may prompt for CPML licence acceptance
# unless COQUI_TOS_AGREED=1 is set — confirm against the pinned TTS version.
try:
    tts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=DEVICE == "cuda")
except Exception as exc:  # pragma: no cover - startup failure path
    # Fail fast on startup; Spaces will show the error in logs
    raise RuntimeError(f"Failed to load XTTS v2 model: {exc}") from exc

# ASGI application; routes are registered by the decorators below.
app = FastAPI(title="xtts-v2-api", version="1.0.0")


class GenerateRequest(BaseModel):
    # Request body for POST /generate.

    # Text to synthesize; pydantic rejects "" and anything over MAX_TEXT_LENGTH chars.
    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)

    # Reference voice: an http(s) URL or a base64 payload (required).
    speaker_wav: str = Field(
        ...,
        description="HTTPS URL or base64-encoded WAV/MP3/M4A",
    )

    # Optional; generate() also falls back to DEFAULT_LANGUAGE when this is falsy.
    language: Optional[str] = Field(
        default=DEFAULT_LANGUAGE,
        description="ISO language code, default en",
    )


def _require_api_key(x_api_key: Optional[str]):
    """Reject the request unless it carries the configured API key.

    When ``SPACE_API_KEY`` is unset or empty, authentication is disabled and
    every request is accepted.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header, or ``None`` if
            the header was absent.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not match.
    """
    if not SPACE_API_KEY:
        return
    # compare_digest takes time independent of where the inputs differ, so the
    # key cannot be recovered byte-by-byte from response timing (``!=`` can
    # short-circuit). Compare UTF-8 bytes because compare_digest raises
    # TypeError for ``str`` arguments containing non-ASCII characters.
    if x_api_key is None or not secrets.compare_digest(
        x_api_key.encode("utf-8"), SPACE_API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")


def _write_temp_audio_from_url(url: HttpUrl) -> str:
    """Download the speaker sample at *url* into a temporary file.

    The file suffix is taken from the URL path (``.wav`` when the path has
    none). The caller owns the returned path and must delete it.

    Raises:
        HTTPException: 400 if the server answers with an HTTP error status,
            or if the request fails at the network level (DNS, connect,
            timeout, truncated body). No partial file is left behind.
    """
    # NOTE(review): this fetches an arbitrary client-supplied URL, including
    # plain http and internal hosts (SSRF), with no size cap. Consider a host
    # allow-list and a byte limit before exposing the Space publicly.
    # NOTE: requests' timeout bounds connect and each read, not the total
    # download time.
    try:
        response = requests.get(url, stream=True, timeout=30)
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="Could not fetch speaker audio") from exc

    # stream=True keeps the pooled connection open until the response is
    # closed; the context manager guarantees that on every exit path.
    with response:
        if response.status_code >= 400:
            raise HTTPException(
                status_code=400,
                detail=f"Could not fetch speaker audio: {response.status_code}",
            )
        suffix = Path(url.path or "").suffix or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            try:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
            except BaseException as exc:
                # delete=False means nothing else will remove a half-written file.
                tmp.close()
                Path(tmp.name).unlink(missing_ok=True)
                if isinstance(exc, requests.RequestException):
                    raise HTTPException(
                        status_code=400, detail="Could not fetch speaker audio"
                    ) from exc
                raise
            return tmp.name


def _write_temp_audio_from_base64(payload: str) -> str:
    """Decode a base64 speaker sample into a temporary ``.wav`` file.

    Accepts bare base64 (surrounding whitespace ignored) or a
    ``data:<mime>;base64,<data>`` URI. The caller owns the returned path and
    must delete it.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or decodes to
            zero bytes.
    """
    encoded = payload.strip()
    # Non-validating b64decode silently drops characters outside the base64
    # alphabet, but "/" (e.g. in "audio/wav") and the letters of a data-URI
    # header are *in* the alphabet and would be decoded as garbage leading
    # bytes. Strip the header before decoding.
    if encoded[:5].lower() == "data:":
        header, sep, body = encoded.partition(",")
        if sep and header.lower().endswith(";base64"):
            encoded = body
    try:
        raw = base64.b64decode(encoded)
    except ValueError as exc:  # binascii.Error (bad padding) and non-ASCII input are ValueErrors
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
    if not raw:
        # An empty reference file would only fail later inside TTS as an opaque 500.
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(raw)
        return tmp.name


def _temp_speaker_file(speaker_wav: str) -> str:
    """Materialize ``speaker_wav`` as a temporary audio file and return its path.

    Values starting with ``http://`` or ``https://`` are downloaded. The scheme
    is matched case-insensitively and surrounding whitespace is ignored.
    Anything else is treated as base64-encoded audio. The caller owns the
    returned path and must delete it.

    Raises:
        HTTPException: 400 if a URL-looking value is not a valid URL, or if
            the download/decoding helper rejects the input.
    """
    candidate = speaker_wav.strip()
    # URI schemes are case-insensitive (RFC 3986 §3.1). Without lower(), an
    # "HTTPS://..." value would be handed to the base64 decoder instead.
    # Only the first 8 chars are lowered to avoid copying a large base64 body.
    if candidate[:8].lower().startswith(("http://", "https://")):
        try:
            url = HttpUrl(candidate)
        except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
            raise HTTPException(status_code=400, detail="Invalid speaker_wav URL") from exc
        return _write_temp_audio_from_url(url)
    return _write_temp_audio_from_base64(candidate)


@app.post("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    """Report that the service is up, which model it serves, and on what device.

    Subject to the same optional API-key check as ``/generate``.
    """
    _require_api_key(x_api_key)
    return dict(status="ok", model="xtts_v2", device=DEVICE)


@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    x_api_key: Optional[str] = Header(default=None),
):
    """Synthesize ``payload.text`` in the voice of ``payload.speaker_wav``.

    Returns the generated audio as a WAV attachment named ``output.wav``.
    Client-side problems (bad API key, unusable speaker audio) propagate as
    ``HTTPException``. Any other failure is logged and returned as a 500 JSON
    body ``{"error": "<message>"}``.

    Temporary files: the speaker sample is always deleted before this function
    returns. The output WAV must outlive the function, because ``FileResponse``
    opens and streams it only after the handler has returned. On success, its
    deletion is therefore deferred to a background task that runs once the
    response has been sent.
    """
    _require_api_key(x_api_key)

    speaker_file: Optional[str] = None
    output_file: Optional[str] = None
    # Set once the FileResponse has taken over responsibility for output_file.
    output_handed_off = False

    try:
        speaker_file = _temp_speaker_file(payload.speaker_wav)
        output_file = os.path.join(tempfile.gettempdir(), f"xtts-{uuid.uuid4()}.wav")

        # Device placement was fixed when tts_model was loaded. tts_to_file has
        # no use_cuda parameter; passing one just forwards a stray kwarg into
        # the model's inference call.
        tts_model.tts_to_file(
            text=payload.text,
            file_path=output_file,
            speaker_wav=speaker_file,
            language=payload.language or DEFAULT_LANGUAGE,
            split_sentences=True,
        )

        cleanup = BackgroundTasks()
        cleanup.add_task(Path(output_file).unlink, missing_ok=True)
        response = FileResponse(
            output_file,
            media_type="audio/wav",
            filename="output.wav",
            background=cleanup,
        )
        output_handed_off = True
        return response

    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover - runtime failure path
        # Keep the traceback in server logs; the client only gets the message.
        logging.getLogger(__name__).exception("XTTS generation failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        if speaker_file:
            Path(speaker_file).unlink(missing_ok=True)
        # On success the background task deletes output_file after streaming;
        # deleting it here would remove the file before FileResponse reads it.
        if output_file and not output_handed_off:
            Path(output_file).unlink(missing_ok=True)


@app.get("/")
def root():
    """Return a small service descriptor listing the available endpoints."""
    return dict(name="xtts-v2-api", endpoints=["/health", "/generate"])