import os import cv2 import numpy as np import random import sys import subprocess from typing import Sequence, Mapping, Any, Union import torch from tqdm import tqdm import argparse import json import logging import shutil import gradio as gr import spaces from huggingface_hub import snapshot_download import time import traceback from utils import get_path_after_pexel LOCAL_GRADIO_TMP = os.path.abspath("./gradio_tmp") os.makedirs(LOCAL_GRADIO_TMP, exist_ok=True) os.environ["GRADIO_TEMP_DIR"] = LOCAL_GRADIO_TMP HF_REPOS = { "QingyanBai/Ditto_models": ["models_comfy/ditto_global_comfy.safetensors"], "Kijai/WanVideo_comfy": [ "Wan2_1-T2V-14B_fp8_e4m3fn.safetensors", "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", "Wan2_1_VAE_bf16.safetensors", "umt5-xxl-enc-bf16.safetensors", ], } MODELS_ROOT = os.path.abspath(os.path.join(os.getcwd(), "models")) PATHS = { "diffusion_model": os.path.join(MODELS_ROOT, "diffusion_models"), "vae_wan": os.path.join(MODELS_ROOT, "vae", "wan"), "loras": os.path.join(MODELS_ROOT, "loras"), "text_encoders": os.path.join(MODELS_ROOT, "text_encoders"), } REQUIRED_FILES = [ ("Wan2_1-T2V-14B_fp8_e4m3fn.safetensors", "diffusion_model"), ("ditto_global_comfy.safetensors", "diffusion_model"), ("Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", "loras"), ("Wan2_1_VAE_bf16.safetensors", "vae_wan"), ("umt5-xxl-enc-bf16.safetensors", "text_encoders"), ] def ensure_dir(path: str) -> None: os.makedirs(path, exist_ok=True) def ensure_models() -> None: for filename, key in REQUIRED_FILES: target_dir = PATHS[key] ensure_dir(target_dir) target_path = os.path.join(target_dir, filename) ready_flag = os.path.join(target_dir, f"{filename}.READY") if os.path.exists(target_path) and os.path.getsize(target_path) > 0: open(ready_flag, "a").close() continue repo_id = None repo_file_path = None for repo, files in HF_REPOS.items(): for file_path in files: if filename in file_path: repo_id = repo repo_file_path = file_path break if repo_id: break if repo_id is None: raise RuntimeError(f"Could not find repository for file: {filename}") print(f"Downloading {filename} from {repo_id} to {target_dir} ...") snapshot_download( repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False, allow_patterns=[repo_file_path], token=os.getenv("HF_TOKEN", None), ) if not os.path.exists(target_path): found = [] for root, _, files in os.walk(target_dir): for f in files: if f == filename: found.append(os.path.join(root, f)) if found: src = found[0] if src != target_path: shutil.copy2(src, target_path) if not os.path.exists(target_path): raise RuntimeError(f"Failed to download required file: {filename}") open(ready_flag, "w").close() print(f"Downloaded and ready: {target_path}") ensure_models() def ensure_t5_tokenizer() -> None: """ Ensure the local T5 tokenizer folder exists and contains valid files. If missing or corrupted, download from 'google/umt5-xxl' and save locally to the exact path expected by the WanVideo wrapper nodes. """ try: script_directory = os.path.dirname(os.path.abspath(__file__)) tokenizer_dir = os.path.join( script_directory, "custom_nodes", "ComfyUI_WanVideoWrapper", "configs", "T5_tokenizer", ) os.makedirs(tokenizer_dir, exist_ok=True) required_files = [ "tokenizer.json", "tokenizer_config.json", "spiece.model", "special_tokens_map.json", ] def is_valid(path: str) -> bool: return os.path.exists(path) and os.path.getsize(path) > 0 all_ok = all(is_valid(os.path.join(tokenizer_dir, f)) for f in required_files) if all_ok: print(f"T5 tokenizer ready at: {tokenizer_dir}") return print(f"Preparing T5 tokenizer at: {tokenizer_dir} ...") from transformers import AutoTokenizer tok = AutoTokenizer.from_pretrained( "google/umt5-xxl", use_fast=True, trust_remote_code=False, ) tok.save_pretrained(tokenizer_dir) # Re-check all_ok = all(is_valid(os.path.join(tokenizer_dir, f)) for f in required_files) if not all_ok: raise RuntimeError("Tokenizer files not fully prepared after save_pretrained") print("T5 tokenizer prepared successfully.") except Exception as e: print(f"Failed to prepare T5 tokenizer: {e}\n{traceback.format_exc()}") raise ensure_t5_tokenizer() def setup_global_logging_filter(): class MemoryLogFilter(logging.Filter): def filter(self, record): msg = record.getMessage() keywords = [ "Allocated memory:", "Max allocated memory:", "Max reserved memory:", "memory=", "max_memory=", "max_reserved=", "Block swap memory summary", "Transformer blocks on", "Total memory used by", "Non-blocking memory transfer" ] return not any(kw in msg for kw in keywords) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True ) logging.getLogger().handlers[0].addFilter(MemoryLogFilter()) setup_global_logging_filter() def tensor_to_video(video_tensor, output_path, fps=20, crf=20): frames = video_tensor.detach().cpu().numpy() if frames.dtype != np.uint8: if frames.max() <= 1.0: frames = (frames * 255).astype(np.uint8) else: frames = frames.astype(np.uint8) num_frames, height, width, _ = frames.shape command = [ 'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-pix_fmt', 'rgb24', '-s', f'{width}x{height}', '-r', str(fps), '-i', '-', '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', str(crf), '-preset', 'medium', '-r', str(fps), '-an', output_path ] with subprocess.Popen(command, stdin=subprocess.PIPE, stderr=subprocess.PIPE) as proc: for frame in frames: proc.stdin.write(frame.tobytes()) proc.stdin.close() if proc.stderr is not None: proc.stderr.read() def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any: try: return obj[index] except KeyError: return obj["result"][index] def find_path(name: str, path: str = None) -> str: if path is None: path = os.getcwd() if name in os.listdir(path): path_name = os.path.join(path, name) print(f"{name} found: {path_name}") return path_name parent_directory = os.path.dirname(path) if parent_directory == path: return None return find_path(name, parent_directory) def add_comfyui_directory_to_sys_path() -> None: comfyui_path = find_path("ComfyUI") if comfyui_path is not None and os.path.isdir(comfyui_path): if comfyui_path not in sys.path: sys.path.append(comfyui_path) print(f"'{comfyui_path}' added to sys.path") def add_extra_model_paths() -> None: try: from main import load_extra_path_config except ImportError: print( "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead." ) from utils.extra_config import load_extra_path_config extra_model_paths = find_path("extra_model_paths.yaml") if extra_model_paths is not None: load_extra_path_config(extra_model_paths) else: print("Could not find the extra_model_paths config file.") add_comfyui_directory_to_sys_path() add_extra_model_paths() def import_custom_nodes() -> None: import asyncio import execution from nodes import init_extra_nodes import server if getattr(import_custom_nodes, "_initialized", False): return loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server_instance = server.PromptServer(loop) execution.PromptQueue(server_instance) init_extra_nodes() import_custom_nodes._initialized = True from nodes import NODE_CLASS_MAPPINGS print(f"Loading custom nodes and models...") import_custom_nodes() @spaces.GPU() def run_pipeline(vpath, prompt, width, height, fps, frame_count, outdir): try: import gc # Clean memory before starting gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() os.makedirs(outdir, exist_ok=True) with torch.inference_mode(): from custom_nodes.ComfyUI_WanVideoWrapper import nodes as wan_nodes vhs_loadvideo = NODE_CLASS_MAPPINGS["VHS_LoadVideo"]() # Set model and settings. wanvideovacemodelselect = wan_nodes.WanVideoVACEModelSelect() wanvideovacemodelselect_89 = wanvideovacemodelselect.getvacepath( vace_model="ditto_global_comfy.safetensors" ) wanvideoslg = wan_nodes.WanVideoSLG() wanvideoslg_113 = wanvideoslg.process( blocks="2", start_percent=0.20000000000000004, end_percent=0.7000000000000002, ) wanvideovaeloader = wan_nodes.WanVideoVAELoader() wanvideovaeloader_133 = wanvideovaeloader.loadmodel( model_name="wan/Wan2_1_VAE_bf16.safetensors", precision="bf16" ) loadwanvideot5textencoder = wan_nodes.LoadWanVideoT5TextEncoder() loadwanvideot5textencoder_134 = loadwanvideot5textencoder.loadmodel( model_name="umt5-xxl-enc-bf16.safetensors", precision="bf16", load_device="offload_device", quantization="disabled", ) wanvideoblockswap = wan_nodes.WanVideoBlockSwap() wanvideoblockswap_137 = wanvideoblockswap.setargs( blocks_to_swap=20, offload_img_emb=False, offload_txt_emb=False, use_non_blocking=True, vace_blocks_to_swap=0, ) wanvideoloraselect = wan_nodes.WanVideoLoraSelect() wanvideoloraselect_380 = wanvideoloraselect.getlorapath( lora="Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", strength=1.0, low_mem_load=False, ) wanvideomodelloader = wan_nodes.WanVideoModelLoader() imageresizekjv2 = NODE_CLASS_MAPPINGS["ImageResizeKJv2"]() wanvideovaceencode = wan_nodes.WanVideoVACEEncode() wanvideotextencode = wan_nodes.WanVideoTextEncode() wanvideosampler = wan_nodes.WanVideoSampler() wanvideodecode = wan_nodes.WanVideoDecode() wanvideomodelloader_142 = wanvideomodelloader.loadmodel( model="Wan2_1-T2V-14B_fp8_e4m3fn.safetensors", base_precision="fp16", quantization="disabled", load_device="offload_device", attention_mode="sdpa", block_swap_args=get_value_at_index(wanvideoblockswap_137, 0), lora=get_value_at_index(wanvideoloraselect_380, 0), vace_model=get_value_at_index(wanvideovacemodelselect_89, 0), ) fname = os.path.basename(vpath) fname_clean = os.path.splitext(fname)[0] vhs_loadvideo_70 = vhs_loadvideo.load_video( video=vpath, force_rate=20, custom_width=width, custom_height=height, frame_load_cap=frame_count, skip_first_frames=1, select_every_nth=1, format="AnimateDiff", unique_id=16696422174153060213, ) imageresizekjv2_205 = imageresizekjv2.resize( width=width, height=height, upscale_method="area", keep_proportion="resize", pad_color="0, 0, 0", crop_position="center", divisible_by=8, device="cpu", image=get_value_at_index(vhs_loadvideo_70, 0), ) wanvideovaceencode_29 = wanvideovaceencode.process( width=width, height=height, num_frames=frame_count, strength=0.9750000000000002, vace_start_percent=0, vace_end_percent=1, tiled_vae=False, vae=get_value_at_index(wanvideovaeloader_133, 0), input_frames=get_value_at_index(imageresizekjv2_205, 0), ) wanvideotextencode_148 = wanvideotextencode.process( positive_prompt=prompt, negative_prompt="flickering artifact, jpg artifacts, compression, distortion, morphing, low-res, fake, oversaturated, overexposed, over bright, strange behavior, distorted limbs, unnatural motion, unrealistic anatomy, glitch, extra limbs,", force_offload=True, t5=get_value_at_index(loadwanvideot5textencoder_134, 0), model_to_offload=get_value_at_index(wanvideomodelloader_142, 0), ) # Clean memory before sampling (most memory-intensive step) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() wanvideosampler_2 = wanvideosampler.process( steps=4, cfg=1.2000000000000002, shift=2.0000000000000004, seed=random.randint(1, 2 ** 64), force_offload=True, scheduler="unipc", riflex_freq_index=0, denoise_strength=1, batched_cfg=False, rope_function="comfy", model=get_value_at_index(wanvideomodelloader_142, 0), image_embeds=get_value_at_index(wanvideovaceencode_29, 0), text_embeds=get_value_at_index(wanvideotextencode_148, 0), slg_args=get_value_at_index(wanvideoslg_113, 0), ) res = wanvideodecode.decode( enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128, vae=get_value_at_index(wanvideovaeloader_133, 0), samples=get_value_at_index(wanvideosampler_2, 0), ) save_path = os.path.join(outdir, f'{fname_clean}_edit.mp4') tensor_to_video(res[0], save_path, fps=fps) # Clean up memory after generation del res gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"Done. Saved to: {save_path}") return save_path except Exception as e: err = f"Error: {e}\n{traceback.format_exc()}" print(err) # Clean memory on error too gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() raise @spaces.GPU() def gradio_infer(vfile, prompt, width, height, fps, frame_count, progress=gr.Progress(track_tqdm=True)): if vfile is None: return None, "Please upload the video!", "\n".join(logs) vpath = vfile if isinstance(vfile, str) else vfile.name if not os.path.exists(vpath) and hasattr(vfile, "save"): os.makedirs("uploads", exist_ok=True) vpath = os.path.join("uploads", os.path.basename(vfile.name)) vfile.save(vpath) outdir = "results" os.makedirs(outdir, exist_ok=True) save_path = run_pipeline( vpath=vpath, prompt=prompt, width=int(width), height=int(height), fps=int(fps), frame_count=int(frame_count), outdir=outdir, ) return save_path def build_interface(): with gr.Blocks(title="Ditto") as demo: gr.Markdown( """# Ditto: Scaling Instruction-Based Video Editing with a High-Quality Synthetic Dataset
📄 Paper   |   🌐 Project Page   |   💻 Github Code   |   📦 Model Weights   |   📊 Dataset
Note1: The backend of this demo is comfy. Though it runs fast, please note that due to the use of quantized and distilled models, there may be some quality degradation. Note2: Considering the limited memory, please try test cases with lower resolution and frame count, otherwise it may cause out of memory error (you can also try re-running it). If you like this project, please consider starring the repo to motivate us. Thank you! """ ) with gr.Column(): with gr.Row(): vfile = gr.Video(label="Input Video", value=os.path.join("input", "dasha.mp4"), sources="upload", interactive=True) out_video = gr.Video(label="Result") prompt = gr.Textbox(label="Editing Instruction", value="Make it in the style of Japanese anime") with gr.Row(): width = gr.Number(label="Width", value=576, precision=0) height = gr.Number(label="Height", value=324, precision=0) fps = gr.Number(label="FPS", value=20, precision=0) frame_count = gr.Number(label="Frame Count", value=49, precision=0) run_btn = gr.Button("Run", variant="primary") run_btn.click( fn=gradio_infer, inputs=[vfile, prompt, width, height, fps, frame_count], outputs=[out_video] ) examples = [ [ os.path.join("input", "dasha.mp4"), "Add some fire and flame to the background", 576, 324, 20, 49 ], [ os.path.join("input", "dasha.mp4"), "Add some snow and flakes to the background", 576, 324, 20, 49 ], [ os.path.join("input", "dasha.mp4"), "Make it in the style of pencil sketch", 576, 324, 20, 49 ], ] gr.Examples( examples=examples, inputs=[vfile, prompt, width, height, fps, frame_count], label="Examples" ) return demo if __name__ == "__main__": demo = build_interface() demo.launch()