CineGen / app.py
VirtualOasis's picture
Update app.py
934e5c4 verified
raw
history blame
22.4 kB
import json
import os
import tempfile
import textwrap
import time
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import requests
try:
from google import genai
except ImportError: # pragma: no cover - dependency is optional at import time
genai = None
# -----------------------------
# Domain data representations
# -----------------------------
@dataclass
class CharacterProfile:
    """One cast member of a storyboard, optionally paired with a rendered portrait."""

    # Identifier that scenes use to refer to this character.
    character_id: str
    name: str
    description: str
    # Short visual descriptors; every instance gets its own fresh list.
    visual_tags: List[str] = field(default_factory=list)
    # Filesystem path of the reference image, or None when none was produced.
    image_path: Optional[str] = field(default=None)
@dataclass
class ScenePlan:
    """A single storyboard scene: narrative summary plus the prompt used to render it."""

    scene_id: str
    title: str
    summary: str
    visual_prompt: str
    # IDs of the characters appearing in this scene; defaults to a new empty list.
    characters: List[str] = field(default_factory=lambda: [])
@dataclass
class StoryboardPlan:
    """Complete storyboard: headline metadata, the cast, and the ordered scenes."""

    title: str
    logline: str
    style: str
    runtime_hint: str
    tone: str
    characters: List[CharacterProfile]
    scenes: List[ScenePlan]

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view with characters and scenes expanded via ``asdict``."""
        scalar_fields = ("title", "logline", "style", "runtime_hint", "tone")
        plain: Dict[str, Any] = {name: getattr(self, name) for name in scalar_fields}
        plain["characters"] = list(map(asdict, self.characters))
        plain["scenes"] = list(map(asdict, self.scenes))
        return plain
# -----------------------------
# Helper utilities
# -----------------------------
def resolve_token(user_supplied: str, env_key: str) -> Optional[str]:
    """Choose a credential: the user-typed value wins, then the environment variable.

    Both sources are whitespace-stripped. Returns ``None`` when neither yields
    a non-empty value.
    """
    token = (user_supplied or "").strip()
    if not token:
        # Nothing usable was typed in, so consult the environment instead.
        token = (os.getenv(env_key) or "").strip()
    return token if token else None
def extract_json_block(text: str) -> str:
    """Return the first balanced ``{...}`` object found inside ``text``.

    Braces that appear inside double-quoted JSON strings (including after
    backslash escapes) are ignored, so a value such as ``"}"`` cannot end the
    object early. If no balanced object is found, ``text`` is returned
    unchanged and the caller's JSON parser decides what to do with it.
    """
    depth = 0
    start_index = None
    in_string = False
    escaped = False
    for index, char in enumerate(text):
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            # Only treat quotes as string delimiters once inside an object;
            # prose surrounding the JSON may contain unbalanced quotes.
            in_string = depth > 0
        elif char == "{":
            if depth == 0:
                start_index = index
            depth += 1
        elif char == "}" and depth:
            depth -= 1
            if depth == 0:
                return text[start_index : index + 1]
    return text
def format_character_markdown(characters: List[CharacterProfile]) -> str:
    """Render the cast as a Markdown bullet list, one bullet per character.

    Returns a placeholder sentence when ``characters`` is empty.
    """

    def _tag_text(tags) -> str:
        # An empty or missing tag list is shown as "n/a".
        return ", ".join(tags) if tags else "n/a"

    bullets = [
        f"- **{entry.name}** ({entry.character_id}): {entry.description} \n Visual tags: {_tag_text(entry.visual_tags)}"
        for entry in characters
    ]
    if not bullets:
        return "No characters were generated yet."
    return "\n".join(bullets)
def ensure_module_available(module_ref, friendly_name: str) -> None:
    """Raise a user-facing ``gr.Error`` when an optional dependency failed to import.

    ``module_ref`` is the imported module object, or ``None`` if the import
    failed. The install hint in the message always names ``google-genai``.
    """
    if module_ref is not None:
        return
    message = f"{friendly_name} is not installed. Install it via `pip install google-genai` and try again."
    raise gr.Error(message)
# -----------------------------
# Gemini services
# -----------------------------
class GeminiService:
    """Wrapper around the google-genai client for storyboard text and character art."""

    # Templates are dedented BEFORE interpolation so that multi-line user input
    # (for example a movie idea spanning several lines) cannot change the
    # common margin that ``textwrap.dedent`` removes.
    _STORYBOARD_TEMPLATE = textwrap.dedent(
        """
        You are CineGen, an AI creative director. Given the following idea, craft a production-ready storyboard.
        Idea: {movie_idea}
        Target visual style: {visual_style}
        Desired runtime: {runtime_hint}
        Tone keywords: {tone}
        Scene count: exactly {scene_count}
        Respond with valid JSON using this schema:
        {{
        "title": "...",
        "logline": "...",
        "style": "...",
        "runtime_hint": "...",
        "tone": "...",
        "characters": [
        {{"id": "char_1", "name": "...", "description": "...", "visual_tags": ["tag1", "tag2"]}}
        ],
        "scenes": [
        {{
        "id": "scene_1",
        "title": "...",
        "summary": "...",
        "visual_prompt": "...",
        "characters": ["char_1", "char_2"]
        }}
        ]
        }}
        Ensure each scene references character IDs from the characters array and highlight cinematic camera or lighting cues inside "visual_prompt".
        """
    ).strip()

    _PORTRAIT_TEMPLATE = textwrap.dedent(
        """
        Create a front-facing character reference portrait for use in a video production pipeline.
        Character: {name}
        Description: {description}
        Visual tags: {tags}
        Style: {visual_style}
        Output a single cohesive concept art image.
        """
    ).strip()

    def __init__(
        self,
        api_key: str,
        story_model: str = "gemini-2.5-flash",
        image_model: str = "gemini-2.5-flash-image",
    ) -> None:
        """Create the Gemini client.

        Raises:
            gr.Error: if the google-genai package is missing or ``api_key`` is empty.
        """
        ensure_module_available(genai, "google-genai")
        if not api_key:
            raise gr.Error("Google API key is required.")
        self.client = genai.Client(api_key=api_key)
        self.story_model = story_model
        self.image_model = image_model

    def generate_storyboard(
        self,
        movie_idea: str,
        visual_style: str,
        scene_count: int,
        runtime_hint: str,
        tone: str,
    ) -> StoryboardPlan:
        """Ask the story model for a JSON storyboard and parse it into a ``StoryboardPlan``.

        Missing keys in the model's JSON fall back to positional defaults
        (``char_N`` / ``scene_N``) or to the values passed in by the caller.

        Raises:
            gr.Error: if the model returns no text, text that is not valid
                JSON, or JSON whose top level is not an object.
        """
        # UI sliders may hand over a float; slicing and the prompt both want an int.
        scene_count = int(scene_count)
        prompt = self._STORYBOARD_TEMPLATE.format(
            movie_idea=movie_idea,
            visual_style=visual_style,
            runtime_hint=runtime_hint,
            tone=tone,
            scene_count=scene_count,
        )
        response = self.client.models.generate_content(
            model=self.story_model,
            contents=prompt,
        )
        # Either attribute may be absent or None; fall back to stitching the parts' text.
        raw_text = getattr(response, "text", None) or "".join(
            getattr(part, "text", None) or ""
            for part in (getattr(response, "parts", None) or [])
        )
        if not raw_text:
            raise gr.Error("Gemini did not return any content for the storyboard.")
        try:
            payload = json.loads(extract_json_block(raw_text))
        except json.JSONDecodeError as exc:
            raise gr.Error("Gemini returned a storyboard that is not valid JSON. Please try again.") from exc
        if not isinstance(payload, dict):
            raise gr.Error("Gemini returned an unexpected storyboard format. Please try again.")

        characters = [
            CharacterProfile(
                character_id=entry.get("id", f"char_{idx+1}"),
                name=entry.get("name", f"Character {idx+1}"),
                description=entry.get("description", ""),
                visual_tags=entry.get("visual_tags") or [],
            )
            for idx, entry in enumerate(payload.get("characters") or [])
        ]
        scenes = [
            ScenePlan(
                scene_id=scene.get("id", f"scene_{idx+1}"),
                title=scene.get("title", f"Scene {idx+1}"),
                summary=scene.get("summary", ""),
                visual_prompt=scene.get("visual_prompt", ""),
                characters=scene.get("characters") or [],
            )
            for idx, scene in enumerate(payload.get("scenes") or [])
        ]
        # Trim surplus scenes if the model over-delivers; a shorter list is
        # passed through unchanged.
        scenes = scenes[:scene_count]
        return StoryboardPlan(
            title=payload.get("title", "Untitled"),
            logline=payload.get("logline", ""),
            style=payload.get("style", visual_style),
            runtime_hint=payload.get("runtime_hint", runtime_hint),
            tone=payload.get("tone", tone),
            characters=characters,
            scenes=scenes,
        )

    def generate_character_images(
        self,
        characters: List[CharacterProfile],
        visual_style: str,
        max_characters: int = 4,
    ) -> List[CharacterProfile]:
        """Render a reference portrait for each of the first ``max_characters`` characters.

        Returns new ``CharacterProfile`` objects (inputs are not mutated) for
        the processed characters only. ``image_path`` is ``None`` when the
        response carried no inline image data.
        """
        rendered: List[CharacterProfile] = []
        for character in characters[:max_characters]:
            prompt = self._PORTRAIT_TEMPLATE.format(
                name=character.name,
                description=character.description,
                tags=", ".join(character.visual_tags) if character.visual_tags else "n/a",
                visual_style=visual_style,
            )
            response = self.client.models.generate_content(
                model=self.image_model,
                contents=prompt,
            )
            image_path = None
            for part in getattr(response, "parts", None) or []:
                if getattr(part, "inline_data", None):
                    image = part.as_image()
                    # Reserve a path, close the handle, then save: writing by
                    # name while the handle is still open is not portable.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                        image_path = tmp.name
                    image.save(image_path)
                    break
            rendered.append(
                CharacterProfile(
                    character_id=character.character_id,
                    name=character.name,
                    description=character.description,
                    visual_tags=character.visual_tags,
                    image_path=image_path,
                )
            )
        return rendered
# -----------------------------
# Hugging Face video service
# -----------------------------
class HuggingFaceVideoService:
    """Text-to-video client that POSTs to Hugging Face inference endpoints.

    Models are tried in order (preferred model first, then ``MODEL_FALLBACK``)
    until one returns a video.
    """

    MODEL_FALLBACK = [
        "Wan-AI/Wan2.1-T2V-14B",
        "Lightricks/LTX-Video-0.9.7-distilled",
        "tencent/HunyuanVideo-1.5",
        "THUDM/CogVideoX-5b",
    ]
    # Pause between attempts on different backends.
    _RETRY_PAUSE_SECONDS = 1.5

    def __init__(self, token: str):
        """Store the token and open an HTTP session.

        Raises:
            gr.Error: if ``token`` is empty.
        """
        if not token:
            raise gr.Error("Hugging Face token is required for video generation.")
        self.token = token
        self.session = requests.Session()

    def generate(
        self,
        prompt: str,
        preferred_model: Optional[str],
        negative_prompt: str,
        duration_seconds: float,
        fps: int,
        seed: Optional[int],
    ) -> Tuple[str, str]:
        """Render ``prompt`` with the first backend that succeeds.

        Returns:
            ``(model_name, video_path)`` for the backend that produced the clip.

        Raises:
            gr.Error: if every backend fails; the message carries the last error.
        """
        ordered_models = self._order_models(preferred_model)
        final_position = len(ordered_models) - 1
        last_error = ""
        for position, model in enumerate(ordered_models):
            try:
                video_path = self._invoke_model(
                    model=model,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    duration_seconds=duration_seconds,
                    fps=fps,
                    seed=seed,
                )
            except Exception as exc:  # pragma: no cover - defensive fallback
                last_error = str(exc)
                # Only pause when another backend is still left to try.
                if position < final_position:
                    time.sleep(self._RETRY_PAUSE_SECONDS)
            else:
                return model, video_path
        raise gr.Error(f"All video backends failed. Last error: {last_error}")

    def _order_models(self, preferred_model: Optional[str]) -> List[str]:
        """Return the fallback list with ``preferred_model`` (if any) moved or added to the front."""
        models = [name for name in self.MODEL_FALLBACK if name != preferred_model]
        if preferred_model:
            models.insert(0, preferred_model)
        return models

    @staticmethod
    def _build_payload(
        prompt: str,
        negative_prompt: str,
        duration_seconds: float,
        fps: int,
        seed: Optional[int],
    ) -> Dict[str, Any]:
        """Assemble the JSON request body; ``seed`` is omitted when it is ``None``."""
        parameters: Dict[str, Any] = {
            "negative_prompt": negative_prompt,
            "num_frames": int(duration_seconds * fps),
            "fps": fps,
            "guidance_scale": 7.5,
        }
        if seed is not None:
            # Leave the key out entirely for a random seed instead of sending JSON null.
            parameters["seed"] = seed
        return {
            "inputs": prompt,
            "parameters": parameters,
            "options": {"use_cache": True, "wait_for_model": True},
        }

    def _invoke_model(
        self,
        model: str,
        prompt: str,
        negative_prompt: str,
        duration_seconds: float,
        fps: int,
        seed: Optional[int],
    ) -> str:
        """POST one generation request to ``model`` and return the saved video path.

        Raises:
            RuntimeError: for any non-200 response; 503/504/524 are reported
                as the model warming up or being busy.
        """
        url = f"https://api-inference.huggingface.co/models/{model}"
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "video/mp4",
        }
        payload = self._build_payload(prompt, negative_prompt, duration_seconds, fps, seed)
        response = self.session.post(
            url,
            headers=headers,
            json=payload,
            timeout=600,
        )
        if response.status_code == 200:
            return self._write_video(response.content)
        if response.status_code in {503, 504, 524}:
            raise RuntimeError(f"{model} is warming up or busy (status {response.status_code}).")
        try:
            message = response.json()
        except Exception:
            # Error bodies are not always JSON; fall back to the raw text.
            message = response.text
        raise RuntimeError(f"{model} failed: {message}")

    @staticmethod
    def _write_video(content: bytes) -> str:
        """Write ``content`` to a new ``.mp4`` temp file (not auto-deleted) and return its path."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as handle:
            handle.write(content)
            return handle.name
# -----------------------------
# CineGen pipeline orchestration
# -----------------------------
def build_scene_prompt(scene: ScenePlan, storyboard: StoryboardPlan) -> str:
    """Compose the text prompt used to render one scene.

    The prompt lists the storyboard title and logline, the scene's details,
    the style and tone, and one line per character referenced by the scene.
    Character IDs that are not in the storyboard's cast are skipped; if none
    resolve, a fixed placeholder line is used instead.
    """
    roster = {member.character_id: member for member in storyboard.characters}
    cast_lines = []
    for actor_id in scene.characters:
        member = roster.get(actor_id)
        if not member:
            continue
        tag_text = ", ".join(member.visual_tags) if member.visual_tags else ""
        # strip() drops the trailing space left behind when there are no tags.
        cast_lines.append(f"{member.name}: {member.description} {tag_text}".strip())
    cast_block = "\n".join(cast_lines) if cast_lines else "Original characters only."

    prompt_lines = [
        f"Title: {storyboard.title}",
        f"Logline: {storyboard.logline}",
        f"Scene: {scene.title} ({scene.scene_id})",
        f"Narrative summary: {scene.summary}",
        f"Visual prompt: {scene.visual_prompt}",
        f"Visual style: {storyboard.style}",
        f"Tone: {storyboard.tone}",
        "Characters:",
        cast_block,
    ]
    # dedent is kept so whitespace-only lines inside interpolated text are
    # still normalised; the first line has no margin, so nothing else shifts.
    return textwrap.dedent("\n".join(prompt_lines)).strip()
# -----------------------------
# Gradio callbacks
# -----------------------------
def storyboard_callback(
    movie_idea: str,
    visual_style: str,
    runtime_hint: str,
    tone: str,
    scene_count: int,
    google_api_key_input: str,
):
    """Gradio handler: generate the storyboard and character portraits.

    Returns a 7-tuple matching the click handler's outputs: status text,
    storyboard dict (for display), character Markdown, gallery entries,
    storyboard dict (for state), character dicts (for state), and an updated
    scene dropdown.

    Raises:
        gr.Error: if the idea is blank, or propagated from ``GeminiService``.
    """
    api_key = resolve_token(google_api_key_input, "GOOGLE_API_KEY")
    # A whitespace-only idea is as useless as an empty one.
    if not (movie_idea or "").strip():
        raise gr.Error("Please describe your movie idea first.")
    storyboard_service = GeminiService(api_key=api_key)
    storyboard = storyboard_service.generate_storyboard(
        movie_idea=movie_idea,
        visual_style=visual_style,
        # Sliders can deliver floats; downstream slicing needs an int.
        scene_count=int(scene_count),
        runtime_hint=runtime_hint,
        tone=tone,
    )
    characters_with_images = storyboard_service.generate_character_images(storyboard.characters, visual_style)
    storyboard_dict = storyboard.to_dict()
    character_markdown = format_character_markdown(characters_with_images)
    gallery_entries = [
        (profile.image_path, f"{profile.name} ({profile.character_id})")
        for profile in characters_with_images
        if profile.image_path
    ]
    scene_choices = [f"{scene.scene_id}: {scene.title}" for scene in storyboard.scenes]
    status_message = f"Storyboard ready: {storyboard.title} with {len(storyboard.scenes)} scenes."
    return (
        status_message,
        storyboard_dict,
        character_markdown,
        gallery_entries,
        storyboard_dict,
        [asdict(profile) for profile in characters_with_images],
        # Component-level ``.update()`` no longer exists in current Gradio;
        # returning a component instance updates the bound output instead.
        gr.Dropdown(choices=scene_choices, value=scene_choices[0] if scene_choices else None),
    )
def generate_video_callback(
    scene_choice: str,
    storyboard_state: Dict[str, Any],
    hf_token_input: str,
    preferred_model: str,
    negative_prompt: str,
    duration_seconds: float,
    fps: int,
    seed: int,
):
    """Gradio handler: render the selected scene to video.

    ``scene_choice`` is the dropdown label (``"<scene_id>: <title>"``) or a
    bare scene ID; when it matches nothing, the first scene is rendered.
    A missing or negative ``seed`` means "random".

    Returns:
        ``(status_message, video_path, metadata)``.

    Raises:
        gr.Error: if there is no storyboard, no token, or no scenes, or if
            every video backend fails.
    """
    if not storyboard_state:
        raise gr.Error("Generate a storyboard first.")
    hf_token = resolve_token(hf_token_input, "HF_TOKEN")
    if not hf_token:
        raise gr.Error("Provide a Hugging Face token to render video.")
    scenes = storyboard_state.get("scenes", [])
    characters = storyboard_state.get("characters", [])
    if not scenes:
        raise gr.Error("Storyboard has no scenes to render.")

    scene_id = (scene_choice or "").split(":")[0]
    # Index of the first scene matching the parsed ID or the raw choice; 0 if none does.
    target_index = next(
        (
            position
            for position, scene in enumerate(scenes)
            if scene["scene_id"] == scene_id or scene["scene_id"] == scene_choice
        ),
        0,
    )

    storyboard = StoryboardPlan(
        title=storyboard_state.get("title", ""),
        logline=storyboard_state.get("logline", ""),
        style=storyboard_state.get("style", ""),
        runtime_hint=storyboard_state.get("runtime_hint", ""),
        tone=storyboard_state.get("tone", ""),
        characters=[
            CharacterProfile(
                character_id=entry.get("character_id") or entry.get("id"),
                name=entry.get("name", ""),
                description=entry.get("description", ""),
                visual_tags=entry.get("visual_tags") or [],
                image_path=entry.get("image_path"),
            )
            for entry in characters
        ],
        scenes=[
            ScenePlan(
                scene_id=scene["scene_id"],
                title=scene["title"],
                summary=scene["summary"],
                visual_prompt=scene["visual_prompt"],
                characters=scene.get("characters") or [],
            )
            for scene in scenes
        ],
    )
    target_scene = storyboard.scenes[target_index]
    prompt = build_scene_prompt(target_scene, storyboard)

    # A cleared Number field arrives as None; treat it like a negative seed.
    resolved_seed = int(seed) if seed is not None and seed >= 0 else None
    fps = int(fps)

    video_service = HuggingFaceVideoService(token=hf_token)
    model_used, video_path = video_service.generate(
        prompt=prompt,
        preferred_model=preferred_model or None,
        negative_prompt=negative_prompt,
        duration_seconds=duration_seconds,
        fps=fps,
        seed=resolved_seed,
    )
    metadata = {
        "model": model_used,
        "scene": target_scene.scene_id,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "duration_seconds": duration_seconds,
        "fps": fps,
    }
    status_message = f"Rendered scene {target_scene.scene_id} via {model_used}."
    return status_message, video_path, metadata
# -----------------------------
# Gradio interface
# -----------------------------
def build_interface() -> gr.Blocks:
    """Build the CineGen UI and wire both buttons to their callbacks.

    The credential fields are intentionally NOT pre-filled from the
    environment: a textbox's initial value is sent to every browser that
    loads the page, which would disclose server-side secrets. Leaving a field
    blank still works because the callbacks resolve tokens through
    ``resolve_token``, which falls back to ``GOOGLE_API_KEY`` / ``HF_TOKEN``.
    """
    # NOTE(review): the container nesting below (two columns in one row, then
    # a separate "Scene Rendering" tab) is reconstructed; confirm the layout.
    with gr.Blocks() as demo:
        gr.Markdown("# CineGen AI Director")
        gr.Markdown(
            "Transform a simple idea into a storyboard, character deck, and video shots. "
            "Tokens can be loaded from the environment for local debugging; in production the fields must be filled manually."
        )
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Credentials")
                google_api_key_input = gr.Textbox(
                    label="Google API Key",
                    type="password",
                    placeholder="GOOGLE_API_KEY",
                )
                hf_token_input = gr.Textbox(
                    label="Hugging Face Token",
                    type="password",
                    placeholder="hf_xxx",
                )
                gr.Markdown("### Story Settings")
                movie_idea = gr.Textbox(
                    label="Movie Idea",
                    value="A lone robot gardener trying to revive a neon-drenched city park.",
                    lines=4,
                )
                visual_style = gr.Dropdown(
                    label="Visual Style",
                    choices=["Cinematic Realism", "American Cartoon", "Anime Noir", "Cyberpunk", "Claymation"],
                    value="Cinematic Realism",
                )
                runtime_hint = gr.Dropdown(
                    label="Runtime Target",
                    choices=["30 seconds", "45 seconds", "60 seconds"],
                    value="45 seconds",
                )
                tone = gr.Textbox(
                    label="Tone keywords",
                    value="hopeful, dynamic camera, sweeping synth score",
                )
                scene_count = gr.Slider(
                    label="Scene Count",
                    minimum=3,
                    maximum=8,
                    value=4,
                    step=1,
                )
                generate_storyboard_btn = gr.Button("Generate Storyboard", variant="primary")
            with gr.Column():
                status_box = gr.Markdown("Status: awaiting input.")
                storyboard_json = gr.JSON(label="Storyboard JSON")
                character_markdown = gr.Markdown(label="Character Profiles")
                character_gallery = gr.Gallery(label="Character Anchors", columns=2, rows=2, height="auto")
        with gr.Tab("Scene Rendering"):
            scene_choice = gr.Dropdown(label="Scene", choices=[])
            preferred_model = gr.Dropdown(
                label="Preferred Video Model",
                choices=HuggingFaceVideoService.MODEL_FALLBACK,
                value=HuggingFaceVideoService.MODEL_FALLBACK[0],
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="low resolution, flicker, watermark, distorted faces",
            )
            duration_seconds = gr.Slider(label="Duration (s)", minimum=1.0, maximum=4.0, value=2.0, step=0.5)
            fps = gr.Slider(label="FPS", minimum=12, maximum=24, value=24, step=1)
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_video_btn = gr.Button("Render Selected Scene", variant="primary")
            video_status = gr.Markdown("Video renderer idle.")
            video_output = gr.Video(label="Generated Clip")
            video_metadata = gr.JSON(label="Render Metadata")

        # Per-session state carried from the storyboard step to the render step.
        storyboard_state = gr.State({})
        character_state = gr.State([])

        generate_storyboard_btn.click(
            storyboard_callback,
            inputs=[movie_idea, visual_style, runtime_hint, tone, scene_count, google_api_key_input],
            outputs=[
                status_box,
                storyboard_json,
                character_markdown,
                character_gallery,
                storyboard_state,
                character_state,
                scene_choice,
            ],
        )
        generate_video_btn.click(
            generate_video_callback,
            inputs=[
                scene_choice,
                storyboard_state,
                hf_token_input,
                preferred_model,
                negative_prompt,
                duration_seconds,
                fps,
                seed,
            ],
            outputs=[video_status, video_output, video_metadata],
        )
    return demo
if __name__ == "__main__":
    interface = build_interface()
    # Host and port come from the environment; defaults are all interfaces on 7860.
    server_name = os.getenv("GRADIO_SERVER_HOST") or "0.0.0.0"
    port_text = os.getenv("GRADIO_SERVER_PORT") or os.getenv("SERVER_PORT") or "7860"
    server_port = int(port_text)
    launch_options = {
        "server_name": server_name,
        "server_port": server_port,
        "theme": gr.themes.Soft(),
        "css": ".gradio-container {max-width: 1200px; margin: auto;}",
        "footer_links": ["gradio", "settings"],
        "allowed_paths": [str(Path.cwd())],
        "ssr_mode": False,
    }
    interface.launch(**launch_options)