|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
import uuid, time |
|
|
import wai_service |
|
|
from typing import Optional |
|
|
|
|
|
# The ASGI application that the endpoints below are registered on.
app = FastAPI()

# Exactly one worker thread: submitted jobs execute one at a time, in the
# order they were submitted; later jobs wait in the executor's queue.
EXEC = ThreadPoolExecutor(max_workers=1)

# In-memory job registry mapping job_id -> (time.time() at submission, Future).
# NOTE(review): nothing visible here evicts entries that are never polled to
# completion, so abandoned jobs stay in memory for the process lifetime.
JOBS = dict()
|
|
|
|
|
class InferenceRequest(BaseModel):
    """Request body for ``POST /generate``.

    The serialized model is handed to ``wai_service.handler`` by the
    ``/generate`` endpoint. Unknown fields are accepted and kept
    (``extra = "allow"``).
    """

    # Required.
    # NOTE(review): meaning is defined by wai_service.handler, presumably the text to render.
    word: str

    # Required. NOTE(review): semantics defined by wai_service.handler — confirm.
    optimized_letter: str

    # Defaults to "KaushanScript-Regular"; presumably a font name understood by the handler.
    font: str = "KaushanScript-Regular"

    # Defaults to 0. NOTE(review): presumably a random seed for generation — confirm.
    seed: int = 0

    # Optional at the schema level, but the /generate endpoint rejects
    # requests where it is missing or null with a 422.
    token: Optional[str] = None

    class Config:
        # Accept and retain fields not declared above instead of rejecting them.
        extra = "allow"
|
|
|
|
|
@app.post("/generate")
def enqueue(req: InferenceRequest):
    """Start a job and return its UUID immediately.

    The request is serialized and submitted to the single-worker executor;
    the caller polls ``GET /result/{job_id}`` for the outcome.

    Args:
        req: The validated request body.

    Returns:
        ``{"job_id": <uuid4 string>}``.

    Raises:
        HTTPException: 422 when ``token`` is missing or null.
    """
    if req.token is None:
        raise HTTPException(422, "field 'token' is required")

    # Serialize through pydantic rather than reading ``req.__dict__``: under
    # pydantic v2, fields admitted by ``extra = "allow"`` are kept outside
    # ``__dict__`` and would otherwise be silently dropped. ``exclude_none``
    # omits fields whose value is None.
    payload = req.dict(exclude_none=True)

    job_id = str(uuid.uuid4())
    fut = EXEC.submit(wai_service.handler, payload)
    # Wall-clock submission time; get_result reports elapsed seconds from it.
    JOBS[job_id] = (time.time(), fut)
    return {"job_id": job_id}
|
|
|
|
|
@app.get("/result/{job_id}")
def get_result(job_id: str):
    """Poll for job completion.

    Args:
        job_id: An id previously returned by ``POST /generate``.

    Returns:
        ``{"status": "running", "elapsed": <whole seconds since submission>}``
        while the job is in progress, or
        ``{"status": "finished", "image_base64": <handler result>}`` once it is
        done. A finished job is removed from the registry, so later polls for
        the same id get a 404.

    Raises:
        HTTPException: 404 when ``job_id`` is unknown or was already collected.
        Any exception raised by ``wai_service.handler`` is re-raised here
        (after the job has been removed).
    """
    # Single atomic lookup instead of ``in`` followed by ``[]``: FastAPI runs
    # plain ``def`` endpoints in a threadpool, so another poll may remove the
    # entry between two separate dict operations.
    entry = JOBS.get(job_id)
    if entry is None:
        raise HTTPException(404, "job_id not found")

    start, fut = entry
    if not fut.done():
        return {"status": "running", "elapsed": int(time.time() - start)}

    # Drop the finished job before reading its result, so it is removed even
    # when the handler raised. The None default avoids a KeyError (and a
    # spurious 500) if a concurrent poll already popped it.
    JOBS.pop(job_id, None)
    img = fut.result()
    return {"status": "finished", "image_base64": img}
|
|
|