import argparse
import copy
import datetime
import gc
import inspect
import logging
import math
import os
import random
from typing import Dict, List, Optional, Tuple

import imageio
import numpy as np
from PIL import Image
from scipy.stats import anderson
from omegaconf import OmegaConf
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms
from torchvision.transforms import ToTensor
from tqdm.auto import tqdm

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

import transformers
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers.models.clip.modeling_clip import CLIPEncoder

import diffusers
from diffusers import (
    AutoencoderKLTemporalDecoder,
    DDIMScheduler,
    EulerDiscreteScheduler,
    StableVideoDiffusionPipeline,
    TextToVideoSDPipeline,
    UNetSpatioTemporalConditionModel,
)
from diffusers.models import AutoencoderKL
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.unet_3d_blocks import (
    CrossAttnDownBlockSpatioTemporal,
    DownBlockSpatioTemporal,
    CrossAttnUpBlockSpatioTemporal,
    UpBlockSpatioTemporal,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
from diffusers.utils.import_utils import is_xformers_available

from i2vedit.utils.dataset import (
    VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset,
    CachedDataset, pad_with_ratio, return_to_original_res,
)
from i2vedit.utils.lora_handler import LoraHandler
from i2vedit.utils.lora import extract_lora_child_module
from i2vedit.utils.euler_utils import euler_inversion, inverse_video
from i2vedit.utils.svd_util import SmoothAreaRandomDetection
from i2vedit.utils.model_utils import P2PEulerDiscreteScheduler, P2PStableVideoDiffusionPipeline
from i2vedit.data import VideoIO, SingleClipDataset, ResolutionControl
from i2vedit.train import train_motion_lora, load_images_from_list
from i2vedit.inference import initialize_pipeline
from i2vedit.prompt_attention import attention_util


def create_output_folders(output_dir, config):
    os.makedirs(output_dir, exist_ok=True)
    OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
    return output_dir


def main(
    pretrained_model_path: str,
    data_params: Dict,
    train_motion_lora_params: Dict,
    sarp_params: Dict,
    attention_matching_params: Dict,
    long_video_params: Dict = {"mode": "skip-interval"},
    use_sarp: bool = True,
    use_motion_lora: bool = True,
    train_motion_lora_only: bool = False,
    retrain_motion_lora: bool = True,
    use_inversed_latents: bool = True,
    use_attention_matching: bool = True,
    use_consistency_attention_control: bool = False,
    output_dir: str = "./outputs",
    num_steps: int = 25,
    device: str = "cuda",
    seed: int = 23,
    enable_xformers_memory_efficient_attention: bool = True,
    enable_torch_2_attn: bool = False,
    dtype: str = 'fp16',
    load_from_last_frames_latents: List[str] = None,
    save_last_frames: bool = True,
    visualize_attention_store: bool = False,
    visualize_attention_store_steps: List[int] = None,
    use_latent_blend: bool = False,
    use_previous_latent_for_train: bool = False,
    use_latent_noise: bool = True,
    load_from_previous_consistency_store_controller: str = None,
    load_from_previous_consistency_edit_controller: List[str] = None,
):
    # capture the resolved arguments so the full run config can be saved with the outputs
    *_, config = inspect.getargvalues(inspect.currentframe())

    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "fp32":
        dtype = torch.float32

    # create output folder
    output_dir = create_output_folders(output_dir, config)

    # prepare video data
    data_params["output_dir"] = output_dir
    data_params["device"] = device
    videoio = VideoIO(**data_params, dtype=dtype)

    # smooth area random perturbation
    if use_sarp:
        sard = SmoothAreaRandomDetection(device, dtype=torch.float32)
    else:
        sard = None

    keyframe = None
    previous_last_frames = load_images_from_list(data_params.keyframe_paths)

    consistency_train_controller = None
    if load_from_last_frames_latents is not None:
        previous_last_frames_latents = [torch.load(thpath).to(device) for thpath in load_from_last_frames_latents]
    else:
        previous_last_frames_latents = [None] * len(previous_last_frames)

    if use_consistency_attention_control and load_from_previous_consistency_store_controller is not None:
        previous_consistency_store_controller = attention_util.ConsistencyAttentionControl(
            additional_attention_store=None,
            use_inversion_attention=False,
            save_self_attention=True,
            save_latents=False,
            disk_store=True,
            load_attention_store=os.path.join(load_from_previous_consistency_store_controller, "clip_0")
        )
    else:
        previous_consistency_store_controller = None

    previous_consistency_edit_controller_list = [None] * len(previous_last_frames)
    if use_consistency_attention_control and load_from_previous_consistency_edit_controller is not None:
        for i in range(len(load_from_previous_consistency_edit_controller)):
            previous_consistency_edit_controller_list[i] = attention_util.ConsistencyAttentionControl(
                additional_attention_store=None,
                use_inversion_attention=False,
                save_self_attention=True,
                save_latents=False,
                disk_store=True,
                load_attention_store=os.path.join(load_from_previous_consistency_edit_controller[i], "clip_0")
            )

    # read data and process it clip by clip
    for clip_id, video in enumerate(videoio.read_video_iter()):

        if clip_id >= data_params.get("end_clip_id", 9):
            break
        if clip_id < data_params.get("begin_clip_id", 0):
            continue

        video = video.unsqueeze(0)
        resctrl = ResolutionControl(video.shape[-2:], data_params.output_res, data_params.pad_to_fit, fill=-1)

        # update keyframe and edited keyframe
        if long_video_params.mode == "skip-interval":
            assert data_params.overlay_size > 0
            # save the first frame as the keyframe for cross-attention
            firstframe = video[:, 0:1, :, :, :]
            keyframe = video[:, 0:1, :, :, :]
            edited_keyframes = copy.deepcopy(previous_last_frames)
            edited_firstframes = edited_keyframes
        elif long_video_params.mode == "auto-regressive":
            assert data_params.overlay_size == 1
            firstframe = video[:, 0:1, :, :, :]
            keyframe = video[:, 0:1, :, :, :]
            edited_keyframes = copy.deepcopy(previous_last_frames)
            edited_firstframes = edited_keyframes

        # register attention controllers on the UNet and perform inversion
        load_attention_store = None
        if use_attention_matching:
            assert use_inversed_latents, "inversion is disabled."
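            # Reuse an attention store previously saved for this clip if one exists
            # on disk; otherwise it is rebuilt during the inversion pass below.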
            if attention_matching_params.get("load_attention_store") is not None:
                load_attention_store = os.path.join(
                    attention_matching_params.get("load_attention_store"), f"clip_{clip_id}"
                )
                if not os.path.exists(load_attention_store):
                    print(f"Load {load_attention_store} failed, folder doesn't exist.")
                    load_attention_store = None
            store_controller = attention_util.AttentionStore(
                disk_store=attention_matching_params.disk_store,
                save_latents=use_latent_blend,
                save_self_attention=True,
                load_attention_store=load_attention_store,
                store_path=os.path.join(output_dir, "attention_store", f"clip_{clip_id}")
            )
            print("store_controller.store_dir:", store_controller.store_dir)
        else:
            store_controller = None

        load_consistency_attention_store = None
        if use_consistency_attention_control:
            if clip_id == 0 and attention_matching_params.get("load_consistency_attention_store") is not None:
                load_consistency_attention_store = os.path.join(
                    attention_matching_params.get("load_consistency_attention_store"), f"clip_{clip_id}"
                )
                if not os.path.exists(load_consistency_attention_store):
                    print(f"Load {load_consistency_attention_store} failed, folder doesn't exist.")
                    load_consistency_attention_store = None
            consistency_store_controller = attention_util.ConsistencyAttentionControl(
                additional_attention_store=previous_consistency_store_controller,
                use_inversion_attention=False,
                save_self_attention=(clip_id == 0),
                load_attention_store=load_consistency_attention_store,
                save_latents=False,
                disk_store=True,
                store_path=os.path.join(output_dir, "consistency_attention_store", f"clip_{clip_id}")
            )
            print("consistency_store_controller.store_dir:", consistency_store_controller.store_dir)
        else:
            consistency_store_controller = None

        if train_motion_lora_only:
            assert use_motion_lora and retrain_motion_lora, \
                "use_motion_lora/retrain_motion_lora should be enabled to train motion lora only."

        # perform smooth area random perturbation and inversion
        if use_inversed_latents:
            print("begin inversion sampling for inference...")
            inversion_noise = inverse_video(
                pretrained_model_path,
                video, keyframe, firstframe,
                num_steps,
                resctrl,
                sard,
                enable_xformers_memory_efficient_attention,
                enable_torch_2_attn,
                store_controller=store_controller,
                consistency_store_controller=consistency_store_controller,
                find_modules=attention_matching_params.registered_modules if load_attention_store is None else {},
                consistency_find_modules=long_video_params.registered_modules if load_consistency_attention_store is None else {},
                **sarp_params,
            )
        else:
            if use_motion_lora and retrain_motion_lora:
                assert not any(p > 0 for p in train_motion_lora_params.validation_data.noise_prior), \
                    "inversion noise was not calculated, but validation during motion lora training uses inversion noise as input latents."
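            # No inversion: leave the initial latents unset, so the pipeline call
            # below (which receives latents=inversion_noise) samples random noise.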
            inversion_noise = None

        if use_motion_lora:
            if retrain_motion_lora:
                if use_consistency_attention_control:
                    # when the training resolution differs from the output resolution,
                    # the consistency attention maps must be recomputed at training size
                    if data_params.output_res[0] != train_motion_lora_params.train_data.height or \
                       data_params.output_res[1] != train_motion_lora_params.train_data.width:
                        if consistency_train_controller is None:
                            load_consistency_train_attention_store = None
                            if attention_matching_params.get("load_consistency_train_attention_store") is not None:
                                load_consistency_train_attention_store = os.path.join(
                                    attention_matching_params.get("load_consistency_train_attention_store"), "clip_0"
                                )
                                if not os.path.exists(load_consistency_train_attention_store):
                                    print(f"Load {load_consistency_train_attention_store} failed, folder doesn't exist.")
                                    load_consistency_train_attention_store = None
                            if load_consistency_train_attention_store is None and clip_id > 0:
                                raise IOError(f"load_consistency_train_attention_store can't be None for clip {clip_id}.")
                            consistency_train_controller = attention_util.ConsistencyAttentionControl(
                                additional_attention_store=None,
                                use_inversion_attention=False,
                                save_self_attention=True,
                                load_attention_store=load_consistency_train_attention_store,
                                save_latents=False,
                                disk_store=True,
                                store_path=os.path.join(output_dir, "consistency_train_attention_store", "clip_0")
                            )
                            print("consistency_train_controller.store_dir:", consistency_train_controller.store_dir)

                        resctrl_train = ResolutionControl(
                            video.shape[-2:],
                            (train_motion_lora_params.train_data.height, train_motion_lora_params.train_data.width),
                            data_params.pad_to_fit,
                            fill=-1
                        )
                        print("begin inversion sampling for training...")
                        # this inversion is run for its side effect: it populates
                        # consistency_train_controller's attention store at training resolution
                        inversion_noise_train = inverse_video(
                            pretrained_model_path,
                            video, keyframe, firstframe,
                            num_steps,
                            resctrl_train,
                            sard,
                            enable_xformers_memory_efficient_attention,
                            enable_torch_2_attn,
                            store_controller=None,
                            consistency_store_controller=consistency_train_controller,
                            find_modules={},
                            consistency_find_modules=long_video_params.registered_modules if long_video_params.get("load_attention_store") is None else {},
                            **sarp_params,
                        )
                    else:
                        if consistency_train_controller is None:
                            consistency_train_controller = consistency_store_controller
                else:
                    consistency_train_controller = None

            if retrain_motion_lora:
                train_dataset = SingleClipDataset(
                    inversion_noise=inversion_noise,
                    video_clip=video,
                    keyframe=((ToTensor()(previous_last_frames[0]) - 0.5) / 0.5).unsqueeze(0).unsqueeze(0)
                        if use_previous_latent_for_train else keyframe,
                    keyframe_latent=previous_last_frames_latents[0] if use_previous_latent_for_train else None,
                    firstframe=firstframe,
                    height=train_motion_lora_params.train_data.height,
                    width=train_motion_lora_params.train_data.width,
                    use_data_aug=train_motion_lora_params.train_data.get("use_data_aug"),
                    pad_to_fit=train_motion_lora_params.train_data.get("pad_to_fit", False)
                )
                train_motion_lora_params.validation_data.num_inference_steps = num_steps
                train_motion_lora(
                    pretrained_model_path,
                    output_dir,
                    train_dataset,
                    edited_firstframes=edited_firstframes,
                    validation_images=edited_keyframes,
                    validation_images_latents=previous_last_frames_latents,
                    seed=seed,
                    clip_id=clip_id,
                    consistency_edit_controller_list=previous_consistency_edit_controller_list,
                    consistency_controller=consistency_train_controller if clip_id != 0 else None,
                    consistency_find_modules=long_video_params.registered_modules,
                    enable_xformers_memory_efficient_attention=enable_xformers_memory_efficient_attention,
                    enable_torch_2_attn=enable_torch_2_attn,
                    **train_motion_lora_params
                )

            if train_motion_lora_only:
                if not use_consistency_attention_control:
                    continue
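            # Note: when only training motion LoRAs, the edit pass below is skipped
            # unless consistency attention control is enabled, in which case it still
            # runs so that the consistency stores for later clips get built.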
            # choose and load motion lora
            best_checkpoint_index = attention_matching_params.get("best_checkpoint_index", 250)
            if retrain_motion_lora:
                lora_dir = f"{os.path.join(output_dir, 'train_motion_lora')}/clip_{clip_id}"
                lora_path = f"{lora_dir}/checkpoint-{best_checkpoint_index}/temporal/lora"
            else:
                lora_path = "/home/user/app/upload/lora"
            assert os.path.exists(lora_path), f"lora path: {lora_path} doesn't exist!"
            lora_rank = train_motion_lora_params.lora_rank
            lora_scale = attention_matching_params.get("lora_scale", 1.0)

            # prepare models
            pipe = initialize_pipeline(
                pretrained_model_path,
                device,
                enable_xformers_memory_efficient_attention,
                enable_torch_2_attn,
                lora_path,
                lora_rank,
                lora_scale,
                load_spatial_lora=False
            ).to(device, dtype=dtype)
        else:
            pipe = P2PStableVideoDiffusionPipeline.from_pretrained(
                pretrained_model_path
            ).to(device, dtype=dtype)

        if use_attention_matching or use_consistency_attention_control:
            pipe.scheduler = P2PEulerDiscreteScheduler.from_config(pipe.scheduler.config)

        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed)

        previous_last_frames = []
        editing_params = [item for name, item in attention_matching_params.params.items()]

        with torch.no_grad():
            with torch.autocast(device, dtype=dtype):
                for kf_id, (edited_keyframe, editing_param) in enumerate(zip(edited_keyframes, editing_params)):
                    print(kf_id, editing_param)

                    # control resolution
                    iw, ih = edited_keyframe.size
                    resctrl = ResolutionControl(
                        (ih, iw),
                        data_params.output_res,
                        data_params.pad_to_fit,
                        fill=0
                    )
                    edited_keyframe = resctrl(edited_keyframe)
                    edited_firstframe = resctrl(edited_firstframes[kf_id])

                    # control attention
                    pipe.scheduler.controller = []
                    if use_attention_matching:
                        edit_controller = attention_util.AttentionControlEdit(
                            num_steps=num_steps,
                            cross_replace_steps=attention_matching_params.cross_replace_steps,
                            temporal_self_replace_steps=attention_matching_params.temporal_self_replace_steps,
                            spatial_self_replace_steps=attention_matching_params.spatial_self_replace_steps,
                            mask_thr=editing_param.get("mask_thr", 0.35),
                            temporal_step_thr=editing_param.get("temporal_step_thr", [0.5, 0.8]),
                            control_mode=attention_matching_params.control_mode,
                            spatial_attention_chunk_size=attention_matching_params.get("spatial_attention_chunk_size", 1),
                            additional_attention_store=store_controller,
                            use_inversion_attention=True,
                            save_self_attention=False,
                            save_latents=False,
                            latent_blend=use_latent_blend,
                            disk_store=attention_matching_params.disk_store
                        )
                        pipe.scheduler.controller.append(edit_controller)
                    else:
                        edit_controller = None

                    if use_consistency_attention_control:
                        consistency_edit_controller = attention_util.ConsistencyAttentionControl(
                            additional_attention_store=previous_consistency_edit_controller_list[kf_id],
                            use_inversion_attention=False,
                            save_self_attention=(clip_id == 0),
                            save_latents=False,
                            disk_store=True,
                            store_path=os.path.join(output_dir, f"consistency_edit{kf_id}_attention_store", f"clip_{clip_id}")
                        )
                        pipe.scheduler.controller.append(consistency_edit_controller)
                    else:
                        consistency_edit_controller = None

                    if use_attention_matching or use_consistency_attention_control:
                        attention_util.register_attention_control(
                            pipe.unet,
                            edit_controller,
                            consistency_edit_controller,
                            find_modules=attention_matching_params.registered_modules,
                            consistency_find_modules=long_video_params.registered_modules
                        )

                    # should be reorganized to perform attention control
                    edited_output = pipe(
                        edited_keyframe,
                        edited_firstframe=edited_firstframe,
                        image_latents=previous_last_frames_latents[kf_id],
                        width=data_params.output_res[1],
                        height=data_params.output_res[0],
                        num_frames=video.shape[1],
                        num_inference_steps=num_steps,
                        decode_chunk_size=8,
                        motion_bucket_id=127,
                        fps=data_params.output_fps,
                        noise_aug_strength=0.02,
                        max_guidance_scale=attention_matching_params.get("max_guidance_scale", 2.5),
                        generator=generator,
                        latents=inversion_noise
                    )
                    edited_video = [img for sublist in edited_output.frames for img in sublist]
                    edited_video_latents = edited_output.latents

                    # callback to replace frames
                    videoio.write_video(edited_video, kf_id, resctrl)

                    # save previous frames
                    if long_video_params.mode == "skip-interval":
                        previous_last_frames.append(resctrl.callback(edited_video[-1]))
                        if use_latent_noise:
                            previous_last_frames_latents[kf_id] = edited_video_latents[:, -1:, :, :, :]
                        else:
                            previous_last_frames_latents[kf_id] = None
                    elif long_video_params.mode == "auto-regressive":
                        previous_last_frames.append(resctrl.callback(edited_video[-1]))
                        if use_latent_noise:
                            previous_last_frames_latents[kf_id] = edited_video_latents[:, -1:, :, :, :]
                        else:
                            previous_last_frames_latents[kf_id] = None

                    # save last frames for convenience
                    if save_last_frames:
                        try:
                            fname = os.path.join(output_dir, f"clip_{clip_id}_lastframe_{kf_id}")
                            previous_last_frames[kf_id].save(fname + ".png")
                            if use_latent_noise:
                                torch.save(previous_last_frames_latents[kf_id], fname + ".pt")
                        except Exception as e:
                            print(f"failed to save last frame: {e}")

                    if use_attention_matching or use_consistency_attention_control:
                        attention_util.register_attention_control(
                            pipe.unet,
                            edit_controller,
                            consistency_edit_controller,
                            find_modules=attention_matching_params.registered_modules,
                            consistency_find_modules=long_video_params.registered_modules,
                            undo=True
                        )

                    if edit_controller is not None:
                        if visualize_attention_store:
                            vis_save_path = os.path.join(output_dir, "visualization", f"{kf_id}", f"clip_{clip_id}")
                            os.makedirs(vis_save_path, exist_ok=True)
                            attention_util.show_avg_difference_maps(
                                edit_controller,
                                save_path=vis_save_path
                            )
                            assert visualize_attention_store_steps is not None
                            attention_util.show_self_attention(
                                edit_controller,
                                steps=visualize_attention_store_steps,
                                save_path=vis_save_path,
                                inversed=False
                            )
                        edit_controller.delete()
                        del edit_controller

                    if use_consistency_attention_control:
                        if clip_id == 0:
                            previous_consistency_edit_controller_list[kf_id] = consistency_edit_controller
                        else:
                            consistency_edit_controller.delete()
                            del consistency_edit_controller
                        print(f"previous_consistency_edit_controller_list[{kf_id}]",
                              previous_consistency_edit_controller_list[kf_id].store_dir)

        if use_attention_matching:
            del store_controller

        if use_consistency_attention_control and clip_id == 0:
            previous_consistency_store_controller = consistency_store_controller

    videoio.close()

    if use_consistency_attention_control:
        print("consistency_store_controller for clip 0:", previous_consistency_store_controller.store_dir)
        if retrain_motion_lora:
            print("consistency_train_controller for clip 0:", consistency_train_controller.store_dir)
        for kf_id in range(len(previous_consistency_edit_controller_list)):
            print(f"previous_consistency_edit_controller_list[{kf_id}]:",
                  previous_consistency_edit_controller_list[kf_id].store_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='./configs/svdedit/item2_2.yaml')
    args = parser.parse_args()
    main(**OmegaConf.load(args.config))
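# Usage: the script is driven entirely by a YAML config whose top-level keys
# mirror main()'s signature, e.g.:
#
#   python <this_script>.py --config ./configs/svdedit/item2_2.yaml
#
# A minimal config sketch (illustrative values only, not tested defaults; the
# real configs define the full data/training/attention parameter dicts):
#
#   pretrained_model_path: /path/to/stable-video-diffusion   # local SVD checkpoint
#   output_dir: ./outputs
#   num_steps: 25
#   data_params: { ... }               # video paths, output_res, keyframe_paths, ...
#   sarp_params: { ... }               # smooth area random perturbation settings
#   train_motion_lora_params: { ... }  # train_data, validation_data, lora_rank, ...
#   attention_matching_params: { ... } # registered_modules, params, disk_store, ...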