|
|
from pydantic import BaseModel, Field |
|
|
from typing import Optional, Dict, Any |
|
|
import json, uuid, time, os |
|
|
import requests |
|
|
import websocket |
|
|
from urllib.parse import urlencode |
|
|
import gradio as gr |
|
|
|
|
|
# Address of the ComfyUI server (host, optionally host:port) used to build every
# http:// and ws:// URL below. Override it with the COMFY_HOST environment variable.
COMFY_HOST = os.environ.get("COMFY_HOST", "134.199.132.159")

# ComfyUI workflow loaded once at import time from ./workflow.json; each request
# starts from this dict as its template.
with open("workflow.json", mode="r", encoding="utf-8") as workflow_file:
    WORKFLOW_TEMPLATE: Dict[str, Any] = json.load(workflow_file)
|
|
|
|
|
class T2VReq(BaseModel):
    """Validated parameters for a single text-to-video generation request.

    ``token`` and ``text`` are required; every other field is optional and falls
    back to the default declared below.
    """

    # Access token forwarded to the ComfyUI server with each call.
    token: str = Field(default=...)
    # Positive prompt describing the video to generate.
    text: str = Field(default=...)
    # Optional negative prompt.
    negative: Optional[str] = Field(default=None)
    # Optional sampling seed; None means "not specified".
    seed: Optional[int] = Field(default=None)
    steps: Optional[int] = Field(default=4)
    cfg: Optional[float] = Field(default=1)
    # Output frame size in pixels.
    width: Optional[int] = Field(default=640)
    height: Optional[int] = Field(default=640)
    # Number of frames (shown as "Frames" in the UI).
    length: Optional[int] = Field(default=81)
    fps: Optional[int] = Field(default=16)
    # Prefix for the saved output file.
    filename_prefix: Optional[str] = Field(default="video/ComfyUI")
|
|
|
|
|
def _inject_params(prompt: Dict[str, Any], r: T2VReq) -> Dict[str, Any]:
    """Return a copy of the API-format workflow ``prompt`` with values from ``r`` applied.

    The workflow is deep-copied via a JSON round-trip, so the caller's dict
    (normally WORKFLOW_TEMPLATE) is never mutated. Node IDs are specific to
    workflow.json:

    * node ``"89"`` input ``text``                        <- ``r.text``
    * node ``"74"`` inputs ``width`` / ``height`` / ``length`` <- ``r.width`` / ``r.height`` / ``r.length``
    * node ``"88"`` input ``fps``                         <- ``r.fps``
    * node ``"80"`` input ``filename_prefix``             <- ``r.filename_prefix``
    * every node with a literal integer ``seed`` / ``noise_seed`` input <- ``r.seed``

    Optional fields that are None (or an empty ``filename_prefix``) keep the
    template's value. ``r.negative``, ``r.steps`` and ``r.cfg`` are NOT written
    into the workflow. NOTE(review): presumably because the sampler setup in
    workflow.json depends on them; confirm before wiring them up.

    Args:
        prompt: API-format workflow (node id -> node dict with an ``inputs`` dict).
        r: Validated request parameters.

    Returns:
        A new workflow dict ready to POST to ComfyUI's ``/prompt`` endpoint.

    Raises:
        KeyError: if one of the node IDs listed above is missing from ``prompt``.
    """
    p = json.loads(json.dumps(prompt))

    p["89"]["inputs"]["text"] = r.text

    if r.seed is not None:
        # Without this, the UI's seed box had no effect and every run reused the
        # template seed. Only literal ints are replaced: in API format a list value
        # is a [node_id, output_index] link to another node and must stay intact.
        for node in p.values():
            inputs = node.get("inputs", {}) if isinstance(node, dict) else {}
            for key in ("seed", "noise_seed"):
                if isinstance(inputs.get(key), int):
                    inputs[key] = r.seed

    if r.width is not None:
        p["74"]["inputs"]["width"] = r.width
    if r.height is not None:
        p["74"]["inputs"]["height"] = r.height
    if r.length is not None:
        p["74"]["inputs"]["length"] = r.length
    if r.fps is not None:
        p["88"]["inputs"]["fps"] = r.fps
    if r.filename_prefix:
        p["80"]["inputs"]["filename_prefix"] = r.filename_prefix

    return p
|
|
|
|
|
def _open_ws(client_id: str, token: str):
    """Open a websocket to ComfyUI's ``/ws`` endpoint for ``client_id``.

    Args:
        client_id: Id sent as the ``clientId`` query parameter; the same id is
            passed when queueing a prompt so its status messages arrive here.
        token: Access token, sent as the ``token`` query parameter.

    Returns:
        A connected ``websocket.WebSocket``. The caller is responsible for closing it.
    """
    # Percent-encode both values: the token comes from an editable textbox, and a
    # raw '&', '#' or space would otherwise corrupt or inject query parameters.
    query = urlencode({"clientId": client_id, "token": token})
    ws = websocket.WebSocket()
    ws.connect(f"ws://{COMFY_HOST}/ws?{query}", timeout=1800)
    return ws
|
|
|
|
|
def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
    """Submit an API-format workflow to ComfyUI's ``/prompt`` endpoint.

    Args:
        prompt: Workflow dict (node id -> node) to execute.
        client_id: Websocket client id the job's status messages should be sent to.
        token: Access token, sent as the ``token`` query parameter.

    Returns:
        The ``prompt_id`` assigned by the server.

    Raises:
        RuntimeError: on a non-200 response, or a response without ``prompt_id``.
    """
    resp = requests.post(
        f"http://{COMFY_HOST}/prompt",
        # Let requests percent-encode the user-editable token rather than pasting it
        # raw into the URL, where '&', '#' or spaces would corrupt the query string.
        params={"token": token},
        json={"prompt": prompt, "client_id": client_id},
        timeout=1800,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
    data = resp.json()
    if "prompt_id" not in data:
        raise RuntimeError(f"/prompt no prompt_id: {data}")
    return data["prompt_id"]
|
|
|
|
|
def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
    """Fetch ComfyUI's history record for one prompt.

    Args:
        prompt_id: Id returned by ``/prompt``.
        token: Access token, sent as the ``token`` query parameter.

    Returns:
        The entry stored under ``prompt_id`` in the ``/history/{prompt_id}``
        response, or an empty dict if the server has no record for it.

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    resp = requests.get(
        f"http://{COMFY_HOST}/history/{prompt_id}",
        params={"token": token},  # percent-encoded by requests, unlike a raw f-string
        timeout=1800,
    )
    resp.raise_for_status()
    return resp.json().get(prompt_id, {})
|
|
|
|
|
def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]: |
|
|
outputs = history.get("outputs", {}) |
|
|
for _, node_out in outputs.items(): |
|
|
if "images" in node_out: |
|
|
for it in node_out["images"]: |
|
|
if all(k in it for k in ("filename", "subfolder", "type")): |
|
|
fn = it["filename"] |
|
|
if fn.lower().endswith((".mp4", ".webm", ".gif", ".mov", ".mkv")): |
|
|
return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]} |
|
|
for key in ("videos", "files"): |
|
|
if key in node_out and node_out[key]: |
|
|
it = node_out[key][0] |
|
|
if all(k in it for k in ("filename", "subfolder", "type")): |
|
|
return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]} |
|
|
raise RuntimeError("No video file found in history outputs") |
|
|
|
|
|
# Example prompts offered under the prompt box via gr.Examples.
# Long prompts are split with implicit string concatenation; the values are unchanged.
sample_prompts = [
    "A golden retriever running across a beach at sunset, cinematic",
    "A cyberpunk city street at night with neon lights, light rain, slow pan",
    (
        "An astronaut walking on an alien planet covered in glowing crystals, "
        "purple sky with two moons, dust particles floating, slow panning shot, "
        "highly detailed, cinematic atmosphere."
    ),
    (
        "A cat gracefully jumping between rooftops in slow motion, "
        "warm sunset lighting, camera tracking the cat midair, "
        "cinematic composition, natural movement."
    ),
]
|
|
|
|
|
# Gradio UI: prompt input, advanced settings, and a Generate button that streams
# (progress, video) updates from generate_fn.
with gr.Blocks(
    title="T2V UI",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"),
) as demo:
    gr.Markdown("# Experience Wan2.2 14B Text-to-Video on AMD MI300X — Free Trial")
    gr.Markdown("Powered by [AMD Devcloud](https://oneclickamd.ai/) and [ComfyUI](https://github.com/comfyanonymous/ComfyUI)")
    gr.Markdown("### Prompt")
    text = gr.Textbox(label="Prompt", placeholder="Describe the video you want", lines=3)

    gr.Examples(examples=sample_prompts, inputs=text)

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=640, precision=0)
            height = gr.Number(label="Height", value=640, precision=0)
        with gr.Row():
            length = gr.Number(label="Frames", value=81, precision=0)
            fps = gr.Number(label="FPS", value=8, precision=0)
        with gr.Row():
            steps = gr.Number(label="Steps", value=4, precision=0)
            cfg = gr.Number(label="CFG", value=5.0)
            seed = gr.Number(label="Seed (optional)", value=None)
        filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")
        st_token = gr.Textbox(label="token", placeholder="name")

    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("Generate", variant="primary", scale=1)
            prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, step=1, interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result", height=480)

    def _init_token():
        """Return a fresh random UUID string used to prefill the token box on page load."""
        return str(uuid.uuid4())

    demo.load(_init_token, outputs=st_token)

    def _track_progress(ws, prompt_id, total_nodes):
        """Yield non-decreasing progress percentages (0-99) for ``prompt_id``.

        Reads websocket messages until ComfyUI sends an ``executing`` message with
        ``node`` set to None for this prompt, then returns.

        Raises:
            RuntimeError: if ComfyUI reports ``execution_error`` or
                ``execution_interrupted`` for this prompt. Without this check a
                failure only surfaced later as "No video file found".
        """
        seen = set()
        p = 0
        yield p  # reset the bar (and clear the previous video) right away
        last_emit = p
        start = time.time()
        while True:
            out = ws.recv()
            if isinstance(out, (bytes, bytearray)):
                # Binary frames carry no node information. Nudge the bar forward
                # (capped at 95) once the job has been running for a couple of seconds.
                if p < 95 and time.time() - start > 2:
                    p = min(95, p + 1)
            else:
                msg = json.loads(out)
                mtype = msg.get("type")
                data = msg.get("data") or {}
                if data.get("prompt_id") != prompt_id:
                    continue
                if mtype in ("execution_error", "execution_interrupted"):
                    raise RuntimeError(
                        f"ComfyUI {mtype} for prompt {prompt_id} at node "
                        f"{data.get('node_id')} ({data.get('node_type')}): "
                        f"{data.get('exception_message', '')}"
                    )
                if mtype != "executing":
                    continue
                node = data.get("node")
                if node is None:
                    return  # this prompt has finished executing
                if node not in seen:
                    seen.add(node)
                    # max() keeps the bar from jumping backwards below the
                    # preview-frame heuristic above.
                    p = max(p, min(99, int(len(seen) / total_nodes * 100)))
            if p != last_emit:
                last_emit = p
                yield p

    def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
        """Run one text-to-video job on ComfyUI and stream UI updates.

        Yields ``(progress_percent, None)`` while the job runs, then
        ``(100, video_url)``, where ``video_url`` points at ComfyUI's ``/view``
        endpoint for the produced file.

        Raises:
            gr.Error: if the prompt is empty or whitespace-only.
            RuntimeError: if queueing fails, ComfyUI reports an execution error or
                interruption, or no video is found in the job's history.
        """
        if not (text or "").strip():
            raise gr.Error("Please enter a prompt.")

        req = T2VReq(
            token=token,
            text=text,
            seed=int(seed) if seed is not None else None,
            steps=int(steps) if steps is not None else None,
            cfg=float(cfg) if cfg is not None else None,
            width=int(width) if width is not None else None,
            height=int(height) if height is not None else None,
            length=int(length) if length is not None else None,
            fps=int(fps) if fps is not None else None,
            filename_prefix=filename_prefix if filename_prefix else None,
        )
        prompt = _inject_params(WORKFLOW_TEMPLATE, req)
        client_id = str(uuid.uuid4())

        # Connect before queueing so no status message for this prompt is missed.
        ws = _open_ws(client_id, req.token)
        try:
            prompt_id = _queue_prompt(prompt, client_id, req.token)
            ws.settimeout(180)
            for p in _track_progress(ws, prompt_id, max(1, len(prompt))):
                yield p, None
        finally:
            # Runs on success, on errors/timeouts, and if this generator is closed
            # early, so the socket is never leaked.
            ws.close()

        hist = _get_history(prompt_id, req.token)
        info = _extract_video_from_history(hist)
        # Encode the token together with the file reference so special characters
        # in a user-edited token cannot break the query string.
        video_url = f"http://{COMFY_HOST}/view?{urlencode({**info, 'token': req.token})}"
        yield 100, video_url

    run_btn.click(
        generate_fn,
        inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
        outputs=[prog_bar, out_video],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    queued_app = demo.queue()
    queued_app.launch()