import argparse
import logging
import os
import sys
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

import random

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, str2bool
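
# Fallback prompts (and an input image for i2v) used when --prompt/--image are omitted.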
EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "i2v-A14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
}


def _validate_args(args):
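    """Validate CLI arguments and fill unset sampling options from the task config."""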
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

    if args.prompt is None:
        args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
    if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
        args.image = EXAMPLE_PROMPT[args.task]["image"]

    if args.task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    cfg = WAN_CONFIGS[args.task]

    # Fall back to the task config for any sampling option left unset.
    if args.sample_steps is None:
        args.sample_steps = cfg.sample_steps

    if args.sample_shift is None:
        args.sample_shift = cfg.sample_shift

    if args.sample_guide_scale is None:
        args.sample_guide_scale = cfg.sample_guide_scale

    if args.frame_num is None:
        args.frame_num = cfg.frame_num

    # A negative seed means "pick one at random".
    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)

    assert args.size in SUPPORTED_SIZES[args.task], (
        f"Unsupported size {args.size} for task {args.task}, "
        f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")


def _parse_args():
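    """Build the CLI parser, parse the arguments, and validate them."""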
    parser = argparse.ArgumentParser(
        description="Generate an image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-A14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video follows that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames of video to generate. The number should be 4n+1."
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each forward pass, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of Ulysses sequence parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place the T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extension.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extension method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extension model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extension.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=None,
        help="Classifier-free guidance scale.")
    parser.add_argument(
        "--convert_model_dtype",
        action="store_true",
        default=False,
        help="Whether to convert model parameter dtypes.")

    args = parser.parse_args()

    _validate_args(args)

    return args


def _init_logging(rank):
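    """Configure logging: INFO on rank 0, ERROR on all other ranks."""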
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def generate(args):
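    """Set up (optional) distributed state, build the task pipeline, generate the video, and save it on rank 0."""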
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        # Default to offloading only in single-GPU runs.
        args.offload_model = world_size == 1
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1
        ), "Sequence parallelism is not supported in non-distributed environments."

    if args.ulysses_size > 1:
        assert args.ulysses_size == world_size, "ulysses_size must be equal to the world size."
        init_distributed_group()
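
    # Select the prompt-extension backend: the DashScope API or a local Qwen model.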
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        # Broadcast the seed so every rank samples with the same one.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    logging.info(f"Input prompt: {args.prompt}")
    img = None
    if args.image is not None:
        img = Image.open(args.image).convert("RGB")
        logging.info(f"Input image: {args.image}")
    if args.use_prompt_extend:
        logging.info("Extending prompt ...")
        if rank == 0:
            prompt_output = prompt_expander(
                args.prompt,
                image=img,
                tar_lang=args.prompt_extend_target_lang,
                seed=args.base_seed)
            if not prompt_output.status:
                logging.info(
                    f"Extending prompt failed: {prompt_output.message}")
                logging.info("Falling back to original prompt.")
                input_prompt = args.prompt
            else:
                input_prompt = prompt_output.prompt
            input_prompt = [input_prompt]
        else:
            input_prompt = [None]
        if dist.is_initialized():
            dist.broadcast_object_list(input_prompt, src=0)
        args.prompt = input_prompt[0]
        logging.info(f"Extended prompt: {args.prompt}")
    if "t2v" in args.task:
        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    elif "ti2v" in args.task:
        logging.info("Creating WanTI2V pipeline.")
        wan_ti2v = wan.WanTI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_ti2v.generate(
            args.prompt,
            img=img,
            size=SIZE_CONFIGS[args.size],
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    else:
        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
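
    # Only rank 0 writes the output; a default file name is derived from the
    # task, size, prompt, and timestamp when --save_file is not given.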
    if rank == 0:
        if args.save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = args.prompt.replace(" ", "_").replace(
                "/", "_")[:50]
            suffix = '.mp4'
            # '*' is not a valid character in Windows file names.
            args.save_file = f"{args.task}_{args.size.replace('*', 'x') if sys.platform == 'win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix

        logging.info(f"Saving generated video to {args.save_file}")
        cache_video(
            tensor=video[None],
            save_file=args.save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    del video

    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")

if __name__ == "__main__":
    args = _parse_args()
    generate(args)