init
- .gitattributes +2 -0
- .gitignore +12 -0
- README.md +4 -4
- app.py +173 -0
- config/customize_subsequent_edit.yaml +152 -0
- config/customize_train.yaml +149 -0
- config/customize_train_multi.yaml +149 -0
- i2vedit/__init__.py +0 -0
- i2vedit/data.py +317 -0
- i2vedit/inference.py +89 -0
- i2vedit/prompt_attention/__init__.py +0 -0
- i2vedit/prompt_attention/attention_register.py +250 -0
- i2vedit/prompt_attention/attention_store.py +305 -0
- i2vedit/prompt_attention/attention_util.py +621 -0
- i2vedit/prompt_attention/common/__init__.py +0 -0
- i2vedit/prompt_attention/common/image_util.py +192 -0
- i2vedit/prompt_attention/common/instantiate_from_config.py +33 -0
- i2vedit/prompt_attention/common/logger.py +17 -0
- i2vedit/prompt_attention/common/set_seed.py +28 -0
- i2vedit/prompt_attention/common/util.py +73 -0
- i2vedit/prompt_attention/ptp_utils.py +199 -0
- i2vedit/prompt_attention/visualization.py +391 -0
- i2vedit/train.py +1488 -0
- i2vedit/utils/__init__.py +0 -0
- i2vedit/utils/bucketing.py +32 -0
- i2vedit/utils/dataset.py +705 -0
- i2vedit/utils/euler_utils.py +226 -0
- i2vedit/utils/lora.py +1493 -0
- i2vedit/utils/lora_handler.py +270 -0
- i2vedit/utils/model_utils.py +588 -0
- i2vedit/utils/svd_util.py +397 -0
- i2vedit/version.py +5 -0
- main.py +595 -0
- mydata/source_and_edits/source.mp4 +3 -0
- mydata/source_and_edits/white.jpg +3 -0
- req.txt +165 -0
- requirements.txt +176 -0
- setup.py +73 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+mydata/source_and_edits/source.mp4 filter=lfs diff=lfs merge=lfs -text
+mydata/source_and_edits/white.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,12 @@
+cache
+ckpts
+trash
+outputs
+run.sh
+__pycache__
+*_tmp
+*.mp4
+*.png
+i2vedit.egg-info/
+!mydata/**/*.mp4
+customize_train_local.yaml
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: I2vedit
-emoji:
-colorFrom:
-colorTo:
+emoji: 📈
+colorFrom: purple
+colorTo: purple
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.32.1
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,173 @@
import gradio as gr
import subprocess
import os
import shutil
import sys

target_paths = {
    "video": "/home/user/app/upload/source_and_edits/source.mp4",
    "image": "/home/user/app/upload/source_and_edits/ref.jpg",
    "config": "/home/user/app/upload/config/customize_train.yaml",
    "lora": "/home/user/app/upload/lora/lora.pt",
    "output_l": "/home/user/app/outputs/train_motion_lora",
    "output_r": "/home/user/app/outputs/ref.mp4",
    "zip": "/home/user/app/outputs/train_motion_lora.zip",
}


def zip_outputs():
    if os.path.exists(target_paths["zip"]):
        os.remove(target_paths["zip"])
    shutil.make_archive(target_paths["zip"].replace(".zip", ""), 'zip', root_dir=target_paths["output_l"])
    return target_paths["zip"]

def output_video():
    if os.path.exists(target_paths["output_r"]):
        return target_paths["output_r"]
    return None


def start_training_stream():
    process = subprocess.Popen(
        ["python", "main.py", "--config=" + target_paths["config"]],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
        universal_newlines=True
    )

    output = []
    for line in process.stdout:
        output.append(line)
        yield "".join(output)

def install_i2vedit():
    try:
        import i2vedit
        print("i2vedit already installed")
    except ImportError:
        print("Installing i2vedit...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./i2vedit"])
        print("i2vedit installed")


def install_package(package_name):
    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", package_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        output = result.stdout + "\n" + result.stderr
        return output
    except Exception as e:
        return f"Error: {str(e)}"


def show_package(pkg_name):
    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "show", pkg_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        return result.stdout if result.stdout else result.stderr
    except Exception as e:
        return str(e)


def uninstall_package(package_name):
    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", package_name, "-y"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        output = result.stdout + "\n" + result.stderr
        return output
    except Exception as e:
        return f"Error: {str(e)}"


def save_files(video_file, image_file, config_file, lora_file=None):
    os.makedirs(os.path.dirname(target_paths["video"]), exist_ok=True)
    os.makedirs(os.path.dirname(target_paths["config"]), exist_ok=True)

    shutil.copy(video_file.name, target_paths["video"])
    shutil.copy(image_file.name, target_paths["image"])
    shutil.copy(config_file.name, target_paths["config"])
    if lora_file:
        os.makedirs(os.path.dirname(target_paths["lora"]), exist_ok=True)
        shutil.copy(lora_file.name, target_paths["lora"])
    return "Files uploaded and saved successfully!"


install_i2vedit()
install_package("huggingface_hub==0.25.1")
install_package("diffusers==0.25.1")
install_package("gradio==5.0.0")
uninstall_package("datasets")
print("package version set complete")


with gr.Blocks(theme=gr.themes.Origin()) as demo:
    gr.Markdown("## Upload files first")
    with gr.Row():
        video_input = gr.File(label="Source video", file_types=[".mp4"])
        image_input = gr.File(label="Edited image", file_types=[".jpg", ".jpeg", ".png"])
        config_input = gr.File(label="Config file", file_types=[".yaml", ".yml"])
        lora_input = gr.File(label="LoRA file", file_types=[".pt"])

    upload_button = gr.Button("Upload and save")
    output = gr.Textbox(label="Status")

    gr.Markdown("## Training")
    with gr.Column():
        log_output = gr.Textbox(label="Training Log", lines=20)
        train_btn = gr.Button("Start Training")

    gr.Markdown("## Pip Installer")
    with gr.Column():
        with gr.Row():
            pkg_input = gr.Textbox(lines=1, placeholder="Package to install, e.g. diffusers or numpy==1.2.0")
            install_output = gr.Textbox(label="Install Output", lines=10)
        install_btn = gr.Button("Install Package")

    gr.Markdown("## Pip Uninstaller")
    with gr.Column():
        with gr.Row():
            pkg_input2 = gr.Textbox(lines=1, placeholder="Package to uninstall, e.g. diffusers or numpy")
            uninstall_output = gr.Textbox(label="Uninstall Output", lines=10)
        uninstall_btn = gr.Button("Uninstall Package")

    gr.Markdown("## Pip show")
    with gr.Column():
        with gr.Row():
            show_input = gr.Textbox(label="Package name (e.g. diffusers)")
            show_output = gr.Textbox(label="Package info", lines=10)
        show_btn = gr.Button("pip show")

    gr.Markdown("## Download lora")
    with gr.Column():
        file_output = gr.File(label="Click to download", interactive=True)
        download_btn = gr.Button("Download LoRA")

    gr.Markdown("## Download results")
    with gr.Column():
        file_output2 = gr.File(label="Click to download", interactive=True)
        download_btn2 = gr.Button("Download results")

    show_btn.click(fn=show_package, inputs=show_input, outputs=show_output)
    download_btn.click(fn=zip_outputs, outputs=file_output)
    download_btn2.click(fn=output_video, outputs=file_output2)
    install_btn.click(fn=install_package, inputs=pkg_input, outputs=install_output)
    train_btn.click(fn=start_training_stream, outputs=log_output)
    uninstall_btn.click(fn=uninstall_package, inputs=pkg_input2, outputs=uninstall_output)
    upload_button.click(fn=save_files, inputs=[video_input, image_input, config_input, lora_input], outputs=output)
demo.launch()
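The "Start Training" button above simply shells out to main.py with the uploaded config. A minimal sketch of driving the same step without the Gradio UI follows; the only assumption taken from start_training_stream is that main.py accepts a --config= argument, since main.py itself is not part of the files shown here.

import subprocess
import sys

config_path = "/home/user/app/upload/config/customize_train.yaml"
proc = subprocess.Popen(
    [sys.executable, "main.py", f"--config={config_path}"],
    stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1,
)
for line in proc.stdout:      # stream logs line by line, as the "Training Log" textbox does
    print(line, end="")
print("exit code:", proc.wait())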
config/customize_subsequent_edit.yaml
ADDED
@@ -0,0 +1,152 @@
# Pretrained diffusers model path.
pretrained_model_path: "ckpts/stable-video-diffusion-img2vid"
# The folder where your training outputs will be placed.
output_dir: "./acc"
seed: 23
num_steps: 25
# Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
enable_xformers_memory_efficient_attention: True
# Use scaled dot product attention (Only available with >= Torch 2.0)
enable_torch_2_attn: True

use_sarp: true

use_motion_lora: true
train_motion_lora_only: false
retrain_motion_lora: true

use_inversed_latents: true
use_attention_matching: true
use_consistency_attention_control: true
dtype: fp16

visualize_attention_store: false
visualize_attention_store_steps: [0, 5, 10, 15, 20, 24]

save_last_frames: True
load_from_last_frames_latents:
  - "./cache/item1/i2vedit_2024-05-11T15-53-54/clip_0_lastframe_0.pt"
  - "./cache/item1/i2vedit_2024-05-11T15-53-54/clip_0_lastframe_1.pt"
load_from_previous_consistency_edit_controller:
  - "./cache/item1/i2vedit_2024-05-11T15-53-54/consistency_edit0_attention_store"
  - "./cache/item1/i2vedit_2024-05-11T15-53-54/consistency_edit1_attention_store"
load_from_previous_consistency_store_controller:
  "./cache/item1/i2vedit_2024-05-11T15-53-54/consistency_attention_store"

# data_params
data_params:
  video_path: "../datasets/svdedit/item1/source.mp4"
  keyframe_paths:
    - "../datasets/svdedit/tmp/edit0.png"
    - "../datasets/svdedit/tmp/edit1.png"
  start_t: 0
  end_t: -1
  sample_fps: 7
  chunk_size: 16
  overlay_size: 1
  normalize: true
  output_fps: 7
  save_sampled_frame: true
  output_res: [576, 1024]
  pad_to_fit: false
  begin_clip_id: 1
  end_clip_id: 2

train_motion_lora_params:
  cache_latents: true
  cached_latent_dir: null #/path/to/cached_latents
  lora_rank: 32
  # Use LoRA for the UNET model.
  use_unet_lora: True
  # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
  # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
  lora_unet_dropout: 0.1
  # The only time you want this off is if you're doing full LoRA training.
  save_pretrained_model: False
  # Learning rate for AdamW
  learning_rate: 5e-4
  # Weight decay. Higher = more regularization. Lower = closer to dataset.
  adam_weight_decay: 1e-2
  # Maximum number of train steps. Model is saved after training.
  max_train_steps: 250
  # Saves a model every nth step.
  checkpointing_steps: 250
  # How many steps to do for validation if sample_preview is enabled.
  validation_steps: 300
  # Whether or not we want to use mixed precision with accelerate
  mixed_precision: "fp16"
  # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM.
  # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2.
  gradient_checkpointing: True
  image_encoder_gradient_checkpointing: True

  train_data:
    # The width and height in which you want your training data to be resized to.
    width: 896
    height: 512
    # This will find the closest aspect ratio to your input width and height.
    # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256
    use_data_aug: ~ #"controlnet"
    pad_to_fit: false

  validation_data:
    # Whether or not to sample preview during training (Requires more VRAM).
    sample_preview: True
    # The number of frames to sample during validation.
    num_frames: 14
    # Height and width of validation sample.
    width: 1024
    height: 576
    pad_to_fit: false
    # scale of spatial LoRAs, default is 0
    spatial_scale: 0
    # scale of noise prior, i.e. the scale of inversion noises
    noise_prior:
      - 0.0
      #- 1.0

sarp_params:
  sarp_noise_scale: 0.005

attention_matching_params:
  best_checkpoint_index: 250
  lora_scale: 1.0
  # lora path
  lora_dir: ~
  max_guidance_scale: 2.0
  disk_store: True
  load_attention_store: "./cache/item1/attention_store"
  load_consistency_attention_store: ~
  registered_modules:
    BasicTransformerBlock:
      - "attn1"
      #- "attn2"
    TemporalBasicTransformerBlock:
      - "attn1"
      #- "attn2"
  control_mode:
    spatial_self: "masked_copy"
    temporal_self: "copy_v2"
  cross_replace_steps: 0.0
  temporal_self_replace_steps: 1.0
  spatial_self_replace_steps: 1.0
  spatial_attention_chunk_size: 1

  params:
    edit0:
      temporal_step_thr: [0.5, 0.8]
      mask_thr: [0.35, 0.35]
    edit1:
      temporal_step_thr: [0.5, 0.8]
      mask_thr: [0.35, 0.35]

long_video_params:
  mode: "skip-interval"
  registered_modules:
    BasicTransformerBlock:
      #- "attn1"
      #- "attn2"
    TemporalBasicTransformerBlock:
      - "attn1"
      #- "attn2"
config/customize_train.yaml
ADDED
@@ -0,0 +1,149 @@
# Pretrained diffusers model path.
# Don't change
pretrained_model_path: "stabilityai/stable-video-diffusion-img2vid"
# The folder where your training outputs will be placed.
# Don't change
output_dir: "/home/user/app/outputs"
seed: 23
num_steps: 25
# Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
enable_xformers_memory_efficient_attention: True
# Use scaled dot product attention (Only available with >= Torch 2.0)
enable_torch_2_attn: True

use_sarp: true

use_motion_lora: true
train_motion_lora_only: false
retrain_motion_lora: true

use_inversed_latents: true
use_attention_matching: true
use_consistency_attention_control: false
dtype: fp16

visualize_attention_store: false
visualize_attention_store_steps: #[0, 5, 10, 15, 20, 24]

save_last_frames: True
load_from_last_frames_latents:

# data_params
data_params:
  # Don't change
  video_path: "/home/user/app/upload/source_and_edits/source.mp4"
  # Don't change
  keyframe_paths:
    - "/home/user/app/upload/source_and_edits/ref.jpg"
  start_t: 0
  end_t: 1.6
  sample_fps: 10
  chunk_size: 16
  overlay_size: 1
  normalize: true
  output_fps: 3
  save_sampled_frame: true
  output_res: [576, 576]
  pad_to_fit: true
  begin_clip_id: 0
  end_clip_id: 1

train_motion_lora_params:
  cache_latents: true
  cached_latent_dir: null #/path/to/cached_latents
  lora_rank: 32
  # Use LoRA for the UNET model.
  use_unet_lora: True
  # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
  # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
  lora_unet_dropout: 0.1
  # The only time you want this off is if you're doing full LoRA training.
  save_pretrained_model: False
  # Learning rate for AdamW
  learning_rate: 5e-4
  # Weight decay. Higher = more regularization. Lower = closer to dataset.
  adam_weight_decay: 1e-2
  # Maximum number of train steps. Model is saved after training.
  max_train_steps: 600
  # Saves a model every nth step.
  checkpointing_steps: 100
  # How many steps to do for validation if sample_preview is enabled.
  validation_steps: 100
  # Whether or not we want to use mixed precision with accelerate
  mixed_precision: "fp16"
  # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM.
  # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2.
  gradient_checkpointing: True
  image_encoder_gradient_checkpointing: True

  train_data:
    # The width and height in which you want your training data to be resized to.
    width: 576
    height: 576
    # This will find the closest aspect ratio to your input width and height.
    # For example, 576x576 width and height with a video of resolution 1280x720 will be resized to 576x256
    use_data_aug: ~ #"controlnet"
    pad_to_fit: true

  validation_data:
    # Whether or not to sample preview during training (Requires more VRAM).
    sample_preview: True
    # The number of frames to sample during validation.
    num_frames: 16
    # Height and width of validation sample.
    width: 576
    height: 576
    pad_to_fit: true
    # scale of spatial LoRAs, default is 0
    spatial_scale: 0
    # scale of noise prior, i.e. the scale of inversion noises
    noise_prior:
      #- 0.0
      - 1.0

sarp_params:
  sarp_noise_scale: 0.005

attention_matching_params:
  best_checkpoint_index: 500
  lora_scale: 1.0
  # lora path
  lora_dir: ~
  max_guidance_scale: 2.0

  disk_store: True
  load_attention_store: ~
  load_consistency_attention_store: ~
  load_consistency_train_attention_store: ~
  registered_modules:
    BasicTransformerBlock:
      - "attn1"
      #- "attn2"
    TemporalBasicTransformerBlock:
      - "attn1"
      #- "attn2"
  control_mode:
    spatial_self: "masked_copy"
    temporal_self: "copy_v2"
  cross_replace_steps: 0.0
  temporal_self_replace_steps: 1.0
  spatial_self_replace_steps: 1.0
  spatial_attention_chunk_size: 1

  params:
    edit0:
      temporal_step_thr: [0.5, 0.8]
      mask_thr: [0.35, 0.35]
    edit1:
      temporal_step_thr: [0.5, 0.8]
      mask_thr: [0.35, 0.35]

long_video_params:
  mode: "skip-interval"
  registered_modules:
    BasicTransformerBlock:
      #- "attn1"
      #- "attn2"
    TemporalBasicTransformerBlock:
      - "attn1"
      #- "attn2"
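This is the config that app.py points main.py at. Since main.py is not part of the files shown in this commit, the following is only a hypothetical sketch of how a driver script might read it; OmegaConf and the access pattern are assumptions, while the field names and values come from the YAML above.

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/customize_train.yaml")
print(cfg.pretrained_model_path)                      # "stabilityai/stable-video-diffusion-img2vid"
print(cfg.data_params.video_path)                     # source clip uploaded through app.py
print(cfg.train_motion_lora_params.max_train_steps)   # 600
clip_len_s = cfg.data_params.end_t - cfg.data_params.start_t   # 1.6 s of video is edited
print(int(clip_len_s * cfg.data_params.sample_fps))            # 16 sampled frames at sample_fps=10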
config/customize_train_multi.yaml
ADDED
@@ -0,0 +1,149 @@
# Pretrained diffusers model path.
# Don't change
pretrained_model_path: "stabilityai/stable-video-diffusion-img2vid"
# The folder where your training outputs will be placed.
# Don't change
output_dir: "/home/user/app/outputs"
seed: 23
num_steps: 25
# Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
enable_xformers_memory_efficient_attention: True
# Use scaled dot product attention (Only available with >= Torch 2.0)
enable_torch_2_attn: True

use_sarp: true

use_motion_lora: true
train_motion_lora_only: false
retrain_motion_lora: true

use_inversed_latents: true
use_attention_matching: true
use_consistency_attention_control: true
dtype: fp16

visualize_attention_store: false
visualize_attention_store_steps: #[0, 5, 10, 15, 20, 24]

save_last_frames: True
load_from_last_frames_latents:

# data_params
data_params:
  # Don't change
  video_path: "/home/user/app/upload/source_and_edits/source.mp4"
  # Don't change
  keyframe_paths:
    - "/home/user/app/upload/source_and_edits/ref.jpg"
  start_t: 0
  end_t: 4.0
  sample_fps: 10
  chunk_size: 12
  overlay_size: 3
  normalize: true
  output_fps: 10
  save_sampled_frame: true
  output_res: [768, 768]
  pad_to_fit: true
  begin_clip_id: 0
  end_clip_id: 4

train_motion_lora_params:
  cache_latents: true
  cached_latent_dir: null #/path/to/cached_latents
  lora_rank: 32
  # Use LoRA for the UNET model.
  use_unet_lora: True
  # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
  # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
  lora_unet_dropout: 0.1
  # The only time you want this off is if you're doing full LoRA training.
  save_pretrained_model: False
  # Learning rate for AdamW
  learning_rate: 5e-4
  # Weight decay. Higher = more regularization. Lower = closer to dataset.
  adam_weight_decay: 1e-2
  # Maximum number of train steps. Model is saved after training.
  max_train_steps: 600
  # Saves a model every nth step.
  checkpointing_steps: 200
  # How many steps to do for validation if sample_preview is enabled.
  validation_steps: 200
  # Whether or not we want to use mixed precision with accelerate
  mixed_precision: "fp16"
  # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM.
  # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2.
  gradient_checkpointing: True
  image_encoder_gradient_checkpointing: True

  train_data:
    # The width and height in which you want your training data to be resized to.
    width: 768
    height: 768
    # This will find the closest aspect ratio to your input width and height.
    # For example, 768x768 width and height with a video of resolution 1280x720 will be resized to 768x256
    use_data_aug: ~ #"controlnet"
    pad_to_fit: true

  validation_data:
    # Whether or not to sample preview during training (Requires more VRAM).
    sample_preview: True
    # The number of frames to sample during validation.
    num_frames: 8
    # Height and width of validation sample.
    width: 768
    height: 768
    pad_to_fit: true
    # scale of spatial LoRAs, default is 0
    spatial_scale: 0
    # scale of noise prior, i.e. the scale of inversion noises
    noise_prior:
      #- 0.0
      - 1.0

sarp_params:
  sarp_noise_scale: 0.005

attention_matching_params:
  best_checkpoint_index: 600
  lora_scale: 1.0
  # lora path
  lora_dir: ~
  max_guidance_scale: 2.0

  disk_store: True
  load_attention_store: ~
  load_consistency_attention_store: ~
  load_consistency_train_attention_store: ~
  registered_modules:
    BasicTransformerBlock:
      - "attn1"
      #- "attn2"
    TemporalBasicTransformerBlock:
      - "attn1"
      #- "attn2"
  control_mode:
    spatial_self: "masked_copy"
    temporal_self: "copy_v2"
  cross_replace_steps: 0.0
  temporal_self_replace_steps: 1.0
  spatial_self_replace_steps: 1.0
  spatial_attention_chunk_size: 1

  params:
    edit0:
      temporal_step_thr: [0.5, 0.8]
      mask_thr: [0.35, 0.35]
    edit1:
      temporal_step_thr: [0.5, 0.8]
      mask_thr: [0.35, 0.35]

long_video_params:
  mode: "skip-interval"
  registered_modules:
    BasicTransformerBlock:
      #- "attn1"
      #- "attn2"
    TemporalBasicTransformerBlock:
      - "attn1"
      #- "attn2"
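As a worked example of the multi-clip settings above: 4.0 s at sample_fps=10 gives 40 sampled frames (assuming the source is at least that long), and chunk_size=12 with overlay_size=3 produces overlapping clips. The sliding-window rule below mirrors VideoIO.read_video_iter in i2vedit/data.py further down in this commit.

num_frames, chunk_size, overlay_size = 40, 12, 3

clips, begin = [], 0
while begin < num_frames:
    start = max(begin - overlay_size, 0)        # step back by the overlap, except for the first clip
    end = min(start + chunk_size, num_frames)   # take up to chunk_size frames
    clips.append((start, end))
    begin = end

print(clips)  # [(0, 12), (9, 21), (18, 30), (27, 39), (36, 40)] -> 5 clips, consistent with begin_clip_id 0 / end_clip_id 4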
i2vedit/__init__.py
ADDED
File without changes
i2vedit/data.py
ADDED
@@ -0,0 +1,317 @@
import os
import decord
import imageio
import numpy as np
import PIL
from PIL import Image
from einops import rearrange, repeat

from torchvision.transforms import Resize, Pad, InterpolationMode, ToTensor

import torch
import torch.nn as nn
from torch.utils.data import Dataset

#from i2vedit.utils.augment import ControlNetDataAugmentation, ColorDataAugmentation
# from utils.euler_utils import tensor_to_vae_latent

class ResolutionControl(object):

    def __init__(self, input_res, output_res, pad_to_fit=False, fill=0, **kwargs):

        self.ih, self.iw = input_res
        self.output_res = output_res
        self.pad_to_fit = pad_to_fit
        self.fill = fill

    def pad_with_ratio(self, frames, res, fill=0):
        if isinstance(frames, torch.Tensor):
            original_dim = frames.ndim
            if frames.ndim > 4:
                batch_size = frames.shape[0]
                frames = rearrange(frames, "b f c h w -> (b f) c h w")
            _, _, ih, iw = frames.shape
        elif isinstance(frames, PIL.Image.Image):
            iw, ih = frames.size
        assert ih == self.ih and iw == self.iw, "resolution doesn't match."
        #print("ih, iw", ih, iw)
        i_ratio = ih / iw
        h, w = res
        #print("h,w", h, w)
        n_ratio = h / w
        if i_ratio > n_ratio:
            nw = int(ih / h * w)
            #print("nw", nw)
            frames = Pad(((nw - iw)//2, 0), fill=fill)(frames)
        else:
            nh = int(iw / w * h)
            frames = Pad((0, (nh - ih)//2), fill=fill)(frames)
        #print("after pad", frames.shape)
        if isinstance(frames, torch.Tensor):
            if original_dim > 4:
                frames = rearrange(frames, "(b f) c h w -> b f c h w", b=batch_size)

        return frames

    def return_to_original_res(self, frames):
        if isinstance(frames, torch.Tensor):
            original_dim = frames.ndim
            if frames.ndim > 4:
                batch_size = frames.shape[0]
                frames = rearrange(frames, "b f c h w -> (b f) c h w")
            _, _, h, w = frames.shape
        elif isinstance(frames, PIL.Image.Image):
            w, h = frames.size
        #print("original res", (self.ih, self.iw))
        #print("current res", (h, w))
        assert h == self.output_res[0] and w == self.output_res[1], "resolution doesn't match."
        n_ratio = h / w
        ih, iw = self.ih, self.iw
        i_ratio = ih / iw
        if self.pad_to_fit:
            if i_ratio > n_ratio:
                nw = int(ih / h * w)
                frames = Resize((ih, iw+2*(nw - iw)//2), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
                if isinstance(frames, torch.Tensor):
                    frames = frames[...,:,(nw - iw)//2:-(nw - iw)//2]
                elif isinstance(frames, PIL.Image.Image):
                    frames = frames.crop(((nw - iw)//2,0,iw+(nw - iw)//2,ih))
            else:
                nh = int(iw / w * h)
                frames = Resize((ih+2*(nh - ih)//2, iw), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
                if isinstance(frames, torch.Tensor):
                    frames = frames[...,(nh - ih)//2:-(nh - ih)//2,:]
                elif isinstance(frames, PIL.Image.Image):
                    frames = frames.crop((0,(nh - ih)//2,iw,ih+(nh - ih)//2))
        else:
            frames = Resize((ih, iw), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)

        if isinstance(frames, torch.Tensor):
            if original_dim > 4:
                frames = rearrange(frames, "(b f) c h w -> b f c h w", b=batch_size)

        return frames

    def __call__(self, frames):
        if self.pad_to_fit:
            frames = self.pad_with_ratio(frames, self.output_res, fill=self.fill)

        if isinstance(frames, torch.Tensor):
            original_dim = frames.ndim
            if frames.ndim > 4:
                batch_size = frames.shape[0]
                frames = rearrange(frames, "b f c h w -> (b f) c h w")
            frames = (frames + 1) / 2.

        frames = Resize(tuple(self.output_res), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
        if isinstance(frames, torch.Tensor):
            if original_dim > 4:
                frames = rearrange(frames, "(b f) c h w -> b f c h w", b=batch_size)
            frames = frames * 2 - 1

        return frames

    def callback(self, frames):
        return self.return_to_original_res(frames)

class VideoIO(object):

    def __init__(
        self,
        video_path,
        keyframe_paths,
        output_dir,
        device,
        dtype,
        start_t: int = 0,
        end_t: int = -1,
        sample_fps: int = -1,
        chunk_size: int = 14,
        overlay_size: int = -1,
        normalize: bool = True,
        output_fps: int = -1,
        save_sampled_video: bool = True,
        **kwargs
    ):
        self.video_path = video_path
        self.keyframe_paths = keyframe_paths
        self.device = device
        self.dtype = dtype
        self.start_t = start_t
        self.end_t = end_t
        self.sample_fps = sample_fps
        self.chunk_size = chunk_size
        self.overlay_size = overlay_size
        self.normalize = normalize
        self.save_sampled_video = save_sampled_video

        vr = decord.VideoReader(video_path)
        initial_fps = vr.get_avg_fps()
        self.initial_fps = initial_fps

        if output_fps == -1: output_fps = initial_fps

        self.video_writer_list = []
        for keyframe_path in keyframe_paths:
            fname, ext = os.path.splitext(os.path.basename(keyframe_path))
            output_video_path = os.path.join(output_dir, fname+".mp4")
            self.video_writer_list.append( imageio.get_writer(output_video_path, fps=output_fps) )

        if save_sampled_video:
            fname, ext = os.path.splitext(os.path.basename(video_path))
            output_sampled_video_path = os.path.join(output_dir, fname+f"_from{start_t}s_to{end_t}s{ext}")
            self.sampled_video_writer = imageio.get_writer(output_sampled_video_path, fps=output_fps)

    def read_keyframe_iter(self):
        for keyframe_path in self.keyframe_paths:
            image = Image.open(keyframe_path).convert("RGB")
            yield image

    def read_video_iter(self):
        vr = decord.VideoReader(self.video_path)
        if self.sample_fps == -1: self.sample_fps = self.initial_fps
        if self.end_t == -1:
            self.end_t = len(vr) / self.initial_fps
        else:
            self.end_t = min(len(vr) / self.initial_fps, self.end_t)
        if self.overlay_size == -1: self.overlay_size = 0
        assert 0 <= self.start_t < self.end_t
        assert self.sample_fps > 0

        start_f_ind = int(self.start_t * self.initial_fps)
        end_f_ind = int(self.end_t * self.initial_fps)
        num_f = int((self.end_t - self.start_t) * self.sample_fps)
        sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
        print("sample_idx", sample_idx)

        assert len(sample_idx) > 0, f"sample_idx is empty!"

        begin_frame_idx = 0
        while begin_frame_idx < len(sample_idx):
            self.begin_frame_idx = begin_frame_idx
            begin_frame_idx = max(begin_frame_idx - self.overlay_size, 0)
            next_frame_idx = min(begin_frame_idx + self.chunk_size, len(sample_idx))

            video = vr.get_batch(sample_idx[begin_frame_idx:next_frame_idx])
            begin_frame_idx = next_frame_idx

            if self.save_sampled_video:
                overlay_size = 0 if self.begin_frame_idx == 0 else self.overlay_size
                print(type(video))
                for frame in video[overlay_size:]:
                    self.sampled_video_writer.append_data(frame.detach().cpu().numpy())

            video = torch.Tensor(video).to(self.device).to(self.dtype)
            video = rearrange(video, "f h w c -> f c h w")

            if self.normalize:
                video = video / 127.5 - 1.0

            yield video

    def write_video(self, video, video_id, resctrl: ResolutionControl = None):
        '''
        Append the frames of `video` to the writer for `video_id`, skipping the overlapped leading frames.
        '''
        overlay_size = 0 if self.begin_frame_idx == 0 else self.overlay_size
        for img in video[overlay_size:]:
            if resctrl is not None:
                img = resctrl.callback(img)
            self.video_writer_list[video_id].append_data(np.array(img))

    def close(self):
        for video_writer in self.video_writer_list:
            video_writer.close()
        if self.save_sampled_video:
            self.sampled_video_writer.close()
        self.begin_frame_idx = 0


class SingleClipDataset(Dataset):

    # data_aug_class = {
    #     "rsfnet": ColorDataAugmentation,
    #     "controlnet": ControlNetDataAugmentation
    # }

    def __init__(
        self,
        inversion_noise,
        video_clip,
        keyframe,
        firstframe,
        height,
        width,
        use_data_aug=None,
        pad_to_fit=False,
        keyframe_latent=None
    ):

        self.resctrl = ResolutionControl(video_clip.shape[-2:], (height, width), pad_to_fit, fill=-1)

        video_clip = rearrange(video_clip, "1 f c h w -> f c h w")
        keyframe = rearrange(keyframe, "1 f c h w -> f c h w")
        firstframe = rearrange(firstframe, "1 f c h w -> f c h w")

        if inversion_noise is not None:
            inversion_noise = rearrange(inversion_noise, "1 f c h w -> f c h w")

        if use_data_aug is not None:
            if use_data_aug in self.data_aug_class:
                self.data_augment = self.data_aug_class[use_data_aug]()
                print(f"Augmentation mode: {use_data_aug} is implemented.")
                use_data_aug = True
            else:
                raise NotImplementedError(f"Augmentation mode: {use_data_aug} is not implemented!")
        else:
            use_data_aug = False

        self.video_clip = video_clip
        self.keyframe = keyframe
        self.firstframe = firstframe
        self.inversion_noise = inversion_noise
        self.use_data_aug = use_data_aug
        self.keyframe_latent = keyframe_latent

    @staticmethod
    def __getname__(): return 'single_clip'

    def __len__(self):
        return 1

    def __getitem__(self, index):

        motion_values = torch.Tensor([127.])

        pixel_values = self.resctrl(self.video_clip)
        refer_pixel_values = self.resctrl(self.keyframe)
        cross_pixel_values = self.resctrl(self.firstframe)

        if self.use_data_aug:
            print("pixel_values before augment", refer_pixel_values.min(), refer_pixel_values.max())
            #pixel_values, refer_pixel_values, cross_pixel_values = \
            #self.data_augment.augment(
            #    torch.cat([pixel_values, refer_pixel_values, cross_pixel_values], dim=0)
            #).tensor_split([pixel_values.shape[0],pixel_values.shape[0]+refer_pixel_values.shape[0]],dim=0)
            refer_pixel_values = self.data_augment.augment(refer_pixel_values)
            print("pixel_values after augment", refer_pixel_values.min(), refer_pixel_values.max())

        outputs = {
            "pixel_values": pixel_values,
            "refer_pixel_values": refer_pixel_values,
            "cross_pixel_values": cross_pixel_values,
            "motion_values": motion_values,
            'dataset': self.__getname__(),
        }

        if self.inversion_noise is not None:
            outputs.update({
                "inversion_noise": self.inversion_noise
            })
        if self.keyframe_latent is not None:
            outputs.update({
                "refer_latents": self.keyframe_latent
            })
        return outputs
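A small usage sketch for ResolutionControl from the file above: frames are expected in [-1, 1] with shape (b, f, c, h, w) or (f, c, h, w), __call__ pads to the working aspect ratio and resizes, and callback() crops the padding back out and restores the original resolution. The 720p input below is an illustrative value; the 576x576 working resolution matches customize_train.yaml.

import torch
from i2vedit.data import ResolutionControl

frames = torch.rand(1, 14, 3, 720, 1280) * 2 - 1          # fake clip in [-1, 1]
resctrl = ResolutionControl(
    input_res=(720, 1280), output_res=(576, 576), pad_to_fit=True, fill=-1
)
model_in = resctrl(frames)              # pad top/bottom to a square, then resize to 576x576
print(model_in.shape)                   # torch.Size([1, 14, 3, 576, 576])
restored = resctrl.callback(model_in)   # resize back up and crop the padding away
print(restored.shape)                   # torch.Size([1, 14, 3, 720, 1280])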
i2vedit/inference.py
ADDED
@@ -0,0 +1,89 @@
import argparse
import os
import platform
import re
import warnings
import imageio
import random
from typing import Optional
from tqdm import trange
from einops import rearrange

import torch
from torch import Tensor
from torch.nn.functional import interpolate
from diffusers import StableVideoDiffusionPipeline, EulerDiscreteScheduler
from diffusers import TextToVideoSDPipeline

from i2vedit.train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
from i2vedit.utils.lora_handler import LoraHandler
from i2vedit.utils.model_utils import P2PStableVideoDiffusionPipeline


def initialize_pipeline(
    model: str,
    device: str = "cuda",
    xformers: bool = False,
    sdp: bool = False,
    lora_path: str = "",
    lora_rank: int = 64,
    lora_scale: float = 1.0,
    load_spatial_lora: bool = False,
    dtype = torch.float16
):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(model)

        # Freeze any necessary models
        freeze_models([vae, image_encoder, unet])

        # Enable xformers if available
        handle_memory_attention(xformers, sdp, unet)

        lora_manager_temporal = LoraHandler(
            version="cloneofsimo",
            use_unet_lora=True,
            use_image_lora=False,
            save_for_webui=False,
            only_for_webui=False,
            unet_replace_modules=["TemporalBasicTransformerBlock"],
            image_encoder_replace_modules=None,
            lora_bias=None
        )

        unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
            True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)

        if load_spatial_lora:
            lora_manager_spatial = LoraHandler(
                version="cloneofsimo",
                use_unet_lora=True,
                use_image_lora=False,
                save_for_webui=False,
                only_for_webui=False,
                unet_replace_modules=["BasicTransformerBlock"],
                image_encoder_replace_modules=None,
                lora_bias=None
            )

            spatial_lora_path = lora_path.replace("temporal", "spatial")
            unet_lora_params, unet_negation = lora_manager_spatial.add_lora_to_model(
                True, unet, lora_manager_spatial.unet_replace_modules, 0, spatial_lora_path, r=lora_rank, scale=lora_scale)

        unet.eval()
        image_encoder.eval()
        unet_and_text_g_c(unet, image_encoder, False, False)

        pipe = P2PStableVideoDiffusionPipeline.from_pretrained(
            model,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder.to(device=device, dtype=dtype),
            vae=vae.to(device=device, dtype=dtype),
            unet=unet.to(device=device, dtype=dtype)
        )
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    return pipe
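A hypothetical sketch of calling initialize_pipeline from the file above. The LoRA checkpoint path is illustrative (actual filenames depend on the training run's checkpointing), and the generation call follows the standard diffusers StableVideoDiffusionPipeline signature; the repo's P2PStableVideoDiffusionPipeline subclass may expect additional controller arguments that are not shown in this commit.

import torch
from PIL import Image
from i2vedit.inference import initialize_pipeline

pipe = initialize_pipeline(
    model="stabilityai/stable-video-diffusion-img2vid",
    device="cuda",
    sdp=True,
    lora_path="outputs/train_motion_lora/temporal/lora.pt",  # hypothetical checkpoint path
    lora_rank=32,
    lora_scale=1.0,
)

# Generation sketch: standard SVD __call__ arguments; the P2P subclass may require more.
edited_first_frame = Image.open("upload/source_and_edits/ref.jpg").convert("RGB")
frames = pipe(
    edited_first_frame, num_frames=16, num_inference_steps=25, max_guidance_scale=2.0
).frames[0]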
i2vedit/prompt_attention/__init__.py
ADDED
File without changes
i2vedit/prompt_attention/attention_register.py
ADDED
@@ -0,0 +1,250 @@
"""
Register the attention controller into the UNet of stable diffusion.
Build a customized attention function `_attention'.
Replace the original attention function with `forward' and `spatial_temporal_forward' in the attention_controlled_forward function.
Most of spatial_temporal_forward is directly copied from `video_diffusion/models/attention.py'.
TODO FIXME: merge redundant code with attention.py
"""

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

import logging
from einops import rearrange, repeat
import math
from inspect import isfunction
from typing import Any, Optional
from packaging import version

from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.utils import USE_PEFT_BACKEND

class AttnControllerProcessor:

    def __init__(self, consistency_controller, controller, place_in_unet, attention_type):

        self.consistency_controller = consistency_controller
        self.controller = controller
        self.place_in_unet = place_in_unet
        self.attention_type = attention_type

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)

        is_cross = True
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            is_cross = False
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if self.consistency_controller is not None:
            key = self.consistency_controller(
                key, is_cross, f"{self.place_in_unet}_{self.attention_type}_k"
            )
            value = self.consistency_controller(
                value, is_cross, f"{self.place_in_unet}_{self.attention_type}_v"
            )

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if self.controller is not None:
            hidden_states = self.controller.attention_control(
                self.place_in_unet, self.attention_type, is_cross,
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
            )
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def register_attention_control(
    model,
    controller=None,
    consistency_controller=None,
    find_modules = {},
    consistency_find_modules = {},
    undo=False
):
    "Connect a model with a controller"
    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    #if controller is None:
    #    controller = DummyController()

    f_keys = list(set(find_modules.keys()).difference(set(consistency_find_modules.keys())))
    c_keys = list(set(consistency_find_modules.keys()).difference(set(find_modules.keys())))
    common_keys = list(set(find_modules.keys()).intersection(set(consistency_find_modules.keys())))
    new_find_modules = {}
    for f_key in f_keys:
        new_find_modules.update({
            f_key: find_modules[f_key]
        })
    new_consistency_find_modules = {}
    for c_key in c_keys:
        new_consistency_find_modules.update({
            c_key: consistency_find_modules[c_key]
        })
    common_modules = {}
    for key in common_keys:
        find_modules[key] = [] if find_modules[key] is None else find_modules[key]
        consistency_find_modules[key] = [] if consistency_find_modules[key] is None else consistency_find_modules[key]
        f_list = list(set(find_modules[key]).difference(set(consistency_find_modules[key])))
        c_list = list(set(consistency_find_modules[key]).difference(set(find_modules[key])))
        common_list = list(set(find_modules[key]).intersection(set(consistency_find_modules[key])))
        if len(f_list) > 0:
            new_find_modules.update({key: f_list})
        if len(c_list) > 0:
            new_consistency_find_modules.update({key: c_list})
        if len(common_list) > 0:
            common_modules.update({key: common_list})

    find_modules = new_find_modules
    consistency_find_modules = new_consistency_find_modules

    print("common_modules", common_modules)
    print("find_modules", find_modules)
    print("consistency_find_modules", consistency_find_modules)
    print("controller", controller, "consistency_controller", consistency_controller)

    def register_recr(net_, count1, count2, place_in_unet):

        if net_[1].__class__.__name__ == 'BasicTransformerBlock':
            attention_type = 'spatial'
        elif net_[1].__class__.__name__ == 'TemporalBasicTransformerBlock':
            attention_type = 'temporal'

        control1, control2 = None, None
        if net_[1].__class__.__name__ in common_modules.keys():
            control1, control2 = consistency_controller, controller
|
| 189 |
+
module_list = common_modules[net_[1].__class__.__name__]
|
| 190 |
+
elif net_[1].__class__.__name__ in find_modules.keys():
|
| 191 |
+
control1, control2 = None, controller
|
| 192 |
+
module_list = find_modules[net_[1].__class__.__name__]
|
| 193 |
+
elif net_[1].__class__.__name__ in consistency_find_modules.keys():
|
| 194 |
+
control1, control2 = consistency_controller, None
|
| 195 |
+
module_list = consistency_find_modules[net_[1].__class__.__name__]
|
| 196 |
+
|
| 197 |
+
if any([control is not None for control in [control1, control2]]):
|
| 198 |
+
|
| 199 |
+
if module_list is not None and 'attn1' in module_list:
|
| 200 |
+
if undo:
|
| 201 |
+
net_[1].attn1.set_processor(AttnProcessor2_0())
|
| 202 |
+
else:
|
| 203 |
+
net_[1].attn1.set_processor(AttnControllerProcessor(control1, control2, place_in_unet, attention_type = attention_type))
|
| 204 |
+
if control1 is not None: count1 += 1
|
| 205 |
+
if control2 is not None: count2 += 1
|
| 206 |
+
|
| 207 |
+
if module_list is not None and 'attn2' in module_list:
|
| 208 |
+
if undo:
|
| 209 |
+
net_[1].attn2.set_processor(AttnProcessor2_0())
|
| 210 |
+
else:
|
| 211 |
+
net_[1].attn2.set_processor(AttnControllerProcessor(control1, control2, place_in_unet, attention_type = attention_type))
|
| 212 |
+
if control1 is not None: count1 += 1
|
| 213 |
+
if control2 is not None: count2 += 1
|
| 214 |
+
|
| 215 |
+
return count1, count2
|
| 216 |
+
|
| 217 |
+
elif hasattr(net_[1], 'children'):
|
| 218 |
+
for net in net_[1].named_children():
|
| 219 |
+
count1, count2 = register_recr(net, count1, count2, place_in_unet)
|
| 220 |
+
|
| 221 |
+
return count1, count2
|
| 222 |
+
|
| 223 |
+
cross_att_count1 = 0
|
| 224 |
+
cross_att_count2 = 0
|
| 225 |
+
sub_nets = model.named_children()
|
| 226 |
+
for net in sub_nets:
|
| 227 |
+
if "down" in net[0]:
|
| 228 |
+
c1, c2 = register_recr(net, 0, 0, "down")
|
| 229 |
+
cross_att_count1 += c1
|
| 230 |
+
cross_att_count2 += c2
|
| 231 |
+
elif "up" in net[0]:
|
| 232 |
+
c1, c2 = register_recr(net, 0, 0, "up")
|
| 233 |
+
cross_att_count1 += c1
|
| 234 |
+
cross_att_count2 += c2
|
| 235 |
+
elif "mid" in net[0]:
|
| 236 |
+
c1, c2 = register_recr(net, 0, 0, "mid")
|
| 237 |
+
cross_att_count1 += c1
|
| 238 |
+
cross_att_count2 += c2
|
| 239 |
+
if undo:
|
| 240 |
+
print(f"Number of attention layer unregistered for controller: {cross_att_count2}")
|
| 241 |
+
print(f"Number of attention layer unregistered for consistency_controller: {cross_att_count1}")
|
| 242 |
+
else:
|
| 243 |
+
print(f"Number of attention layer registered for controller: {cross_att_count2}")
|
| 244 |
+
if controller is not None:
|
| 245 |
+
controller.num_att_layers = cross_att_count2
|
| 246 |
+
print(f"Number of attention layer registered for consistency_controller: {cross_att_count1}")
|
| 247 |
+
if consistency_controller is not None:
|
| 248 |
+
consistency_controller.num_att_layers = cross_att_count1
|
| 249 |
+
|
| 250 |
+
|
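Usage sketch (not part of the commit): one plausible way to attach an AttentionStore to a loaded UNet with register_attention_control and to detach it again. The `unet` variable and the choice of find_modules are assumptions for illustration; only the signatures defined above are relied on.

# Hypothetical wiring example; `unet` is assumed to be an already-loaded
# spatio-temporal UNet whose transformer blocks expose .attn1 / .attn2.
from i2vedit.prompt_attention.attention_store import AttentionStore
from i2vedit.prompt_attention.attention_register import register_attention_control

store = AttentionStore(save_self_attention=True, save_latents=True, disk_store=False)

# Register the controller on the temporal self-attention layers.
register_attention_control(
    unet,
    controller=store,
    find_modules={"TemporalBasicTransformerBlock": ["attn1"]},
)

# ... run the denoising loop, calling store.step_callback(x_t) once per step ...

# Restore the stock AttnProcessor2_0 processors when finished.
register_attention_control(
    unet,
    controller=store,
    find_modules={"TemporalBasicTransformerBlock": ["attn1"]},
    undo=True,
)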
i2vedit/prompt_attention/attention_store.py
ADDED
|
@@ -0,0 +1,305 @@
| 1 |
+
"""
|
| 2 |
+
Attention storer AttentionStore, the base class for the attention editors in attention_util.py.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import abc
|
| 7 |
+
import os
|
| 8 |
+
import copy
|
| 9 |
+
import shutil
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from packaging import version
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
from i2vedit.prompt_attention.common.util import get_time_string
|
| 17 |
+
|
| 18 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
| 19 |
+
SDP_IS_AVAILABLE = True
|
| 20 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
| 21 |
+
|
| 22 |
+
BACKEND_MAP = {
|
| 23 |
+
SDPBackend.MATH: {
|
| 24 |
+
"enable_math": True,
|
| 25 |
+
"enable_flash": False,
|
| 26 |
+
"enable_mem_efficient": False,
|
| 27 |
+
},
|
| 28 |
+
SDPBackend.FLASH_ATTENTION: {
|
| 29 |
+
"enable_math": False,
|
| 30 |
+
"enable_flash": True,
|
| 31 |
+
"enable_mem_efficient": False,
|
| 32 |
+
},
|
| 33 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
| 34 |
+
"enable_math": False,
|
| 35 |
+
"enable_flash": False,
|
| 36 |
+
"enable_mem_efficient": True,
|
| 37 |
+
},
|
| 38 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
| 39 |
+
}
|
| 40 |
+
else:
|
| 41 |
+
from contextlib import nullcontext
|
| 42 |
+
|
| 43 |
+
SDP_IS_AVAILABLE = False
|
| 44 |
+
sdp_kernel = nullcontext
|
| 45 |
+
BACKEND_MAP = {}
|
| 46 |
+
print(
|
| 47 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
| 48 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
| 49 |
+
f"You might want to consider upgrading."
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
class AttentionControl(abc.ABC):
|
| 53 |
+
|
| 54 |
+
def step_callback(self, x_t):
|
| 55 |
+
self.cur_att_layer = 0
|
| 56 |
+
self.cur_step += 1
|
| 57 |
+
self.between_steps()
|
| 58 |
+
return x_t
|
| 59 |
+
|
| 60 |
+
def between_steps(self):
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def num_uncond_att_layers(self):
|
| 65 |
+
"""I guess the diffusion of google has some unconditional attention layer
|
| 66 |
+
Stable Diffusion has no unconditional attention layers, so this is always 0.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
int: always 0 for this codebase.
|
| 70 |
+
"""
|
| 71 |
+
# return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0
|
| 72 |
+
return 0
|
| 73 |
+
|
| 74 |
+
@abc.abstractmethod
|
| 75 |
+
def forward (self, attn, is_cross: bool, place_in_unet: str):
|
| 76 |
+
raise NotImplementedError
|
| 77 |
+
|
| 78 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
| 79 |
+
if self.cur_att_layer >= self.num_uncond_att_layers:
|
| 80 |
+
if self.LOW_RESOURCE or 'mask' in place_in_unet:
|
| 81 |
+
# For inversion without null text file
|
| 82 |
+
attn = self.forward(attn, is_cross, place_in_unet)
|
| 83 |
+
else:
|
| 84 |
+
# For classifier-free guidance scale!=1
|
| 85 |
+
h = attn.shape[0]
|
| 86 |
+
attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
|
| 87 |
+
self.cur_att_layer += 1
|
| 88 |
+
|
| 89 |
+
return attn
|
| 90 |
+
|
| 91 |
+
def reset(self):
|
| 92 |
+
self.cur_step = 0
|
| 93 |
+
self.cur_att_layer = 0
|
| 94 |
+
|
| 95 |
+
def __init__(self,
|
| 96 |
+
):
|
| 97 |
+
self.LOW_RESOURCE = False # assume the edit have cfg
|
| 98 |
+
self.cur_step = 0
|
| 99 |
+
self.num_att_layers = -1
|
| 100 |
+
self.cur_att_layer = 0
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class AttentionStore(AttentionControl):
|
| 104 |
+
def step_callback(self, x_t):
|
| 105 |
+
|
| 106 |
+
x_t = super().step_callback(x_t)
|
| 107 |
+
if self.save_latents:
|
| 108 |
+
self.latents_store.append(x_t.cpu().detach())
|
| 109 |
+
return x_t
|
| 110 |
+
|
| 111 |
+
@staticmethod
|
| 112 |
+
def get_empty_store():
|
| 113 |
+
return {"down_spatial_q_cross": [], "mid_spatial_q_cross": [], "up_spatial_q_cross": [],
|
| 114 |
+
"down_spatial_k_cross": [], "mid_spatial_k_cross": [], "up_spatial_k_cross": [],
|
| 115 |
+
"down_spatial_mask_cross": [], "mid_spatial_mask_cross": [], "up_spatial_mask_cross": [],
|
| 116 |
+
"down_temporal_cross": [], "mid_temporal_cross": [], "up_temporal_cross": [],
|
| 117 |
+
"down_spatial_q_self": [], "mid_spatial_q_self": [], "up_spatial_q_self": [],
|
| 118 |
+
"down_spatial_k_self": [], "mid_spatial_k_self": [], "up_spatial_k_self": [],
|
| 119 |
+
"down_spatial_mask_self": [], "mid_spatial_mask_self": [], "up_spatial_mask_self": [],
|
| 120 |
+
"down_spatial_self": [], "mid_spatial_self": [], "up_spatial_self": [],
|
| 121 |
+
"down_temporal_self": [], "mid_temporal_self": [], "up_temporal_self": []}
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def get_empty_cross_store():
|
| 125 |
+
return {"down_spatial_q_cross": [], "mid_spatial_q_cross": [], "up_spatial_q_cross": [],
|
| 126 |
+
"down_spatial_k_cross": [], "mid_spatial_k_cross": [], "up_spatial_k_cross": [],
|
| 127 |
+
"down_spatial_mask_cross": [], "mid_spatial_mask_cross": [], "up_spatial_mask_cross": [],
|
| 128 |
+
"down_temporal_cross": [], "mid_temporal_cross": [], "up_temporal_cross": [],
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
| 132 |
+
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
| 133 |
+
if attn.shape[-2] <= 8*9*8*16: # avoid memory overload
|
| 134 |
+
# print(f"Store attention map {key} of shape {attn.shape}")
|
| 135 |
+
if (is_cross or self.save_self_attention or 'mask' in key):
|
| 136 |
+
if False:  # attn.shape[-2] >= 4*9*4*16:
|
| 137 |
+
append_tensor = attn.cpu().detach()
|
| 138 |
+
else:
|
| 139 |
+
append_tensor = attn
|
| 140 |
+
self.step_store[key].append(copy.deepcopy(append_tensor))
|
| 141 |
+
# FIXME: Are all of these deepcopies necessary?
|
| 142 |
+
# self.step_store[key].append(append_tensor)
|
| 143 |
+
return attn
|
| 144 |
+
|
| 145 |
+
def between_steps(self):
|
| 146 |
+
if len(self.attention_store) == 0:
|
| 147 |
+
self.attention_store = {key: self.step_store[key] for key in self.step_store if 'mask' in key}
|
| 148 |
+
else:
|
| 149 |
+
for key in self.attention_store:
|
| 150 |
+
if 'mask' in key:
|
| 151 |
+
for i in range(len(self.attention_store[key])):
|
| 152 |
+
self.attention_store[key][i] += self.step_store[key][i]
|
| 153 |
+
|
| 154 |
+
if self.disk_store:
|
| 155 |
+
path = self.store_dir + f'/{self.cur_step:03d}.pt'
|
| 156 |
+
if self.load_attention_store is None:
|
| 157 |
+
torch.save(copy.deepcopy(self.step_store), path)
|
| 158 |
+
self.attention_store_all_step.append(path)
|
| 159 |
+
else:
|
| 160 |
+
self.attention_store_all_step.append(copy.deepcopy(self.step_store))
|
| 161 |
+
self.step_store = self.get_empty_store()
|
| 162 |
+
|
| 163 |
+
def get_average_attention(self):
|
| 164 |
+
"divide the attention map value in attention store by denoising steps"
|
| 165 |
+
average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store if 'mask' in key}
|
| 166 |
+
return average_attention
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def reset(self):
|
| 170 |
+
super(AttentionStore, self).reset()
|
| 171 |
+
self.step_store = self.get_empty_store()
|
| 172 |
+
self.attention_store_all_step = []
|
| 173 |
+
if self.disk_store:
|
| 174 |
+
if self.load_attention_store is not None:
|
| 175 |
+
flist = sorted(os.listdir(self.load_attention_store), key=lambda x: int(x[:-3]))
|
| 176 |
+
self.attention_store_all_step = [
|
| 177 |
+
os.path.join(self.load_attention_store, fn) for fn in flist
|
| 178 |
+
]
|
| 179 |
+
self.attention_store = {}
|
| 180 |
+
|
| 181 |
+
def __init__(self,
|
| 182 |
+
save_self_attention:bool=True,
|
| 183 |
+
save_latents:bool=True,
|
| 184 |
+
disk_store=False,
|
| 185 |
+
load_attention_store:str=None,
|
| 186 |
+
store_path:str=None
|
| 187 |
+
):
|
| 188 |
+
super(AttentionStore, self).__init__()
|
| 189 |
+
self.disk_store = disk_store
|
| 190 |
+
if load_attention_store is not None:
|
| 191 |
+
if not os.path.exists(load_attention_store):
|
| 192 |
+
print(f"can not load attentions from {load_attention_store}: file doesn't exist.")
|
| 193 |
+
load_attention_store = None
|
| 194 |
+
else:
|
| 195 |
+
assert self.disk_store, f"can not load attentions from {load_attention_store} because disk_store is disabled."
|
| 196 |
+
self.attention_store_all_step = []
|
| 197 |
+
if self.disk_store:
|
| 198 |
+
if load_attention_store is not None:
|
| 199 |
+
self.store_dir = load_attention_store
|
| 200 |
+
flist = sorted([fpath for fpath in os.listdir(load_attention_store) if "inverted" not in fpath], key=lambda x: int(x[:-3]))
|
| 201 |
+
self.attention_store_all_step = [
|
| 202 |
+
os.path.join(load_attention_store, fn) for fn in flist
|
| 203 |
+
]
|
| 204 |
+
else:
|
| 205 |
+
if store_path is None:
|
| 206 |
+
time_string = get_time_string()
|
| 207 |
+
path = f'./trash/{self.__class__.__name__}_attention_cache_{time_string}'
|
| 208 |
+
else:
|
| 209 |
+
path = store_path
|
| 210 |
+
os.makedirs(path, exist_ok=True)
|
| 211 |
+
self.store_dir = path
|
| 212 |
+
else:
|
| 213 |
+
self.store_dir = None
|
| 214 |
+
self.step_store = self.get_empty_store()
|
| 215 |
+
self.attention_store = {}
|
| 216 |
+
self.save_self_attention = save_self_attention
|
| 217 |
+
self.latents_store = []
|
| 218 |
+
|
| 219 |
+
self.save_latents = save_latents
|
| 220 |
+
self.load_attention_store = load_attention_store
|
| 221 |
+
|
| 222 |
+
def delete(self):
|
| 223 |
+
if self.disk_store:
|
| 224 |
+
try:
|
| 225 |
+
shutil.rmtree(self.store_dir)
|
| 226 |
+
print(f"Successfully remove {self.store_dir}")
|
| 227 |
+
except Exception:
|
| 228 |
+
print(f"Fail to remove {self.store_dir}")
|
| 229 |
+
|
| 230 |
+
def attention_control(
|
| 231 |
+
self, place_in_unet, attention_type, is_cross,
|
| 232 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False
|
| 233 |
+
):
|
| 234 |
+
if attention_type == "temporal":
|
| 235 |
+
|
| 236 |
+
return self.temporal_attention_control(
|
| 237 |
+
place_in_unet, attention_type, is_cross,
|
| 238 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
elif attention_type == "spatial":
|
| 242 |
+
|
| 243 |
+
return self.spatial_attention_control(
|
| 244 |
+
place_in_unet, attention_type, is_cross,
|
| 245 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def temporal_attention_control(
|
| 249 |
+
self, place_in_unet, attention_type, is_cross,
|
| 250 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False
|
| 251 |
+
):
|
| 252 |
+
|
| 253 |
+
h = q.shape[1]
|
| 254 |
+
q, k, v = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q, k, v))
|
| 255 |
+
attention_scores = torch.baddbmm(
|
| 256 |
+
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
| 257 |
+
q,
|
| 258 |
+
k.transpose(-1, -2),
|
| 259 |
+
beta=0,
|
| 260 |
+
alpha=1 / math.sqrt(q.size(-1)),
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
if attn_mask is not None:
|
| 264 |
+
if attn_mask.dtype == torch.bool:
|
| 265 |
+
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
| 266 |
+
attention_scores = attention_scores + attn_mask
|
| 267 |
+
|
| 268 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 269 |
+
|
| 270 |
+
# cast back to the original dtype
|
| 271 |
+
attention_probs = attention_probs.to(v.dtype)
|
| 272 |
+
|
| 273 |
+
# START OF CORE FUNCTION
|
| 274 |
+
# Record during inversion and edit the attention probs during editing
|
| 275 |
+
attention_probs = rearrange(
|
| 276 |
+
self.__call__(
|
| 277 |
+
rearrange(attention_probs, "(b h) n d -> b h n d", h=h),
|
| 278 |
+
is_cross,
|
| 279 |
+
f'{place_in_unet}_{attention_type}'
|
| 280 |
+
),
|
| 281 |
+
"b h n d -> (b h) n d"
|
| 282 |
+
)
|
| 283 |
+
# END OF CORE FUNCTION
|
| 284 |
+
|
| 285 |
+
# compute attention output
|
| 286 |
+
hidden_states = torch.bmm(attention_probs, v)
|
| 287 |
+
|
| 288 |
+
# reshape hidden_states
|
| 289 |
+
hidden_states = rearrange(hidden_states, "(b h) n d -> b h n d", h=h)
|
| 290 |
+
|
| 291 |
+
return hidden_states
|
| 292 |
+
|
| 293 |
+
def spatial_attention_control(
|
| 294 |
+
self, place_in_unet, attention_type, is_cross,
|
| 295 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False
|
| 296 |
+
):
|
| 297 |
+
|
| 298 |
+
q = self.__call__(q, is_cross, f"{place_in_unet}_{attention_type}_q")
|
| 299 |
+
k = self.__call__(k, is_cross, f"{place_in_unet}_{attention_type}_k")
|
| 300 |
+
|
| 301 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 302 |
+
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
return hidden_states
|
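Illustrative sketch of the storage modes implemented above; ./cache/attn is a made-up path, and the comments describe how the class is assumed to be driven (forward() per hooked layer, step_callback() per denoising step) rather than code taken from the repository.

from i2vedit.prompt_attention.attention_store import AttentionStore

# In-memory store: each step's step_store dict is appended to attention_store_all_step.
mem_store = AttentionStore(disk_store=False)

# Disk-backed store: each step is written to ./cache/attn/000.pt, 001.pt, ...
# and attention_store_all_step keeps the file paths instead of the tensors.
disk_backed = AttentionStore(disk_store=True, store_path="./cache/attn")

# A later run can replay those saved maps; this requires disk_store=True,
# matching the assert in __init__.
replayed = AttentionStore(disk_store=True, load_attention_store="./cache/attn")

# After a completed pass, the accumulated 'mask' maps averaged over steps:
# avg_masks = mem_store.get_average_attention()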
i2vedit/prompt_attention/attention_util.py
ADDED
|
@@ -0,0 +1,621 @@
| 1 |
+
"""
|
| 2 |
+
Collect all functions in the prompt_attention folder.
|
| 3 |
+
Provide an API `make_controller` that returns an initialized AttentionControlEdit object for the main validation loop.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Union, Tuple, List, Dict
|
| 7 |
+
import abc
|
| 8 |
+
import numpy as np
|
| 9 |
+
import copy
|
| 10 |
+
import math
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from packaging import version
|
| 18 |
+
|
| 19 |
+
from i2vedit.prompt_attention.visualization import (
|
| 20 |
+
show_cross_attention,
|
| 21 |
+
show_self_attention_comp,
|
| 22 |
+
show_self_attention,
|
| 23 |
+
show_self_attention_distance,
|
| 24 |
+
calculate_attention_mask,
|
| 25 |
+
show_avg_difference_maps
|
| 26 |
+
)
|
| 27 |
+
from i2vedit.prompt_attention.attention_store import AttentionStore, AttentionControl
|
| 28 |
+
from i2vedit.prompt_attention.attention_register import register_attention_control
|
| 29 |
+
|
| 30 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 31 |
+
|
| 32 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
| 33 |
+
SDP_IS_AVAILABLE = True
|
| 34 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
| 35 |
+
|
| 36 |
+
BACKEND_MAP = {
|
| 37 |
+
SDPBackend.MATH: {
|
| 38 |
+
"enable_math": True,
|
| 39 |
+
"enable_flash": False,
|
| 40 |
+
"enable_mem_efficient": False,
|
| 41 |
+
},
|
| 42 |
+
SDPBackend.FLASH_ATTENTION: {
|
| 43 |
+
"enable_math": False,
|
| 44 |
+
"enable_flash": True,
|
| 45 |
+
"enable_mem_efficient": False,
|
| 46 |
+
},
|
| 47 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
| 48 |
+
"enable_math": False,
|
| 49 |
+
"enable_flash": False,
|
| 50 |
+
"enable_mem_efficient": True,
|
| 51 |
+
},
|
| 52 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
| 53 |
+
}
|
| 54 |
+
else:
|
| 55 |
+
from contextlib import nullcontext
|
| 56 |
+
|
| 57 |
+
SDP_IS_AVAILABLE = False
|
| 58 |
+
sdp_kernel = nullcontext
|
| 59 |
+
BACKEND_MAP = {}
|
| 60 |
+
print(
|
| 61 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
| 62 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
| 63 |
+
f"You might want to consider upgrading."
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class EmptyControl:
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def step_callback(self, x_t):
|
| 72 |
+
return x_t
|
| 73 |
+
|
| 74 |
+
def between_steps(self):
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
| 78 |
+
return attn
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class AttentionControlEdit(AttentionStore, abc.ABC):
|
| 82 |
+
"""Decide self or cross-attention. Call the reweighting cross attention module
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
AttentionStore (_type_): ([1, 4, 8, 64, 64])
|
| 86 |
+
abc (_type_): [8, 8, 1024, 77]
|
| 87 |
+
"""
|
| 88 |
+
def get_all_last_latents(self, overlay_size):
|
| 89 |
+
return [latents[:,-overlay_size:,...] for latents in self.latents_store]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def step_callback(self, x_t):
|
| 93 |
+
x_t = super().step_callback(x_t)
|
| 94 |
+
x_t_device = x_t.device
|
| 95 |
+
x_t_dtype = x_t.dtype
|
| 96 |
+
|
| 97 |
+
# if self.previous_latents is not None:
|
| 98 |
+
# # replace latents
|
| 99 |
+
# step_in_store = self.cur_step - 1
|
| 100 |
+
# previous_latents = self.previous_latents[step_in_store]
|
| 101 |
+
# x_t[:,:len(previous_latents),...] = previous_latents.to(x_t_device, x_t_dtype)
|
| 102 |
+
if self.latent_blend:
|
| 103 |
+
|
| 104 |
+
avg_attention = self.get_average_attention()
|
| 105 |
+
masks = []
|
| 106 |
+
for key in avg_attention:
|
| 107 |
+
if 'down' in key and 'mask' in key:
|
| 108 |
+
for attn in avg_attention[key]:
|
| 109 |
+
if attn.shape[-2] == 8 * 9:
|
| 110 |
+
masks.append( attn )
|
| 111 |
+
mask = sum(masks) / len(masks)
|
| 112 |
+
mask[mask > 0.2] = 1.0
|
| 113 |
+
if self.use_inversion_attention and self.additional_attention_store is not None:
|
| 114 |
+
step_in_store = len(self.additional_attention_store.latents_store) - self.cur_step
|
| 115 |
+
elif self.additional_attention_store is None:
|
| 116 |
+
pass
|
| 117 |
+
else:
|
| 118 |
+
step_in_store = self.cur_step - 1
|
| 119 |
+
|
| 120 |
+
inverted_latents = self.additional_attention_store.latents_store[step_in_store]
|
| 121 |
+
inverted_latents = inverted_latents.to(device =x_t_device, dtype=x_t_dtype)
|
| 122 |
+
|
| 123 |
+
x_t = (1 - mask) * inverted_latents + mask * x_t
|
| 124 |
+
|
| 125 |
+
self.step_in_store_atten_dict = None
|
| 126 |
+
|
| 127 |
+
return x_t
|
| 128 |
+
|
| 129 |
+
def replace_self_attention(self, attn_base, attn_replace, reshaped_mask=None, key=None):
|
| 130 |
+
|
| 131 |
+
target_device = attn_replace.device
|
| 132 |
+
target_dtype = attn_replace.dtype
|
| 133 |
+
attn_base = attn_base.to(target_device, dtype=target_dtype)
|
| 134 |
+
|
| 135 |
+
if "temporal" in key:
|
| 136 |
+
|
| 137 |
+
if self.control_mode["temporal_self"] == "copy_v2":
|
| 138 |
+
if self.cur_step < int(self.temporal_step_thr[0] * self.num_steps):
|
| 139 |
+
return attn_base
|
| 140 |
+
if self.cur_step >= int(self.temporal_step_thr[1] * self.num_steps):
|
| 141 |
+
return attn_replace
|
| 142 |
+
if ('down' in key and self.current_pos<4) or \
|
| 143 |
+
('up' in key and self.current_pos>1):
|
| 144 |
+
return attn_replace
|
| 145 |
+
return attn_base
|
| 146 |
+
|
| 147 |
+
else:
|
| 148 |
+
raise NotImplementedError
|
| 149 |
+
|
| 150 |
+
elif "spatial" in key:
|
| 151 |
+
|
| 152 |
+
raise NotImplementedError
|
| 153 |
+
|
| 154 |
+
def replace_cross_attention(self, attn_base, attn_replace, key=None):
|
| 155 |
+
raise NotImplementedError
|
| 156 |
+
|
| 157 |
+
def update_attention_position_dict(self, current_attention_key):
|
| 158 |
+
self.attention_position_counter_dict[current_attention_key] +=1
|
| 159 |
+
|
| 160 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
| 161 |
+
super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
|
| 162 |
+
|
| 163 |
+
if 'mask' in place_in_unet:
|
| 164 |
+
return attn
|
| 165 |
+
|
| 166 |
+
if (not is_cross and 'temporal' in place_in_unet and (self.cur_step < self.num_temporal_self_replace[0] or self.cur_step >=self.num_temporal_self_replace[1])):
|
| 167 |
+
if self.control_mode["temporal_self"] == "copy" or \
|
| 168 |
+
self.control_mode["temporal_self"] == "copy_v2":
|
| 169 |
+
return attn
|
| 170 |
+
|
| 171 |
+
if (not is_cross and 'spatial' in place_in_unet and (self.cur_step < self.num_spatial_self_replace[0] or self.cur_step >=self.num_spatial_self_replace[1])):
|
| 172 |
+
if self.control_mode["spatial_self"] == "copy":
|
| 173 |
+
return attn
|
| 174 |
+
|
| 175 |
+
if (is_cross and (self.cur_step < self.num_cross_replace[0] or self.cur_step >= self.num_cross_replace[1])):
|
| 176 |
+
return attn
|
| 177 |
+
|
| 178 |
+
if True:#'temporal' in place_in_unet:
|
| 179 |
+
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
| 180 |
+
current_pos = self.attention_position_counter_dict[key]
|
| 181 |
+
|
| 182 |
+
if self.use_inversion_attention and self.additional_attention_store is not None:
|
| 183 |
+
step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1
|
| 184 |
+
elif self.additional_attention_store is None:
|
| 185 |
+
return attn
|
| 186 |
+
|
| 187 |
+
else:
|
| 188 |
+
step_in_store = self.cur_step
|
| 189 |
+
|
| 190 |
+
step_in_store_atten_dict = self.additional_attention_store.attention_store_all_step[step_in_store]
|
| 191 |
+
|
| 192 |
+
if isinstance(step_in_store_atten_dict, str):
|
| 193 |
+
if self.step_in_store_atten_dict is None:
|
| 194 |
+
step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
|
| 195 |
+
self.step_in_store_atten_dict = step_in_store_atten_dict
|
| 196 |
+
else:
|
| 197 |
+
step_in_store_atten_dict = self.step_in_store_atten_dict
|
| 198 |
+
|
| 199 |
+
# Note that attn is appended to step_store;
|
| 200 |
+
# if attn was collected clean -> noisy, the step index has to be reversed
|
| 201 |
+
#print(key)
|
| 202 |
+
attn_base = step_in_store_atten_dict[key][current_pos]
|
| 203 |
+
self.current_pos = current_pos
|
| 204 |
+
|
| 205 |
+
self.update_attention_position_dict(key)
|
| 206 |
+
# save in format of [temporal, head, resolution, text_embedding]
|
| 207 |
+
attn_base, attn_replace = attn_base, attn
|
| 208 |
+
|
| 209 |
+
if not is_cross:
|
| 210 |
+
attn = self.replace_self_attention(attn_base, attn_replace, None, key)
|
| 211 |
+
|
| 212 |
+
#elif is_cross and (self.num_cross_replace[0] <= self.cur_step < self.num_cross_replace[1]):
|
| 213 |
+
elif is_cross:
|
| 214 |
+
attn = self.replace_cross_attention(attn_base, attn_replace, key)
|
| 215 |
+
|
| 216 |
+
return attn
|
| 217 |
+
|
| 218 |
+
else:
|
| 219 |
+
|
| 220 |
+
raise NotImplementedError("Due to CUDA RAM limit, direct replace functions for spatial are not implemented.")
|
| 221 |
+
|
| 222 |
+
def between_steps(self):
|
| 223 |
+
|
| 224 |
+
super().between_steps()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
self.step_store = self.get_empty_store()
|
| 229 |
+
|
| 230 |
+
self.attention_position_counter_dict = {
|
| 231 |
+
'down_spatial_q_cross': 0,
|
| 232 |
+
'mid_spatial_q_cross': 0,
|
| 233 |
+
'up_spatial_q_cross': 0,
|
| 234 |
+
'down_spatial_k_cross': 0,
|
| 235 |
+
'mid_spatial_k_cross': 0,
|
| 236 |
+
'up_spatial_k_cross': 0,
|
| 237 |
+
'down_spatial_mask_cross': 0,
|
| 238 |
+
'mid_spatial_mask_cross': 0,
|
| 239 |
+
'up_spatial_mask_cross': 0,
|
| 240 |
+
'down_spatial_q_self': 0,
|
| 241 |
+
'mid_spatial_q_self': 0,
|
| 242 |
+
'up_spatial_q_self': 0,
|
| 243 |
+
'down_spatial_k_self': 0,
|
| 244 |
+
'mid_spatial_k_self': 0,
|
| 245 |
+
'up_spatial_k_self': 0,
|
| 246 |
+
'down_spatial_mask_self': 0,
|
| 247 |
+
'mid_spatial_mask_self': 0,
|
| 248 |
+
'up_spatial_mask_self': 0,
|
| 249 |
+
'down_temporal_cross': 0,
|
| 250 |
+
'mid_temporal_cross': 0,
|
| 251 |
+
'up_temporal_cross': 0,
|
| 252 |
+
'down_temporal_self': 0,
|
| 253 |
+
'mid_temporal_self': 0,
|
| 254 |
+
'up_temporal_self': 0
|
| 255 |
+
}
|
| 256 |
+
return
|
| 257 |
+
|
| 258 |
+
def __init__(self, num_steps: int,
|
| 259 |
+
cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
|
| 260 |
+
temporal_self_replace_steps: Union[float, Tuple[float, float]],
|
| 261 |
+
spatial_self_replace_steps: Union[float, Tuple[float, float]],
|
| 262 |
+
control_mode={"temporal_self":"copy","spatial_self":"copy"},
|
| 263 |
+
spatial_attention_chunk_size = 1,
|
| 264 |
+
additional_attention_store: AttentionStore =None,
|
| 265 |
+
use_inversion_attention: bool=False,
|
| 266 |
+
save_self_attention: bool=True,
|
| 267 |
+
save_latents: bool=True,
|
| 268 |
+
disk_store=False,
|
| 269 |
+
*args, **kwargs
|
| 270 |
+
):
|
| 271 |
+
super(AttentionControlEdit, self).__init__(
|
| 272 |
+
save_self_attention=save_self_attention,
|
| 273 |
+
save_latents=save_latents,
|
| 274 |
+
disk_store=disk_store)
|
| 275 |
+
self.additional_attention_store = additional_attention_store
|
| 276 |
+
if type(temporal_self_replace_steps) is float:
|
| 277 |
+
temporal_self_replace_steps = 0, temporal_self_replace_steps
|
| 278 |
+
if type(spatial_self_replace_steps) is float:
|
| 279 |
+
spatial_self_replace_steps = 0, spatial_self_replace_steps
|
| 280 |
+
if type(cross_replace_steps) is float:
|
| 281 |
+
cross_replace_steps = 0, cross_replace_steps
|
| 282 |
+
self.num_temporal_self_replace = int(num_steps * temporal_self_replace_steps[0]), int(num_steps * temporal_self_replace_steps[1])
|
| 283 |
+
self.num_spatial_self_replace = int(num_steps * spatial_self_replace_steps[0]), int(num_steps * spatial_self_replace_steps[1])
|
| 284 |
+
self.num_cross_replace = int(num_steps * cross_replace_steps[0]), int(num_steps * cross_replace_steps[1])
|
| 285 |
+
self.control_mode = control_mode
|
| 286 |
+
self.spatial_attention_chunk_size = spatial_attention_chunk_size
|
| 287 |
+
self.step_in_store_atten_dict = None
|
| 288 |
+
# We need to know the current position in attention
|
| 289 |
+
self.prev_attention_key_name = 0
|
| 290 |
+
self.use_inversion_attention = use_inversion_attention
|
| 291 |
+
self.attention_position_counter_dict = {
|
| 292 |
+
'down_spatial_q_cross': 0,
|
| 293 |
+
'mid_spatial_q_cross': 0,
|
| 294 |
+
'up_spatial_q_cross': 0,
|
| 295 |
+
'down_spatial_k_cross': 0,
|
| 296 |
+
'mid_spatial_k_cross': 0,
|
| 297 |
+
'up_spatial_k_cross': 0,
|
| 298 |
+
'down_spatial_mask_cross': 0,
|
| 299 |
+
'mid_spatial_mask_cross': 0,
|
| 300 |
+
'up_spatial_mask_cross': 0,
|
| 301 |
+
'down_spatial_q_self': 0,
|
| 302 |
+
'mid_spatial_q_self': 0,
|
| 303 |
+
'up_spatial_q_self': 0,
|
| 304 |
+
'down_spatial_k_self': 0,
|
| 305 |
+
'mid_spatial_k_self': 0,
|
| 306 |
+
'up_spatial_k_self': 0,
|
| 307 |
+
'down_spatial_mask_self': 0,
|
| 308 |
+
'mid_spatial_mask_self': 0,
|
| 309 |
+
'up_spatial_mask_self': 0,
|
| 310 |
+
'down_temporal_cross': 0,
|
| 311 |
+
'mid_temporal_cross': 0,
|
| 312 |
+
'up_temporal_cross': 0,
|
| 313 |
+
'down_temporal_self': 0,
|
| 314 |
+
'mid_temporal_self': 0,
|
| 315 |
+
'up_temporal_self': 0
|
| 316 |
+
}
|
| 317 |
+
self.mask_thr = kwargs.get("mask_thr", 0.35)
|
| 318 |
+
self.latent_blend = kwargs.get('latent_blend', False)
|
| 319 |
+
|
| 320 |
+
self.temporal_step_thr = kwargs.get("temporal_step_thr", [0.4,0.8])
|
| 321 |
+
self.num_steps = num_steps
|
| 322 |
+
|
| 323 |
+
def spatial_attention_control(
|
| 324 |
+
self, place_in_unet, attention_type, is_cross,
|
| 325 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False
|
| 326 |
+
):
|
| 327 |
+
|
| 328 |
+
return self.spatial_attention_matching(
|
| 329 |
+
place_in_unet, attention_type, is_cross,
|
| 330 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False,
|
| 331 |
+
mode = self.control_mode["spatial_self"]
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def spatial_attention_matching(
|
| 336 |
+
self, place_in_unet, attention_type, is_cross,
|
| 337 |
+
q, k, v, attn_mask, dropout_p=0.0, is_causal=False,
|
| 338 |
+
mode = "matching"
|
| 339 |
+
):
|
| 340 |
+
place_in_unet = f"{place_in_unet}_{attention_type}"
|
| 341 |
+
with sdp_kernel(**BACKEND_MAP[None]):
|
| 342 |
+
# print("register", q.shape, k.shape, v.shape)
|
| 343 |
+
|
| 344 |
+
# fetch inversion q and k
|
| 345 |
+
key_q = f"{place_in_unet}_q_{'cross' if is_cross else 'self'}"
|
| 346 |
+
key_k = f"{place_in_unet}_k_{'cross' if is_cross else 'self'}"
|
| 347 |
+
current_pos_q = self.attention_position_counter_dict[key_q]
|
| 348 |
+
current_pos_k = self.attention_position_counter_dict[key_k]
|
| 349 |
+
|
| 350 |
+
if self.use_inversion_attention and self.additional_attention_store is not None:
|
| 351 |
+
step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1
|
| 352 |
+
else:
|
| 353 |
+
step_in_store = self.cur_step
|
| 354 |
+
|
| 355 |
+
step_in_store_atten_dict = self.additional_attention_store.attention_store_all_step[step_in_store]
|
| 356 |
+
|
| 357 |
+
if isinstance(step_in_store_atten_dict, str):
|
| 358 |
+
if self.step_in_store_atten_dict is None:
|
| 359 |
+
step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
|
| 360 |
+
self.step_in_store_atten_dict = step_in_store_atten_dict
|
| 361 |
+
else:
|
| 362 |
+
step_in_store_atten_dict = self.step_in_store_atten_dict
|
| 363 |
+
|
| 364 |
+
q0s = step_in_store_atten_dict[key_q][current_pos_q].to(q.device)
|
| 365 |
+
k0s = step_in_store_atten_dict[key_k][current_pos_k].to(k.device)
|
| 366 |
+
|
| 367 |
+
self.update_attention_position_dict(key_q)
|
| 368 |
+
self.update_attention_position_dict(key_k)
|
| 369 |
+
|
| 370 |
+
qs, ks, vs = q, k, v
|
| 371 |
+
|
| 372 |
+
h = q.shape[1]
|
| 373 |
+
res = int(np.sqrt(q.shape[-2] / (9*16)))
|
| 374 |
+
if res == 0:
|
| 375 |
+
res = 1
|
| 376 |
+
#res = int(np.sqrt(q.shape[-2] / (8*14)))
|
| 377 |
+
bs = self.spatial_attention_chunk_size
|
| 378 |
+
if bs is None: bs = qs.shape[0]
|
| 379 |
+
N = qs.shape[0] // bs
|
| 380 |
+
assert qs.shape[0] % bs == 0
|
| 381 |
+
i1st, n1st = qs.shape[0]//2//bs, qs.shape[0]//2%bs
|
| 382 |
+
outs = []
|
| 383 |
+
masks = []
|
| 384 |
+
|
| 385 |
+
# this might reduce time costs but will introduce inaccurate motions
|
| 386 |
+
# if current_pos_q >= 6 and 'up' in key_q:
|
| 387 |
+
# return F.scaled_dot_product_attention(
|
| 388 |
+
# q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
| 389 |
+
# )
|
| 390 |
+
|
| 391 |
+
for i in range(N):
|
| 392 |
+
q = qs[i*bs:(i+1)*bs,...].type(torch.float32)
|
| 393 |
+
k = ks[i*bs:(i+1)*bs,...].type(torch.float32)
|
| 394 |
+
v = vs[i*bs:(i+1)*bs,...].type(torch.float32)
|
| 395 |
+
|
| 396 |
+
q, k, v = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q, k, v))
|
| 397 |
+
|
| 398 |
+
with torch.autocast("cuda", enabled=False):
|
| 399 |
+
attention_scores = torch.baddbmm(
|
| 400 |
+
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
| 401 |
+
q,
|
| 402 |
+
k.transpose(-1, -2),
|
| 403 |
+
beta=0,
|
| 404 |
+
alpha=1 / math.sqrt(q.size(-1)),
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
if attn_mask is not None:
|
| 408 |
+
if attn_mask.dtype == torch.bool:
|
| 409 |
+
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
| 410 |
+
attention_scores = attention_scores + attn_mask
|
| 411 |
+
|
| 412 |
+
attention_probs = attention_scores.softmax(dim=-1).to(vs.dtype)
|
| 413 |
+
|
| 414 |
+
# only compute conditional output
|
| 415 |
+
if i >= N//2:
|
| 416 |
+
|
| 417 |
+
q0 = q0s[(i-N//2)*bs:(i-N//2+1)*bs,...].type(torch.float32)
|
| 418 |
+
k0 = k0s[(i-N//2)*bs:(i-N//2+1)*bs,...].type(torch.float32)
|
| 419 |
+
|
| 420 |
+
q0, k0 = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q0, k0))
|
| 421 |
+
|
| 422 |
+
with torch.autocast("cuda", enabled=False):
|
| 423 |
+
attention_scores_0 = torch.baddbmm(
|
| 424 |
+
torch.empty(q0.shape[0], q0.shape[1], k0.shape[1], dtype=q0.dtype, device=q0.device),
|
| 425 |
+
q0,
|
| 426 |
+
k0.transpose(-1, -2),
|
| 427 |
+
beta=0,
|
| 428 |
+
alpha=1 / math.sqrt(q0.size(-1)),
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
attention_probs_0 = attention_scores_0.softmax(dim=-1).to(vs.dtype)
|
| 432 |
+
|
| 433 |
+
attention_probs, attention_probs_0 = \
|
| 434 |
+
map(lambda t: rearrange(t, "(b h) n d -> b h n d", h=h),
|
| 435 |
+
(attention_probs, attention_probs_0))
|
| 436 |
+
|
| 437 |
+
if mode == "masked_copy":
|
| 438 |
+
|
| 439 |
+
mask = torch.sum(
|
| 440 |
+
torch.mean(
|
| 441 |
+
torch.abs(attention_probs_0 - attention_probs),
|
| 442 |
+
dim=1
|
| 443 |
+
),
|
| 444 |
+
dim=2
|
| 445 |
+
).reshape(bs,1,-1,1).clamp(0,2)/2.0
|
| 446 |
+
mask_thr = (self.mask_thr[1]-self.mask_thr[0]) / (qs.shape[0]//2)*(i-N//2) + self.mask_thr[0]
|
| 447 |
+
mask_tmp = mask.clone()
|
| 448 |
+
mask[mask>=mask_thr] = 1.0
|
| 449 |
+
masks.append(mask)
|
| 450 |
+
|
| 451 |
+
# apply mask
|
| 452 |
+
attention_probs = (1 - mask) * attention_probs_0 + mask * attention_probs
|
| 453 |
+
|
| 454 |
+
else:
|
| 455 |
+
raise NotImplementedError
|
| 456 |
+
|
| 457 |
+
attention_probs = rearrange(attention_probs, "b h n d -> (b h) n d")
|
| 458 |
+
|
| 459 |
+
# compute attention output
|
| 460 |
+
hidden_states = torch.bmm(attention_probs, v)
|
| 461 |
+
|
| 462 |
+
# reshape hidden_states
|
| 463 |
+
hidden_states = rearrange(hidden_states, "(b h) n d -> b h n d", h=h)
|
| 464 |
+
|
| 465 |
+
outs.append(hidden_states)
|
| 466 |
+
|
| 467 |
+
if mode == "masked_copy":
|
| 468 |
+
|
| 469 |
+
# masks = rearrange(torch.cat(masks, 0), "b 1 (h w) 1 -> h (b w)", h=res*9)
|
| 470 |
+
masks = torch.cat(masks, 0)
|
| 471 |
+
#print(f"{place_in_unet}_masked_copy")
|
| 472 |
+
# save mask
|
| 473 |
+
_ = self.__call__(masks, is_cross, f"{place_in_unet}_mask")
|
| 474 |
+
|
| 475 |
+
return torch.cat(outs, 0)
|
| 476 |
+
|
| 477 |
+
class ConsistencyAttentionControl(AttentionStore, abc.ABC):
|
| 478 |
+
"""Decide self or cross-attention. Call the reweighting cross attention module
|
| 479 |
+
|
| 480 |
+
Args:
|
| 481 |
+
AttentionStore (_type_): ([1, 4, 8, 64, 64])
|
| 482 |
+
abc (_type_): [8, 8, 1024, 77]
|
| 483 |
+
"""
|
| 484 |
+
def step_callback(self, x_t):
|
| 485 |
+
x_t = super().step_callback(x_t)
|
| 486 |
+
x_t_device = x_t.device
|
| 487 |
+
x_t_dtype = x_t.dtype
|
| 488 |
+
|
| 489 |
+
# if self.previous_latents is not None:
|
| 490 |
+
# # replace latents
|
| 491 |
+
# step_in_store = self.cur_step - 1
|
| 492 |
+
# previous_latents = self.previous_latents[step_in_store]
|
| 493 |
+
# x_t[:,:len(previous_latents),...] = previous_latents.to(x_t_device, x_t_dtype)
|
| 494 |
+
|
| 495 |
+
self.step_in_store_atten_dict = None
|
| 496 |
+
|
| 497 |
+
return x_t
|
| 498 |
+
|
| 499 |
+
def update_attention_position_dict(self, current_attention_key):
|
| 500 |
+
self.attention_position_counter_dict[current_attention_key] +=1
|
| 501 |
+
|
| 502 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
| 503 |
+
if self.cur_att_layer >= self.num_uncond_att_layers:
|
| 504 |
+
attn = self.forward(attn, is_cross, place_in_unet)
|
| 505 |
+
|
| 506 |
+
self.cur_att_layer += 1
|
| 507 |
+
|
| 508 |
+
return attn
|
| 509 |
+
|
| 510 |
+
def set_cur_step(self, step: int = 0):
|
| 511 |
+
self.cur_step = step
|
| 512 |
+
|
| 513 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
| 514 |
+
super(ConsistencyAttentionControl, self).forward(attn, is_cross, place_in_unet)
|
| 515 |
+
|
| 516 |
+
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
| 517 |
+
current_pos = self.attention_position_counter_dict[key]
|
| 518 |
+
|
| 519 |
+
if self.use_inversion_attention and self.additional_attention_store is not None:
|
| 520 |
+
step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1
|
| 521 |
+
elif self.additional_attention_store is None:
|
| 522 |
+
return attn
|
| 523 |
+
|
| 524 |
+
else:
|
| 525 |
+
step_in_store = self.cur_step
|
| 526 |
+
|
| 527 |
+
step_in_store_atten_dict = self.additional_attention_store.attention_store_all_step[step_in_store]
|
| 528 |
+
|
| 529 |
+
if isinstance(step_in_store_atten_dict, str):
|
| 530 |
+
if self.step_in_store_atten_dict is None:
|
| 531 |
+
step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
|
| 532 |
+
self.step_in_store_atten_dict = step_in_store_atten_dict
|
| 533 |
+
else:
|
| 534 |
+
step_in_store_atten_dict = self.step_in_store_atten_dict
|
| 535 |
+
|
| 536 |
+
# Note that attn is appended to step_store;
|
| 537 |
+
# if attn was collected clean -> noisy, the step index has to be reversed
|
| 538 |
+
#print("consistency", key)
|
| 539 |
+
attn_base = step_in_store_atten_dict[key][current_pos].to(attn.device, attn.dtype)
|
| 540 |
+
attn_base = attn_base.detach()
|
| 541 |
+
|
| 542 |
+
self.update_attention_position_dict(key)
|
| 543 |
+
# save in format of [temporal, head, resolution, text_embedding]
|
| 544 |
+
|
| 545 |
+
attn = torch.cat([attn_base, attn], dim=2)
|
| 546 |
+
|
| 547 |
+
return attn
|
| 548 |
+
|
| 549 |
+
@staticmethod
|
| 550 |
+
def get_empty_store():
|
| 551 |
+
return {
|
| 552 |
+
"down_temporal_k_self": [], "mid_temporal_k_self": [], "up_temporal_k_self": [],
|
| 553 |
+
"down_temporal_v_self": [], "mid_temporal_v_self": [], "up_temporal_v_self": []
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
def between_steps(self):
|
| 557 |
+
|
| 558 |
+
super().between_steps()
|
| 559 |
+
|
| 560 |
+
self.step_store = self.get_empty_store()
|
| 561 |
+
|
| 562 |
+
self.attention_position_counter_dict = {
|
| 563 |
+
'down_temporal_k_self': 0,
|
| 564 |
+
'mid_temporal_k_self': 0,
|
| 565 |
+
'up_temporal_k_self': 0,
|
| 566 |
+
'down_temporal_v_self': 0,
|
| 567 |
+
'mid_temporal_v_self': 0,
|
| 568 |
+
'up_temporal_v_self': 0
|
| 569 |
+
}
|
| 570 |
+
return
|
| 571 |
+
|
| 572 |
+
def __init__(self,
|
| 573 |
+
additional_attention_store: AttentionStore =None,
|
| 574 |
+
use_inversion_attention: bool=False,
|
| 575 |
+
load_attention_store: str = None,
|
| 576 |
+
save_self_attention: bool=True,
|
| 577 |
+
save_latents: bool=True,
|
| 578 |
+
disk_store=False,
|
| 579 |
+
store_path:str=None
|
| 580 |
+
):
|
| 581 |
+
super(ConsistencyAttentionControl, self).__init__(
|
| 582 |
+
save_self_attention=save_self_attention,
|
| 583 |
+
load_attention_store=load_attention_store,
|
| 584 |
+
save_latents=save_latents,
|
| 585 |
+
disk_store=disk_store,
|
| 586 |
+
store_path=store_path
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
self.additional_attention_store = additional_attention_store
|
| 590 |
+
self.step_in_store_atten_dict = None
|
| 591 |
+
# We need to know the current position in attention
|
| 592 |
+
self.use_inversion_attention = use_inversion_attention
|
| 593 |
+
self.attention_position_counter_dict = {
|
| 594 |
+
'down_temporal_k_self': 0,
|
| 595 |
+
'mid_temporal_k_self': 0,
|
| 596 |
+
'up_temporal_k_self': 0,
|
| 597 |
+
'down_temporal_v_self': 0,
|
| 598 |
+
'mid_temporal_v_self': 0,
|
| 599 |
+
'up_temporal_v_self': 0
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def make_controller(
|
| 605 |
+
cross_replace_steps: Dict[str, float], self_replace_steps: float=0.0,
|
| 606 |
+
additional_attention_store=None, use_inversion_attention = False,
|
| 607 |
+
NUM_DDIM_STEPS=None,
|
| 608 |
+
save_path = None,
|
| 609 |
+
save_self_attention = True,
|
| 610 |
+
disk_store = False
|
| 611 |
+
) -> AttentionControlEdit:
|
| 612 |
+
controller = AttentionControlEdit(NUM_DDIM_STEPS,
|
| 613 |
+
cross_replace_steps=cross_replace_steps,
|
| 614 |
+
temporal_self_replace_steps=self_replace_steps,
spatial_self_replace_steps=self_replace_steps,
|
| 615 |
+
additional_attention_store=additional_attention_store,
|
| 616 |
+
use_inversion_attention = use_inversion_attention,
|
| 617 |
+
save_self_attention = save_self_attention,
|
| 618 |
+
disk_store=disk_store
|
| 619 |
+
)
|
| 620 |
+
return controller
|
| 621 |
+
|
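A minimal end-to-end sketch of how the pieces in this file might be combined: record attention during inversion with an AttentionStore, then replay it during editing through an AttentionControlEdit. The step count, control_mode and thresholds are example values only, and the register_attention_control calls are left commented out because they need a loaded UNet.

from i2vedit.prompt_attention.attention_store import AttentionStore
from i2vedit.prompt_attention.attention_util import AttentionControlEdit
from i2vedit.prompt_attention.attention_register import register_attention_control

NUM_DDIM_STEPS = 30  # example value

# 1) Inversion pass: record per-layer attention and per-step latents.
inversion_store = AttentionStore(save_self_attention=True, save_latents=True)
# register_attention_control(unet, controller=inversion_store, find_modules=...)
# ... run the inversion loop so attention_store_all_step / latents_store fill up ...

# 2) Editing pass: replay the recorded attention while denoising the edit.
edit_controller = AttentionControlEdit(
    NUM_DDIM_STEPS,
    cross_replace_steps=0.0,
    temporal_self_replace_steps=1.0,
    spatial_self_replace_steps=1.0,
    control_mode={"temporal_self": "copy_v2", "spatial_self": "masked_copy"},
    additional_attention_store=inversion_store,
    use_inversion_attention=True,
    temporal_step_thr=[0.4, 0.8],
    mask_thr=[0.3, 0.5],  # masked_copy expects a (low, high) pair
)
# register_attention_control(unet, controller=edit_controller, find_modules=...)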
i2vedit/prompt_attention/common/__init__.py
ADDED
|
File without changes
|
i2vedit/prompt_attention/common/image_util.py
ADDED
|
@@ -0,0 +1,192 @@
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import textwrap
|
| 4 |
+
|
| 5 |
+
import imageio
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Sequence
|
| 8 |
+
import requests
|
| 9 |
+
import cv2
|
| 10 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
IMAGE_EXTENSION = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".JPEG")
|
| 18 |
+
|
| 19 |
+
FONT_URL = "https://raw.github.com/googlefonts/opensans/main/fonts/ttf/OpenSans-Regular.ttf"
|
| 20 |
+
FONT_PATH = "./docs/OpenSans-Regular.ttf"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def pad(image: Image.Image, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)) -> Image.Image:
|
| 24 |
+
new_image = Image.new(image.mode, (image.width + right + left, image.height + top + bottom), color)
|
| 25 |
+
new_image.paste(image, (left, top))
|
| 26 |
+
return new_image
|
| 27 |
+
|
| 28 |
+
|
def download_font_opensans(path=FONT_PATH):
    font_url = FONT_URL
    response = requests.get(font_url)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        f.write(response.content)


def annotate_image_with_font(image: Image.Image, text: str, font: ImageFont.FreeTypeFont) -> Image.Image:
    image_w = image.width
    _, _, text_w, text_h = font.getbbox(text)
    line_size = math.floor(len(text) * image_w / text_w)

    lines = textwrap.wrap(text, width=line_size)
    padding = text_h * len(lines)
    image = pad(image, top=padding + 3)

    ImageDraw.Draw(image).text((0, 0), "\n".join(lines), fill=(0, 0, 0), font=font)
    return image


def annotate_image(image: Image.Image, text: str, font_size: int = 15):
    if not os.path.isfile(FONT_PATH):
        download_font_opensans()
    font = ImageFont.truetype(FONT_PATH, size=font_size)
    return annotate_image_with_font(image=image, text=text, font=font)


def make_grid(images: Sequence[Image.Image], rows=None, cols=None) -> Image.Image:
    if isinstance(images[0], np.ndarray):
        images = [Image.fromarray(i) for i in images]

    if rows is None:
        assert cols is not None
        rows = math.ceil(len(images) / cols)
    else:
        cols = math.ceil(len(images) / rows)

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        if image.size != (w, h):
            image = image.resize((w, h))
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def save_images_as_gif(
    images: Sequence[Image.Image],
    save_path: str,
    loop=0,
    duration=100,
    optimize=False,
) -> None:

    images[0].save(
        save_path,
        save_all=True,
        append_images=images[1:],
        optimize=optimize,
        loop=loop,
        duration=duration,
    )

def save_images_as_mp4(
    images: Sequence[Image.Image],
    save_path: str,
) -> None:

    writer_edit = imageio.get_writer(
        save_path,
        fps=10)
    for i in images:
        init_image = i.convert("RGB")
        writer_edit.append_data(np.array(init_image))
    writer_edit.close()



def save_images_as_folder(
    images: Sequence[Image.Image],
    save_path: str,
) -> None:
    os.makedirs(save_path, exist_ok=True)
    for index, image in enumerate(images):
        init_image = image
        if len(np.array(init_image).shape) == 3:
            cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image)[:, :, ::-1])
        else:
            cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image))

def log_train_samples(
    train_dataloader,
    save_path,
    num_batch: int = 4,
):
    train_samples = []
    for idx, batch in enumerate(train_dataloader):
        if idx >= num_batch:
            break
        train_samples.append(batch["images"])

    train_samples = torch.cat(train_samples).numpy()
    train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
    train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
    train_samples = numpy_batch_seq_to_pil(train_samples)
    train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
    # save_images_as_gif(train_samples, save_path)
    save_gif_mp4_folder_type(train_samples, save_path)

def log_train_reg_samples(
    train_dataloader,
    save_path,
    num_batch: int = 4,
):
    train_samples = []
    for idx, batch in enumerate(train_dataloader):
        if idx >= num_batch:
            break
        train_samples.append(batch["class_images"])

    train_samples = torch.cat(train_samples).numpy()
    train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
    train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
    train_samples = numpy_batch_seq_to_pil(train_samples)
    train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
    # save_images_as_gif(train_samples, save_path)
    save_gif_mp4_folder_type(train_samples, save_path)


def save_gif_mp4_folder_type(images, save_path, save_gif=True):

    if isinstance(images[0], np.ndarray):
        images = [Image.fromarray(i) for i in images]
    elif isinstance(images[0], torch.Tensor):
        images = [transforms.ToPILImage()(i.cpu().clone()[0]) for i in images]
    save_path_mp4 = save_path.replace('gif', 'mp4')
    save_path_folder = save_path.replace('.gif', '')
    if save_gif: save_images_as_gif(images, save_path)
    save_images_as_mp4(images, save_path_mp4)
    save_images_as_folder(images, save_path_folder)

# copy from video_diffusion/pipelines/stable_diffusion.py
def numpy_seq_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images

# copy from diffusers-0.11.1/src/diffusers/pipeline_utils.py
def numpy_batch_seq_to_pil(images):
    pil_images = []
    for sequence in images:
        pil_images.append(numpy_seq_to_pil(sequence))
    return pil_images
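A minimal usage sketch for the grid and saving helpers above. It assumes Pillow, numpy, imageio (with an ffmpeg backend) and OpenCV are installed and that the package is importable as i2vedit; the frame contents and output paths are illustrative only and are not part of the committed files.

import numpy as np
from PIL import Image

from i2vedit.prompt_attention.common.image_util import make_grid, save_gif_mp4_folder_type

# Four dummy 128x128 frames with increasing gray levels.
frames = [Image.fromarray(np.full((128, 128, 3), 60 * i, dtype=np.uint8)) for i in range(4)]

# Tile the frames on a 2x2 grid and save it as a single image.
grid = make_grid(frames, cols=2)   # 256x256 RGB image
grid.save("grid.png")

# One call writes frames.gif, frames.mp4 and a frames/ folder of numbered PNGs.
save_gif_mp4_folder_type(frames, "frames.gif")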
i2vedit/prompt_attention/common/instantiate_from_config.py
ADDED
|
@@ -0,0 +1,33 @@
"""
Copy from stable diffusion
"""
import importlib


def instantiate_from_config(config:dict, **args_from_code):
    """Util function to decompose different modules using config

    Args:
        config (dict): with key of "target" and "params", better from yaml
        args_from_code: additional keyword arguments supplied from code


    Returns:
        a validation/training pipeline, a module
    """
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **args_from_code)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
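A short usage sketch of the config-driven factory above. The datetime.timedelta target is only an illustration; in this repository the "target" entries in the YAML configs typically point at pipeline or module classes.

from i2vedit.prompt_attention.common.instantiate_from_config import instantiate_from_config

# "target" is a dotted import path; "params" are constructor kwargs from the YAML side.
config = {"target": "datetime.timedelta", "params": {"days": 2}}

# Keyword arguments supplied from code are merged with the config params.
delta = instantiate_from_config(config, hours=3)
print(delta)  # 2 days, 3:00:00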
i2vedit/prompt_attention/common/logger.py
ADDED
|
@@ -0,0 +1,17 @@
import os
import logging, logging.handlers
from accelerate.logging import get_logger

def get_logger_config_path(logdir):
    # accelerate handles the logger in multiprocessing
    logger = get_logger(__name__)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s:%(levelname)s : %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=os.path.join(logdir, 'log.log'),
        filemode='w')
    chlr = logging.StreamHandler()
    chlr.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s : %(message)s'))
    logger.logger.addHandler(chlr)
    return logger
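A possible way to call this helper, assuming accelerate is installed; the log directory name is hypothetical. Accelerate's logger expects its process state to be initialized, so the sketch creates a PartialState first.

import os
from accelerate import PartialState
from i2vedit.prompt_attention.common.logger import get_logger_config_path

PartialState()  # initialize accelerate's process state for single-process use

logdir = "./outputs/example_run"   # hypothetical output directory
os.makedirs(logdir, exist_ok=True)

logger = get_logger_config_path(logdir)
logger.info("messages go to the console and to outputs/example_run/log.log")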
i2vedit/prompt_attention/common/set_seed.py
ADDED
|
@@ -0,0 +1,28 @@
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
import numpy as np
import random

from accelerate.utils import set_seed


def video_set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.
        device_specific (`bool`, *optional*, defaults to `False`):
            Whether to differ the seed on each device slightly with `self.process_index`.
    """
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True, warn_only=True)
    # [W Context.cpp:82] Warning: efficient_attention_forward_cutlass does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (function alertNotDeterministic)
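A small reproducibility check for the seeding helper above, assuming torch, numpy and accelerate are installed; it runs on CPU (the CUDA call is a no-op without a GPU).

import torch
from i2vedit.prompt_attention.common.set_seed import video_set_seed

video_set_seed(33)
a = torch.rand(2, 2)
video_set_seed(33)
b = torch.rand(2, 2)
assert torch.equal(a, b)  # same seed -> identical random draws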
i2vedit/prompt_attention/common/util.py
ADDED
|
@@ -0,0 +1,73 @@
import os
import sys
import copy
import inspect
import datetime
from typing import List, Tuple, Optional, Dict


def glob_files(
    root_path: str,
    extensions: Tuple[str],
    recursive: bool = True,
    skip_hidden_directories: bool = True,
    max_directories: Optional[int] = None,
    max_files: Optional[int] = None,
    relative_path: bool = False,
) -> Tuple[List[str], bool, bool]:
    """glob files with specified extensions

    Args:
        root_path (str): _description_
        extensions (Tuple[str]): _description_
        recursive (bool, optional): _description_. Defaults to True.
        skip_hidden_directories (bool, optional): _description_. Defaults to True.
        max_directories (Optional[int], optional): max number of directories to search. Defaults to None.
        max_files (Optional[int], optional): max file number limit. Defaults to None.
        relative_path (bool, optional): _description_. Defaults to False.

    Returns:
        Tuple[List[str], bool, bool]: _description_
    """
    paths = []
    hit_max_directories = False
    hit_max_files = False
    for directory_idx, (directory, _, fnames) in enumerate(os.walk(root_path, followlinks=True)):
        if skip_hidden_directories and os.path.basename(directory).startswith("."):
            continue

        if max_directories is not None and directory_idx >= max_directories:
            hit_max_directories = True
            break

        paths += [
            os.path.join(directory, fname)
            for fname in sorted(fnames)
            if fname.lower().endswith(extensions)
        ]

        if not recursive:
            break

        if max_files is not None and len(paths) > max_files:
            hit_max_files = True
            paths = paths[:max_files]
            break

    if relative_path:
        paths = [os.path.relpath(p, root_path) for p in paths]

    return paths, hit_max_directories, hit_max_files


def get_time_string() -> str:
    x = datetime.datetime.now()
    return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}"


def get_function_args() -> Dict:
    frame = sys._getframe(1)
    args, _, _, values = inspect.getargvalues(frame)
    args_dict = copy.deepcopy({arg: values[arg] for arg in args})

    return args_dict
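A runnable sketch of glob_files and get_time_string, using a throwaway temporary directory so it does not depend on any particular dataset layout; file names are illustrative.

import os, tempfile
from i2vedit.prompt_attention.common.util import glob_files, get_time_string

# Create a small tree with two images and one unrelated file.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "clips"), exist_ok=True)
for name in ("a.png", os.path.join("clips", "b.jpg"), "notes.txt"):
    open(os.path.join(root, name), "w").close()

paths, hit_dirs, hit_files = glob_files(root, extensions=(".png", ".jpg"), relative_path=True)
print(sorted(paths))        # ['a.png', 'clips/b.jpg'] on Linux
print(hit_dirs, hit_files)  # False False (no directory/file limits were hit)

print(get_time_string())    # e.g. '240607-153012' (YYMMDD-HHMMSS)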
i2vedit/prompt_attention/ptp_utils.py
ADDED
|
@@ -0,0 +1,199 @@
'''
utils code for image visualization
'''


# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from PIL import Image
import cv2
from typing import Optional, Union, Tuple, List, Callable, Dict

import datetime


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    if save_path is not None:
        pil_img = Image.fromarray(image_)
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        pil_img.save(f'{save_path}/{now}.png')
        # display(pil_img)



def register_attention_control_p2p_deprecated(model, controller):
    "Original code from prompt to prompt"
    def ca_forward(self, place_in_unet):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        # def forward(x, encoder_hidden_states=None, attention_mask=None):
        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            batch_size, sequence_length, _ = hidden_states.shape
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            query = self.to_q(hidden_states)
            query = self.head_to_batch_dim(query)

            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
            key = self.to_k(encoder_hidden_states)
            value = self.to_v(encoder_hidden_states)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            attention_probs = self.get_attention_scores(query, key, attention_mask)  # [16, 4096, 4096]
            attention_probs = controller(attention_probs, is_cross, place_in_unet)
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = self.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)

            return hidden_states

        return forward

    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
            return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count


def get_word_inds(text: str, word_place: int, tokenizer):
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor]=None):
    # Edit the alpha map during attention map editing
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha

import omegaconf
def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    # Not understand
    if (type(cross_replace_steps) is not dict) and \
       (type(cross_replace_steps) is not omegaconf.dictconfig.DictConfig):
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
                                                  i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
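A self-contained sketch of update_alpha_time_word, which is how get_time_words_attention_alpha builds its per-step word masks; the tensor sizes are illustrative (a single float bound would be treated as (0, bound)).

import torch
from i2vedit.prompt_attention.ptp_utils import update_alpha_time_word

# alpha: (num_steps, num_target_prompts, max_num_words)
alpha = torch.zeros(10, 1, 77)

# Enable cross-attention replacement for prompt 0 during the first 40% of the steps only.
alpha = update_alpha_time_word(alpha, bounds=(0.0, 0.4), prompt_ind=0)
print(alpha[:, 0, 0])  # tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0.])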
i2vedit/prompt_attention/visualization.py
ADDED
|
@@ -0,0 +1,391 @@
from typing import List
import os
import datetime
import warnings
import numpy as np
from PIL import Image
from einops import rearrange, repeat
import math

import torch
import torch.nn.functional as F
from packaging import version

from i2vedit.prompt_attention import ptp_utils
from i2vedit.prompt_attention.common.image_util import save_gif_mp4_folder_type
from i2vedit.prompt_attention.attention_store import AttentionStore

if version.parse(torch.__version__) >= version.parse("2.0.0"):
    SDP_IS_AVAILABLE = True
    from torch.backends.cuda import SDPBackend, sdp_kernel

    BACKEND_MAP = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
    }
else:
    from contextlib import nullcontext

    SDP_IS_AVAILABLE = False
    sdp_kernel = nullcontext
    BACKEND_MAP = {}
    warnings.warn(
        f"No SDP backend available, likely because you are running in pytorch "
        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
        f"You might want to consider upgrading."
    )

def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.dim() == 3:
                if item.shape[1] == num_pixels:
                    cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                    out.append(cross_maps)
            elif item.dim() == 4:
                t, h, res_sq, token = item.shape
                if item.shape[2] == num_pixels:
                    cross_maps = item.reshape(len(prompts), t, -1, res, res, item.shape[-1])[select]
                    out.append(cross_maps)

    out = torch.cat(out, dim=-4)
    out = out.sum(-4) / out.shape[-4]
    return out.cpu()


def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore,
                         res: int, from_where: List[str], select: int = 0, save_path = None):
    """
        attention_store (AttentionStore):
            ["down", "mid", "up"] X ["self", "cross"]
            4, 1, 6
            head*res*text_token_len = 8*res*77
            res=1024 -> 64 -> 1024
        res (int): res
        from_where (List[str]): "up", "down"
    """
    if isinstance(prompts, str):
        prompts = [prompts,]
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode

    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    os.makedirs('trash', exist_ok=True)
    attention_list = []
    if attention_maps.dim()==3: attention_maps=attention_maps[None, ...]
    for j in range(attention_maps.shape[0]):
        images = []
        for i in range(len(tokens)):
            image = attention_maps[j, :, :, i]
            image = 255 * image / image.max()
            image = image.unsqueeze(-1).expand(*image.shape, 3)
            image = image.numpy().astype(np.uint8)
            image = np.array(Image.fromarray(image).resize((256, 256)))
            image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
            images.append(image)
        ptp_utils.view_images(np.stack(images, axis=0), save_path=save_path)
        atten_j = np.concatenate(images, axis=1)
        attention_list.append(atten_j)
    if save_path is not None:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        video_save_path = f'{save_path}/{now}.gif'
        save_gif_mp4_folder_type(attention_list, video_save_path)
    return attention_list


def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    ptp_utils.view_images(np.concatenate(images, axis=1))

def show_avg_difference_maps(
    attention_store: AttentionStore,
    save_path = None
):
    avg_attention = attention_store.get_average_attention()
    masks = []
    for key in avg_attention:
        if 'mask' in key:
            for cur_pos in range(len(avg_attention[key])):
                mask = avg_attention[key][cur_pos]
                res = mask.shape[0] / 9
                file_path = os.path.join(
                    save_path,
                    f"avg_key_{key}_curpos_{cur_pos}_res_{res}_mask.png"
                )
                print(key, cur_pos, mask.shape)
                image = 255 * mask #/ attn.max()
                image = image.cpu().numpy().astype(np.uint8)
                image = Image.fromarray(image)
                image.save(file_path)




def show_self_attention(
    attention_store: AttentionStore,
    steps: List[int],
    save_path = None,
    inversed = False):
    """
        attention_store (AttentionStore):
            ["down", "mid", "up"] X ["self", "cross"]
            4, 1, 6
            head*res*text_token_len = 8*res*77
            res=1024 -> 64 -> 1024
        res (int): res
        from_where (List[str]): "up", "down"
    """
    #os.system(f"rm -rf {save_path}")
    os.makedirs(save_path, exist_ok=True)
    for step in steps:
        step_in_store = len(attention_store.attention_store_all_step) - step - 1 if inversed else step
        print("step_in_store", step_in_store)
        step_in_store_atten_dict = attention_store.attention_store_all_step[step_in_store]
        if isinstance(step_in_store_atten_dict, str):
            step_in_store_atten_dict = torch.load(step_in_store_atten_dict)

        step_in_store_atten_dict_reorg = {}

        for key in step_in_store_atten_dict:
            if '_q_' not in key and ('_k_' not in key and '_v_' not in key):
                step_in_store_atten_dict_reorg[key] = step_in_store_atten_dict[key]
            elif '_q_' in key:
                step_in_store_atten_dict_reorg[key.replace("_q_","_qxk_")] = \
                    [[step_in_store_atten_dict[key][i], \
                      step_in_store_atten_dict[key.replace("_q_","_k_")][i] \
                     ] \
                     for i in range(len(step_in_store_atten_dict[key]))]

        for key in step_in_store_atten_dict_reorg:
            if '_mask_' not in key and '_qxk_' not in key:
                for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
                    attn = step_in_store_atten_dict_reorg[key][cur_pos]
                    attn = torch.mean(attn, dim=1)
                    s, t, d = attn.shape
                    res = int(np.sqrt(s / (9*16)))
                    attn = attn.reshape(res*9,res*16,t,d).permute(2,0,3,1).reshape(t*res*9,d*res*16)
                    file_path = os.path.join(
                        save_path,
                        f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}.png"
                    )
                    print(step, key, cur_pos, attn.shape)
                    image = 255 * attn #/ attn.max()
                    image = image.cpu().numpy().astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save(file_path)

            elif '_mask_' in key:
                for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
                    mask = step_in_store_atten_dict_reorg[key][cur_pos]
                    res = mask.shape[0] / 9
                    file_path = os.path.join(
                        save_path,
                        f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_mask.png"
                    )
                    print(step, key, cur_pos, mask.shape)
                    image = 255 * mask #/ attn.max()
                    image = image.cpu().numpy().astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save(file_path)

            else:
                for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
                    q, k = step_in_store_atten_dict_reorg[key][cur_pos]
                    q = q.to("cuda").type(torch.float32)
                    k = k.to("cuda").type(torch.float32)
                    res = int(np.sqrt(q.shape[-2] / (9*16)))
                    h = q.shape[1]
                    bs = 1
                    N = q.shape[0] // bs
                    vectors = []
                    vectors_diff = []
                    for i in range(N):
                        attn_prob = calculate_attention_probs(q[i*bs:(i+1)*bs], k[i*bs:(i+1)*bs])
                        print("attn_prob 1", attn_prob.min(), attn_prob.max())
                        attn_prob = torch.mean(attn_prob, dim=2).reshape(h, res*9, res*16)
                        print("attn_prob 2", attn_prob.min(), attn_prob.max())
                        attn_prob = torch.mean(attn_prob, dim=0)
                        print("attn_prob 3", attn_prob.min(), attn_prob.max())
                        vectors.append( attn_prob )
                    for i in range(1, len(vectors)):
                        vectors_diff.append(vectors[i] - vectors[i-1])
                    vectors = torch.cat(vectors, dim=1)
                    vectors_diff = torch.cat(vectors_diff, dim=1)
                    file_path = os.path.join(
                        save_path,
                        f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_vector.png"
                    )
                    print(step, key, cur_pos, vectors.shape)
                    image = 255 * vectors / vectors.max()
                    image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save(file_path)

                    file_path = os.path.join(
                        save_path,
                        f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_diff.png"
                    )
                    print(step, key, cur_pos, vectors_diff.shape)
                    image = 255 * vectors_diff / vectors_diff.max()
                    image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save(file_path)


            # else:
            #     # only visualize the last two frames
            #     for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
            #         q, k, v = step_in_store_atten_dict_reorg[key][cur_pos]
            #         q = q[-2:,...].to("cuda")
            #         k = k[-2:,...].to("cuda")
            #         v = v[-2:,...].to("cuda")
            #         res = int(np.sqrt(q.shape[-2] / (9*16)))
            #         attn = calculate_attention_probs(q,k,v)
            #         attn_d = torch.sum(torch.mean(torch.abs(attn[0,...] - attn[1,...]), dim=0), dim=1).reshape(res*9,res*16)
            #         print(step, key, cur_pos, attn_d.shape, attn_d.min(), attn_d.max())
            #         file_path = os.path.join(
            #             save_path,
            #             f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_attn_d.png"
            #         )
            #         image = (255 * attn_d + 1e-3) / 2.#attn_d.max()
            #         image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
            #         image = Image.fromarray(image)
            #         image.save(file_path)

def show_self_attention_distance(
    attention_store: List[AttentionStore],
    steps: List[int],
    save_path = None,
):
    """
        attention_store (AttentionStore):
            ["down", "mid", "up"] X ["self", "cross"]
            4, 1, 6
            head*res*text_token_len = 8*res*77
            res=1024 -> 64 -> 1024
        res (int): res
        from_where (List[str]): "up", "down"
    """
    os.system(f"rm -rf {save_path}")
    os.makedirs(save_path, exist_ok=True)
    assert len(attention_store) == 2
    for step in steps:
        step_in_store = [len(attention_store[0].attention_store_all_step) - step - 1, step]
        step_in_store_atten_dict = [attention_store[i].attention_store_all_step[step_in_store[i]] \
                                    for i in range(2)]
        step_in_store_atten_dict = [ \
            torch.load(step_in_store_atten_dict[i]) \
            if isinstance(step_in_store_atten_dict[i], str) \
            else step_in_store_atten_dict[i] \
            for i in range(2)]

        step_in_store_atten_dict_reorg = [{},{}]

        for i in range(2):
            item = step_in_store_atten_dict[i]
            for key in item:
                if '_q_' in key:
                    step_in_store_atten_dict_reorg[i][key.replace("_q_","_qxk_")] = \
                        [[step_in_store_atten_dict[i][key][j], \
                          step_in_store_atten_dict[i][key.replace("_q_","_k_")][j] \
                         ] \
                         for j in range(len(step_in_store_atten_dict[i][key]))]

        for key in step_in_store_atten_dict_reorg[1]:
            for cur_pos in range(len(step_in_store_atten_dict_reorg[1][key])):
                q1, k1 = step_in_store_atten_dict_reorg[1][key][cur_pos]
                q0, k0 = step_in_store_atten_dict_reorg[0][key][cur_pos]
                res = int(np.sqrt(q1.shape[-2] / (9*16)))

                attn_d = calculate_attention_mask(q0, k0, q1, k1, bs=1, device="cuda")
                attn_d = rearrange(attn_d, "b h w -> h (b w)")

                print(step, key, cur_pos, attn_d.shape, "attnd", attn_d.min(), attn_d.max())
                file_path = os.path.join(
                    save_path,
                    f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_attn_d.png"
                )
                image = 255 * attn_d #attn_d.max()
                image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
                image = Image.fromarray(image)
                image.save(file_path)

def calculate_attention_mask(q0, k0, q1, k1, bs=1, device="cuda"):
    q1 = q1.to(device)
    k1 = k1.to(device)
    q0 = q0.to(device)
    k0 = k0.to(device)
    res = int(np.sqrt(q1.shape[-2] / (9*16)))
    N = q0.shape[0] // bs
    attn_d = []
    for i in range(N):
        attn0 = calculate_attention_probs(q0[bs*i:bs*(i+1),...],k0[bs*i:bs*(i+1),...])
        attn1 = calculate_attention_probs(q1[bs*i:bs*(i+1),...],k1[bs*i:bs*(i+1),...])
        attn_d_i = torch.sum(torch.mean(torch.abs(attn0 - attn1), dim=1), dim=2).reshape(bs,res*9,res*16)
        attn_d.append( attn_d_i )
    attn_d = torch.cat(attn_d, dim=0) / 2.0
    return attn_d.clamp(0,1)

def calculate_attention_probs(q, k, attn_mask=None):
    with sdp_kernel(**BACKEND_MAP[None]):
        h = q.shape[1]
        q, k = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q, k))

        with torch.autocast("cuda", enabled=False):
            attention_scores = torch.baddbmm(
                torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
                q,
                k.transpose(-1, -2),
                beta=0,
                alpha=1 / math.sqrt(q.size(-1)),
            )
            #print("attention_scores", attention_scores.min(), attention_scores.max())

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
                attention_scores = attention_scores + attn_mask

            attention_probs = attention_scores.softmax(dim=-1)
            #print("attention_softmax", attention_probs.min(), attention_probs.max())

        # cast back to the original dtype
        attention_probs = attention_probs.to(q.dtype)

        # reshape hidden_states
        attention_probs = rearrange(attention_probs, "(b h) n d -> b h n d", h=h)

        # v = torch.eye(q.shape[-2], device=q.device)
        # v = repeat(v, "... -> b h ...", b=q.shape[0], h=q.shape[1])
        # attention_probs = F.scaled_dot_product_attention(
        #     q, k, v, attn_mask=attn_mask
        # )  # scale is dim_head ** -0.5 per default

    return attention_probs
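A small sanity check for calculate_attention_probs, assuming PyTorch >= 2.0 and that the package and its dependencies are importable; the tensor shapes mirror the (batch, heads, tokens, head_dim) layout used by the attention hooks and are otherwise arbitrary.

import torch
from i2vedit.prompt_attention.visualization import calculate_attention_probs

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)

probs = calculate_attention_probs(q, k)
print(probs.shape)  # torch.Size([1, 8, 16, 16])
print(torch.allclose(probs.sum(dim=-1), torch.ones(1, 8, 16)))  # True: each row is softmax-normalised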
i2vedit/train.py
ADDED
|
@@ -0,0 +1,1488 @@
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import logging
|
| 4 |
+
import inspect
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import gc
|
| 9 |
+
import copy
|
| 10 |
+
from scipy.stats import anderson
|
| 11 |
+
import imageio
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from typing import Dict, Optional, Tuple, List
|
| 15 |
+
from omegaconf import OmegaConf
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
import diffusers
|
| 22 |
+
import transformers
|
| 23 |
+
|
| 24 |
+
from torchvision import transforms
|
| 25 |
+
from tqdm.auto import tqdm
|
| 26 |
+
from PIL import Image
|
| 27 |
+
|
| 28 |
+
from accelerate import Accelerator
|
| 29 |
+
from accelerate.logging import get_logger
|
| 30 |
+
from accelerate.utils import set_seed
|
| 31 |
+
|
| 32 |
+
import diffusers
|
| 33 |
+
from diffusers.models import AutoencoderKL
|
| 34 |
+
from diffusers import DDIMScheduler, TextToVideoSDPipeline
|
| 35 |
+
from diffusers.optimization import get_scheduler
|
| 36 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 37 |
+
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
|
| 38 |
+
from diffusers.models.attention import BasicTransformerBlock
|
| 39 |
+
from diffusers import StableVideoDiffusionPipeline
|
| 40 |
+
from diffusers.models.lora import LoRALinearLayer
|
| 41 |
+
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
|
| 42 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 43 |
+
from diffusers.optimization import get_scheduler
|
| 44 |
+
from diffusers.training_utils import EMAModel
|
| 45 |
+
from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
|
| 46 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 47 |
+
from diffusers.models.unet_3d_blocks import \
|
| 48 |
+
(CrossAttnDownBlockSpatioTemporal,
|
| 49 |
+
DownBlockSpatioTemporal,
|
| 50 |
+
CrossAttnUpBlockSpatioTemporal,
|
| 51 |
+
UpBlockSpatioTemporal)
|
| 52 |
+
|
| 53 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 54 |
+
from transformers.models.clip.modeling_clip import CLIPEncoder
|
| 55 |
+
|
| 56 |
+
from i2vedit.utils.dataset import (
|
| 57 |
+
CachedDataset,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
from i2vedit.utils.lora_handler import LoraHandler
|
| 61 |
+
from i2vedit.utils.lora import extract_lora_child_module
|
| 62 |
+
from i2vedit.utils.euler_utils import euler_inversion
|
| 63 |
+
from i2vedit.utils.svd_util import SmoothAreaRandomDetection
|
| 64 |
+
from i2vedit.utils.model_utils import (
|
| 65 |
+
tensor_to_vae_latent,
|
| 66 |
+
P2PEulerDiscreteScheduler,
|
| 67 |
+
P2PStableVideoDiffusionPipeline
|
| 68 |
+
)
|
| 69 |
+
from i2vedit.data import ResolutionControl, SingleClipDataset
|
| 70 |
+
from i2vedit.prompt_attention import attention_util
|
| 71 |
+
|
| 72 |
+
already_printed_trainables = False
|
| 73 |
+
|
| 74 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def create_logging(logging, logger, accelerator):
|
| 78 |
+
logging.basicConfig(
|
| 79 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 80 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 81 |
+
level=logging.INFO,
|
| 82 |
+
)
|
| 83 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def accelerate_set_verbose(accelerator):
|
| 87 |
+
if accelerator.is_local_main_process:
|
| 88 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 89 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 90 |
+
else:
|
| 91 |
+
transformers.utils.logging.set_verbosity_error()
|
| 92 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 93 |
+
|
| 94 |
+
def extend_datasets(datasets, dataset_items, extend=False):
|
| 95 |
+
biggest_data_len = max(x.__len__() for x in datasets)
|
| 96 |
+
extended = []
|
| 97 |
+
for dataset in datasets:
|
| 98 |
+
if dataset.__len__() == 0:
|
| 99 |
+
del dataset
|
| 100 |
+
continue
|
| 101 |
+
if dataset.__len__() < biggest_data_len:
|
| 102 |
+
for item in dataset_items:
|
| 103 |
+
if extend and item not in extended and hasattr(dataset, item):
|
| 104 |
+
print(f"Extending {item}")
|
| 105 |
+
|
| 106 |
+
value = getattr(dataset, item)
|
| 107 |
+
value *= biggest_data_len
|
| 108 |
+
value = value[:biggest_data_len]
|
| 109 |
+
|
| 110 |
+
setattr(dataset, item, value)
|
| 111 |
+
|
| 112 |
+
print(f"New {item} dataset length: {dataset.__len__()}")
|
| 113 |
+
extended.append(item)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def export_to_video(video_frames, output_video_path, fps, resctrl:ResolutionControl):
|
| 117 |
+
flattened_video_frames = [img for sublist in video_frames for img in sublist]
|
| 118 |
+
video_writer = imageio.get_writer(output_video_path, fps=fps)
|
| 119 |
+
for img in flattened_video_frames:
|
| 120 |
+
img = resctrl.callback(img)
|
| 121 |
+
video_writer.append_data(np.array(img))
|
| 122 |
+
video_writer.close()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def create_output_folders(output_dir, config, clip_id):
|
| 126 |
+
out_dir = os.path.join(output_dir, f"train_motion_lora/clip_{clip_id}")
|
| 127 |
+
|
| 128 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 129 |
+
os.makedirs(f"{out_dir}/samples", exist_ok=True)
|
| 130 |
+
# OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
|
| 131 |
+
|
| 132 |
+
return out_dir
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def load_primary_models(pretrained_model_path):
|
| 136 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
|
| 137 |
+
pretrained_model_path, subfolder="scheduler")
|
| 138 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(
|
| 139 |
+
pretrained_model_path, subfolder="feature_extractor", revision=None
|
| 140 |
+
)
|
| 141 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 142 |
+
pretrained_model_path, subfolder="image_encoder", revision=None, variant="fp16"
|
| 143 |
+
)
|
| 144 |
+
vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
| 145 |
+
pretrained_model_path, subfolder="vae", revision=None, variant="fp16")
|
| 146 |
+
unet = UNetSpatioTemporalConditionModel.from_pretrained(
|
| 147 |
+
pretrained_model_path,
|
| 148 |
+
subfolder="unet",
|
| 149 |
+
low_cpu_mem_usage=True,
|
| 150 |
+
variant="fp16",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return noise_scheduler, feature_extractor, image_encoder, vae, unet
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def unet_and_text_g_c(unet, image_encoder, unet_enable, image_enable):
|
| 157 |
+
unet.gradient_checkpointing = unet_enable
|
| 158 |
+
unet.mid_block.gradient_checkpointing = unet_enable
|
| 159 |
+
for module in unet.down_blocks + unet.up_blocks:
|
| 160 |
+
if isinstance(module,
|
| 161 |
+
(CrossAttnDownBlockSpatioTemporal,
|
| 162 |
+
DownBlockSpatioTemporal,
|
| 163 |
+
CrossAttnUpBlockSpatioTemporal,
|
| 164 |
+
UpBlockSpatioTemporal)):
|
| 165 |
+
module.gradient_checkpointing = unet_enable
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def freeze_models(models_to_freeze):
|
| 169 |
+
for model in models_to_freeze:
|
| 170 |
+
if model is not None: model.requires_grad_(False)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def is_attn(name):
|
| 174 |
+
return ('attn1' or 'attn2' == name.split('.')[-1])
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def set_processors(attentions):
|
| 178 |
+
for attn in attentions: attn.set_processor(AttnProcessor2_0())
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def set_torch_2_attn(unet):
|
| 182 |
+
optim_count = 0
|
| 183 |
+
|
| 184 |
+
for name, module in unet.named_modules():
|
| 185 |
+
if is_attn(name):
|
| 186 |
+
if isinstance(module, torch.nn.ModuleList):
|
| 187 |
+
for m in module:
|
| 188 |
+
if isinstance(m, BasicTransformerBlock):
|
| 189 |
+
set_processors([m.attn1, m.attn2])
|
| 190 |
+
optim_count += 1
|
| 191 |
+
if optim_count > 0:
|
| 192 |
+
print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
|
| 196 |
+
try:
|
| 197 |
+
is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
|
| 198 |
+
enable_torch_2 = is_torch_2 and enable_torch_2_attn
|
| 199 |
+
|
| 200 |
+
if enable_xformers_memory_efficient_attention and not enable_torch_2:
|
| 201 |
+
if is_xformers_available():
|
| 202 |
+
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
| 203 |
+
unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
| 204 |
+
else:
|
| 205 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 206 |
+
|
| 207 |
+
if enable_torch_2:
|
| 208 |
+
set_torch_2_attn(unet)
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
print("Could not enable memory efficient attention for xformers or Torch 2.0.")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
|
| 215 |
+
extra_params = extra_params if extra_params is not None and len(extra_params.keys()) > 0 else None
|
| 216 |
+
return {
|
| 217 |
+
"model": model,
|
| 218 |
+
"condition": condition,
|
| 219 |
+
'extra_params': extra_params,
|
| 220 |
+
'is_lora': is_lora,
|
| 221 |
+
"negation": negation
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
|
| 226 |
+
params = {
|
| 227 |
+
"name": name,
|
| 228 |
+
"params": params,
|
| 229 |
+
"lr": lr
|
| 230 |
+
}
|
| 231 |
+
if extra_params is not None:
|
| 232 |
+
for k, v in extra_params.items():
|
| 233 |
+
params[k] = v
|
| 234 |
+
|
| 235 |
+
return params
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def negate_params(name, negation):
|
| 239 |
+
# We have to do this if we are co-training with LoRA.
|
| 240 |
+
# This ensures that parameter groups aren't duplicated.
|
| 241 |
+
if negation is None: return False
|
| 242 |
+
for n in negation:
|
| 243 |
+
if n in name and 'temp' not in name:
|
| 244 |
+
return True
|
| 245 |
+
return False
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def create_optimizer_params(model_list, lr):
|
| 249 |
+
import itertools
|
| 250 |
+
optimizer_params = []
|
| 251 |
+
|
| 252 |
+
for optim in model_list:
|
| 253 |
+
model, condition, extra_params, is_lora, negation = optim.values()
|
| 254 |
+
# Check if we are doing LoRA training.
|
| 255 |
+
if is_lora and condition and isinstance(model, list):
|
| 256 |
+
params = create_optim_params(
|
| 257 |
+
params=itertools.chain(*model),
|
| 258 |
+
extra_params=extra_params
|
| 259 |
+
)
|
| 260 |
+
optimizer_params.append(params)
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
if is_lora and condition and not isinstance(model, list):
|
| 264 |
+
for n, p in model.named_parameters():
|
| 265 |
+
if 'lora' in n:
|
| 266 |
+
params = create_optim_params(n, p, lr, extra_params)
|
| 267 |
+
optimizer_params.append(params)
|
| 268 |
+
continue
|
| 269 |
+
|
| 270 |
+
# If this is true, we can train it.
|
| 271 |
+
if condition:
|
| 272 |
+
for n, p in model.named_parameters():
|
| 273 |
+
should_negate = 'lora' in n and not is_lora
|
| 274 |
+
if should_negate: continue
|
| 275 |
+
|
| 276 |
+
params = create_optim_params(n, p, lr, extra_params)
|
| 277 |
+
optimizer_params.append(params)
|
| 278 |
+
|
| 279 |
+
return optimizer_params
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_optimizer(use_8bit_adam):
|
| 283 |
+
if use_8bit_adam:
|
| 284 |
+
try:
|
| 285 |
+
import bitsandbytes as bnb
|
| 286 |
+
except ImportError:
|
| 287 |
+
raise ImportError(
|
| 288 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
return bnb.optim.AdamW8bit
|
| 292 |
+
else:
|
| 293 |
+
return torch.optim.AdamW
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def is_mixed_precision(accelerator):
|
| 297 |
+
weight_dtype = torch.float32
|
| 298 |
+
|
| 299 |
+
if accelerator.mixed_precision == "fp16":
|
| 300 |
+
weight_dtype = torch.float16
|
| 301 |
+
|
| 302 |
+
elif accelerator.mixed_precision == "bf16":
|
| 303 |
+
weight_dtype = torch.bfloat16
|
| 304 |
+
|
| 305 |
+
return weight_dtype
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
|
| 309 |
+
for model in model_list:
|
| 310 |
+
if model is not None: model.to(accelerator.device, dtype=weight_dtype)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def inverse_video(pipe, latents, num_steps, image):
|
| 314 |
+
euler_inv_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 315 |
+
euler_inv_scheduler.set_timesteps(num_steps)
|
| 316 |
+
|
| 317 |
+
euler_inv_latent = euler_inversion(
|
| 318 |
+
pipe, euler_inv_scheduler, video_latent=latents.to(pipe.device),
|
| 319 |
+
num_inv_steps=num_steps, image=image)[-1]
|
| 320 |
+
return euler_inv_latent
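# Runs Euler inversion for `num_steps` and keeps only the final latent of the
# trajectory; presumably this is what feeds `inversion_noise` when a non-zero
# noise_prior is configured for validation sampling.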
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def handle_cache_latents(
|
| 324 |
+
should_cache,
|
| 325 |
+
output_dir,
|
| 326 |
+
train_dataloader,
|
| 327 |
+
train_batch_size,
|
| 328 |
+
vae,
|
| 329 |
+
unet,
|
| 330 |
+
cached_latent_dir=None,
|
| 331 |
+
):
|
| 332 |
+
# Cache latents by storing them in VRAM.
|
| 333 |
+
# Speeds up training and saves memory by not encoding during the train loop.
|
| 334 |
+
if not should_cache: return None
|
| 335 |
+
vae.to('cuda', dtype=torch.float32)
|
| 336 |
+
#vae.enable_slicing()
|
| 337 |
+
|
| 338 |
+
cached_latent_dir = (
|
| 339 |
+
os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if cached_latent_dir is None:
|
| 343 |
+
cache_save_dir = f"{output_dir}/cached_latents"
|
| 344 |
+
os.makedirs(cache_save_dir, exist_ok=True)
|
| 345 |
+
|
| 346 |
+
for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
|
| 347 |
+
|
| 348 |
+
save_name = f"cached_{i}"
|
| 349 |
+
full_out_path = f"{cache_save_dir}/{save_name}.pt"
|
| 350 |
+
|
| 351 |
+
pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float32)
|
| 352 |
+
refer_pixel_values = batch['refer_pixel_values'].to('cuda', dtype=torch.float32)
|
| 353 |
+
cross_pixel_values = batch['cross_pixel_values'].to('cuda', dtype=torch.float32)
|
| 354 |
+
batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
|
| 355 |
+
if batch.get("refer_latents") is None:
|
| 356 |
+
batch['refer_latents'] = tensor_to_vae_latent(refer_pixel_values, vae)
|
| 357 |
+
batch['cross_latents'] = tensor_to_vae_latent(cross_pixel_values, vae)
|
| 358 |
+
|
| 359 |
+
for k, v in batch.items(): batch[k] = v[0]
|
| 360 |
+
|
| 361 |
+
torch.save(batch, full_out_path)
|
| 362 |
+
del pixel_values
|
| 363 |
+
del batch
|
| 364 |
+
|
| 365 |
+
# We do this to avoid fragmentation from casting latents between devices.
|
| 366 |
+
torch.cuda.empty_cache()
|
| 367 |
+
else:
|
| 368 |
+
cache_save_dir = cached_latent_dir
|
| 369 |
+
|
| 370 |
+
return torch.utils.data.DataLoader(
|
| 371 |
+
CachedDataset(cache_dir=cache_save_dir),
|
| 372 |
+
batch_size=train_batch_size,
|
| 373 |
+
shuffle=True,
|
| 374 |
+
num_workers=0
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
|
| 379 |
+
global already_printed_trainables
|
| 380 |
+
|
| 381 |
+
# This can most definitely be refactored :-)
|
| 382 |
+
unfrozen_params = 0
|
| 383 |
+
if trainable_modules is not None:
|
| 384 |
+
for name, module in model.named_modules():
|
| 385 |
+
for tm in tuple(trainable_modules):
|
| 386 |
+
if tm == 'all':
|
| 387 |
+
model.requires_grad_(is_enabled)
|
| 388 |
+
unfrozen_params = len(list(model.parameters()))
|
| 389 |
+
break
|
| 390 |
+
|
| 391 |
+
if tm in name and 'lora' not in name:
|
| 392 |
+
for m in module.parameters():
|
| 393 |
+
m.requires_grad_(is_enabled)
|
| 394 |
+
if is_enabled: unfrozen_params += 1
|
| 395 |
+
|
| 396 |
+
if unfrozen_params > 0 and not already_printed_trainables:
|
| 397 |
+
already_printed_trainables = True
|
| 398 |
+
print(f"{unfrozen_params} params have been unfrozen for training.")
|
| 399 |
+
|
| 400 |
+
def sample_noise(latents, noise_strength, use_offset_noise=False):
|
| 401 |
+
b, c, f, *_ = latents.shape
|
| 402 |
+
noise_latents = torch.randn_like(latents, device=latents.device)
|
| 403 |
+
|
| 404 |
+
if use_offset_noise:
|
| 405 |
+
offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
|
| 406 |
+
noise_latents = noise_latents + noise_strength * offset_noise
|
| 407 |
+
|
| 408 |
+
return noise_latents
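# Offset noise adds one random value per (batch, channel, frame) that is broadcast
# over the spatial dimensions, nudging the model to learn global brightness/colour
# shifts on top of per-pixel noise.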
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def enforce_zero_terminal_snr(betas):
|
| 412 |
+
"""
|
| 413 |
+
Corrects noise in diffusion schedulers.
|
| 414 |
+
From: Common Diffusion Noise Schedules and Sample Steps are Flawed
|
| 415 |
+
https://arxiv.org/pdf/2305.08891.pdf
|
| 416 |
+
"""
|
| 417 |
+
# Convert betas to alphas_bar_sqrt
|
| 418 |
+
alphas = 1 - betas
|
| 419 |
+
alphas_bar = alphas.cumprod(0)
|
| 420 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
| 421 |
+
|
| 422 |
+
# Store old values.
|
| 423 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 424 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 425 |
+
|
| 426 |
+
# Shift so the last timestep is zero.
|
| 427 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 428 |
+
|
| 429 |
+
# Scale so the first timestep is back to the old value.
|
| 430 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
|
| 431 |
+
alphas_bar_sqrt_0 - alphas_bar_sqrt_T
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Convert alphas_bar_sqrt to betas
|
| 435 |
+
alphas_bar = alphas_bar_sqrt ** 2
|
| 436 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
| 437 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
| 438 |
+
betas = 1 - alphas
|
| 439 |
+
|
| 440 |
+
return betas
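# After this rescaling the final alphas_bar is exactly zero (zero terminal SNR),
# as recommended by the referenced paper. Note this function is currently only
# wired up through the commented-out `rescale_schedule` branch further below.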
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def should_sample(global_step, validation_steps, validation_data):
|
| 444 |
+
return global_step % validation_steps == 0 and validation_data.sample_preview
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def save_pipe(
|
| 448 |
+
path,
|
| 449 |
+
global_step,
|
| 450 |
+
accelerator,
|
| 451 |
+
unet,
|
| 452 |
+
image_encoder,
|
| 453 |
+
vae,
|
| 454 |
+
output_dir,
|
| 455 |
+
lora_manager_spatial: LoraHandler,
|
| 456 |
+
lora_manager_temporal: LoraHandler,
|
| 457 |
+
unet_target_replace_module=None,
|
| 458 |
+
image_target_replace_module=None,
|
| 459 |
+
is_checkpoint=False,
|
| 460 |
+
save_pretrained_model=True,
|
| 461 |
+
):
|
| 462 |
+
if is_checkpoint:
|
| 463 |
+
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
| 464 |
+
os.makedirs(save_path, exist_ok=True)
|
| 465 |
+
else:
|
| 466 |
+
save_path = output_dir
|
| 467 |
+
|
| 468 |
+
# Save the dtypes so we can continue training at the same precision.
|
| 469 |
+
u_dtype, i_dtype, v_dtype = unet.dtype, image_encoder.dtype, vae.dtype
|
| 470 |
+
|
| 471 |
+
# Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
|
| 472 |
+
unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
|
| 473 |
+
image_encoder_out = copy.deepcopy(accelerator.unwrap_model(image_encoder.cpu(), keep_fp32_wrapper=False))
|
| 474 |
+
pipeline = P2PStableVideoDiffusionPipeline.from_pretrained(
|
| 475 |
+
path,
|
| 476 |
+
unet=unet_out,
|
| 477 |
+
image_encoder=image_encoder_out,
|
| 478 |
+
vae=accelerator.unwrap_model(vae),
|
| 479 |
+
# torch_dtype=weight_dtype,
|
| 480 |
+
).to(torch_dtype=torch.float32)
|
| 481 |
+
|
| 482 |
+
# lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/spatial', step=global_step)
|
| 483 |
+
if lora_manager_temporal is not None:
|
| 484 |
+
lora_manager_temporal.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/temporal', step=global_step)
|
| 485 |
+
|
| 486 |
+
if save_pretrained_model:
|
| 487 |
+
pipeline.save_pretrained(save_path)
|
| 488 |
+
|
| 489 |
+
if is_checkpoint:
|
| 490 |
+
unet, image_encoder = accelerator.prepare(unet, image_encoder)
|
| 491 |
+
models_to_cast_back = [(unet, u_dtype), (image_encoder, i_dtype), (vae, v_dtype)]
|
| 492 |
+
[x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
|
| 493 |
+
|
| 494 |
+
logger.info(f"Saved model at {save_path} on step {global_step}")
|
| 495 |
+
|
| 496 |
+
del pipeline
|
| 497 |
+
del unet_out
|
| 498 |
+
del image_encoder_out
|
| 499 |
+
torch.cuda.empty_cache()
|
| 500 |
+
gc.collect()
|
| 501 |
+
|
| 502 |
+
def load_images_from_list(img_list):
|
| 503 |
+
images = []
|
| 504 |
+
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"} # Add or remove extensions as needed
|
| 505 |
+
|
| 506 |
+
# Function to extract frame number from the filename
|
| 507 |
+
def frame_number(filename):
|
| 508 |
+
parts = filename.split('_')
|
| 509 |
+
if len(parts) > 1 and parts[0] == 'frame':
|
| 510 |
+
try:
|
| 511 |
+
return int(parts[1].split('.')[0]) # Extracting the number part
|
| 512 |
+
except ValueError:
|
| 513 |
+
return float('inf') # In case of non-integer part, place this file at the end
|
| 514 |
+
return float('inf') # Non-frame files are placed at the end
|
| 515 |
+
|
| 516 |
+
# Sorting files based on frame number
|
| 517 |
+
#sorted_files = sorted(os.listdir(folder), key=frame_number)
|
| 518 |
+
sorted_files = img_list
|
| 519 |
+
|
| 520 |
+
# Load images in sorted order
|
| 521 |
+
for filename in sorted_files:
|
| 522 |
+
ext = os.path.splitext(filename)[1].lower()
|
| 523 |
+
if ext in valid_extensions:
|
| 524 |
+
img = Image.open(filename).convert('RGB')
|
| 525 |
+
images.append(img)
|
| 526 |
+
|
| 527 |
+
return images
|
| 528 |
+
|
| 529 |
+
# copy from https://github.com/crowsonkb/k-diffusion.git
|
| 530 |
+
def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
|
| 531 |
+
"""Draws stratified samples from a uniform distribution."""
|
| 532 |
+
if groups <= 0:
|
| 533 |
+
raise ValueError(f"groups must be positive, got {groups}")
|
| 534 |
+
if group < 0 or group >= groups:
|
| 535 |
+
raise ValueError(f"group must be in [0, {groups})")
|
| 536 |
+
n = shape[-1] * groups
|
| 537 |
+
offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
|
| 538 |
+
u = torch.rand(shape, dtype=dtype, device=device)
|
| 539 |
+
return (offsets + u) / n
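# Stratified sampling: with groups == 1 each of the n strata [k/n, (k+1)/n) receives
# exactly one uniform sample, giving lower-variance coverage of [0, 1) than i.i.d. draws.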
|
| 540 |
+
|
| 541 |
+
def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
|
| 542 |
+
"""Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""
|
| 543 |
+
|
| 544 |
+
def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
|
| 545 |
+
t_min = math.atan(math.exp(-0.5 * logsnr_max))
|
| 546 |
+
t_max = math.atan(math.exp(-0.5 * logsnr_min))
|
| 547 |
+
return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
|
| 548 |
+
|
| 549 |
+
def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
|
| 550 |
+
shift = 2 * math.log(noise_d / image_d)
|
| 551 |
+
return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift
|
| 552 |
+
|
| 553 |
+
def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
|
| 554 |
+
logsnr_low = logsnr_schedule_cosine_shifted(
|
| 555 |
+
t, image_d, noise_d_low, logsnr_min, logsnr_max)
|
| 556 |
+
logsnr_high = logsnr_schedule_cosine_shifted(
|
| 557 |
+
t, image_d, noise_d_high, logsnr_min, logsnr_max)
|
| 558 |
+
return torch.lerp(logsnr_low, logsnr_high, t)
|
| 559 |
+
|
| 560 |
+
logsnr_min = -2 * math.log(min_value / sigma_data)
|
| 561 |
+
logsnr_max = -2 * math.log(max_value / sigma_data)
|
| 562 |
+
u = stratified_uniform(
|
| 563 |
+
shape, group=0, groups=1, dtype=dtype, device=device
|
| 564 |
+
)
|
| 565 |
+
logsnr = logsnr_schedule_cosine_interpolated(
|
| 566 |
+
u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
|
| 567 |
+
return torch.exp(-logsnr / 2) * sigma_data, u
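# The sampled logSNR is converted to an EDM-style noise level via
# sigma = sigma_data * exp(-logsnr / 2); `u` is returned as well so the caller can
# recover the underlying uniform draw.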
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
min_value = 0.002
|
| 571 |
+
max_value = 700
|
| 572 |
+
image_d = 64
|
| 573 |
+
noise_d_low = 32
|
| 574 |
+
noise_d_high = 64
|
| 575 |
+
sigma_data = 0.5
|
| 576 |
+
|
| 577 |
+
def _compute_padding(kernel_size):
|
| 578 |
+
"""Compute padding tuple."""
|
| 579 |
+
# 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
|
| 580 |
+
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
|
| 581 |
+
if len(kernel_size) < 2:
|
| 582 |
+
raise AssertionError(kernel_size)
|
| 583 |
+
computed = [k - 1 for k in kernel_size]
|
| 584 |
+
|
| 585 |
+
# for even kernels we need to do asymmetric padding :(
|
| 586 |
+
out_padding = 2 * len(kernel_size) * [0]
|
| 587 |
+
|
| 588 |
+
for i in range(len(kernel_size)):
|
| 589 |
+
computed_tmp = computed[-(i + 1)]
|
| 590 |
+
|
| 591 |
+
pad_front = computed_tmp // 2
|
| 592 |
+
pad_rear = computed_tmp - pad_front
|
| 593 |
+
|
| 594 |
+
out_padding[2 * i + 0] = pad_front
|
| 595 |
+
out_padding[2 * i + 1] = pad_rear
|
| 596 |
+
|
| 597 |
+
return out_padding
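# Example: a (ky, kx) = (3, 5) kernel yields (2, 2, 1, 1), i.e. (left, right, top, bottom)
# as expected by torch.nn.functional.pad.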
|
| 598 |
+
|
| 599 |
+
def _filter2d(input, kernel):
|
| 600 |
+
# prepare kernel
|
| 601 |
+
b, c, h, w = input.shape
|
| 602 |
+
tmp_kernel = kernel[:, None, ...].to(
|
| 603 |
+
device=input.device, dtype=input.dtype)
|
| 604 |
+
|
| 605 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
|
| 606 |
+
|
| 607 |
+
height, width = tmp_kernel.shape[-2:]
|
| 608 |
+
|
| 609 |
+
padding_shape: list[int] = _compute_padding([height, width])
|
| 610 |
+
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
|
| 611 |
+
|
| 612 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
| 613 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
|
| 614 |
+
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
|
| 615 |
+
|
| 616 |
+
# convolve the tensor with the kernel.
|
| 617 |
+
output = torch.nn.functional.conv2d(
|
| 618 |
+
input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
| 619 |
+
|
| 620 |
+
out = output.view(b, c, h, w)
|
| 621 |
+
return out
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def _gaussian(window_size: int, sigma):
|
| 625 |
+
if isinstance(sigma, float):
|
| 626 |
+
sigma = torch.tensor([[sigma]])
|
| 627 |
+
|
| 628 |
+
batch_size = sigma.shape[0]
|
| 629 |
+
|
| 630 |
+
x = (torch.arange(window_size, device=sigma.device,
|
| 631 |
+
dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
|
| 632 |
+
|
| 633 |
+
if window_size % 2 == 0:
|
| 634 |
+
x = x + 0.5
|
| 635 |
+
|
| 636 |
+
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
|
| 637 |
+
|
| 638 |
+
return gauss / gauss.sum(-1, keepdim=True)
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def _gaussian_blur2d(input, kernel_size, sigma):
|
| 642 |
+
if isinstance(sigma, tuple):
|
| 643 |
+
sigma = torch.tensor([sigma], dtype=input.dtype)
|
| 644 |
+
else:
|
| 645 |
+
sigma = sigma.to(dtype=input.dtype)
|
| 646 |
+
|
| 647 |
+
ky, kx = int(kernel_size[0]), int(kernel_size[1])
|
| 648 |
+
bs = sigma.shape[0]
|
| 649 |
+
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
|
| 650 |
+
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
|
| 651 |
+
out_x = _filter2d(input, kernel_x[..., None, :])
|
| 652 |
+
out = _filter2d(out_x, kernel_y[..., None])
|
| 653 |
+
|
| 654 |
+
return out
|
| 655 |
+
|
| 656 |
+
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
|
| 657 |
+
h, w = input.shape[-2:]
|
| 658 |
+
factors = (h / size[0], w / size[1])
|
| 659 |
+
|
| 660 |
+
# First, we have to determine sigma
|
| 661 |
+
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
|
| 662 |
+
sigmas = (
|
| 663 |
+
max((factors[0] - 1.0) / 2.0, 0.001),
|
| 664 |
+
max((factors[1] - 1.0) / 2.0, 0.001),
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
|
| 668 |
+
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
|
| 669 |
+
# But they do it in two passes, which gives better results. Let's try 2 sigmas for now
|
| 670 |
+
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
|
| 671 |
+
|
| 672 |
+
# Make sure it is odd
|
| 673 |
+
if (ks[0] % 2) == 0:
|
| 674 |
+
ks = ks[0] + 1, ks[1]
|
| 675 |
+
|
| 676 |
+
if (ks[1] % 2) == 0:
|
| 677 |
+
ks = ks[0], ks[1] + 1
|
| 678 |
+
|
| 679 |
+
input = _gaussian_blur2d(input, ks, sigmas)
|
| 680 |
+
|
| 681 |
+
output = torch.nn.functional.interpolate(
|
| 682 |
+
input, size=size, mode=interpolation, align_corners=align_corners)
|
| 683 |
+
return output
|
| 684 |
+
|
| 685 |
+
def train_motion_lora(
|
| 686 |
+
pretrained_model_path,
|
| 687 |
+
output_dir: str,
|
| 688 |
+
train_dataset: SingleClipDataset,
|
| 689 |
+
validation_data: Dict,
|
| 690 |
+
edited_firstframes: List[Image.Image],
|
| 691 |
+
train_data: Dict,
|
| 692 |
+
validation_images: List[Image.Image],
|
| 693 |
+
validation_images_latents: List[torch.Tensor],
|
| 694 |
+
clip_id: int,
|
| 695 |
+
consistency_controller: attention_util.ConsistencyAttentionControl = None,
|
| 696 |
+
consistency_edit_controller_list: List[attention_util.ConsistencyAttentionControl] = [None,],
|
| 697 |
+
consistency_find_modules: Dict = {},
|
| 698 |
+
single_spatial_lora: bool = False,
|
| 699 |
+
train_temporal_lora: bool = True,
|
| 700 |
+
validation_steps: int = 100,
|
| 701 |
+
trainable_modules: Tuple[str] = None, # Eg: ("attn1", "attn2")
|
| 702 |
+
extra_unet_params=None,
|
| 703 |
+
train_batch_size: int = 1,
|
| 704 |
+
max_train_steps: int = 500,
|
| 705 |
+
learning_rate: float = 5e-5,
|
| 706 |
+
lr_scheduler: str = "constant",
|
| 707 |
+
lr_warmup_steps: int = 0,
|
| 708 |
+
adam_beta1: float = 0.9,
|
| 709 |
+
adam_beta2: float = 0.999,
|
| 710 |
+
adam_weight_decay: float = 1e-2,
|
| 711 |
+
adam_epsilon: float = 1e-08,
|
| 712 |
+
gradient_accumulation_steps: int = 1,
|
| 713 |
+
gradient_checkpointing: bool = False,
|
| 714 |
+
image_encoder_gradient_checkpointing: bool = False,
|
| 715 |
+
checkpointing_steps: int = 500,
|
| 716 |
+
resume_from_checkpoint: Optional[str] = None,
|
| 717 |
+
resume_step: Optional[int] = None,
|
| 718 |
+
mixed_precision: Optional[str] = "fp16",
|
| 719 |
+
use_8bit_adam: bool = False,
|
| 720 |
+
enable_xformers_memory_efficient_attention: bool = True,
|
| 721 |
+
enable_torch_2_attn: bool = False,
|
| 722 |
+
seed: Optional[int] = None,
|
| 723 |
+
use_offset_noise: bool = False,
|
| 724 |
+
rescale_schedule: bool = False,
|
| 725 |
+
offset_noise_strength: float = 0.1,
|
| 726 |
+
extend_dataset: bool = False,
|
| 727 |
+
cache_latents: bool = False,
|
| 728 |
+
cached_latent_dir=None,
|
| 729 |
+
use_unet_lora: bool = False,
|
| 730 |
+
unet_lora_modules: Tuple[str] = [],
|
| 731 |
+
image_encoder_lora_modules: Tuple[str] = [],
|
| 732 |
+
save_pretrained_model: bool = True,
|
| 733 |
+
lora_rank: int = 16,
|
| 734 |
+
lora_path: str = '',
|
| 735 |
+
lora_unet_dropout: float = 0.1,
|
| 736 |
+
logger_type: str = 'tensorboard',
|
| 737 |
+
**kwargs
|
| 738 |
+
):
|
| 739 |
+
|
| 740 |
+
*_, config = inspect.getargvalues(inspect.currentframe())
|
| 741 |
+
|
| 742 |
+
accelerator = Accelerator(
|
| 743 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 744 |
+
mixed_precision=mixed_precision,
|
| 745 |
+
log_with=logger_type,
|
| 746 |
+
project_dir=output_dir
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
# Make one log on every process with the configuration for debugging.
|
| 750 |
+
create_logging(logging, logger, accelerator)
|
| 751 |
+
|
| 752 |
+
# Initialize accelerate, transformers, and diffusers warnings
|
| 753 |
+
accelerate_set_verbose(accelerator)
|
| 754 |
+
|
| 755 |
+
# Handle the output folder creation
|
| 756 |
+
if accelerator.is_main_process:
|
| 757 |
+
output_dir = create_output_folders(output_dir, config, clip_id)
|
| 758 |
+
|
| 759 |
+
# Load scheduler, tokenizer and models.
|
| 760 |
+
noise_scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(pretrained_model_path)
|
| 761 |
+
|
| 762 |
+
# Freeze any necessary models
|
| 763 |
+
freeze_models([vae, image_encoder, unet])
|
| 764 |
+
|
| 765 |
+
# Enable xformers if available
|
| 766 |
+
handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
|
| 767 |
+
|
| 768 |
+
# Initialize the optimizer
|
| 769 |
+
optimizer_cls = get_optimizer(use_8bit_adam)
|
| 770 |
+
|
| 771 |
+
# Create parameters to optimize over with a condition (if "condition" is true, optimize it)
|
| 772 |
+
#extra_unet_params = extra_unet_params if extra_unet_params is not None else {}
|
| 773 |
+
#extra_text_encoder_params = extra_unet_params if extra_unet_params is not None else {}
|
| 774 |
+
|
| 775 |
+
# Temporal LoRA
|
| 776 |
+
if train_temporal_lora:
|
| 777 |
+
# one temporal lora
|
| 778 |
+
lora_manager_temporal = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["TemporalBasicTransformerBlock"])
|
| 779 |
+
|
| 780 |
+
unet_lora_params_temporal, unet_negation_temporal = lora_manager_temporal.add_lora_to_model(
|
| 781 |
+
use_unet_lora, unet, lora_manager_temporal.unet_replace_modules, lora_unet_dropout,
|
| 782 |
+
lora_path + '/temporal/lora/', r=lora_rank)
|
| 783 |
+
|
| 784 |
+
optimizer_temporal = optimizer_cls(
|
| 785 |
+
create_optimizer_params([param_optim(unet_lora_params_temporal, use_unet_lora, is_lora=True,
|
| 786 |
+
extra_params={**{"lr": learning_rate}}
|
| 787 |
+
)], learning_rate),
|
| 788 |
+
lr=learning_rate,
|
| 789 |
+
betas=(adam_beta1, adam_beta2),
|
| 790 |
+
weight_decay=adam_weight_decay,
|
| 791 |
+
eps=adam_epsilon,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
lr_scheduler_temporal = get_scheduler(
|
| 795 |
+
lr_scheduler,
|
| 796 |
+
optimizer=optimizer_temporal,
|
| 797 |
+
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
| 798 |
+
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
| 799 |
+
)
|
| 800 |
+
else:
|
| 801 |
+
lora_manager_temporal = None
|
| 802 |
+
unet_lora_params_temporal, unet_negation_temporal = [], []
|
| 803 |
+
optimizer_temporal = None
|
| 804 |
+
lr_scheduler_temporal = None
|
| 805 |
+
|
| 806 |
+
# Spatial LoRAs
|
| 807 |
+
if single_spatial_lora:
|
| 808 |
+
spatial_lora_num = 1
|
| 809 |
+
else:
|
| 810 |
+
# one spatial lora for each video
|
| 811 |
+
spatial_lora_num = train_dataset.__len__()
|
| 812 |
+
|
| 813 |
+
lora_managers_spatial = []
|
| 814 |
+
unet_lora_params_spatial_list = []
|
| 815 |
+
optimizer_spatial_list = []
|
| 816 |
+
lr_scheduler_spatial_list = []
|
| 817 |
+
for i in range(spatial_lora_num):
|
| 818 |
+
lora_manager_spatial = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["BasicTransformerBlock"])
|
| 819 |
+
lora_managers_spatial.append(lora_manager_spatial)
|
| 820 |
+
unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
|
| 821 |
+
use_unet_lora, unet, lora_manager_spatial.unet_replace_modules, lora_unet_dropout,
|
| 822 |
+
lora_path + '/spatial/lora/', r=lora_rank)
|
| 823 |
+
|
| 824 |
+
unet_lora_params_spatial_list.append(unet_lora_params_spatial)
|
| 825 |
+
|
| 826 |
+
optimizer_spatial = optimizer_cls(
|
| 827 |
+
create_optimizer_params([param_optim(unet_lora_params_spatial, use_unet_lora, is_lora=True,
|
| 828 |
+
extra_params={**{"lr": learning_rate}}
|
| 829 |
+
)], learning_rate),
|
| 830 |
+
lr=learning_rate,
|
| 831 |
+
betas=(adam_beta1, adam_beta2),
|
| 832 |
+
weight_decay=adam_weight_decay,
|
| 833 |
+
eps=adam_epsilon,
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
optimizer_spatial_list.append(optimizer_spatial)
|
| 837 |
+
|
| 838 |
+
# Scheduler
|
| 839 |
+
lr_scheduler_spatial = get_scheduler(
|
| 840 |
+
lr_scheduler,
|
| 841 |
+
optimizer=optimizer_spatial,
|
| 842 |
+
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
| 843 |
+
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
| 844 |
+
)
|
| 845 |
+
lr_scheduler_spatial_list.append(lr_scheduler_spatial)
|
| 846 |
+
|
| 847 |
+
unet_negation_all = unet_negation_spatial + unet_negation_temporal
|
| 848 |
+
|
| 849 |
+
# DataLoaders creation:
|
| 850 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 851 |
+
train_dataset,
|
| 852 |
+
batch_size=train_batch_size,
|
| 853 |
+
shuffle=True
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
# Latents caching
|
| 857 |
+
cached_data_loader = handle_cache_latents(
|
| 858 |
+
cache_latents,
|
| 859 |
+
output_dir,
|
| 860 |
+
train_dataloader,
|
| 861 |
+
train_batch_size,
|
| 862 |
+
vae,
|
| 863 |
+
unet,
|
| 864 |
+
cached_latent_dir
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
if cached_data_loader is not None and train_data.get("use_data_aug") is None:
|
| 868 |
+
train_dataloader = cached_data_loader
|
| 869 |
+
|
| 870 |
+
# Prepare everything with our `accelerator`.
|
| 871 |
+
unet, optimizer_temporal, train_dataloader, lr_scheduler_temporal, image_encoder = accelerator.prepare(
|
| 872 |
+
unet,
|
| 873 |
+
optimizer_temporal,
|
| 874 |
+
train_dataloader,
|
| 875 |
+
lr_scheduler_temporal,
|
| 876 |
+
image_encoder
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
# Use Gradient Checkpointing if enabled.
|
| 880 |
+
unet_and_text_g_c(
|
| 881 |
+
unet,
|
| 882 |
+
image_encoder,
|
| 883 |
+
gradient_checkpointing,
|
| 884 |
+
image_encoder_gradient_checkpointing
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
# Enable VAE slicing to save memory.
|
| 888 |
+
#vae.enable_slicing()
|
| 889 |
+
|
| 890 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
| 891 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
| 892 |
+
weight_dtype = is_mixed_precision(accelerator)
|
| 893 |
+
|
| 894 |
+
# Move the image encoder and VAE to the GPU
|
| 895 |
+
models_to_cast = [image_encoder, vae]
|
| 896 |
+
cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)
|
| 897 |
+
|
| 898 |
+
# Fix noise schedules to predict light and dark areas if available.
|
| 899 |
+
# if not use_offset_noise and rescale_schedule:
|
| 900 |
+
# noise_scheduler.betas = enforce_zero_terminal_snr(noise_scheduler.betas)
|
| 901 |
+
|
| 902 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 903 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
| 904 |
+
|
| 905 |
+
# Afterwards we recalculate our number of training epochs
|
| 906 |
+
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
| 907 |
+
|
| 908 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 909 |
+
# The trackers initialize automatically on the main process.
|
| 910 |
+
if accelerator.is_main_process:
|
| 911 |
+
accelerator.init_trackers("svd-finetune")
|
| 912 |
+
|
| 913 |
+
# Train!
|
| 914 |
+
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
| 915 |
+
|
| 916 |
+
logger.info("***** Running training for motion lora*****")
|
| 917 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 918 |
+
logger.info(f" Num Epochs = {num_train_epochs}")
|
| 919 |
+
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
| 920 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 921 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
| 922 |
+
logger.info(f" Total optimization steps = {max_train_steps}")
|
| 923 |
+
global_step = 0
|
| 924 |
+
first_epoch = 0
|
| 925 |
+
|
| 926 |
+
def encode_image(pixel_values):
|
| 927 |
+
# pixel_values = pixel_values * 2.0 - 1.0
|
| 928 |
+
pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
|
| 929 |
+
pixel_values = (pixel_values + 1.0) / 2.0
|
| 930 |
+
|
| 931 |
+
# Normalize the image for CLIP input
|
| 932 |
+
pixel_values = feature_extractor(
|
| 933 |
+
images=pixel_values,
|
| 934 |
+
do_normalize=True,
|
| 935 |
+
do_center_crop=False,
|
| 936 |
+
do_resize=False,
|
| 937 |
+
do_rescale=False,
|
| 938 |
+
return_tensors="pt",
|
| 939 |
+
).pixel_values
|
| 940 |
+
|
| 941 |
+
pixel_values = pixel_values.to(
|
| 942 |
+
device=accelerator.device, dtype=weight_dtype)
|
| 943 |
+
image_embeddings = image_encoder(pixel_values).image_embeds
|
| 944 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
| 945 |
+
return image_embeddings
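# encode_image: resizes to 224x224 with antialiasing, rescales pixels from [-1, 1]
# to [0, 1], applies the CLIP feature extractor's normalisation, and returns image
# embeddings of shape (B, 1, proj_dim); SVD uses these in place of text embeddings
# for cross-attention conditioning.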
|
| 946 |
+
|
| 947 |
+
def _get_add_time_ids(
|
| 948 |
+
fps,
|
| 949 |
+
motion_bucket_ids, # Expecting a list of tensor floats
|
| 950 |
+
noise_aug_strength,
|
| 951 |
+
dtype,
|
| 952 |
+
batch_size,
|
| 953 |
+
unet=None,
|
| 954 |
+
device=None, # Add a device parameter
|
| 955 |
+
):
|
| 956 |
+
# Determine the target device
|
| 957 |
+
target_device = device if device is not None else 'cpu'
|
| 958 |
+
|
| 959 |
+
# Ensure motion_bucket_ids is a tensor and on the target device
|
| 960 |
+
if not isinstance(motion_bucket_ids, torch.Tensor):
|
| 961 |
+
motion_bucket_ids = torch.tensor(motion_bucket_ids, dtype=dtype, device=target_device)
|
| 962 |
+
else:
|
| 963 |
+
motion_bucket_ids = motion_bucket_ids.to(device=target_device)
|
| 964 |
+
|
| 965 |
+
# Reshape motion_bucket_ids if necessary
|
| 966 |
+
if motion_bucket_ids.dim() == 1:
|
| 967 |
+
motion_bucket_ids = motion_bucket_ids.view(-1, 1)
|
| 968 |
+
|
| 969 |
+
# Check for batch size consistency
|
| 970 |
+
if motion_bucket_ids.size(0) != batch_size:
|
| 971 |
+
raise ValueError("The length of motion_bucket_ids must match the batch_size.")
|
| 972 |
+
|
| 973 |
+
# Create fps and noise_aug_strength tensors on the target device
|
| 974 |
+
add_time_ids = torch.tensor([fps, noise_aug_strength], dtype=dtype, device=target_device).repeat(batch_size, 1)
|
| 975 |
+
|
| 976 |
+
# Concatenate with motion_bucket_ids
|
| 977 |
+
add_time_ids = torch.cat([add_time_ids, motion_bucket_ids], dim=1)
|
| 978 |
+
|
| 979 |
+
# Checking the dimensions of the added time embedding
|
| 980 |
+
passed_add_embed_dim = unet.config.addition_time_embed_dim * add_time_ids.size(1)
|
| 981 |
+
expected_add_embed_dim = unet.add_embedding.linear_1.in_features
|
| 982 |
+
|
| 983 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 984 |
+
raise ValueError(
|
| 985 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
| 986 |
+
f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. "
|
| 987 |
+
"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
return add_time_ids
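# Builds a (batch, 3) conditioning tensor of [fps, noise_aug_strength, motion_bucket_id]
# and sanity-checks that its projected width matches the UNet's additional-time
# embedding layer before it is passed as `added_time_ids`.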
|
| 991 |
+
|
| 992 |
+
# Only show the progress bar once on each machine.
|
| 993 |
+
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
| 994 |
+
progress_bar.set_description("Steps")
|
| 995 |
+
|
| 996 |
+
# set consistency controller
|
| 997 |
+
if consistency_controller is not None:
|
| 998 |
+
consistency_train_controller = attention_util.ConsistencyAttentionControl(
|
| 999 |
+
additional_attention_store=consistency_controller,
|
| 1000 |
+
use_inversion_attention=True,
|
| 1001 |
+
save_self_attention=False,
|
| 1002 |
+
save_latents=False,
|
| 1003 |
+
disk_store=True
|
| 1004 |
+
)
|
| 1005 |
+
attention_util.register_attention_control(
|
| 1006 |
+
unet,
|
| 1007 |
+
None,
|
| 1008 |
+
consistency_train_controller,
|
| 1009 |
+
find_modules={},
|
| 1010 |
+
consistency_find_modules=consistency_find_modules
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
def finetune_unet(batch, step, mask_spatial_lora=False, mask_temporal_lora=False):
|
| 1014 |
+
nonlocal use_offset_noise
|
| 1015 |
+
nonlocal rescale_schedule
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
# Unfreeze UNET Layers
|
| 1019 |
+
if global_step == 0:
|
| 1020 |
+
already_printed_trainables = False
|
| 1021 |
+
unet.train()
|
| 1022 |
+
handle_trainable_modules(
|
| 1023 |
+
unet,
|
| 1024 |
+
trainable_modules,
|
| 1025 |
+
is_enabled=True,
|
| 1026 |
+
negation=unet_negation_all
|
| 1027 |
+
)
|
| 1028 |
+
|
| 1029 |
+
# Convert videos to latent space
|
| 1030 |
+
#print("use_data_aug", train_data.get("use_data_aug"))
|
| 1031 |
+
if not cache_latents or train_data.get("use_data_aug") is not None:
|
| 1032 |
+
latents = tensor_to_vae_latent(batch["pixel_values"], vae)
|
| 1033 |
+
refer_latents = tensor_to_vae_latent(batch["refer_pixel_values"], vae)
|
| 1034 |
+
cross_latents = tensor_to_vae_latent(batch["cross_pixel_values"], vae)
|
| 1035 |
+
else:
|
| 1036 |
+
latents = batch["latents"]
|
| 1037 |
+
refer_latents = batch["refer_latents"]
|
| 1038 |
+
cross_latents = batch["cross_latents"]
|
| 1039 |
+
|
| 1040 |
+
# Sample noise that we'll add to the latents
|
| 1041 |
+
use_offset_noise = use_offset_noise and not rescale_schedule
|
| 1042 |
+
noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
|
| 1043 |
+
noise_1 = sample_noise(latents, offset_noise_strength, False)
|
| 1044 |
+
bsz = latents.shape[0]
|
| 1045 |
+
|
| 1046 |
+
# Sample a random timestep for each video
|
| 1047 |
+
sigmas, u = rand_cosine_interpolated(shape=[bsz,], image_d=image_d, noise_d_low=noise_d_low, noise_d_high=noise_d_high,
|
| 1048 |
+
sigma_data=sigma_data, min_value=min_value, max_value=max_value)
|
| 1049 |
+
noise_scheduler.set_timesteps(validation_data.num_inference_steps, device=latents.device)
|
| 1050 |
+
all_sigmas = noise_scheduler.sigmas
|
| 1051 |
+
sigmas = sigmas.to(latents.device)
|
| 1052 |
+
timestep = (validation_data.num_inference_steps - torch.searchsorted(all_sigmas.to(latents.device).flip(dims=(0,)), sigmas, right=False)).clamp(0,validation_data.num_inference_steps-1)[0]
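# Map the continuous sigma sample onto the index of the closest discrete inference
# step so the consistency attention controller below can be keyed by that timestep.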
|
| 1053 |
+
u = u.item()
|
| 1054 |
+
if consistency_controller is not None:
|
| 1055 |
+
#timestep = int(u * (validation_data.num_inference_steps-1)+0.5)
|
| 1056 |
+
#print("u", u, "timestep", timestep, "sigmas", sigmas, "all_sigmas", all_sigmas)
|
| 1057 |
+
consistency_train_controller.set_cur_step(timestep)
|
| 1058 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 1059 |
+
# (this is the forward diffusion process)
|
| 1060 |
+
sigmas_reshaped = sigmas.clone()
|
| 1061 |
+
while len(sigmas_reshaped.shape) < len(latents.shape):
|
| 1062 |
+
sigmas_reshaped = sigmas_reshaped.unsqueeze(-1)
|
| 1063 |
+
|
| 1064 |
+
# add noise to the latents or the original image?
|
| 1065 |
+
train_noise_aug = 0.02
|
| 1066 |
+
conditional_latents = refer_latents / vae.config.scaling_factor
|
| 1067 |
+
small_noise_latents = conditional_latents + noise_1[:,0:1,:,:,:] * train_noise_aug
|
| 1068 |
+
conditional_latents = small_noise_latents[:, 0, :, :, :]
|
| 1069 |
+
|
| 1070 |
+
noisy_latents = latents + noise * sigmas_reshaped
|
| 1071 |
+
|
| 1072 |
+
timesteps = torch.Tensor(
|
| 1073 |
+
[0.25 * sigma.log() for sigma in sigmas]).to(latents.device)
|
| 1074 |
+
|
| 1075 |
+
inp_noisy_latents = noisy_latents / ((sigmas_reshaped**2 + 1) ** 0.5)
|
| 1076 |
+
|
| 1077 |
+
# *Potentially* Fixes gradient checkpointing training.
|
| 1078 |
+
# See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
|
| 1079 |
+
if kwargs.get('eval_train', False):
|
| 1080 |
+
unet.eval()
|
| 1081 |
+
image_encoder.eval()
|
| 1082 |
+
|
| 1083 |
+
# Get the CLIP image embedding for conditioning.
|
| 1084 |
+
encoder_hidden_states = encode_image(
|
| 1085 |
+
batch["cross_pixel_values"][:, 0, :, :, :])
|
| 1086 |
+
detached_encoder_state = encoder_hidden_states.clone().detach()
|
| 1087 |
+
|
| 1088 |
+
added_time_ids = _get_add_time_ids(
|
| 1089 |
+
6,
|
| 1090 |
+
batch["motion_values"],
|
| 1091 |
+
train_noise_aug, # noise_aug_strength used during training (0.02)
|
| 1092 |
+
encoder_hidden_states.dtype,
|
| 1093 |
+
bsz,
|
| 1094 |
+
unet,
|
| 1095 |
+
device=latents.device
|
| 1096 |
+
)
|
| 1097 |
+
added_time_ids = added_time_ids.to(latents.device)
|
| 1098 |
+
|
| 1099 |
+
# See section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
|
| 1100 |
+
conditioning_dropout_prob = kwargs.get('conditioning_dropout_prob')
|
| 1101 |
+
if conditioning_dropout_prob is not None:
|
| 1102 |
+
random_p = torch.rand(
|
| 1103 |
+
bsz, device=latents.device, generator=generator)
|
| 1104 |
+
# Sample masks for the edit prompts.
|
| 1105 |
+
prompt_mask = random_p < 2 * conditioning_dropout_prob
|
| 1106 |
+
prompt_mask = prompt_mask.reshape(bsz, 1, 1)
|
| 1107 |
+
# Final text conditioning.
|
| 1108 |
+
null_conditioning = torch.zeros_like(encoder_hidden_states)
|
| 1109 |
+
encoder_hidden_states = torch.where(
|
| 1110 |
+
prompt_mask, null_conditioning, encoder_hidden_states)
|
| 1111 |
+
|
| 1112 |
+
# Sample masks for the original images.
|
| 1113 |
+
image_mask_dtype = conditional_latents.dtype
|
| 1114 |
+
image_mask = 1 - (
|
| 1115 |
+
(random_p >= conditioning_dropout_prob).to(
|
| 1116 |
+
image_mask_dtype)
|
| 1117 |
+
* (random_p < 3 * conditioning_dropout_prob).to(image_mask_dtype)
|
| 1118 |
+
)
|
| 1119 |
+
image_mask = image_mask.reshape(bsz, 1, 1, 1)
|
| 1120 |
+
# Final image conditioning.
|
| 1121 |
+
conditional_latents = image_mask * conditional_latents
|
| 1122 |
+
|
| 1123 |
+
# Concatenate the `conditional_latents` with the `noisy_latents`.
|
| 1124 |
+
conditional_latents = conditional_latents.unsqueeze(
|
| 1125 |
+
1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
|
| 1126 |
+
inp_noisy_latents = torch.cat(
|
| 1127 |
+
[inp_noisy_latents, conditional_latents], dim=2)
|
| 1128 |
+
|
| 1129 |
+
# Get the target for loss depending on the prediction type
|
| 1130 |
+
# if noise_scheduler.config.prediction_type == "epsilon":
|
| 1131 |
+
# target = latents # we are computing loss against denoise latents
|
| 1132 |
+
# elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 1133 |
+
# target = noise_scheduler.get_velocity(
|
| 1134 |
+
# latents, noise, timesteps)
|
| 1135 |
+
# else:
|
| 1136 |
+
# raise ValueError(
|
| 1137 |
+
# f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 1138 |
+
|
| 1139 |
+
target = latents
|
| 1140 |
+
|
| 1141 |
+
encoder_hidden_states = detached_encoder_state
|
| 1142 |
+
|
| 1143 |
+
if True:#mask_spatial_lora:
|
| 1144 |
+
loras = extract_lora_child_module(unet, target_replace_module=["BasicTransformerBlock"])
|
| 1145 |
+
for lora_i in loras:
|
| 1146 |
+
lora_i.scale = 0.
|
| 1147 |
+
loss_spatial = None
|
| 1148 |
+
else:
|
| 1149 |
+
loras = extract_lora_child_module(unet, target_replace_module=["BasicTransformerBlock"])
|
| 1150 |
+
|
| 1151 |
+
if spatial_lora_num == 1:
|
| 1152 |
+
for lora_i in loras:
|
| 1153 |
+
lora_i.scale = 1.
|
| 1154 |
+
else:
|
| 1155 |
+
for lora_i in loras:
|
| 1156 |
+
lora_i.scale = 0.
|
| 1157 |
+
|
| 1158 |
+
for lora_idx in range(0, len(loras), spatial_lora_num):
|
| 1159 |
+
loras[lora_idx + step].scale = 1.
|
| 1160 |
+
|
| 1161 |
+
loras = extract_lora_child_module(unet, target_replace_module=["TemporalBasicTransformerBlock"])
|
| 1162 |
+
if len(loras) > 0:
|
| 1163 |
+
for lora_i in loras:
|
| 1164 |
+
lora_i.scale = 0.
|
| 1165 |
+
|
| 1166 |
+
ran_idx = 0#torch.randint(0, noisy_latents.shape[2], (1,)).item()
|
| 1167 |
+
|
| 1168 |
+
#spatial_inp_noisy_latents = inp_noisy_refer_latents[:, ran_idx:ran_idx+1, :, :, :]
|
| 1169 |
+
inp_noisy_spatial_latents = inp_noisy_latents#[:, ran_idx:ran_idx+1, :, :, :]
|
| 1170 |
+
|
| 1171 |
+
target_spatial = latents#[:, ran_idx:ran_idx+1, :, :, :]
|
| 1172 |
+
# Predict the noise residual
|
| 1173 |
+
model_pred = unet(
|
| 1174 |
+
inp_noisy_spatial_latents, timesteps, encoder_hidden_states,
|
| 1175 |
+
added_time_ids
|
| 1176 |
+
).sample
|
| 1177 |
+
|
| 1178 |
+
sigmas = sigmas_reshaped
|
| 1179 |
+
# Denoise the latents
|
| 1180 |
+
c_out = -sigmas / ((sigmas**2 + 1)**0.5)
|
| 1181 |
+
c_skip = 1 / (sigmas**2 + 1)
|
| 1182 |
+
denoised_latents = model_pred * c_out + c_skip * noisy_latents#[:, ran_idx:ran_idx+1, :, :, :]
|
| 1183 |
+
weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
|
| 1184 |
+
|
| 1185 |
+
# MSE loss
|
| 1186 |
+
loss_spatial = torch.mean(
|
| 1187 |
+
(weighing.float() * (denoised_latents.float() -
|
| 1188 |
+
target_spatial.float()) ** 2).reshape(target_spatial.shape[0], -1),
|
| 1189 |
+
dim=1,
|
| 1190 |
+
)
|
| 1191 |
+
loss_spatial = loss_spatial.mean()
|
| 1192 |
+
|
| 1193 |
+
if mask_temporal_lora:
|
| 1194 |
+
loras = extract_lora_child_module(unet, target_replace_module=["TemporalBasicTransformerBlock"])
|
| 1195 |
+
for lora_i in loras:
|
| 1196 |
+
lora_i.scale = 0.
|
| 1197 |
+
loss_temporal = None
|
| 1198 |
+
else:
|
| 1199 |
+
loras = extract_lora_child_module(unet, target_replace_module=["TemporalBasicTransformerBlock"])
|
| 1200 |
+
for lora_i in loras:
|
| 1201 |
+
lora_i.scale = 1.
|
| 1202 |
+
# Predict the noise residual
|
| 1203 |
+
model_pred = unet(
|
| 1204 |
+
inp_noisy_latents, timesteps, encoder_hidden_states,
|
| 1205 |
+
added_time_ids=added_time_ids,
|
| 1206 |
+
).sample
|
| 1207 |
+
|
| 1208 |
+
sigmas = sigmas_reshaped
|
| 1209 |
+
# Denoise the latents
|
| 1210 |
+
c_out = -sigmas / ((sigmas**2 + 1)**0.5)
|
| 1211 |
+
c_skip = 1 / (sigmas**2 + 1)
|
| 1212 |
+
denoised_latents = model_pred * c_out + c_skip * noisy_latents
|
| 1213 |
+
if consistency_controller is not None:
|
| 1214 |
+
consistency_train_controller.step_callback(denoised_latents.detach())
|
| 1215 |
+
weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
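# EDM-style preconditioning: the network output is combined as
# denoised = c_skip * noisy_latents + c_out * model_pred, with c_skip = 1/(sigma^2+1)
# and c_out = -sigma/sqrt(sigma^2+1), and the MSE loss is weighted by (1+sigma^2)/sigma^2.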
|
| 1216 |
+
|
| 1217 |
+
# MSE loss
|
| 1218 |
+
loss_temporal = torch.mean(
|
| 1219 |
+
(weighing.float() * (denoised_latents.float() -
|
| 1220 |
+
target.float()) ** 2).reshape(target.shape[0], -1),
|
| 1221 |
+
dim=1,
|
| 1222 |
+
)
|
| 1223 |
+
loss_temporal = loss_temporal.mean()
|
| 1224 |
+
|
| 1225 |
+
# beta = 1
|
| 1226 |
+
# alpha = (beta ** 2 + 1) ** 0.5
|
| 1227 |
+
# ran_idx = torch.randint(0, model_pred.shape[1], (1,)).item()
|
| 1228 |
+
# model_pred_decent = alpha * model_pred - beta * model_pred[:, ran_idx, :, :, :].unsqueeze(1)
|
| 1229 |
+
# target_decent = alpha * target - beta * target[:, ran_idx, :, :, :].unsqueeze(1)
|
| 1230 |
+
# loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
|
| 1231 |
+
loss_temporal = loss_temporal #+ loss_ad_temporal
|
| 1232 |
+
|
| 1233 |
+
return loss_spatial, loss_temporal, latents, noise
|
| 1234 |
+
|
| 1235 |
+
for epoch in range(first_epoch, num_train_epochs):
|
| 1236 |
+
train_loss_spatial = 0.0
|
| 1237 |
+
train_loss_temporal = 0.0
|
| 1238 |
+
|
| 1239 |
+
|
| 1240 |
+
for step, batch in enumerate(train_dataloader):
|
| 1241 |
+
#torch.cuda.empty_cache()
|
| 1242 |
+
# Skip steps until we reach the resumed step
|
| 1243 |
+
if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
| 1244 |
+
if step % gradient_accumulation_steps == 0:
|
| 1245 |
+
progress_bar.update(1)
|
| 1246 |
+
continue
|
| 1247 |
+
|
| 1248 |
+
with accelerator.accumulate(unet):
|
| 1249 |
+
|
| 1250 |
+
for optimizer_spatial in optimizer_spatial_list:
|
| 1251 |
+
optimizer_spatial.zero_grad(set_to_none=True)
|
| 1252 |
+
|
| 1253 |
+
if optimizer_temporal is not None:
|
| 1254 |
+
optimizer_temporal.zero_grad(set_to_none=True)
|
| 1255 |
+
|
| 1256 |
+
if train_temporal_lora:
|
| 1257 |
+
mask_temporal_lora = False
|
| 1258 |
+
else:
|
| 1259 |
+
mask_temporal_lora = True
|
| 1260 |
+
if False:#clip_id != 0:
|
| 1261 |
+
mask_spatial_lora = random.uniform(0, 1) < 0.2 and not mask_temporal_lora
|
| 1262 |
+
else:
|
| 1263 |
+
mask_spatial_lora = True
|
| 1264 |
+
|
| 1265 |
+
with accelerator.autocast():
|
| 1266 |
+
loss_spatial, loss_temporal, latents, init_noise = finetune_unet(batch, step, mask_spatial_lora=mask_spatial_lora, mask_temporal_lora=mask_temporal_lora)
|
| 1267 |
+
|
| 1268 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 1269 |
+
if not mask_spatial_lora:
|
| 1270 |
+
avg_loss_spatial = accelerator.gather(loss_spatial.repeat(train_batch_size)).mean()
|
| 1271 |
+
train_loss_spatial += avg_loss_spatial.item() / gradient_accumulation_steps
|
| 1272 |
+
|
| 1273 |
+
if not mask_temporal_lora and train_temporal_lora:
|
| 1274 |
+
avg_loss_temporal = accelerator.gather(loss_temporal.repeat(train_batch_size)).mean()
|
| 1275 |
+
train_loss_temporal += avg_loss_temporal.item() / gradient_accumulation_steps
|
| 1276 |
+
|
| 1277 |
+
# Backpropagate
|
| 1278 |
+
if not mask_spatial_lora:
|
| 1279 |
+
accelerator.backward(loss_spatial, retain_graph=True)
|
| 1280 |
+
if spatial_lora_num == 1:
|
| 1281 |
+
optimizer_spatial_list[0].step()
|
| 1282 |
+
else:
|
| 1283 |
+
optimizer_spatial_list[step].step()
|
| 1284 |
+
if spatial_lora_num == 1:
|
| 1285 |
+
lr_scheduler_spatial_list[0].step()
|
| 1286 |
+
else:
|
| 1287 |
+
lr_scheduler_spatial_list[step].step()
|
| 1288 |
+
|
| 1289 |
+
if not mask_temporal_lora and train_temporal_lora:
|
| 1290 |
+
accelerator.backward(loss_temporal)
|
| 1291 |
+
optimizer_temporal.step()
|
| 1292 |
+
|
| 1293 |
+
if lr_scheduler_temporal is not None:
|
| 1294 |
+
lr_scheduler_temporal.step()
|
| 1295 |
+
|
| 1296 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1297 |
+
if accelerator.sync_gradients:
|
| 1298 |
+
progress_bar.update(1)
|
| 1299 |
+
global_step += 1
|
| 1300 |
+
accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
|
| 1301 |
+
train_loss_temporal = 0.0
|
| 1302 |
+
if global_step % checkpointing_steps == 0 and global_step > 0:
|
| 1303 |
+
save_pipe(
|
| 1304 |
+
pretrained_model_path,
|
| 1305 |
+
global_step,
|
| 1306 |
+
accelerator,
|
| 1307 |
+
unet,
|
| 1308 |
+
image_encoder,
|
| 1309 |
+
vae,
|
| 1310 |
+
output_dir,
|
| 1311 |
+
lora_manager_spatial,
|
| 1312 |
+
lora_manager_temporal,
|
| 1313 |
+
unet_lora_modules,
|
| 1314 |
+
image_encoder_lora_modules,
|
| 1315 |
+
is_checkpoint=True,
|
| 1316 |
+
save_pretrained_model=save_pretrained_model
|
| 1317 |
+
)
|
| 1318 |
+
|
| 1319 |
+
if should_sample(global_step, validation_steps, validation_data):
|
| 1320 |
+
if accelerator.is_main_process:
|
| 1321 |
+
with accelerator.autocast():
|
| 1322 |
+
unet.eval()
|
| 1323 |
+
image_encoder.eval()
|
| 1324 |
+
generator = torch.Generator(device="cpu")
|
| 1325 |
+
generator.manual_seed(seed)
|
| 1326 |
+
unet_and_text_g_c(unet, image_encoder, False, False)
|
| 1327 |
+
loras = extract_lora_child_module(unet, target_replace_module=["BasicTransformerBlock"])
|
| 1328 |
+
for lora_i in loras:
|
| 1329 |
+
lora_i.scale = 0.0
|
| 1330 |
+
|
| 1331 |
+
if consistency_controller is not None:
|
| 1332 |
+
attention_util.register_attention_control(
|
| 1333 |
+
unet,
|
| 1334 |
+
None,
|
| 1335 |
+
consistency_train_controller,
|
| 1336 |
+
find_modules={},
|
| 1337 |
+
consistency_find_modules=consistency_find_modules,
|
| 1338 |
+
undo=True
|
| 1339 |
+
)
|
| 1340 |
+
|
| 1341 |
+
pipeline = P2PStableVideoDiffusionPipeline.from_pretrained(
|
| 1342 |
+
pretrained_model_path,
|
| 1343 |
+
image_encoder=image_encoder,
|
| 1344 |
+
vae=vae,
|
| 1345 |
+
unet=unet
|
| 1346 |
+
)
|
| 1347 |
+
if consistency_controller is not None:
|
| 1348 |
+
pipeline.scheduler = P2PEulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
| 1349 |
+
|
| 1350 |
+
# # recalculate inversed noise latent
|
| 1351 |
+
# if any([np > 0. for np in validation_data.noise_prior]):
|
| 1352 |
+
# pixel_values_for_inv = batch['pixel_values_for_inv'].to('cuda', dtype=torch.float16)
|
| 1353 |
+
# batch['inversion_noise'] = inverse_video(pipeline, batch['latents_for_inv'], 25, pixel_values_for_inv[:,0,:,:,:])
|
| 1354 |
+
|
| 1355 |
+
preset_noises = []
|
| 1356 |
+
for noise_prior in validation_data.noise_prior:
|
| 1357 |
+
if noise_prior > 0:
|
| 1358 |
+
assert batch['inversion_noise'] is not None, "inversion_noise should not be None when noise_prior > 0"
|
| 1359 |
+
preset_noise = (noise_prior) ** 0.5 * batch['inversion_noise'] + (
|
| 1360 |
+
1-noise_prior) ** 0.5 * torch.randn_like(batch['inversion_noise'])
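# Mixes the inversion noise with fresh Gaussian noise using sqrt(p) / sqrt(1-p)
# weights; assuming both are (roughly) independent unit-variance tensors, the blend
# keeps unit variance while biasing sampling towards the source video's inversion.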
|
| 1361 |
+
#print("preset noise", torch.mean(preset_noise), torch.std(preset_noise))
|
| 1362 |
+
else:
|
| 1363 |
+
preset_noise = None
|
| 1364 |
+
preset_noises.append( preset_noise )
|
| 1365 |
+
|
| 1366 |
+
for val_img_idx in range(len(validation_images)):
|
| 1367 |
+
for i in range(len(preset_noises)):
|
| 1368 |
+
|
| 1369 |
+
if consistency_controller is not None:
|
| 1370 |
+
consistency_edit_controller = attention_util.ConsistencyAttentionControl(
|
| 1371 |
+
additional_attention_store=consistency_edit_controller_list[val_img_idx],
|
| 1372 |
+
use_inversion_attention=False,
|
| 1373 |
+
save_self_attention=False,
|
| 1374 |
+
save_latents=False,
|
| 1375 |
+
disk_store=True
|
| 1376 |
+
)
|
| 1377 |
+
attention_util.register_attention_control(
|
| 1378 |
+
pipeline.unet,
|
| 1379 |
+
None,
|
| 1380 |
+
consistency_edit_controller,
+find_modules={},
+consistency_find_modules=consistency_find_modules,
+)
+pipeline.scheduler.controller = [consistency_edit_controller]
+
+preset_noise = preset_noises[i]
+save_filename = f"step_{global_step}_noise_{i}_{val_img_idx}"
+
+out_file = f"{output_dir}/samples/{save_filename}.mp4"
+
+val_img = validation_images[val_img_idx]
+edited_firstframe = edited_firstframes[val_img_idx]
+original_res = val_img.size
+resctrl = ResolutionControl(
+(original_res[1],original_res[0]),
+(validation_data.height, validation_data.width),
+validation_data.get("pad_to_fit", False),
+fill=0
+)
+
+#val_img = Image.open("white.png").convert("RGB")
+val_img = resctrl(val_img)
+edited_firstframe = resctrl(edited_firstframe)
+
+with torch.no_grad():
+video_frames = pipeline(
+val_img,
+edited_firstframe=edited_firstframe,
+image_latents=validation_images_latents[val_img_idx],
+width=validation_data.width,
+height=validation_data.height,
+num_frames=batch["pixel_values"].shape[1],
+decode_chunk_size=8,
+motion_bucket_id=127,
+fps=validation_data.get('fps', 7),
+noise_aug_strength=0.02,
+generator=generator,
+num_inference_steps=validation_data.num_inference_steps,
+latents=preset_noise
+).frames
+export_to_video(video_frames, out_file, validation_data.get('fps', 7), resctrl)
+if consistency_controller is not None:
+attention_util.register_attention_control(
+pipeline.unet,
+None,
+consistency_edit_controller,
+find_modules={},
+consistency_find_modules=consistency_find_modules,
+undo=True
+)
+consistency_edit_controller.delete()
+del consistency_edit_controller
+logger.info(f"Saved a new sample to {out_file}")
+if consistency_controller is not None:
+attention_util.register_attention_control(
+unet,
+None,
+consistency_train_controller,
+find_modules={},
+consistency_find_modules=consistency_find_modules,
+)
+del pipeline
+torch.cuda.empty_cache()
+
+unet_and_text_g_c(
+unet,
+image_encoder,
+gradient_checkpointing,
+image_encoder_gradient_checkpointing
+)
+
+if loss_temporal is not None:
+accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)
+
+if global_step >= max_train_steps:
+break
+
+# Create the pipeline using the trained modules and save it.
+accelerator.wait_for_everyone()
+if accelerator.is_main_process:
+save_pipe(
+pretrained_model_path,
+global_step,
+accelerator,
+unet,
+image_encoder,
+vae,
+output_dir,
+lora_manager_spatial,
+lora_manager_temporal,
+unet_lora_modules,
+image_encoder_lora_modules,
+is_checkpoint=False,
+save_pretrained_model=save_pretrained_model
+)
+accelerator.end_training()
+
+if consistency_controller is not None:
+consistency_train_controller.delete()
+del consistency_train_controller
+
+
+if __name__ == "__main__":
+    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default='./configs/config_multi_videos.yaml')
+    args = parser.parse_args()
+    train_motion_lora(**OmegaConf.load(args.config))
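The `__main__` block above drives the whole trainer from a single OmegaConf YAML whose top-level keys match the keyword arguments of `train_motion_lora`. A minimal sketch of calling it programmatically, assuming the package is installed via `setup.py` and that one of the YAMLs under `config/` matches the function's signature (the path below is only an example):

    # sketch: run LoRA fine-tuning from a YAML config (path is an example)
    from omegaconf import OmegaConf
    from i2vedit.train import train_motion_lora

    cfg = OmegaConf.load("config/customize_train.yaml")
    train_motion_lora(**cfg)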
i2vedit/utils/__init__.py
ADDED
File without changes

i2vedit/utils/bucketing.py
ADDED
@@ -0,0 +1,32 @@
+from PIL import Image
+
+def min_res(size, min_size): return 192 if size < 192 else size
+
+def up_down_bucket(m_size, in_size, direction):
+    if direction == 'down': return abs(int(m_size - in_size))
+    if direction == 'up': return abs(int(m_size + in_size))
+
+def get_bucket_sizes(size, direction: 'down', min_size):
+    multipliers = [64, 128]
+    for i, m in enumerate(multipliers):
+        res = up_down_bucket(m, size, direction)
+        multipliers[i] = min_res(res, min_size=min_size)
+    return multipliers
+
+def closest_bucket(m_size, size, direction, min_size):
+    lst = get_bucket_sizes(m_size, direction, min_size)
+    return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))]
+
+def resolve_bucket(i,h,w): return (i / (h / w))
+
+def sensible_buckets(m_width, m_height, w, h, min_size=192):
+    if h > w:
+        w = resolve_bucket(m_width, h, w)
+        w = closest_bucket(m_width, w, 'down', min_size=min_size)
+        return w, m_height
+    if h < w:
+        h = resolve_bucket(m_height, w, h)
+        h = closest_bucket(m_height, h, 'down', min_size=min_size)
+        return m_width, h
+
+    return m_width, m_height
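A quick worked example of `sensible_buckets` (the numbers are hypothetical; if I trace the arithmetic correctly, a 1280x720 landscape clip requested at 512x512 keeps the max width and buckets the height down to 384):

    # sketch: bucket a landscape clip to the training resolution
    from i2vedit.utils.bucketing import sensible_buckets

    width, height = sensible_buckets(512, 512, w=1280, h=720)
    # resolve_bucket gives 512 / (1280 / 720) = 288; the closest entry of [448, 384] is 384
    print(width, height)  # -> 512 384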
i2vedit/utils/dataset.py
ADDED
@@ -0,0 +1,705 @@
| 1 |
+
import os
|
| 2 |
+
import decord
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import json
|
| 6 |
+
import torchvision
|
| 7 |
+
import torchvision.transforms as T
|
| 8 |
+
import torch
|
| 9 |
+
from torchvision.transforms import Resize, Pad, InterpolationMode, ToTensor
|
| 10 |
+
|
| 11 |
+
from glob import glob
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from itertools import islice
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from .bucketing import sensible_buckets
|
| 16 |
+
|
| 17 |
+
decord.bridge.set_bridge('torch')
|
| 18 |
+
|
| 19 |
+
from torch.utils.data import Dataset
|
| 20 |
+
from einops import rearrange, repeat
|
| 21 |
+
|
| 22 |
+
def pad_with_ratio(frames, res, fill=0):
|
| 23 |
+
process = False
|
| 24 |
+
if not isinstance(frames, torch.Tensor):
|
| 25 |
+
frames = ToTensor()(frames).unsqueeze(0)
|
| 26 |
+
process = True
|
| 27 |
+
_, _, ih, iw = frames.shape
|
| 28 |
+
# print("ih, iw", ih, iw)
|
| 29 |
+
i_ratio = ih / iw
|
| 30 |
+
h, w = res
|
| 31 |
+
# print("h,w", h ,w)
|
| 32 |
+
n_ratio = h / w
|
| 33 |
+
if i_ratio > n_ratio:
|
| 34 |
+
nw = int(ih / h * w)
|
| 35 |
+
# print("nw", nw)
|
| 36 |
+
frames = Pad((nw - iw)//2, fill=fill)(frames)
|
| 37 |
+
frames = frames[...,(nw - iw)//2:-(nw - iw)//2,:]
|
| 38 |
+
else:
|
| 39 |
+
nh = int(iw / w * h)
|
| 40 |
+
frames = Pad((nh - ih)//2, fill=fill)(frames)
|
| 41 |
+
frames = frames[...,:,(nh - ih)//2:-(nh - ih)//2]
|
| 42 |
+
# print("after pad", frames.shape)
|
| 43 |
+
if process:
|
| 44 |
+
frames = (frames * 255.).type(torch.uint8).permute(0,2,3,1).squeeze().cpu().numpy()
|
| 45 |
+
frames = Image.fromarray(frames)
|
| 46 |
+
return frames
|
| 47 |
+
|
| 48 |
+
def return_to_original_res(frames, res, pad_to_fix=False):
|
| 49 |
+
process = False
|
| 50 |
+
if not isinstance(frames, torch.Tensor):
|
| 51 |
+
frames = ToTensor()(frames).unsqueeze(0)
|
| 52 |
+
process = True
|
| 53 |
+
|
| 54 |
+
# print("original res", res)
|
| 55 |
+
_, _, h, w = frames.shape
|
| 56 |
+
# print("h w", h, w)
|
| 57 |
+
n_ratio = h / w
|
| 58 |
+
ih, iw = res
|
| 59 |
+
i_ratio = ih / iw
|
| 60 |
+
if pad_to_fix:
|
| 61 |
+
if i_ratio > n_ratio:
|
| 62 |
+
nw = int(ih / h * w)
|
| 63 |
+
frames = Resize((ih, iw+2*(nw - iw)//2), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
|
| 64 |
+
frames = frames[...,:,(nw - iw)//2:-(nw - iw)//2]
|
| 65 |
+
else:
|
| 66 |
+
nh = int(iw / w * h)
|
| 67 |
+
frames = Resize((ih+2*(nh - ih)//2, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
|
| 68 |
+
|
| 69 |
+
frames = frames[...,(nh - ih)//2:-(nh - ih)//2,:]
|
| 70 |
+
else:
|
| 71 |
+
frames = Resize((ih, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
|
| 72 |
+
|
| 73 |
+
if process:
|
| 74 |
+
frames = (frames * 255.).type(torch.uint8).permute(0,2,3,1).squeeze().cpu().numpy()
|
| 75 |
+
frames = Image.fromarray(frames)
|
| 76 |
+
|
| 77 |
+
return frames
|
| 78 |
+
|
| 79 |
+
def get_prompt_ids(prompt, tokenizer):
|
| 80 |
+
prompt_ids = tokenizer(
|
| 81 |
+
prompt,
|
| 82 |
+
truncation=True,
|
| 83 |
+
padding="max_length",
|
| 84 |
+
max_length=tokenizer.model_max_length,
|
| 85 |
+
return_tensors="pt",
|
| 86 |
+
).input_ids
|
| 87 |
+
|
| 88 |
+
return prompt_ids
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def read_caption_file(caption_file):
|
| 92 |
+
with open(caption_file, 'r', encoding="utf8") as t:
|
| 93 |
+
return t.read()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_text_prompt(
|
| 97 |
+
text_prompt: str = '',
|
| 98 |
+
fallback_prompt: str= '',
|
| 99 |
+
file_path:str = '',
|
| 100 |
+
ext_types=['.mp4'],
|
| 101 |
+
use_caption=False
|
| 102 |
+
):
|
| 103 |
+
try:
|
| 104 |
+
if use_caption:
|
| 105 |
+
if len(text_prompt) > 1: return text_prompt
|
| 106 |
+
caption_file = ''
|
| 107 |
+
# Use caption on per-video basis (One caption PER video)
|
| 108 |
+
for ext in ext_types:
|
| 109 |
+
maybe_file = file_path.replace(ext, '.txt')
|
| 110 |
+
if maybe_file.endswith(ext_types): continue
|
| 111 |
+
if os.path.exists(maybe_file):
|
| 112 |
+
caption_file = maybe_file
|
| 113 |
+
break
|
| 114 |
+
|
| 115 |
+
if os.path.exists(caption_file):
|
| 116 |
+
return read_caption_file(caption_file)
|
| 117 |
+
|
| 118 |
+
# Return fallback prompt if no conditions are met.
|
| 119 |
+
return fallback_prompt
|
| 120 |
+
|
| 121 |
+
return text_prompt
|
| 122 |
+
except:
|
| 123 |
+
print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
|
| 124 |
+
return fallback_prompt
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
|
| 128 |
+
max_range = len(vr)
|
| 129 |
+
frame_number = sorted((0, start_idx, max_range))[1]
|
| 130 |
+
|
| 131 |
+
frame_range = range(frame_number, max_range, sample_rate)
|
| 132 |
+
frame_range_indices = list(frame_range)[:max_frames]
|
| 133 |
+
|
| 134 |
+
return frame_range_indices
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch, pad_to_fix=False, use_aug=False):
|
| 138 |
+
use_aug = False
|
| 139 |
+
if use_bucketing:
|
| 140 |
+
vr = decord.VideoReader(vid_path)
|
| 141 |
+
resize = get_frame_buckets(vr)
|
| 142 |
+
video = get_frame_batch(vr, resize=resize)
|
| 143 |
+
|
| 144 |
+
else:
|
| 145 |
+
if not pad_to_fix:
|
| 146 |
+
vr = decord.VideoReader(vid_path, width=w, height=h)
|
| 147 |
+
video = get_frame_batch(vr, use_aug=use_aug)
|
| 148 |
+
else:
|
| 149 |
+
vr = decord.VideoReader(vid_path)
|
| 150 |
+
video = get_frame_batch(vr, use_aug=use_aug)
|
| 151 |
+
video = pad_with_ratio(video, (h, w))
|
| 152 |
+
video = T.transforms.Resize((h, w), antialias=True)(video)
|
| 153 |
+
|
| 154 |
+
return video, vr
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# https://github.com/ExponentialML/Video-BLIP2-Preprocessor
|
| 158 |
+
class VideoJsonDataset(Dataset):
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
tokenizer = None,
|
| 162 |
+
width: int = 256,
|
| 163 |
+
height: int = 256,
|
| 164 |
+
n_sample_frames: int = 4,
|
| 165 |
+
sample_start_idx: int = 1,
|
| 166 |
+
frame_step: int = 1,
|
| 167 |
+
json_path: str ="",
|
| 168 |
+
json_data = None,
|
| 169 |
+
vid_data_key: str = "video_path",
|
| 170 |
+
preprocessed: bool = False,
|
| 171 |
+
use_bucketing: bool = False,
|
| 172 |
+
**kwargs
|
| 173 |
+
):
|
| 174 |
+
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
|
| 175 |
+
self.use_bucketing = use_bucketing
|
| 176 |
+
self.tokenizer = tokenizer
|
| 177 |
+
self.preprocessed = preprocessed
|
| 178 |
+
|
| 179 |
+
self.vid_data_key = vid_data_key
|
| 180 |
+
self.train_data = self.load_from_json(json_path, json_data)
|
| 181 |
+
|
| 182 |
+
self.width = width
|
| 183 |
+
self.height = height
|
| 184 |
+
|
| 185 |
+
self.n_sample_frames = n_sample_frames
|
| 186 |
+
self.sample_start_idx = sample_start_idx
|
| 187 |
+
self.frame_step = frame_step
|
| 188 |
+
|
| 189 |
+
def build_json(self, json_data):
|
| 190 |
+
extended_data = []
|
| 191 |
+
for data in json_data['data']:
|
| 192 |
+
for nested_data in data['data']:
|
| 193 |
+
self.build_json_dict(
|
| 194 |
+
data,
|
| 195 |
+
nested_data,
|
| 196 |
+
extended_data
|
| 197 |
+
)
|
| 198 |
+
json_data = extended_data
|
| 199 |
+
return json_data
|
| 200 |
+
|
| 201 |
+
def build_json_dict(self, data, nested_data, extended_data):
|
| 202 |
+
clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None
|
| 203 |
+
|
| 204 |
+
extended_data.append({
|
| 205 |
+
self.vid_data_key: data[self.vid_data_key],
|
| 206 |
+
'frame_index': nested_data['frame_index'],
|
| 207 |
+
'prompt': nested_data['prompt'],
|
| 208 |
+
'clip_path': clip_path
|
| 209 |
+
})
|
| 210 |
+
|
| 211 |
+
def load_from_json(self, path, json_data):
|
| 212 |
+
try:
|
| 213 |
+
with open(path) as jpath:
|
| 214 |
+
print(f"Loading JSON from {path}")
|
| 215 |
+
json_data = json.load(jpath)
|
| 216 |
+
|
| 217 |
+
return self.build_json(json_data)
|
| 218 |
+
|
| 219 |
+
except:
|
| 220 |
+
self.train_data = []
|
| 221 |
+
print("Non-existant JSON path. Skipping.")
|
| 222 |
+
|
| 223 |
+
def validate_json(self, base_path, path):
|
| 224 |
+
return os.path.exists(f"{base_path}/{path}")
|
| 225 |
+
|
| 226 |
+
def get_frame_range(self, vr):
|
| 227 |
+
return get_video_frames(
|
| 228 |
+
vr,
|
| 229 |
+
self.sample_start_idx,
|
| 230 |
+
self.frame_step,
|
| 231 |
+
self.n_sample_frames
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
def get_vid_idx(self, vr, vid_data=None):
|
| 235 |
+
frames = self.n_sample_frames
|
| 236 |
+
|
| 237 |
+
if vid_data is not None:
|
| 238 |
+
idx = vid_data['frame_index']
|
| 239 |
+
else:
|
| 240 |
+
idx = self.sample_start_idx
|
| 241 |
+
|
| 242 |
+
return idx
|
| 243 |
+
|
| 244 |
+
def get_frame_buckets(self, vr):
|
| 245 |
+
_, h, w = vr[0].shape
|
| 246 |
+
width, height = sensible_buckets(self.width, self.height, h, w)
|
| 247 |
+
# width, height = self.width, self.height
|
| 248 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
| 249 |
+
|
| 250 |
+
return resize
|
| 251 |
+
|
| 252 |
+
def get_frame_batch(self, vr, resize=None):
|
| 253 |
+
frame_range = self.get_frame_range(vr)
|
| 254 |
+
frames = vr.get_batch(frame_range)
|
| 255 |
+
video = rearrange(frames, "f h w c -> f c h w")
|
| 256 |
+
|
| 257 |
+
if resize is not None: video = resize(video)
|
| 258 |
+
return video
|
| 259 |
+
|
| 260 |
+
def process_video_wrapper(self, vid_path):
|
| 261 |
+
video, vr = process_video(
|
| 262 |
+
vid_path,
|
| 263 |
+
self.use_bucketing,
|
| 264 |
+
self.width,
|
| 265 |
+
self.height,
|
| 266 |
+
self.get_frame_buckets,
|
| 267 |
+
self.get_frame_batch
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return video, vr
|
| 271 |
+
|
| 272 |
+
def train_data_batch(self, index):
|
| 273 |
+
|
| 274 |
+
# If we are training on individual clips.
|
| 275 |
+
if 'clip_path' in self.train_data[index] and \
|
| 276 |
+
self.train_data[index]['clip_path'] is not None:
|
| 277 |
+
|
| 278 |
+
vid_data = self.train_data[index]
|
| 279 |
+
|
| 280 |
+
clip_path = vid_data['clip_path']
|
| 281 |
+
|
| 282 |
+
# Get video prompt
|
| 283 |
+
prompt = vid_data['prompt']
|
| 284 |
+
|
| 285 |
+
video, _ = self.process_video_wrapper(clip_path)
|
| 286 |
+
|
| 287 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
| 288 |
+
|
| 289 |
+
return video, prompt, prompt_ids
|
| 290 |
+
|
| 291 |
+
# Assign train data
|
| 292 |
+
train_data = self.train_data[index]
|
| 293 |
+
|
| 294 |
+
# Get the frame of the current index.
|
| 295 |
+
self.sample_start_idx = train_data['frame_index']
|
| 296 |
+
|
| 297 |
+
# Initialize resize
|
| 298 |
+
resize = None
|
| 299 |
+
|
| 300 |
+
video, vr = self.process_video_wrapper(train_data[self.vid_data_key])
|
| 301 |
+
|
| 302 |
+
# Get video prompt
|
| 303 |
+
prompt = train_data['prompt']
|
| 304 |
+
vr.seek(0)
|
| 305 |
+
|
| 306 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
| 307 |
+
|
| 308 |
+
return video, prompt, prompt_ids
|
| 309 |
+
|
| 310 |
+
@staticmethod
|
| 311 |
+
def __getname__(): return 'json'
|
| 312 |
+
|
| 313 |
+
def __len__(self):
|
| 314 |
+
if self.train_data is not None:
|
| 315 |
+
return len(self.train_data)
|
| 316 |
+
else:
|
| 317 |
+
return 0
|
| 318 |
+
|
| 319 |
+
def __getitem__(self, index):
|
| 320 |
+
|
| 321 |
+
# Initialize variables
|
| 322 |
+
video = None
|
| 323 |
+
prompt = None
|
| 324 |
+
prompt_ids = None
|
| 325 |
+
|
| 326 |
+
# Use default JSON training
|
| 327 |
+
if self.train_data is not None:
|
| 328 |
+
video, prompt, prompt_ids = self.train_data_batch(index)
|
| 329 |
+
|
| 330 |
+
example = {
|
| 331 |
+
"pixel_values": (video / 127.5 - 1.0),
|
| 332 |
+
"prompt_ids": prompt_ids[0],
|
| 333 |
+
"text_prompt": prompt,
|
| 334 |
+
'dataset': self.__getname__()
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
return example
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class SingleVideoDataset(Dataset):
|
| 341 |
+
def __init__(
|
| 342 |
+
self,
|
| 343 |
+
width: int = 256,
|
| 344 |
+
height: int = 256,
|
| 345 |
+
inversion_width: int = 256,
|
| 346 |
+
inversion_height: int = 256,
|
| 347 |
+
start_t: float=0,
|
| 348 |
+
end_t: float=-1,
|
| 349 |
+
sample_fps: int=-1,
|
| 350 |
+
single_video_path: str = "",
|
| 351 |
+
refer_image_path: str = "",
|
| 352 |
+
use_caption: bool = False,
|
| 353 |
+
use_bucketing: bool = False,
|
| 354 |
+
pad_to_fix: bool = False,
|
| 355 |
+
use_aug: bool = False,
|
| 356 |
+
**kwargs
|
| 357 |
+
):
|
| 358 |
+
self.use_bucketing = use_bucketing
|
| 359 |
+
self.frames = []
|
| 360 |
+
self.index = 1
|
| 361 |
+
|
| 362 |
+
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
|
| 363 |
+
self.start_t = start_t
|
| 364 |
+
self.end_t = end_t
|
| 365 |
+
self.output_fps = sample_fps
|
| 366 |
+
|
| 367 |
+
self.single_video_path = single_video_path
|
| 368 |
+
self.refer_image_path = refer_image_path
|
| 369 |
+
|
| 370 |
+
self.width = width
|
| 371 |
+
self.height = height
|
| 372 |
+
self.inversion_width = inversion_width
|
| 373 |
+
self.inversion_height = inversion_height
|
| 374 |
+
|
| 375 |
+
self.pad_to_fix = pad_to_fix
|
| 376 |
+
|
| 377 |
+
self.use_aug = use_aug
|
| 378 |
+
#self.data_augment = ControlNetDataAugmentation()
|
| 379 |
+
|
| 380 |
+
def create_video_chunks(self):
|
| 381 |
+
output_fps = self.output_fps
|
| 382 |
+
start_t = self.start_t
|
| 383 |
+
end_t = self.end_t
|
| 384 |
+
vr = decord.VideoReader(self.single_video_path)
|
| 385 |
+
initial_fps = vr.get_avg_fps()
|
| 386 |
+
if output_fps == -1:
|
| 387 |
+
output_fps = int(initial_fps)
|
| 388 |
+
if end_t == -1:
|
| 389 |
+
end_t = len(vr) / initial_fps
|
| 390 |
+
else:
|
| 391 |
+
end_t = min(len(vr) / initial_fps, end_t)
|
| 392 |
+
assert 0 <= start_t < end_t
|
| 393 |
+
assert output_fps > 0
|
| 394 |
+
start_f_ind = int(start_t * initial_fps)
|
| 395 |
+
end_f_ind = int(end_t * initial_fps)
|
| 396 |
+
num_f = int((end_t - start_t) * output_fps)
|
| 397 |
+
sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
|
| 398 |
+
self.frames = [sample_idx]
|
| 399 |
+
return self.frames
|
| 400 |
+
|
| 401 |
+
def chunk(self, it, size):
|
| 402 |
+
it = iter(it)
|
| 403 |
+
return iter(lambda: tuple(islice(it, size)), ())
|
| 404 |
+
|
| 405 |
+
def get_frame_batch(self, vr, resize=None, use_aug=False):
|
| 406 |
+
index = self.index
|
| 407 |
+
frames = vr.get_batch(self.frames[self.index])
|
| 408 |
+
|
| 409 |
+
if use_aug:
|
| 410 |
+
frames = self.data_augment.augment(frames)
|
| 411 |
+
print(frames.min(), frames.max())
|
| 412 |
+
|
| 413 |
+
video = rearrange(frames, "f h w c -> f c h w")
|
| 414 |
+
|
| 415 |
+
if resize is not None: video = resize(video)
|
| 416 |
+
return video
|
| 417 |
+
|
| 418 |
+
def get_frame_buckets(self, vr):
|
| 419 |
+
h, w, c = vr[0].shape
|
| 420 |
+
width, height = sensible_buckets(self.width, self.height, w, h)
|
| 421 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
| 422 |
+
|
| 423 |
+
return resize
|
| 424 |
+
|
| 425 |
+
def process_video_wrapper(self, vid_path):
|
| 426 |
+
video, vr = process_video(
|
| 427 |
+
vid_path,
|
| 428 |
+
self.use_bucketing,
|
| 429 |
+
self.width,
|
| 430 |
+
self.height,
|
| 431 |
+
self.get_frame_buckets,
|
| 432 |
+
self.get_frame_batch,
|
| 433 |
+
self.pad_to_fix,
|
| 434 |
+
self.use_aug
|
| 435 |
+
)
|
| 436 |
+
video_for_inversion, vr = process_video(
|
| 437 |
+
vid_path,
|
| 438 |
+
self.use_bucketing,
|
| 439 |
+
self.inversion_width,
|
| 440 |
+
self.inversion_height,
|
| 441 |
+
self.get_frame_buckets,
|
| 442 |
+
self.get_frame_batch,
|
| 443 |
+
self.pad_to_fix
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
return video, video_for_inversion, vr
|
| 447 |
+
|
| 448 |
+
def image_batch(self):
|
| 449 |
+
train_data = self.refer_image_path
|
| 450 |
+
img = train_data
|
| 451 |
+
|
| 452 |
+
try:
|
| 453 |
+
img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
|
| 454 |
+
except:
|
| 455 |
+
img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
|
| 456 |
+
|
| 457 |
+
width = self.width
|
| 458 |
+
height = self.height
|
| 459 |
+
|
| 460 |
+
if self.use_bucketing:
|
| 461 |
+
_, h, w = img.shape
|
| 462 |
+
width, height = sensible_buckets(width, height, w, h)
|
| 463 |
+
|
| 464 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
| 465 |
+
|
| 466 |
+
img = resize(img)
|
| 467 |
+
img = repeat(img, 'c h w -> f c h w', f=1)
|
| 468 |
+
|
| 469 |
+
return img
|
| 470 |
+
|
| 471 |
+
def single_video_batch(self, index):
|
| 472 |
+
train_data = self.single_video_path
|
| 473 |
+
self.index = index
|
| 474 |
+
|
| 475 |
+
if train_data.endswith(self.vid_types):
|
| 476 |
+
video, video_for_inv, _ = self.process_video_wrapper(train_data)
|
| 477 |
+
|
| 478 |
+
return video, video_for_inv
|
| 479 |
+
else:
|
| 480 |
+
raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
|
| 481 |
+
|
| 482 |
+
@staticmethod
|
| 483 |
+
def __getname__(): return 'single_video'
|
| 484 |
+
|
| 485 |
+
def __len__(self):
|
| 486 |
+
|
| 487 |
+
return len(self.create_video_chunks())
|
| 488 |
+
|
| 489 |
+
def __getitem__(self, index):
|
| 490 |
+
|
| 491 |
+
video, video_for_inv = self.single_video_batch(index)
|
| 492 |
+
image = self.image_batch()
|
| 493 |
+
motion_values = torch.Tensor([127.])
|
| 494 |
+
|
| 495 |
+
example = {
|
| 496 |
+
"pixel_values": (video / 127.5 - 1.0),
|
| 497 |
+
"pixel_values_for_inv": (video_for_inv / 127.5 - 1.0),
|
| 498 |
+
"refer_pixel_values": (image / 127.5 - 1.0),
|
| 499 |
+
"motion_values": motion_values,
|
| 500 |
+
'dataset': self.__getname__()
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
return example
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class ImageDataset(Dataset):
|
| 507 |
+
|
| 508 |
+
def __init__(
|
| 509 |
+
self,
|
| 510 |
+
tokenizer = None,
|
| 511 |
+
width: int = 256,
|
| 512 |
+
height: int = 256,
|
| 513 |
+
base_width: int = 256,
|
| 514 |
+
base_height: int = 256,
|
| 515 |
+
use_caption: bool = False,
|
| 516 |
+
image_dir: str = '',
|
| 517 |
+
single_img_prompt: str = '',
|
| 518 |
+
use_bucketing: bool = False,
|
| 519 |
+
fallback_prompt: str = '',
|
| 520 |
+
**kwargs
|
| 521 |
+
):
|
| 522 |
+
self.tokenizer = tokenizer
|
| 523 |
+
self.img_types = (".png", ".jpg", ".jpeg", '.bmp')
|
| 524 |
+
self.use_bucketing = use_bucketing
|
| 525 |
+
|
| 526 |
+
self.image_dir = self.get_images_list(image_dir)
|
| 527 |
+
self.fallback_prompt = fallback_prompt
|
| 528 |
+
|
| 529 |
+
self.use_caption = use_caption
|
| 530 |
+
self.single_img_prompt = single_img_prompt
|
| 531 |
+
|
| 532 |
+
self.width = width
|
| 533 |
+
self.height = height
|
| 534 |
+
|
| 535 |
+
def get_images_list(self, image_dir):
|
| 536 |
+
if os.path.exists(image_dir):
|
| 537 |
+
imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)]
|
| 538 |
+
full_img_dir = []
|
| 539 |
+
|
| 540 |
+
for img in imgs:
|
| 541 |
+
full_img_dir.append(f"{image_dir}/{img}")
|
| 542 |
+
|
| 543 |
+
return sorted(full_img_dir)
|
| 544 |
+
|
| 545 |
+
return ['']
|
| 546 |
+
|
| 547 |
+
def image_batch(self, index):
|
| 548 |
+
train_data = self.image_dir[index]
|
| 549 |
+
img = train_data
|
| 550 |
+
|
| 551 |
+
try:
|
| 552 |
+
img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
|
| 553 |
+
except:
|
| 554 |
+
img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
|
| 555 |
+
|
| 556 |
+
width = self.width
|
| 557 |
+
height = self.height
|
| 558 |
+
|
| 559 |
+
if self.use_bucketing:
|
| 560 |
+
_, h, w = img.shape
|
| 561 |
+
width, height = sensible_buckets(width, height, w, h)
|
| 562 |
+
|
| 563 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
| 564 |
+
|
| 565 |
+
img = resize(img)
|
| 566 |
+
img = repeat(img, 'c h w -> f c h w', f=16)
|
| 567 |
+
|
| 568 |
+
prompt = get_text_prompt(
|
| 569 |
+
file_path=train_data,
|
| 570 |
+
text_prompt=self.single_img_prompt,
|
| 571 |
+
fallback_prompt=self.fallback_prompt,
|
| 572 |
+
ext_types=self.img_types,
|
| 573 |
+
use_caption=True
|
| 574 |
+
)
|
| 575 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
| 576 |
+
|
| 577 |
+
return img, prompt, prompt_ids
|
| 578 |
+
|
| 579 |
+
@staticmethod
|
| 580 |
+
def __getname__(): return 'image'
|
| 581 |
+
|
| 582 |
+
def __len__(self):
|
| 583 |
+
# Image directory
|
| 584 |
+
if os.path.exists(self.image_dir[0]):
|
| 585 |
+
return len(self.image_dir)
|
| 586 |
+
else:
|
| 587 |
+
return 0
|
| 588 |
+
|
| 589 |
+
def __getitem__(self, index):
|
| 590 |
+
img, prompt, prompt_ids = self.image_batch(index)
|
| 591 |
+
example = {
|
| 592 |
+
"pixel_values": (img / 127.5 - 1.0),
|
| 593 |
+
"prompt_ids": prompt_ids[0],
|
| 594 |
+
"text_prompt": prompt,
|
| 595 |
+
'dataset': self.__getname__()
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
return example
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class VideoFolderDataset(Dataset):
|
| 602 |
+
def __init__(
|
| 603 |
+
self,
|
| 604 |
+
tokenizer=None,
|
| 605 |
+
width: int = 256,
|
| 606 |
+
height: int = 256,
|
| 607 |
+
n_sample_frames: int = 16,
|
| 608 |
+
fps: int = 8,
|
| 609 |
+
path: str = "./data",
|
| 610 |
+
fallback_prompt: str = "",
|
| 611 |
+
use_bucketing: bool = False,
|
| 612 |
+
**kwargs
|
| 613 |
+
):
|
| 614 |
+
self.tokenizer = tokenizer
|
| 615 |
+
self.use_bucketing = use_bucketing
|
| 616 |
+
|
| 617 |
+
self.fallback_prompt = fallback_prompt
|
| 618 |
+
|
| 619 |
+
self.video_files = glob(f"{path}/*.mp4")
|
| 620 |
+
|
| 621 |
+
self.width = width
|
| 622 |
+
self.height = height
|
| 623 |
+
|
| 624 |
+
self.n_sample_frames = n_sample_frames
|
| 625 |
+
self.fps = fps
|
| 626 |
+
|
| 627 |
+
def get_frame_buckets(self, vr):
|
| 628 |
+
h, w, c = vr[0].shape
|
| 629 |
+
width, height = sensible_buckets(self.width, self.height, w, h)
|
| 630 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
| 631 |
+
|
| 632 |
+
return resize
|
| 633 |
+
|
| 634 |
+
def get_frame_batch(self, vr, resize=None):
|
| 635 |
+
n_sample_frames = self.n_sample_frames
|
| 636 |
+
native_fps = vr.get_avg_fps()
|
| 637 |
+
|
| 638 |
+
every_nth_frame = max(1, round(native_fps / self.fps))
|
| 639 |
+
every_nth_frame = min(len(vr), every_nth_frame)
|
| 640 |
+
|
| 641 |
+
effective_length = len(vr) // every_nth_frame
|
| 642 |
+
if effective_length < n_sample_frames:
|
| 643 |
+
n_sample_frames = effective_length
|
| 644 |
+
|
| 645 |
+
effective_idx = random.randint(0, (effective_length - n_sample_frames))
|
| 646 |
+
idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)
|
| 647 |
+
|
| 648 |
+
video = vr.get_batch(idxs)
|
| 649 |
+
video = rearrange(video, "f h w c -> f c h w")
|
| 650 |
+
|
| 651 |
+
if resize is not None: video = resize(video)
|
| 652 |
+
return video, vr
|
| 653 |
+
|
| 654 |
+
def process_video_wrapper(self, vid_path):
|
| 655 |
+
video, vr = process_video(
|
| 656 |
+
vid_path,
|
| 657 |
+
self.use_bucketing,
|
| 658 |
+
self.width,
|
| 659 |
+
self.height,
|
| 660 |
+
self.get_frame_buckets,
|
| 661 |
+
self.get_frame_batch
|
| 662 |
+
)
|
| 663 |
+
return video, vr
|
| 664 |
+
|
| 665 |
+
def get_prompt_ids(self, prompt):
|
| 666 |
+
return self.tokenizer(
|
| 667 |
+
prompt,
|
| 668 |
+
truncation=True,
|
| 669 |
+
padding="max_length",
|
| 670 |
+
max_length=self.tokenizer.model_max_length,
|
| 671 |
+
return_tensors="pt",
|
| 672 |
+
).input_ids
|
| 673 |
+
|
| 674 |
+
@staticmethod
|
| 675 |
+
def __getname__(): return 'folder'
|
| 676 |
+
|
| 677 |
+
def __len__(self):
|
| 678 |
+
return len(self.video_files)
|
| 679 |
+
|
| 680 |
+
def __getitem__(self, index):
|
| 681 |
+
|
| 682 |
+
video, _ = self.process_video_wrapper(self.video_files[index])
|
| 683 |
+
|
| 684 |
+
prompt = self.fallback_prompt
|
| 685 |
+
|
| 686 |
+
prompt_ids = self.get_prompt_ids(prompt)
|
| 687 |
+
|
| 688 |
+
return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class CachedDataset(Dataset):
|
| 692 |
+
def __init__(self,cache_dir: str = ''):
|
| 693 |
+
self.cache_dir = cache_dir
|
| 694 |
+
self.cached_data_list = self.get_files_list()
|
| 695 |
+
|
| 696 |
+
def get_files_list(self):
|
| 697 |
+
tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')]
|
| 698 |
+
return sorted(tensors_list)
|
| 699 |
+
|
| 700 |
+
def __len__(self):
|
| 701 |
+
return len(self.cached_data_list)
|
| 702 |
+
|
| 703 |
+
def __getitem__(self, index):
|
| 704 |
+
cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0')
|
| 705 |
+
return cached_latent
|
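A rough usage sketch for `SingleVideoDataset` above (paths and sizes are placeholders, and it assumes decord can open the clip): each item holds the clip at training resolution, the same clip at the inversion resolution, and the reference image, all scaled to [-1, 1].

    from torch.utils.data import DataLoader
    from i2vedit.utils.dataset import SingleVideoDataset

    dataset = SingleVideoDataset(
        width=512, height=512,
        inversion_width=512, inversion_height=512,
        single_video_path="mydata/source_and_edits/source.mp4",  # placeholder path
        refer_image_path="mydata/source_and_edits/white.jpg",    # placeholder path
        sample_fps=7,
    )
    batch = next(iter(DataLoader(dataset, batch_size=1)))
    print(batch["pixel_values"].shape, batch["refer_pixel_values"].shape)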
i2vedit/utils/euler_utils.py
ADDED
@@ -0,0 +1,226 @@
+import os
+import numpy as np
+from PIL import Image
+from typing import Union
+import copy
+from scipy.stats import anderson
+
+import torch
+
+from tqdm import tqdm
+from diffusers import StableVideoDiffusionPipeline
+from i2vedit.prompt_attention import attention_util
+
+# Euler Inversion
+@torch.no_grad()
+def init_image(image, firstframe, pipeline):
+    if isinstance(image, torch.Tensor):
+        height, width = image.shape[-2:]
+        image = (image + 1) / 2. * 255.
+        image = image.type(torch.uint8).squeeze().permute(1,2,0).cpu().numpy()
+        image = Image.fromarray(image)
+    if isinstance(firstframe, torch.Tensor):
+        firstframe = (firstframe + 1) / 2. * 255.
+        firstframe = firstframe.type(torch.uint8).squeeze().permute(1,2,0).cpu().numpy()
+        firstframe = Image.fromarray(firstframe)
+
+    device = pipeline._execution_device
+    image_embeddings = pipeline._encode_image(firstframe, device, 1, False)
+    image = pipeline.image_processor.preprocess(image, height=height, width=width)
+    firstframe = pipeline.image_processor.preprocess(firstframe, height=height, width=width)
+    #print(image.dtype)
+    noise = torch.randn(image.shape, device=image.device, dtype=image.dtype)
+    image = image + 0.02 * noise
+    firstframe = firstframe + 0.02 * noise
+    #print(image.dtype)
+    image_latents = pipeline._encode_vae_image(image, device, 1, False)
+    firstframe_latents = pipeline._encode_vae_image(firstframe, device, 1, False)
+    image_latents = image_latents.to(image_embeddings.dtype)
+    firstframe_latents = firstframe_latents.to(image_embeddings.dtype)
+
+    return image_embeddings, image_latents, firstframe_latents
+
+
+def next_step(model_output: Union[torch.FloatTensor, np.ndarray], sigma, sigma_next,
+              sample: Union[torch.FloatTensor, np.ndarray], euler_scheduler, controller=None, consistency_controller=None):
+    pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+    if controller is not None:
+        pred_original_sample = controller.step_callback(pred_original_sample)
+    if consistency_controller is not None:
+        pred_original_sample = consistency_controller.step_callback(pred_original_sample)
+    #print("sample", sample.mean(), sample.std(), "pred_original_sample", pred_original_sample.mean())
+    #pred_original_sample = sample.mean() - pred_original_sample.mean() + pred_original_sample
+    next_sample = sample + (sigma_next - sigma) * (sample - pred_original_sample) / sigma
+    #print(sigma, sigma_next)
+    #print("next sample", torch.mean(next_sample), torch.std(next_sample))
+    return next_sample
+
+
+def get_model_pred_single(latents, t, image_embeddings, added_time_ids, unet):
+    noise_pred = unet(
+        latents,
+        t,
+        encoder_hidden_states=image_embeddings,
+        added_time_ids=added_time_ids,
+        return_dict=False,
+    )[0]
+    return noise_pred
+
+@torch.no_grad()
+def euler_loop(pipeline, euler_scheduler, latents, num_inv_steps, image, firstframe, controller=None, consistency_controller=None):
+    device = pipeline._execution_device
+
+    # prepare image conditions
+    image_embeddings, image_latents, firstframe_latents = init_image(image, firstframe, pipeline)
+    skip = 1#latents.shape[1]
+    image_latents = torch.cat(
+        [
+            image_latents.unsqueeze(1).repeat(1, skip, 1, 1, 1),
+            firstframe_latents.unsqueeze(1).repeat(1, latents.shape[1]-skip, 1, 1, 1)
+        ],
+        dim=1
+    )
+    #image_latents = image_latents.unsqueeze(1).repeat(1, latents.shape[1], 1, 1, 1)
+
+    # Get Added Time IDs
+    added_time_ids = pipeline._get_add_time_ids(
+        8,
+        127,
+        0.02,
+        image_embeddings.dtype,
+        1,
+        1,
+        False
+    )
+    added_time_ids = added_time_ids.to(device)
+
+    # Prepare timesteps
+    euler_scheduler.set_timesteps(num_inv_steps, device=device)
+    sigmas_0 = euler_scheduler.sigmas[-2] * euler_scheduler.sigmas[-2] / euler_scheduler.sigmas[-3]
+    timesteps = torch.cat([euler_scheduler.timesteps[1:],torch.Tensor([0.25 * sigmas_0.log()]).to(device)])
+    sigmas = copy.deepcopy(euler_scheduler.sigmas)
+    sigmas[-1] = sigmas_0
+    #print(sigmas)
+
+    # prepare latents
+    all_latent = [latents]
+    latents = latents.clone().detach()
+
+    for i in tqdm(range(num_inv_steps)):
+        t = timesteps[len(timesteps) -i -1]
+        sigma = sigmas[len(sigmas) -i -1]
+        sigma_next = sigmas[len(sigmas) -i -2]
+        latent_model_input = latents
+        latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+        latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+        model_pred = get_model_pred_single(latent_model_input, t, image_embeddings, added_time_ids, pipeline.unet)
+        latents = next_step(model_pred, sigma, sigma_next, latents, euler_scheduler, controller=controller, consistency_controller=consistency_controller)
+        all_latent.append(latents)
+    all_latent[-1] = all_latent[-1] / ((sigmas[0]**2 + 1)**0.5)
+    return all_latent
+
+
+@torch.no_grad()
+def euler_inversion(pipeline, euler_scheduler, video_latent, num_inv_steps, image, firstframe, controller=None, consistency_controller=None):
+    euler_latents = euler_loop(pipeline, euler_scheduler, video_latent, num_inv_steps, image, firstframe, controller=controller, consistency_controller=consistency_controller)
+    return euler_latents
+
+from diffusers import EulerDiscreteScheduler
+from .model_utils import tensor_to_vae_latent, load_primary_models, handle_memory_attention
+
+def inverse_video(
+    pretrained_model_path,
+    video,
+    keyframe,
+    firstframe,
+    num_steps,
+    resctrl=None,
+    sard=None,
+    enable_xformers_memory_efficient_attention=True,
+    enable_torch_2_attn=False,
+    store_controller = None,
+    consistency_store_controller = None,
+    find_modules={},
+    consistency_find_modules={},
+    sarp_noise_scale=0.002,
+):
+    dtype = torch.float32
+
+    # check if inverted latents exists
+    for _controller in [store_controller, consistency_store_controller]:
+        if _controller is not None:
+            if os.path.exists(os.path.join(_controller.store_dir, "inverted_latents.pt")):
+                euler_inv_latent = torch.load(os.path.join(_controller.store_dir, "inverted_latents.pt")).to("cuda", dtype)
+                print(f"Successfully load inverted latents from {os.path.join(_controller.store_dir, 'inverted_latents.pt')}")
+                return euler_inv_latent
+
+    # prepare model, Load scheduler, tokenizer and models.
+    noise_scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(pretrained_model_path)
+
+    # Enable xformers if available
+    handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
+
+    vae.to('cuda', dtype=dtype)
+    unet.to('cuda')
+    pipe = StableVideoDiffusionPipeline.from_pretrained(
+        pretrained_model_path,
+        feature_extractor=feature_extractor,
+        image_encoder=image_encoder,
+        vae=vae,
+        unet=unet
+    )
+    pipe.image_encoder.to('cuda')
+
+    attention_util.register_attention_control(
+        pipe.unet,
+        store_controller,
+        consistency_store_controller,
+        find_modules=find_modules,
+        consistency_find_modules=consistency_find_modules
+    )
+    if store_controller is not None:
+        store_controller.LOW_RESOURCE = True
+
+    video_for_inv = torch.cat([firstframe,keyframe,video],dim=1).to(dtype)
+    #print(video_for_inv.shape)
+    if resctrl is not None:
+        video_for_inv = resctrl(video_for_inv)
+    if sard is not None:
+        indx = sard.detection(video_for_inv, 0.001)
+        #import cv2
+        #cv2.imwrite("indx.png", indx[0,0,:,:,:].permute(1,2,0).type(torch.uint8).cpu().numpy()*255)
+        noise = torch.randn(video_for_inv.shape, device=video.device, dtype=video.dtype)
+        video_for_inv[indx] = video_for_inv[indx] + noise[indx] * sarp_noise_scale
+        video_for_inv = video_for_inv.clamp(-1,1)
+
+    firstframe, keyframe, video_for_inv = video_for_inv.tensor_split([1,2],dim=1)
+
+    #print("video for inv", video_for_inv.mean(), video_for_inv.std())
+    latents_for_inv = tensor_to_vae_latent(video_for_inv, vae)
+    #noise = torch.randn(latents_for_inv.shape, device=video.device, dtype=video.dtype)
+    #latents_for_inv = latents_for_inv + noise * sarp_noise_scale
+    #print("video latent for inv", latents_for_inv.mean(), latents_for_inv.std(), latents_for_inv.shape)
+
+    euler_inv_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+    euler_inv_scheduler.set_timesteps(num_steps)
+
+    euler_inv_latent = euler_inversion(
+        pipe, euler_inv_scheduler, video_latent=latents_for_inv.to(pipe.device),
+        num_inv_steps=num_steps, image=keyframe[:,0,:,:,:], firstframe=firstframe[:,0,:,:,:], controller=store_controller, consistency_controller=consistency_store_controller)[-1]
+
+    torch.cuda.empty_cache()
+    del pipe
+
+    #res = anderson(euler_inv_latent.cpu().view(-1).numpy())
+    #print(euler_inv_latent.mean(), euler_inv_latent.std())
+    #print(res.statistic)
+    #print(res.critical_values)
+    #print(res.significance_level)
+
+    # save inverted latents
+    for _controller in [store_controller, consistency_store_controller]:
+        if _controller is not None:
+            torch.save(euler_inv_latent, os.path.join(_controller.store_dir, "inverted_latents.pt"))
+            break
+
+    return euler_inv_latent
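A hedged sketch of driving `inverse_video` above: `video`, `keyframe` and `firstframe` are [-1, 1] tensors shaped (1, frames, 3, H, W), and the returned latent is the starting noise handed to the SVD pipeline. The checkpoint id, shapes and step count below are examples only, and loading the model downloads the full SVD weights.

    import torch
    from i2vedit.utils.euler_utils import inverse_video

    video = torch.randn(1, 14, 3, 512, 512).clamp(-1, 1)   # stand-in for a real clip
    firstframe = video[:, :1]                               # first frame of the clip
    keyframe = video[:, :1]                                 # frame used as the image condition
    inv_latent = inverse_video(
        "stabilityai/stable-video-diffusion-img2vid",       # example SVD checkpoint
        video, keyframe, firstframe,
        num_steps=25,
    )
    print(inv_latent.shape)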
i2vedit/utils/lora.py
ADDED
@@ -0,0 +1,1493 @@
+import json
+import math
+from itertools import groupby
+import os
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from safetensors.torch import safe_open
+    from safetensors.torch import save_file as safe_save
+
+    safetensors_available = True
+except ImportError:
+    from .safe_open import safe_open
+
+    def safe_save(
+        tensors: Dict[str, torch.Tensor],
+        filename: str,
+        metadata: Optional[Dict[str, str]] = None,
+    ) -> None:
+        raise EnvironmentError(
+            "Saving safetensors requires the safetensors library. Please install with pip or similar."
+        )
+
+    safetensors_available = False
+
+
+class LoraInjectedLinear(nn.Module):
+    def __init__(
+        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
+    ):
+        super().__init__()
+
+        if r > min(in_features, out_features):
+            #raise ValueError(
+            #    f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+            #)
+            print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
+            r = min(in_features, out_features)
+
+        self.r = r
+        self.linear = nn.Linear(in_features, out_features, bias)
+        self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.dropout = nn.Dropout(dropout_p)
+        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.scale = scale
+        self.selector = nn.Identity()
+
+        nn.init.normal_(self.lora_down.weight, std=1 / r)
+        nn.init.zeros_(self.lora_up.weight)
+
+    def forward(self, input):
+        return (
+            self.linear(input)
+            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
+            * self.scale
+        )
+
+    def realize_as_lora(self):
+        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
+
+    def set_selector_from_diag(self, diag: torch.Tensor):
+        # diag is a 1D tensor of size (r,)
+        assert diag.shape == (self.r,)
+        self.selector = nn.Linear(self.r, self.r, bias=False)
+        self.selector.weight.data = torch.diag(diag)
+        self.selector.weight.data = self.selector.weight.data.to(
+            self.lora_up.weight.device
+        ).to(self.lora_up.weight.dtype)
+
+
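A small sanity-check sketch for `LoraInjectedLinear`: because `lora_up` is zero-initialised, the injected layer reproduces the wrapped linear exactly until the LoRA weights are trained (sizes and rank below are arbitrary).

    import torch
    import torch.nn as nn
    from i2vedit.utils.lora import LoraInjectedLinear

    base = nn.Linear(320, 320)
    lora = LoraInjectedLinear(320, 320, bias=True, r=16, scale=1.0)
    lora.linear.weight.data.copy_(base.weight.data)
    lora.linear.bias.data.copy_(base.bias.data)

    x = torch.randn(2, 320)
    # zero-initialised lora_up means the LoRA branch contributes nothing yet
    assert torch.allclose(base(x), lora(x), atol=1e-6)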
| 77 |
+
class MultiLoraInjectedLinear(nn.Module):
    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, lora_num=1, scales=None
    ):
        super().__init__()

        if r > min(in_features, out_features):
            print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
            r = min(in_features, out_features)

        self.r = r
        self.lora_num = lora_num
        self.linear = nn.Linear(in_features, out_features, bias)

        # One scale per LoRA branch; default to 1.0 for every branch.
        self.scales = list(scales) if scales is not None else [1.0] * lora_num

        # Keep the per-branch layers in ModuleLists so they are registered as
        # submodules (plain Python lists are not moved by .to() and do not
        # appear in state_dict()).
        self.lora_down = nn.ModuleList(
            [nn.Linear(in_features, r, bias=False) for _ in range(lora_num)]
        )
        self.dropout = nn.ModuleList([nn.Dropout(dropout_p) for _ in range(lora_num)])
        self.lora_up = nn.ModuleList(
            [nn.Linear(r, out_features, bias=False) for _ in range(lora_num)]
        )
        self.selector = nn.ModuleList([nn.Identity() for _ in range(lora_num)])

        for down, up in zip(self.lora_down, self.lora_up):
            nn.init.normal_(down.weight, std=1 / r)
            nn.init.zeros_(up.weight)

    def forward(self, input):
        out = self.linear(input)
        for i in range(self.lora_num):
            out = out + (
                self.dropout[i](self.lora_up[i](self.selector[i](self.lora_down[i](input))))
                * self.scales[i]
            )
        return out

    def realize_as_lora(self):
        # One (up, down) pair per LoRA branch.
        return [
            (up.weight.data * scale, down.weight.data)
            for up, down, scale in zip(self.lora_up, self.lora_down, self.scales)
        ]

    def set_selector_from_diag(self, diag: torch.Tensor):
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        for i in range(self.lora_num):
            selector = nn.Linear(self.r, self.r, bias=False)
            selector.weight.data = torch.diag(diag).to(
                self.lora_up[i].weight.device
            ).to(self.lora_up[i].weight.dtype)
            self.selector[i] = selector
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class LoraInjectedConv2d(nn.Module):
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
in_channels: int,
|
| 133 |
+
out_channels: int,
|
| 134 |
+
kernel_size,
|
| 135 |
+
stride=1,
|
| 136 |
+
padding=0,
|
| 137 |
+
dilation=1,
|
| 138 |
+
groups: int = 1,
|
| 139 |
+
bias: bool = True,
|
| 140 |
+
r: int = 4,
|
| 141 |
+
dropout_p: float = 0.1,
|
| 142 |
+
scale: float = 1.0,
|
| 143 |
+
):
|
| 144 |
+
super().__init__()
|
| 145 |
+
if r > min(in_channels, out_channels):
|
| 146 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
|
| 147 |
+
r = min(in_channels, out_channels)
|
| 148 |
+
|
| 149 |
+
self.r = r
|
| 150 |
+
self.conv = nn.Conv2d(
|
| 151 |
+
in_channels=in_channels,
|
| 152 |
+
out_channels=out_channels,
|
| 153 |
+
kernel_size=kernel_size,
|
| 154 |
+
stride=stride,
|
| 155 |
+
padding=padding,
|
| 156 |
+
dilation=dilation,
|
| 157 |
+
groups=groups,
|
| 158 |
+
bias=bias,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
self.lora_down = nn.Conv2d(
|
| 162 |
+
in_channels=in_channels,
|
| 163 |
+
out_channels=r,
|
| 164 |
+
kernel_size=kernel_size,
|
| 165 |
+
stride=stride,
|
| 166 |
+
padding=padding,
|
| 167 |
+
dilation=dilation,
|
| 168 |
+
groups=groups,
|
| 169 |
+
bias=False,
|
| 170 |
+
)
|
| 171 |
+
self.dropout = nn.Dropout(dropout_p)
|
| 172 |
+
self.lora_up = nn.Conv2d(
|
| 173 |
+
in_channels=r,
|
| 174 |
+
out_channels=out_channels,
|
| 175 |
+
kernel_size=1,
|
| 176 |
+
stride=1,
|
| 177 |
+
padding=0,
|
| 178 |
+
bias=False,
|
| 179 |
+
)
|
| 180 |
+
self.selector = nn.Identity()
|
| 181 |
+
self.scale = scale
|
| 182 |
+
|
| 183 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
| 184 |
+
nn.init.zeros_(self.lora_up.weight)
|
| 185 |
+
|
| 186 |
+
def forward(self, input):
|
| 187 |
+
return (
|
| 188 |
+
self.conv(input)
|
| 189 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
| 190 |
+
* self.scale
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
def realize_as_lora(self):
|
| 194 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
| 195 |
+
|
| 196 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
| 197 |
+
# diag is a 1D tensor of size (r,)
|
| 198 |
+
assert diag.shape == (self.r,)
|
| 199 |
+
self.selector = nn.Conv2d(
|
| 200 |
+
in_channels=self.r,
|
| 201 |
+
out_channels=self.r,
|
| 202 |
+
kernel_size=1,
|
| 203 |
+
stride=1,
|
| 204 |
+
padding=0,
|
| 205 |
+
bias=False,
|
| 206 |
+
)
|
| 207 |
+
self.selector.weight.data = torch.diag(diag)
|
| 208 |
+
|
| 209 |
+
# same device + dtype as lora_up
|
| 210 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
| 211 |
+
self.lora_up.weight.device
|
| 212 |
+
).to(self.lora_up.weight.dtype)
|
| 213 |
+
|
| 214 |
+
class LoraInjectedConv3d(nn.Module):
|
| 215 |
+
def __init__(
|
| 216 |
+
self,
|
| 217 |
+
in_channels: int,
|
| 218 |
+
out_channels: int,
|
| 219 |
+
        kernel_size=(3, 1, 1),
|
| 220 |
+
        padding=(1, 0, 0),
|
| 221 |
+
bias: bool = False,
|
| 222 |
+
r: int = 4,
|
| 223 |
+
dropout_p: float = 0,
|
| 224 |
+
scale: float = 1.0,
|
| 225 |
+
):
|
| 226 |
+
super().__init__()
|
| 227 |
+
if r > min(in_channels, out_channels):
|
| 228 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
|
| 229 |
+
r = min(in_channels, out_channels)
|
| 230 |
+
|
| 231 |
+
self.r = r
|
| 232 |
+
self.kernel_size = kernel_size
|
| 233 |
+
self.padding = padding
|
| 234 |
+
self.conv = nn.Conv3d(
|
| 235 |
+
in_channels=in_channels,
|
| 236 |
+
out_channels=out_channels,
|
| 237 |
+
kernel_size=kernel_size,
|
| 238 |
+
padding=padding,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
self.lora_down = nn.Conv3d(
|
| 242 |
+
in_channels=in_channels,
|
| 243 |
+
out_channels=r,
|
| 244 |
+
kernel_size=kernel_size,
|
| 245 |
+
bias=False,
|
| 246 |
+
padding=padding
|
| 247 |
+
)
|
| 248 |
+
self.dropout = nn.Dropout(dropout_p)
|
| 249 |
+
self.lora_up = nn.Conv3d(
|
| 250 |
+
in_channels=r,
|
| 251 |
+
out_channels=out_channels,
|
| 252 |
+
kernel_size=1,
|
| 253 |
+
stride=1,
|
| 254 |
+
padding=0,
|
| 255 |
+
bias=False,
|
| 256 |
+
)
|
| 257 |
+
self.selector = nn.Identity()
|
| 258 |
+
self.scale = scale
|
| 259 |
+
|
| 260 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
| 261 |
+
nn.init.zeros_(self.lora_up.weight)
|
| 262 |
+
|
| 263 |
+
def forward(self, input):
|
| 264 |
+
return (
|
| 265 |
+
self.conv(input)
|
| 266 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
| 267 |
+
* self.scale
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
def realize_as_lora(self):
|
| 271 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
| 272 |
+
|
| 273 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
| 274 |
+
# diag is a 1D tensor of size (r,)
|
| 275 |
+
assert diag.shape == (self.r,)
|
| 276 |
+
self.selector = nn.Conv3d(
|
| 277 |
+
in_channels=self.r,
|
| 278 |
+
out_channels=self.r,
|
| 279 |
+
kernel_size=1,
|
| 280 |
+
stride=1,
|
| 281 |
+
padding=0,
|
| 282 |
+
bias=False,
|
| 283 |
+
)
|
| 284 |
+
self.selector.weight.data = torch.diag(diag)
|
| 285 |
+
|
| 286 |
+
# same device + dtype as lora_up
|
| 287 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
| 288 |
+
self.lora_up.weight.device
|
| 289 |
+
).to(self.lora_up.weight.dtype)
|
| 290 |
+
|
| 291 |
+
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
|
| 292 |
+
|
| 293 |
+
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
|
| 294 |
+
|
| 295 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
|
| 296 |
+
|
| 297 |
+
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
|
| 298 |
+
|
| 299 |
+
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
|
| 300 |
+
|
| 301 |
+
EMBED_FLAG = "<embed>"
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def _find_children(
|
| 305 |
+
model,
|
| 306 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 307 |
+
):
|
| 308 |
+
"""
|
| 309 |
+
Find all modules of a certain class (or union of classes).
|
| 310 |
+
|
| 311 |
+
    Returns all matching modules, along with the parent of those modules and the
|
| 312 |
+
names they are referenced by.
|
| 313 |
+
"""
|
| 314 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
| 315 |
+
for parent in model.modules():
|
| 316 |
+
for name, module in parent.named_children():
|
| 317 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 318 |
+
yield parent, name, module
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _find_modules_v2(
|
| 322 |
+
model,
|
| 323 |
+
ancestor_class: Optional[Set[str]] = None,
|
| 324 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 325 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = None,
|
| 326 |
+
# [
|
| 327 |
+
# LoraInjectedLinear,
|
| 328 |
+
# LoraInjectedConv2d,
|
| 329 |
+
# LoraInjectedConv3d
|
| 330 |
+
# ],
|
| 331 |
+
):
|
| 332 |
+
"""
|
| 333 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
| 334 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
| 335 |
+
|
| 336 |
+
    Returns all matching modules, along with the parent of those modules and the
|
| 337 |
+
names they are referenced by.
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
# Get the targets we should replace all linears under
|
| 341 |
+
if ancestor_class is not None:
|
| 342 |
+
ancestors = (
|
| 343 |
+
module
|
| 344 |
+
for name, module in model.named_modules()
|
| 345 |
+
if module.__class__.__name__ in ancestor_class # and ('transformer_in' not in name)
|
| 346 |
+
)
|
| 347 |
+
else:
|
| 348 |
+
        # this, in case you want to naively iterate over all modules.
|
| 349 |
+
ancestors = [module for module in model.modules()]
|
| 350 |
+
|
| 351 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
| 352 |
+
for ancestor in ancestors:
|
| 353 |
+
for fullname, module in ancestor.named_modules():
|
| 354 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 355 |
+
                continue_flag = True
                # Guard the membership tests below: ancestor_class may be None
                # (e.g. monkeypatch_remove_lora passes only a search_class), in
                # which case every matching module should be kept.
                if ancestor_class is None:
                    continue_flag = False
                elif 'Transformer2DModel' in ancestor_class and ('attn1' in fullname or 'ff' in fullname):
                    continue_flag = False
                elif 'TransformerTemporalModel' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname):
                    continue_flag = False
                elif 'TemporalBasicTransformerBlock' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname):
                    continue_flag = False
                elif 'BasicTransformerBlock' in ancestor_class and ('attn2' in fullname or 'ff' in fullname):
                    continue_flag = False
                if continue_flag:
                    continue
|
| 368 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
| 369 |
+
*path, name = fullname.split(".")
|
| 370 |
+
parent = ancestor
|
| 371 |
+
while path:
|
| 372 |
+
parent = parent.get_submodule(path.pop(0))
|
| 373 |
+
# Skip this linear if it's a child of a LoraInjectedLinear
|
| 374 |
+
if exclude_children_of and any(
|
| 375 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
| 376 |
+
):
|
| 377 |
+
continue
|
| 378 |
+
if name in ['lora_up', 'dropout', 'lora_down']:
|
| 379 |
+
continue
|
| 380 |
+
# Otherwise, yield it
|
| 381 |
+
yield parent, name, module
|
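# A minimal illustrative sketch of what _find_modules_v2 yields: every module of
# a search_class that sits under an ancestor whose class name is listed in
# ancestor_class, restricted by the attn*/ff filters above, together with its
# direct parent and attribute name. The tiny block below is a stand-in whose
# class and attribute names are chosen only so the BasicTransformerBlock filter
# applies; it is not a real diffusers block.
def _sketch_find_modules():
    class BasicTransformerBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn2 = nn.Linear(8, 8)   # kept: matches the attn2/ff filter
            self.ff = nn.Linear(8, 16)     # kept
            self.norm = nn.Linear(8, 8)    # skipped by the filter

    model = nn.Sequential(BasicTransformerBlock())
    for parent, name, module in _find_modules_v2(
        model, ["BasicTransformerBlock"], search_class=[nn.Linear]
    ):
        print(parent.__class__.__name__, name, module)
    # -> BasicTransformerBlock attn2 Linear(...)
    #    BasicTransformerBlock ff Linear(...)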
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _find_modules_old(
|
| 385 |
+
model,
|
| 386 |
+
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
|
| 387 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 388 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
|
| 389 |
+
):
|
| 390 |
+
ret = []
|
| 391 |
+
for _module in model.modules():
|
| 392 |
+
if _module.__class__.__name__ in ancestor_class:
|
| 393 |
+
|
| 394 |
+
for name, _child_module in _module.named_modules():
|
| 395 |
+
if _child_module.__class__ in search_class:
|
| 396 |
+
ret.append((_module, name, _child_module))
|
| 397 |
+
print(ret)
|
| 398 |
+
return ret
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
_find_modules = _find_modules_v2
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def inject_trainable_lora(
|
| 405 |
+
model: nn.Module,
|
| 406 |
+
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
|
| 407 |
+
r: int = 4,
|
| 408 |
+
loras=None, # path to lora .pt
|
| 409 |
+
verbose: bool = False,
|
| 410 |
+
dropout_p: float = 0.0,
|
| 411 |
+
scale: float = 1.0,
|
| 412 |
+
):
|
| 413 |
+
"""
|
| 414 |
+
inject lora into model, and returns lora parameter groups.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
require_grad_params = []
|
| 418 |
+
names = []
|
| 419 |
+
|
| 420 |
+
if loras != None:
|
| 421 |
+
loras = torch.load(loras)
|
| 422 |
+
|
| 423 |
+
for _module, name, _child_module in _find_modules(
|
| 424 |
+
model, target_replace_module, search_class=[nn.Linear]
|
| 425 |
+
):
|
| 426 |
+
weight = _child_module.weight
|
| 427 |
+
bias = _child_module.bias
|
| 428 |
+
if verbose:
|
| 429 |
+
print("LoRA Injection : injecting lora into ", name)
|
| 430 |
+
print("LoRA Injection : weight shape", weight.shape)
|
| 431 |
+
_tmp = LoraInjectedLinear(
|
| 432 |
+
_child_module.in_features,
|
| 433 |
+
_child_module.out_features,
|
| 434 |
+
_child_module.bias is not None,
|
| 435 |
+
r=r,
|
| 436 |
+
dropout_p=dropout_p,
|
| 437 |
+
scale=scale,
|
| 438 |
+
)
|
| 439 |
+
_tmp.linear.weight = weight
|
| 440 |
+
if bias is not None:
|
| 441 |
+
_tmp.linear.bias = bias
|
| 442 |
+
|
| 443 |
+
# switch the module
|
| 444 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
| 445 |
+
_module._modules[name] = _tmp
|
| 446 |
+
|
| 447 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
| 448 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
| 449 |
+
|
| 450 |
+
if loras != None:
|
| 451 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
| 452 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
| 453 |
+
|
| 454 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
| 455 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
| 456 |
+
names.append(name)
|
| 457 |
+
|
| 458 |
+
return require_grad_params, names
|
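# A minimal illustrative sketch of how the return value of inject_trainable_lora
# is typically consumed: the parameter generators go to the optimizer while the
# base weights stay frozen. The stand-in block, target set and hyper-parameters
# are placeholders; the class/attribute names only need to satisfy the
# _find_modules filter above.
def _sketch_inject_and_train():
    import itertools

    class BasicTransformerBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn2 = nn.Linear(32, 32)
            self.ff = nn.Linear(32, 64)

    model = nn.Sequential(BasicTransformerBlock())
    model.requires_grad_(False)  # freeze the base weights

    lora_params, names = inject_trainable_lora(
        model, target_replace_module={"BasicTransformerBlock"}, r=4, scale=1.0
    )
    optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)
    print("injected LoRA into:", names)  # ['attn2', 'ff']
    return model, optimizer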
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def inject_trainable_lora_extended(
|
| 462 |
+
model: nn.Module,
|
| 463 |
+
target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
|
| 464 |
+
r: int = 4,
|
| 465 |
+
loras=None, # path to lora .pt
|
| 466 |
+
dropout_p: float = 0.0,
|
| 467 |
+
scale: float = 1.0,
|
| 468 |
+
):
|
| 469 |
+
"""
|
| 470 |
+
inject lora into model, and returns lora parameter groups.
|
| 471 |
+
"""
|
| 472 |
+
|
| 473 |
+
require_grad_params = []
|
| 474 |
+
names = []
|
| 475 |
+
|
| 476 |
+
if loras != None:
|
| 477 |
+
print(f"Load from lora: {loras} ...")
|
| 478 |
+
loras = torch.load(loras)
|
| 479 |
+
if True:
|
| 480 |
+
for target_replace_module_i in target_replace_module:
|
| 481 |
+
for _module, name, _child_module in _find_modules(
|
| 482 |
+
model, [target_replace_module_i], search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
|
| 483 |
+
):
|
| 484 |
+
# if name == 'to_q':
|
| 485 |
+
# continue
|
| 486 |
+
if _child_module.__class__ == nn.Linear:
|
| 487 |
+
weight = _child_module.weight
|
| 488 |
+
bias = _child_module.bias
|
| 489 |
+
_tmp = LoraInjectedLinear(
|
| 490 |
+
_child_module.in_features,
|
| 491 |
+
_child_module.out_features,
|
| 492 |
+
_child_module.bias is not None,
|
| 493 |
+
r=r,
|
| 494 |
+
dropout_p=dropout_p,
|
| 495 |
+
scale=scale,
|
| 496 |
+
)
|
| 497 |
+
_tmp.linear.weight = weight
|
| 498 |
+
if bias is not None:
|
| 499 |
+
_tmp.linear.bias = bias
|
| 500 |
+
elif _child_module.__class__ == nn.Conv2d:
|
| 501 |
+
weight = _child_module.weight
|
| 502 |
+
bias = _child_module.bias
|
| 503 |
+
_tmp = LoraInjectedConv2d(
|
| 504 |
+
_child_module.in_channels,
|
| 505 |
+
_child_module.out_channels,
|
| 506 |
+
_child_module.kernel_size,
|
| 507 |
+
_child_module.stride,
|
| 508 |
+
_child_module.padding,
|
| 509 |
+
_child_module.dilation,
|
| 510 |
+
_child_module.groups,
|
| 511 |
+
_child_module.bias is not None,
|
| 512 |
+
r=r,
|
| 513 |
+
dropout_p=dropout_p,
|
| 514 |
+
scale=scale,
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
_tmp.conv.weight = weight
|
| 518 |
+
if bias is not None:
|
| 519 |
+
_tmp.conv.bias = bias
|
| 520 |
+
|
| 521 |
+
elif _child_module.__class__ == nn.Conv3d:
|
| 522 |
+
weight = _child_module.weight
|
| 523 |
+
bias = _child_module.bias
|
| 524 |
+
_tmp = LoraInjectedConv3d(
|
| 525 |
+
_child_module.in_channels,
|
| 526 |
+
_child_module.out_channels,
|
| 527 |
+
bias=_child_module.bias is not None,
|
| 528 |
+
kernel_size=_child_module.kernel_size,
|
| 529 |
+
padding=_child_module.padding,
|
| 530 |
+
r=r,
|
| 531 |
+
dropout_p=dropout_p,
|
| 532 |
+
scale=scale,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
_tmp.conv.weight = weight
|
| 536 |
+
if bias is not None:
|
| 537 |
+
_tmp.conv.bias = bias
|
| 538 |
+
# LoRA layer
|
| 539 |
+
else:
|
| 540 |
+
_tmp = _child_module
|
| 541 |
+
# switch the module
|
| 542 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
| 543 |
+
try:
|
| 544 |
+
if bias is not None:
|
| 545 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
| 546 |
+
except:
|
| 547 |
+
pass
|
| 548 |
+
|
| 549 |
+
_module._modules[name] = _tmp
|
| 550 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
| 551 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
| 552 |
+
|
| 553 |
+
if loras != None:
|
| 554 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
| 555 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
| 556 |
+
|
| 557 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
| 558 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
| 559 |
+
names.append(name)
|
| 560 |
+
else:
|
| 561 |
+
for _module, name, _child_module in _find_modules(
|
| 562 |
+
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
|
| 563 |
+
):
|
| 564 |
+
if _child_module.__class__ == nn.Linear:
|
| 565 |
+
weight = _child_module.weight
|
| 566 |
+
bias = _child_module.bias
|
| 567 |
+
_tmp = LoraInjectedLinear(
|
| 568 |
+
_child_module.in_features,
|
| 569 |
+
_child_module.out_features,
|
| 570 |
+
_child_module.bias is not None,
|
| 571 |
+
r=r,
|
| 572 |
+
dropout_p=dropout_p,
|
| 573 |
+
scale=scale,
|
| 574 |
+
)
|
| 575 |
+
_tmp.linear.weight = weight
|
| 576 |
+
if bias is not None:
|
| 577 |
+
_tmp.linear.bias = bias
|
| 578 |
+
elif _child_module.__class__ == nn.Conv2d:
|
| 579 |
+
weight = _child_module.weight
|
| 580 |
+
bias = _child_module.bias
|
| 581 |
+
_tmp = LoraInjectedConv2d(
|
| 582 |
+
_child_module.in_channels,
|
| 583 |
+
_child_module.out_channels,
|
| 584 |
+
_child_module.kernel_size,
|
| 585 |
+
_child_module.stride,
|
| 586 |
+
_child_module.padding,
|
| 587 |
+
_child_module.dilation,
|
| 588 |
+
_child_module.groups,
|
| 589 |
+
_child_module.bias is not None,
|
| 590 |
+
r=r,
|
| 591 |
+
dropout_p=dropout_p,
|
| 592 |
+
scale=scale,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
_tmp.conv.weight = weight
|
| 596 |
+
if bias is not None:
|
| 597 |
+
_tmp.conv.bias = bias
|
| 598 |
+
|
| 599 |
+
elif _child_module.__class__ == nn.Conv3d:
|
| 600 |
+
weight = _child_module.weight
|
| 601 |
+
bias = _child_module.bias
|
| 602 |
+
_tmp = LoraInjectedConv3d(
|
| 603 |
+
_child_module.in_channels,
|
| 604 |
+
_child_module.out_channels,
|
| 605 |
+
bias=_child_module.bias is not None,
|
| 606 |
+
kernel_size=_child_module.kernel_size,
|
| 607 |
+
padding=_child_module.padding,
|
| 608 |
+
r=r,
|
| 609 |
+
dropout_p=dropout_p,
|
| 610 |
+
scale=scale,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
_tmp.conv.weight = weight
|
| 614 |
+
if bias is not None:
|
| 615 |
+
_tmp.conv.bias = bias
|
| 616 |
+
# switch the module
|
| 617 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
| 618 |
+
if bias is not None:
|
| 619 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
| 620 |
+
|
| 621 |
+
_module._modules[name] = _tmp
|
| 622 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
| 623 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
| 624 |
+
|
| 625 |
+
if loras != None:
|
| 626 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
| 627 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
| 628 |
+
|
| 629 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
| 630 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
| 631 |
+
names.append(name)
|
| 632 |
+
|
| 633 |
+
return require_grad_params, names
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def inject_inferable_lora(
|
| 637 |
+
model,
|
| 638 |
+
lora_path='',
|
| 639 |
+
unet_replace_modules=["UNet3DConditionModel"],
|
| 640 |
+
text_encoder_replace_modules=["CLIPEncoderLayer"],
|
| 641 |
+
is_extended=False,
|
| 642 |
+
r=16
|
| 643 |
+
):
|
| 644 |
+
from transformers.models.clip import CLIPTextModel
|
| 645 |
+
from diffusers import UNet3DConditionModel
|
| 646 |
+
|
| 647 |
+
def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
|
| 648 |
+
def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
|
| 649 |
+
|
| 650 |
+
if os.path.exists(lora_path):
|
| 651 |
+
try:
|
| 652 |
+
for f in os.listdir(lora_path):
|
| 653 |
+
if f.endswith('.pt'):
|
| 654 |
+
lora_file = os.path.join(lora_path, f)
|
| 655 |
+
|
| 656 |
+
if is_text_model(f):
|
| 657 |
+
monkeypatch_or_replace_lora(
|
| 658 |
+
model.text_encoder,
|
| 659 |
+
torch.load(lora_file),
|
| 660 |
+
target_replace_module=text_encoder_replace_modules,
|
| 661 |
+
r=r
|
| 662 |
+
)
|
| 663 |
+
print("Successfully loaded Text Encoder LoRa.")
|
| 664 |
+
continue
|
| 665 |
+
|
| 666 |
+
if is_unet(f):
|
| 667 |
+
monkeypatch_or_replace_lora_extended(
|
| 668 |
+
model.unet,
|
| 669 |
+
torch.load(lora_file),
|
| 670 |
+
target_replace_module=unet_replace_modules,
|
| 671 |
+
r=r
|
| 672 |
+
)
|
| 673 |
+
print("Successfully loaded UNET LoRa.")
|
| 674 |
+
continue
|
| 675 |
+
|
| 676 |
+
print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")
|
| 677 |
+
|
| 678 |
+
except Exception as e:
|
| 679 |
+
print(e)
|
| 680 |
+
print("Couldn't inject LoRA's due to an error.")
|
| 681 |
+
|
| 682 |
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
| 683 |
+
|
| 684 |
+
loras = []
|
| 685 |
+
|
| 686 |
+
for target_replace_module_i in target_replace_module:
|
| 687 |
+
|
| 688 |
+
for _m, _n, _child_module in _find_modules(
|
| 689 |
+
model,
|
| 690 |
+
[target_replace_module_i],
|
| 691 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
| 692 |
+
):
|
| 693 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
| 694 |
+
|
| 695 |
+
if len(loras) == 0:
|
| 696 |
+
raise ValueError("No lora injected.")
|
| 697 |
+
|
| 698 |
+
return loras
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def extract_lora_child_module(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
| 702 |
+
|
| 703 |
+
loras = []
|
| 704 |
+
|
| 705 |
+
for target_replace_module_i in target_replace_module:
|
| 706 |
+
|
| 707 |
+
for _m, _n, _child_module in _find_modules(
|
| 708 |
+
model,
|
| 709 |
+
[target_replace_module_i],
|
| 710 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
| 711 |
+
):
|
| 712 |
+
loras.append(_child_module)
|
| 713 |
+
|
| 714 |
+
return loras
|
| 715 |
+
|
| 716 |
+
def extract_lora_as_tensor(
|
| 717 |
+
model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
|
| 718 |
+
):
|
| 719 |
+
|
| 720 |
+
loras = []
|
| 721 |
+
|
| 722 |
+
for _m, _n, _child_module in _find_modules(
|
| 723 |
+
model,
|
| 724 |
+
target_replace_module,
|
| 725 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
| 726 |
+
):
|
| 727 |
+
up, down = _child_module.realize_as_lora()
|
| 728 |
+
if as_fp16:
|
| 729 |
+
up = up.to(torch.float16)
|
| 730 |
+
down = down.to(torch.float16)
|
| 731 |
+
|
| 732 |
+
loras.append((up, down))
|
| 733 |
+
|
| 734 |
+
if len(loras) == 0:
|
| 735 |
+
raise ValueError("No lora injected.")
|
| 736 |
+
|
| 737 |
+
return loras
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def save_lora_weight(
|
| 741 |
+
model,
|
| 742 |
+
path="./lora.pt",
|
| 743 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
| 744 |
+
flag=None
|
| 745 |
+
):
|
| 746 |
+
weights = []
|
| 747 |
+
for _up, _down in extract_lora_ups_down(
|
| 748 |
+
model, target_replace_module=target_replace_module
|
| 749 |
+
):
|
| 750 |
+
weights.append(_up.weight.to("cpu").to(torch.float32))
|
| 751 |
+
weights.append(_down.weight.to("cpu").to(torch.float32))
|
| 752 |
+
if not flag:
|
| 753 |
+
torch.save(weights, path)
|
| 754 |
+
else:
|
| 755 |
+
weights_new=[]
|
| 756 |
+
for i in range(0, len(weights), 4):
|
| 757 |
+
subset = weights[i+(flag-1)*2:i+(flag-1)*2+2]
|
| 758 |
+
weights_new.extend(subset)
|
| 759 |
+
torch.save(weights_new, path)
|
| 760 |
+
|
| 761 |
+
def save_lora_as_json(model, path="./lora.json"):
|
| 762 |
+
weights = []
|
| 763 |
+
for _up, _down in extract_lora_ups_down(model):
|
| 764 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
| 765 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
| 766 |
+
|
| 767 |
+
import json
|
| 768 |
+
|
| 769 |
+
with open(path, "w") as f:
|
| 770 |
+
json.dump(weights, f)
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
def save_safeloras_with_embeds(
|
| 774 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
| 775 |
+
embeds: Dict[str, torch.Tensor] = {},
|
| 776 |
+
outpath="./lora.safetensors",
|
| 777 |
+
):
|
| 778 |
+
"""
|
| 779 |
+
Saves the Lora from multiple modules in a single safetensor file.
|
| 780 |
+
|
| 781 |
+
modelmap is a dictionary of {
|
| 782 |
+
"module name": (module, target_replace_module)
|
| 783 |
+
}
|
| 784 |
+
"""
|
| 785 |
+
weights = {}
|
| 786 |
+
metadata = {}
|
| 787 |
+
|
| 788 |
+
for name, (model, target_replace_module) in modelmap.items():
|
| 789 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
| 790 |
+
|
| 791 |
+
for i, (_up, _down) in enumerate(
|
| 792 |
+
extract_lora_as_tensor(model, target_replace_module)
|
| 793 |
+
):
|
| 794 |
+
rank = _down.shape[0]
|
| 795 |
+
|
| 796 |
+
metadata[f"{name}:{i}:rank"] = str(rank)
|
| 797 |
+
weights[f"{name}:{i}:up"] = _up
|
| 798 |
+
weights[f"{name}:{i}:down"] = _down
|
| 799 |
+
|
| 800 |
+
for token, tensor in embeds.items():
|
| 801 |
+
metadata[token] = EMBED_FLAG
|
| 802 |
+
weights[token] = tensor
|
| 803 |
+
|
| 804 |
+
print(f"Saving weights to {outpath}")
|
| 805 |
+
safe_save(weights, outpath, metadata)
|
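# A minimal illustrative sketch of the expected `modelmap` shape. It assumes
# LoRA layers were already injected into `unet` (e.g. with
# inject_trainable_lora_extended) and that the safetensors package is installed,
# otherwise safe_save raises; the module name, replace set and output path are
# placeholders.
def _sketch_save_safeloras(unet):
    modelmap = {
        "unet": (unet, {"TemporalBasicTransformerBlock"}),
    }
    save_safeloras_with_embeds(modelmap, outpath="./lora.safetensors")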
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def save_safeloras(
|
| 809 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
| 810 |
+
outpath="./lora.safetensors",
|
| 811 |
+
):
|
| 812 |
+
return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def convert_loras_to_safeloras_with_embeds(
|
| 816 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
| 817 |
+
embeds: Dict[str, torch.Tensor] = {},
|
| 818 |
+
outpath="./lora.safetensors",
|
| 819 |
+
):
|
| 820 |
+
"""
|
| 821 |
+
Converts the Lora from multiple pytorch .pt files into a single safetensor file.
|
| 822 |
+
|
| 823 |
+
modelmap is a dictionary of {
|
| 824 |
+
"module name": (pytorch_model_path, target_replace_module, rank)
|
| 825 |
+
}
|
| 826 |
+
"""
|
| 827 |
+
|
| 828 |
+
weights = {}
|
| 829 |
+
metadata = {}
|
| 830 |
+
|
| 831 |
+
for name, (path, target_replace_module, r) in modelmap.items():
|
| 832 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
| 833 |
+
|
| 834 |
+
lora = torch.load(path)
|
| 835 |
+
for i, weight in enumerate(lora):
|
| 836 |
+
is_up = i % 2 == 0
|
| 837 |
+
i = i // 2
|
| 838 |
+
|
| 839 |
+
if is_up:
|
| 840 |
+
metadata[f"{name}:{i}:rank"] = str(r)
|
| 841 |
+
weights[f"{name}:{i}:up"] = weight
|
| 842 |
+
else:
|
| 843 |
+
weights[f"{name}:{i}:down"] = weight
|
| 844 |
+
|
| 845 |
+
for token, tensor in embeds.items():
|
| 846 |
+
metadata[token] = EMBED_FLAG
|
| 847 |
+
weights[token] = tensor
|
| 848 |
+
|
| 849 |
+
print(f"Saving weights to {outpath}")
|
| 850 |
+
safe_save(weights, outpath, metadata)
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
def convert_loras_to_safeloras(
|
| 854 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
| 855 |
+
outpath="./lora.safetensors",
|
| 856 |
+
):
|
| 857 |
+
convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def parse_safeloras(
|
| 861 |
+
safeloras,
|
| 862 |
+
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
|
| 863 |
+
"""
|
| 864 |
+
Converts a loaded safetensor file that contains a set of module Loras
|
| 865 |
+
into Parameters and other information
|
| 866 |
+
|
| 867 |
+
Output is a dictionary of {
|
| 868 |
+
"module name": (
|
| 869 |
+
[list of weights],
|
| 870 |
+
[list of ranks],
|
| 871 |
+
target_replacement_modules
|
| 872 |
+
)
|
| 873 |
+
}
|
| 874 |
+
"""
|
| 875 |
+
loras = {}
|
| 876 |
+
metadata = safeloras.metadata()
|
| 877 |
+
|
| 878 |
+
get_name = lambda k: k.split(":")[0]
|
| 879 |
+
|
| 880 |
+
keys = list(safeloras.keys())
|
| 881 |
+
keys.sort(key=get_name)
|
| 882 |
+
|
| 883 |
+
for name, module_keys in groupby(keys, get_name):
|
| 884 |
+
info = metadata.get(name)
|
| 885 |
+
|
| 886 |
+
if not info:
|
| 887 |
+
raise ValueError(
|
| 888 |
+
f"Tensor {name} has no metadata - is this a Lora safetensor?"
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
# Skip Textual Inversion embeds
|
| 892 |
+
if info == EMBED_FLAG:
|
| 893 |
+
continue
|
| 894 |
+
|
| 895 |
+
# Handle Loras
|
| 896 |
+
# Extract the targets
|
| 897 |
+
target = json.loads(info)
|
| 898 |
+
|
| 899 |
+
# Build the result lists - Python needs us to preallocate lists to insert into them
|
| 900 |
+
module_keys = list(module_keys)
|
| 901 |
+
ranks = [4] * (len(module_keys) // 2)
|
| 902 |
+
weights = [None] * len(module_keys)
|
| 903 |
+
|
| 904 |
+
for key in module_keys:
|
| 905 |
+
# Split the model name and index out of the key
|
| 906 |
+
_, idx, direction = key.split(":")
|
| 907 |
+
idx = int(idx)
|
| 908 |
+
|
| 909 |
+
# Add the rank
|
| 910 |
+
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
|
| 911 |
+
|
| 912 |
+
# Insert the weight into the list
|
| 913 |
+
idx = idx * 2 + (1 if direction == "down" else 0)
|
| 914 |
+
weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
|
| 915 |
+
|
| 916 |
+
loras[name] = (weights, ranks, target)
|
| 917 |
+
|
| 918 |
+
return loras
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
def parse_safeloras_embeds(
|
| 922 |
+
safeloras,
|
| 923 |
+
) -> Dict[str, torch.Tensor]:
|
| 924 |
+
"""
|
| 925 |
+
Converts a loaded safetensor file that contains Textual Inversion embeds into
|
| 926 |
+
a dictionary of embed_token: Tensor
|
| 927 |
+
"""
|
| 928 |
+
embeds = {}
|
| 929 |
+
metadata = safeloras.metadata()
|
| 930 |
+
|
| 931 |
+
for key in safeloras.keys():
|
| 932 |
+
# Only handle Textual Inversion embeds
|
| 933 |
+
meta = metadata.get(key)
|
| 934 |
+
if not meta or meta != EMBED_FLAG:
|
| 935 |
+
continue
|
| 936 |
+
|
| 937 |
+
embeds[key] = safeloras.get_tensor(key)
|
| 938 |
+
|
| 939 |
+
return embeds
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def load_safeloras(path, device="cpu"):
|
| 943 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
| 944 |
+
return parse_safeloras(safeloras)
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def load_safeloras_embeds(path, device="cpu"):
|
| 948 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
| 949 |
+
return parse_safeloras_embeds(safeloras)
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def load_safeloras_both(path, device="cpu"):
|
| 953 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
| 954 |
+
return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
def collapse_lora(model, alpha=1.0):
|
| 958 |
+
|
| 959 |
+
for _module, name, _child_module in _find_modules(
|
| 960 |
+
model,
|
| 961 |
+
UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
|
| 962 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
| 963 |
+
):
|
| 964 |
+
|
| 965 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
| 966 |
+
print("Collapsing Lin Lora in", name)
|
| 967 |
+
|
| 968 |
+
_child_module.linear.weight = nn.Parameter(
|
| 969 |
+
_child_module.linear.weight.data
|
| 970 |
+
+ alpha
|
| 971 |
+
* (
|
| 972 |
+
_child_module.lora_up.weight.data
|
| 973 |
+
@ _child_module.lora_down.weight.data
|
| 974 |
+
)
|
| 975 |
+
.type(_child_module.linear.weight.dtype)
|
| 976 |
+
.to(_child_module.linear.weight.device)
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
else:
|
| 980 |
+
print("Collapsing Conv Lora in", name)
|
| 981 |
+
_child_module.conv.weight = nn.Parameter(
|
| 982 |
+
_child_module.conv.weight.data
|
| 983 |
+
+ alpha
|
| 984 |
+
* (
|
| 985 |
+
_child_module.lora_up.weight.data.flatten(start_dim=1)
|
| 986 |
+
@ _child_module.lora_down.weight.data.flatten(start_dim=1)
|
| 987 |
+
)
|
| 988 |
+
.reshape(_child_module.conv.weight.data.shape)
|
| 989 |
+
.type(_child_module.conv.weight.dtype)
|
| 990 |
+
.to(_child_module.conv.weight.device)
|
| 991 |
+
)
|
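# A minimal illustrative check of the identity behind collapse_lora: folding
# alpha * up @ down into the base weight reproduces the injected layer's output
# (with dropout disabled), which is why the wrapper can afterwards be replaced
# by a plain layer via monkeypatch_remove_lora below. Single-layer sketch only;
# the shapes and init below are arbitrary.
def _sketch_collapse_identity():
    layer = LoraInjectedLinear(16, 8, bias=True, r=4, dropout_p=0.0, scale=1.0)
    nn.init.normal_(layer.lora_up.weight, std=0.02)  # pretend it was trained

    x = torch.randn(2, 16)
    with torch.no_grad():
        before = layer(x)
        layer.linear.weight += layer.lora_up.weight @ layer.lora_down.weight
        after = layer.linear(x)
    assert torch.allclose(before, after, atol=1e-5)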
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def monkeypatch_or_replace_lora(
|
| 995 |
+
model,
|
| 996 |
+
loras,
|
| 997 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
| 998 |
+
r: Union[int, List[int]] = 4,
|
| 999 |
+
):
|
| 1000 |
+
for _module, name, _child_module in _find_modules(
|
| 1001 |
+
model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
|
| 1002 |
+
):
|
| 1003 |
+
_source = (
|
| 1004 |
+
_child_module.linear
|
| 1005 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
| 1006 |
+
else _child_module
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
weight = _source.weight
|
| 1010 |
+
bias = _source.bias
|
| 1011 |
+
_tmp = LoraInjectedLinear(
|
| 1012 |
+
_source.in_features,
|
| 1013 |
+
_source.out_features,
|
| 1014 |
+
_source.bias is not None,
|
| 1015 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
| 1016 |
+
)
|
| 1017 |
+
_tmp.linear.weight = weight
|
| 1018 |
+
|
| 1019 |
+
if bias is not None:
|
| 1020 |
+
_tmp.linear.bias = bias
|
| 1021 |
+
|
| 1022 |
+
# switch the module
|
| 1023 |
+
_module._modules[name] = _tmp
|
| 1024 |
+
|
| 1025 |
+
up_weight = loras.pop(0)
|
| 1026 |
+
down_weight = loras.pop(0)
|
| 1027 |
+
|
| 1028 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
| 1029 |
+
up_weight.type(weight.dtype)
|
| 1030 |
+
)
|
| 1031 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
| 1032 |
+
down_weight.type(weight.dtype)
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
_module._modules[name].to(weight.device)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
def monkeypatch_or_replace_lora_extended(
|
| 1039 |
+
model,
|
| 1040 |
+
loras,
|
| 1041 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
| 1042 |
+
r: Union[int, List[int]] = 4,
|
| 1043 |
+
):
|
| 1044 |
+
for _module, name, _child_module in _find_modules(
|
| 1045 |
+
model,
|
| 1046 |
+
target_replace_module,
|
| 1047 |
+
search_class=[
|
| 1048 |
+
nn.Linear,
|
| 1049 |
+
nn.Conv2d,
|
| 1050 |
+
nn.Conv3d,
|
| 1051 |
+
LoraInjectedLinear,
|
| 1052 |
+
LoraInjectedConv2d,
|
| 1053 |
+
LoraInjectedConv3d,
|
| 1054 |
+
],
|
| 1055 |
+
):
|
| 1056 |
+
|
| 1057 |
+
if (_child_module.__class__ == nn.Linear) or (
|
| 1058 |
+
_child_module.__class__ == LoraInjectedLinear
|
| 1059 |
+
):
|
| 1060 |
+
if len(loras[0].shape) != 2:
|
| 1061 |
+
continue
|
| 1062 |
+
|
| 1063 |
+
_source = (
|
| 1064 |
+
_child_module.linear
|
| 1065 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
| 1066 |
+
else _child_module
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
weight = _source.weight
|
| 1070 |
+
bias = _source.bias
|
| 1071 |
+
_tmp = LoraInjectedLinear(
|
| 1072 |
+
_source.in_features,
|
| 1073 |
+
_source.out_features,
|
| 1074 |
+
_source.bias is not None,
|
| 1075 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
| 1076 |
+
)
|
| 1077 |
+
_tmp.linear.weight = weight
|
| 1078 |
+
|
| 1079 |
+
if bias is not None:
|
| 1080 |
+
_tmp.linear.bias = bias
|
| 1081 |
+
|
| 1082 |
+
elif (_child_module.__class__ == nn.Conv2d) or (
|
| 1083 |
+
_child_module.__class__ == LoraInjectedConv2d
|
| 1084 |
+
):
|
| 1085 |
+
if len(loras[0].shape) != 4:
|
| 1086 |
+
continue
|
| 1087 |
+
_source = (
|
| 1088 |
+
_child_module.conv
|
| 1089 |
+
if isinstance(_child_module, LoraInjectedConv2d)
|
| 1090 |
+
else _child_module
|
| 1091 |
+
)
|
| 1092 |
+
|
| 1093 |
+
weight = _source.weight
|
| 1094 |
+
bias = _source.bias
|
| 1095 |
+
_tmp = LoraInjectedConv2d(
|
| 1096 |
+
_source.in_channels,
|
| 1097 |
+
_source.out_channels,
|
| 1098 |
+
_source.kernel_size,
|
| 1099 |
+
_source.stride,
|
| 1100 |
+
_source.padding,
|
| 1101 |
+
_source.dilation,
|
| 1102 |
+
_source.groups,
|
| 1103 |
+
_source.bias is not None,
|
| 1104 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
_tmp.conv.weight = weight
|
| 1108 |
+
|
| 1109 |
+
if bias is not None:
|
| 1110 |
+
_tmp.conv.bias = bias
|
| 1111 |
+
|
| 1112 |
+
        elif (_child_module.__class__ == nn.Conv3d) or (
|
| 1113 |
+
_child_module.__class__ == LoraInjectedConv3d
|
| 1114 |
+
):
|
| 1115 |
+
|
| 1116 |
+
if len(loras[0].shape) != 5:
|
| 1117 |
+
continue
|
| 1118 |
+
|
| 1119 |
+
_source = (
|
| 1120 |
+
_child_module.conv
|
| 1121 |
+
if isinstance(_child_module, LoraInjectedConv3d)
|
| 1122 |
+
else _child_module
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
weight = _source.weight
|
| 1126 |
+
bias = _source.bias
|
| 1127 |
+
_tmp = LoraInjectedConv3d(
|
| 1128 |
+
_source.in_channels,
|
| 1129 |
+
_source.out_channels,
|
| 1130 |
+
bias=_source.bias is not None,
|
| 1131 |
+
kernel_size=_source.kernel_size,
|
| 1132 |
+
padding=_source.padding,
|
| 1133 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
_tmp.conv.weight = weight
|
| 1137 |
+
|
| 1138 |
+
if bias is not None:
|
| 1139 |
+
_tmp.conv.bias = bias
|
| 1140 |
+
|
| 1141 |
+
# switch the module
|
| 1142 |
+
_module._modules[name] = _tmp
|
| 1143 |
+
|
| 1144 |
+
up_weight = loras.pop(0)
|
| 1145 |
+
down_weight = loras.pop(0)
|
| 1146 |
+
|
| 1147 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
| 1148 |
+
up_weight.type(weight.dtype)
|
| 1149 |
+
)
|
| 1150 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
| 1151 |
+
down_weight.type(weight.dtype)
|
| 1152 |
+
)
|
| 1153 |
+
|
| 1154 |
+
_module._modules[name].to(weight.device)
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
def monkeypatch_or_replace_safeloras(models, safeloras):
|
| 1158 |
+
loras = parse_safeloras(safeloras)
|
| 1159 |
+
|
| 1160 |
+
for name, (lora, ranks, target) in loras.items():
|
| 1161 |
+
model = getattr(models, name, None)
|
| 1162 |
+
|
| 1163 |
+
if not model:
|
| 1164 |
+
print(f"No model provided for {name}, contained in Lora")
|
| 1165 |
+
continue
|
| 1166 |
+
|
| 1167 |
+
monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
def monkeypatch_remove_lora(model):
|
| 1171 |
+
for _module, name, _child_module in _find_modules(
|
| 1172 |
+
model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
|
| 1173 |
+
):
|
| 1174 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
| 1175 |
+
_source = _child_module.linear
|
| 1176 |
+
weight, bias = _source.weight, _source.bias
|
| 1177 |
+
|
| 1178 |
+
_tmp = nn.Linear(
|
| 1179 |
+
_source.in_features, _source.out_features, bias is not None
|
| 1180 |
+
)
|
| 1181 |
+
|
| 1182 |
+
_tmp.weight = weight
|
| 1183 |
+
if bias is not None:
|
| 1184 |
+
_tmp.bias = bias
|
| 1185 |
+
|
| 1186 |
+
else:
|
| 1187 |
+
_source = _child_module.conv
|
| 1188 |
+
weight, bias = _source.weight, _source.bias
|
| 1189 |
+
|
| 1190 |
+
if isinstance(_source, nn.Conv2d):
|
| 1191 |
+
_tmp = nn.Conv2d(
|
| 1192 |
+
in_channels=_source.in_channels,
|
| 1193 |
+
out_channels=_source.out_channels,
|
| 1194 |
+
kernel_size=_source.kernel_size,
|
| 1195 |
+
stride=_source.stride,
|
| 1196 |
+
padding=_source.padding,
|
| 1197 |
+
dilation=_source.dilation,
|
| 1198 |
+
groups=_source.groups,
|
| 1199 |
+
bias=bias is not None,
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
_tmp.weight = weight
|
| 1203 |
+
if bias is not None:
|
| 1204 |
+
_tmp.bias = bias
|
| 1205 |
+
|
| 1206 |
+
if isinstance(_source, nn.Conv3d):
|
| 1207 |
+
_tmp = nn.Conv3d(
|
| 1208 |
+
_source.in_channels,
|
| 1209 |
+
_source.out_channels,
|
| 1210 |
+
bias=_source.bias is not None,
|
| 1211 |
+
kernel_size=_source.kernel_size,
|
| 1212 |
+
padding=_source.padding,
|
| 1213 |
+
)
|
| 1214 |
+
|
| 1215 |
+
_tmp.weight = weight
|
| 1216 |
+
if bias is not None:
|
| 1217 |
+
_tmp.bias = bias
|
| 1218 |
+
|
| 1219 |
+
_module._modules[name] = _tmp
|
| 1220 |
+
|
| 1221 |
+
|
| 1222 |
+
def monkeypatch_add_lora(
|
| 1223 |
+
model,
|
| 1224 |
+
loras,
|
| 1225 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
| 1226 |
+
alpha: float = 1.0,
|
| 1227 |
+
beta: float = 1.0,
|
| 1228 |
+
):
|
| 1229 |
+
for _module, name, _child_module in _find_modules(
|
| 1230 |
+
model, target_replace_module, search_class=[LoraInjectedLinear]
|
| 1231 |
+
):
|
| 1232 |
+
weight = _child_module.linear.weight
|
| 1233 |
+
|
| 1234 |
+
up_weight = loras.pop(0)
|
| 1235 |
+
down_weight = loras.pop(0)
|
| 1236 |
+
|
| 1237 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
| 1238 |
+
up_weight.type(weight.dtype).to(weight.device) * alpha
|
| 1239 |
+
+ _module._modules[name].lora_up.weight.to(weight.device) * beta
|
| 1240 |
+
)
|
| 1241 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
| 1242 |
+
down_weight.type(weight.dtype).to(weight.device) * alpha
|
| 1243 |
+
+ _module._modules[name].lora_down.weight.to(weight.device) * beta
|
| 1244 |
+
)
|
| 1245 |
+
|
| 1246 |
+
_module._modules[name].to(weight.device)
|
| 1247 |
+
|
| 1248 |
+
|
| 1249 |
+
def tune_lora_scale(model, alpha: float = 1.0):
|
| 1250 |
+
for _module in model.modules():
|
| 1251 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
|
| 1252 |
+
_module.scale = alpha
|
| 1253 |
+
|
| 1254 |
+
|
| 1255 |
+
def set_lora_diag(model, diag: torch.Tensor):
|
| 1256 |
+
for _module in model.modules():
|
| 1257 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
|
| 1258 |
+
_module.set_selector_from_diag(diag)
|
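# A minimal illustrative sketch: tune_lora_scale changes the strength of every
# injected LoRA at inference time without touching the weights (0.0 recovers the
# base model, 1.0 applies the full learned update). `model` is assumed to
# already contain LoraInjected* layers.
def _sketch_tune_scale(model):
    for alpha in (0.0, 0.5, 1.0):
        tune_lora_scale(model, alpha)
        # ... run the model here and compare outputs across alphas ...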
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
def _text_lora_path(path: str) -> str:
|
| 1262 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
| 1263 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
def _ti_lora_path(path: str) -> str:
|
| 1267 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
| 1268 |
+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
def apply_learned_embed_in_clip(
|
| 1272 |
+
learned_embeds,
|
| 1273 |
+
text_encoder,
|
| 1274 |
+
tokenizer,
|
| 1275 |
+
token: Optional[Union[str, List[str]]] = None,
|
| 1276 |
+
idempotent=False,
|
| 1277 |
+
):
|
| 1278 |
+
if isinstance(token, str):
|
| 1279 |
+
trained_tokens = [token]
|
| 1280 |
+
elif isinstance(token, list):
|
| 1281 |
+
assert len(learned_embeds.keys()) == len(
|
| 1282 |
+
token
|
| 1283 |
+
), "The number of tokens and the number of embeds should be the same"
|
| 1284 |
+
trained_tokens = token
|
| 1285 |
+
else:
|
| 1286 |
+
trained_tokens = list(learned_embeds.keys())
|
| 1287 |
+
|
| 1288 |
+
for token in trained_tokens:
|
| 1289 |
+
print(token)
|
| 1290 |
+
embeds = learned_embeds[token]
|
| 1291 |
+
|
| 1292 |
+
# cast to dtype of text_encoder
|
| 1293 |
+
dtype = text_encoder.get_input_embeddings().weight.dtype
|
| 1294 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
| 1295 |
+
|
| 1296 |
+
i = 1
|
| 1297 |
+
if not idempotent:
|
| 1298 |
+
while num_added_tokens == 0:
|
| 1299 |
+
print(f"The tokenizer already contains the token {token}.")
|
| 1300 |
+
token = f"{token[:-1]}-{i}>"
|
| 1301 |
+
print(f"Attempting to add the token {token}.")
|
| 1302 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
| 1303 |
+
i += 1
|
| 1304 |
+
elif num_added_tokens == 0 and idempotent:
|
| 1305 |
+
print(f"The tokenizer already contains the token {token}.")
|
| 1306 |
+
print(f"Replacing {token} embedding.")
|
| 1307 |
+
|
| 1308 |
+
# resize the token embeddings
|
| 1309 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
| 1310 |
+
|
| 1311 |
+
# get the id for the token and assign the embeds
|
| 1312 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
| 1313 |
+
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
| 1314 |
+
return token
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
def load_learned_embed_in_clip(
|
| 1318 |
+
learned_embeds_path,
|
| 1319 |
+
text_encoder,
|
| 1320 |
+
tokenizer,
|
| 1321 |
+
token: Optional[Union[str, List[str]]] = None,
|
| 1322 |
+
idempotent=False,
|
| 1323 |
+
):
|
| 1324 |
+
learned_embeds = torch.load(learned_embeds_path)
|
| 1325 |
+
apply_learned_embed_in_clip(
|
| 1326 |
+
learned_embeds, text_encoder, tokenizer, token, idempotent
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
|
| 1330 |
+
def patch_pipe(
|
| 1331 |
+
pipe,
|
| 1332 |
+
maybe_unet_path,
|
| 1333 |
+
token: Optional[str] = None,
|
| 1334 |
+
r: int = 4,
|
| 1335 |
+
patch_unet=True,
|
| 1336 |
+
patch_text=True,
|
| 1337 |
+
patch_ti=True,
|
| 1338 |
+
idempotent_token=True,
|
| 1339 |
+
unet_target_replace_module=DEFAULT_TARGET_REPLACE,
|
| 1340 |
+
text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
| 1341 |
+
):
|
| 1342 |
+
if maybe_unet_path.endswith(".pt"):
|
| 1343 |
+
# torch format
|
| 1344 |
+
|
| 1345 |
+
if maybe_unet_path.endswith(".ti.pt"):
|
| 1346 |
+
unet_path = maybe_unet_path[:-6] + ".pt"
|
| 1347 |
+
elif maybe_unet_path.endswith(".text_encoder.pt"):
|
| 1348 |
+
unet_path = maybe_unet_path[:-16] + ".pt"
|
| 1349 |
+
else:
|
| 1350 |
+
unet_path = maybe_unet_path
|
| 1351 |
+
|
| 1352 |
+
ti_path = _ti_lora_path(unet_path)
|
| 1353 |
+
text_path = _text_lora_path(unet_path)
|
| 1354 |
+
|
| 1355 |
+
if patch_unet:
|
| 1356 |
+
print("LoRA : Patching Unet")
|
| 1357 |
+
monkeypatch_or_replace_lora(
|
| 1358 |
+
pipe.unet,
|
| 1359 |
+
torch.load(unet_path),
|
| 1360 |
+
r=r,
|
| 1361 |
+
target_replace_module=unet_target_replace_module,
|
| 1362 |
+
)
|
| 1363 |
+
|
| 1364 |
+
if patch_text:
|
| 1365 |
+
print("LoRA : Patching text encoder")
|
| 1366 |
+
monkeypatch_or_replace_lora(
|
| 1367 |
+
pipe.text_encoder,
|
| 1368 |
+
torch.load(text_path),
|
| 1369 |
+
target_replace_module=text_target_replace_module,
|
| 1370 |
+
r=r,
|
| 1371 |
+
)
|
| 1372 |
+
if patch_ti:
|
| 1373 |
+
print("LoRA : Patching token input")
|
| 1374 |
+
token = load_learned_embed_in_clip(
|
| 1375 |
+
ti_path,
|
| 1376 |
+
pipe.text_encoder,
|
| 1377 |
+
pipe.tokenizer,
|
| 1378 |
+
token=token,
|
| 1379 |
+
idempotent=idempotent_token,
|
| 1380 |
+
)
|
| 1381 |
+
|
| 1382 |
+
elif maybe_unet_path.endswith(".safetensors"):
|
| 1383 |
+
safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
|
| 1384 |
+
monkeypatch_or_replace_safeloras(pipe, safeloras)
|
| 1385 |
+
tok_dict = parse_safeloras_embeds(safeloras)
|
| 1386 |
+
if patch_ti:
|
| 1387 |
+
apply_learned_embed_in_clip(
|
| 1388 |
+
tok_dict,
|
| 1389 |
+
pipe.text_encoder,
|
| 1390 |
+
pipe.tokenizer,
|
| 1391 |
+
token=token,
|
| 1392 |
+
idempotent=idempotent_token,
|
| 1393 |
+
)
|
| 1394 |
+
return tok_dict
|
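# A minimal illustrative sketch of patching a diffusers pipeline from a
# .safetensors checkpoint produced by save_all/save_safeloras; the checkpoint
# path and the pipeline object are placeholders, and the safetensors package
# must be installed for safe_open above to work.
def _sketch_patch_pipe(pipe, lora_checkpoint="./lora.safetensors"):
    patch_pipe(
        pipe,
        lora_checkpoint,
        patch_unet=True,
        patch_text=False,
        patch_ti=False,
    )
    tune_lora_scale(pipe.unet, 1.0)
    return pipe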
| 1395 |
+
|
| 1396 |
+
|
| 1397 |
+
def train_patch_pipe(pipe, patch_unet, patch_text):
|
| 1398 |
+
if patch_unet:
|
| 1399 |
+
print("LoRA : Patching Unet")
|
| 1400 |
+
collapse_lora(pipe.unet)
|
| 1401 |
+
monkeypatch_remove_lora(pipe.unet)
|
| 1402 |
+
|
| 1403 |
+
if patch_text:
|
| 1404 |
+
print("LoRA : Patching text encoder")
|
| 1405 |
+
|
| 1406 |
+
collapse_lora(pipe.text_encoder)
|
| 1407 |
+
monkeypatch_remove_lora(pipe.text_encoder)
|
| 1408 |
+
|
| 1409 |
+
@torch.no_grad()
|
| 1410 |
+
def inspect_lora(model):
|
| 1411 |
+
moved = {}
|
| 1412 |
+
|
| 1413 |
+
for name, _module in model.named_modules():
|
| 1414 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
|
| 1415 |
+
ups = _module.lora_up.weight.data.clone()
|
| 1416 |
+
downs = _module.lora_down.weight.data.clone()
|
| 1417 |
+
|
| 1418 |
+
wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
|
| 1419 |
+
|
| 1420 |
+
dist = wght.flatten().abs().mean().item()
|
| 1421 |
+
if name in moved:
|
| 1422 |
+
moved[name].append(dist)
|
| 1423 |
+
else:
|
| 1424 |
+
moved[name] = [dist]
|
| 1425 |
+
|
| 1426 |
+
return moved
|
| 1427 |
+
|
| 1428 |
+
|
| 1429 |
+
def save_all(
|
| 1430 |
+
unet,
|
| 1431 |
+
text_encoder,
|
| 1432 |
+
save_path,
|
| 1433 |
+
placeholder_token_ids=None,
|
| 1434 |
+
placeholder_tokens=None,
|
| 1435 |
+
save_lora=True,
|
| 1436 |
+
save_ti=True,
|
| 1437 |
+
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
| 1438 |
+
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
|
| 1439 |
+
safe_form=True,
|
| 1440 |
+
):
|
| 1441 |
+
if not safe_form:
|
| 1442 |
+
# save ti
|
| 1443 |
+
if save_ti:
|
| 1444 |
+
ti_path = _ti_lora_path(save_path)
|
| 1445 |
+
learned_embeds_dict = {}
|
| 1446 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
| 1447 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
| 1448 |
+
print(
|
| 1449 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
| 1450 |
+
learned_embeds[:4],
|
| 1451 |
+
)
|
| 1452 |
+
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
|
| 1453 |
+
|
| 1454 |
+
torch.save(learned_embeds_dict, ti_path)
|
| 1455 |
+
print("Ti saved to ", ti_path)
|
| 1456 |
+
|
| 1457 |
+
# save text encoder
|
| 1458 |
+
if save_lora:
|
| 1459 |
+
save_lora_weight(
|
| 1460 |
+
unet, save_path, target_replace_module=target_replace_module_unet
|
| 1461 |
+
)
|
| 1462 |
+
print("Unet saved to ", save_path)
|
| 1463 |
+
|
| 1464 |
+
save_lora_weight(
|
| 1465 |
+
text_encoder,
|
| 1466 |
+
_text_lora_path(save_path),
|
| 1467 |
+
target_replace_module=target_replace_module_text,
|
| 1468 |
+
)
|
| 1469 |
+
print("Text Encoder saved to ", _text_lora_path(save_path))
|
| 1470 |
+
|
| 1471 |
+
else:
|
| 1472 |
+
assert save_path.endswith(
|
| 1473 |
+
".safetensors"
|
| 1474 |
+
), f"Save path : {save_path} should end with .safetensors"
|
| 1475 |
+
|
| 1476 |
+
loras = {}
|
| 1477 |
+
embeds = {}
|
| 1478 |
+
|
| 1479 |
+
if save_lora:
|
| 1480 |
+
|
| 1481 |
+
loras["unet"] = (unet, target_replace_module_unet)
|
| 1482 |
+
loras["text_encoder"] = (text_encoder, target_replace_module_text)
|
| 1483 |
+
|
| 1484 |
+
if save_ti:
|
| 1485 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
| 1486 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
| 1487 |
+
print(
|
| 1488 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
| 1489 |
+
learned_embeds[:4],
|
| 1490 |
+
)
|
| 1491 |
+
embeds[tok] = learned_embeds.detach().cpu()
|
| 1492 |
+
|
| 1493 |
+
save_safeloras_with_embeds(loras, embeds, save_path)
|
i2vedit/utils/lora_handler.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Union
|
| 5 |
+
from types import SimpleNamespace
|
| 6 |
+
from diffusers import UNetSpatioTemporalConditionModel
|
| 7 |
+
from transformers import CLIPTextModel
|
| 8 |
+
|
| 9 |
+
from .lora import (
|
| 10 |
+
extract_lora_ups_down,
|
| 11 |
+
inject_trainable_lora_extended,
|
| 12 |
+
save_lora_weight,
|
| 13 |
+
train_patch_pipe,
|
| 14 |
+
monkeypatch_or_replace_lora,
|
| 15 |
+
monkeypatch_or_replace_lora_extended
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
FILE_BASENAMES = ['unet', 'text_encoder']
|
| 20 |
+
LORA_FILE_TYPES = ['.pt', '.safetensors']
|
| 21 |
+
CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
|
| 22 |
+
STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']
|
| 23 |
+
|
| 24 |
+
lora_versions = dict(
|
| 25 |
+
stable_lora = "stable_lora",
|
| 26 |
+
cloneofsimo = "cloneofsimo"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
lora_func_types = dict(
|
| 30 |
+
loader = "loader",
|
| 31 |
+
injector = "injector"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
lora_args = dict(
|
| 35 |
+
model = None,
|
| 36 |
+
loras = None,
|
| 37 |
+
target_replace_module = [],
|
| 38 |
+
target_module = [],
|
| 39 |
+
r = 4,
|
| 40 |
+
search_class = [torch.nn.Linear],
|
| 41 |
+
dropout = 0,
|
| 42 |
+
lora_bias = 'none'
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
LoraVersions = SimpleNamespace(**lora_versions)
|
| 46 |
+
LoraFuncTypes = SimpleNamespace(**lora_func_types)
|
| 47 |
+
|
| 48 |
+
LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
|
| 49 |
+
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
|
| 50 |
+
|
| 51 |
+
def filter_dict(_dict, keys=[]):
|
| 52 |
+
if len(keys) == 0:
|
| 53 |
+
assert "Keys cannot empty for filtering return dict."
|
| 54 |
+
|
| 55 |
+
for k in keys:
|
| 56 |
+
if k not in lora_args.keys():
|
| 57 |
+
assert f"{k} does not exist in available LoRA arguments"
|
| 58 |
+
|
| 59 |
+
return {k: v for k, v in _dict.items() if k in keys}
|
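# A minimal illustrative sketch of how the LoraHandler defined below is driven
# from a training script: one handler per run, then add_lora_to_model() for each
# model that should receive LoRA layers. The version string, replace-module name
# and hyper-parameters are placeholders assumed for illustration, not a fixed
# API.
def _sketch_lora_handler(unet):
    handler = LoraHandler(
        version="cloneofsimo",
        use_unet_lora=True,
        unet_replace_modules=["TemporalBasicTransformerBlock"],
    )
    handler.add_lora_to_model(
        True, unet, handler.unet_replace_modules,
        dropout=0.1, lora_path="", r=16, scale=1.0,
    )
    return handler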
| 60 |
+
|
| 61 |
+
class LoraHandler(object):
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
version: LORA_VERSIONS = LoraVersions.cloneofsimo,
|
| 65 |
+
use_unet_lora: bool = False,
|
| 66 |
+
use_image_lora: bool = False,
|
| 67 |
+
save_for_webui: bool = False,
|
| 68 |
+
only_for_webui: bool = False,
|
| 69 |
+
lora_bias: str = 'none',
|
| 70 |
+
unet_replace_modules: list = None,
|
| 71 |
+
image_encoder_replace_modules: list = None
|
| 72 |
+
):
|
| 73 |
+
self.version = version
|
| 74 |
+
self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
|
| 75 |
+
self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
|
| 76 |
+
self.lora_bias = lora_bias
|
| 77 |
+
self.use_unet_lora = use_unet_lora
|
| 78 |
+
self.use_image_lora = use_image_lora
|
| 79 |
+
self.save_for_webui = save_for_webui
|
| 80 |
+
self.only_for_webui = only_for_webui
|
| 81 |
+
self.unet_replace_modules = unet_replace_modules
|
| 82 |
+
self.image_encoder_replace_modules = image_encoder_replace_modules
|
| 83 |
+
self.use_lora = any([use_image_lora, use_unet_lora])
|
| 84 |
+
|
| 85 |
+
def is_cloneofsimo_lora(self):
|
| 86 |
+
return self.version == LoraVersions.cloneofsimo
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader):
|
| 90 |
+
|
| 91 |
+
if self.is_cloneofsimo_lora():
|
| 92 |
+
|
| 93 |
+
if func_type == LoraFuncTypes.loader:
|
| 94 |
+
return monkeypatch_or_replace_lora_extended
|
| 95 |
+
|
| 96 |
+
if func_type == LoraFuncTypes.injector:
|
| 97 |
+
return inject_trainable_lora_extended
|
| 98 |
+
|
| 99 |
+
assert "LoRA Version does not exist."
|
| 100 |
+
|
| 101 |
+
def check_lora_ext(self, lora_file: str):
|
| 102 |
+
return lora_file.endswith(tuple(LORA_FILE_TYPES))
|
| 103 |
+
|
| 104 |
+
def get_lora_file_path(
|
| 105 |
+
self,
|
| 106 |
+
lora_path: str,
|
| 107 |
+
model: Union[UNetSpatioTemporalConditionModel, CLIPTextModel]
|
| 108 |
+
):
|
| 109 |
+
if os.path.exists(lora_path):
|
| 110 |
+
lora_filenames = [fns for fns in os.listdir(lora_path)]
|
| 111 |
+
is_lora = self.check_lora_ext(lora_path)
|
| 112 |
+
|
| 113 |
+
is_unet = isinstance(model, UNetSpatioTemporalConditionModel)
|
| 114 |
+
#is_text = isinstance(model, CLIPTextModel)
|
| 115 |
+
idx = 0 if is_unet else 1
|
| 116 |
+
|
| 117 |
+
base_name = FILE_BASENAMES[idx]
|
| 118 |
+
|
| 119 |
+
for lora_filename in lora_filenames:
|
| 120 |
+
is_lora = self.check_lora_ext(lora_filename)
|
| 121 |
+
if not is_lora:
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
if base_name in lora_filename:
|
| 125 |
+
return os.path.join(lora_path, lora_filename)
|
| 126 |
+
else:
|
| 127 |
+
print(f"lora_path: {lora_path} does not exist. Inject without pretrained loras...")
|
| 128 |
+
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
def handle_lora_load(self, file_name:str, lora_loader_args: dict = None):
|
| 132 |
+
self.lora_loader(**lora_loader_args)
|
| 133 |
+
print(f"Successfully loaded LoRA from: {file_name}")
|
| 134 |
+
|
| 135 |
+
def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,):
|
| 136 |
+
try:
|
| 137 |
+
lora_file = self.get_lora_file_path(lora_path, model)
|
| 138 |
+
|
| 139 |
+
if lora_file is not None:
|
| 140 |
+
lora_loader_args.update({"lora_path": lora_file})
|
| 141 |
+
self.handle_lora_load(lora_file, lora_loader_args)
|
| 142 |
+
|
| 143 |
+
else:
|
| 144 |
+
print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
|
| 145 |
+
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"An error occurred while loading a LoRA file: {e}")
|
| 148 |
+
|
| 149 |
+
def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale):
|
| 150 |
+
return_dict = lora_args.copy()
|
| 151 |
+
|
| 152 |
+
if self.is_cloneofsimo_lora():
|
| 153 |
+
return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
|
| 154 |
+
return_dict.update({
|
| 155 |
+
"model": model,
|
| 156 |
+
"loras": self.get_lora_file_path(lora_path, model),
|
| 157 |
+
"target_replace_module": replace_modules,
|
| 158 |
+
"r": r,
|
| 159 |
+
"scale": scale,
|
| 160 |
+
"dropout_p": dropout,
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
return return_dict
|
| 164 |
+
|
| 165 |
+
def do_lora_injection(
|
| 166 |
+
self,
|
| 167 |
+
model,
|
| 168 |
+
replace_modules,
|
| 169 |
+
bias='none',
|
| 170 |
+
dropout=0,
|
| 171 |
+
r=4,
|
| 172 |
+
lora_loader_args=None,
|
| 173 |
+
):
|
| 174 |
+
REPLACE_MODULES = replace_modules
|
| 175 |
+
|
| 176 |
+
params = None
|
| 177 |
+
negation = None
|
| 178 |
+
is_injection_hybrid = False
|
| 179 |
+
|
| 180 |
+
if self.is_cloneofsimo_lora():
|
| 181 |
+
is_injection_hybrid = True
|
| 182 |
+
injector_args = lora_loader_args
|
| 183 |
+
|
| 184 |
+
params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended
|
| 185 |
+
for _up, _down in extract_lora_ups_down(
|
| 186 |
+
model,
|
| 187 |
+
target_replace_module=REPLACE_MODULES):
|
| 188 |
+
|
| 189 |
+
if all(x is not None for x in [_up, _down]):
|
| 190 |
+
print(f"Lora successfully injected into {model.__class__.__name__}.")
|
| 191 |
+
|
| 192 |
+
break
|
| 193 |
+
|
| 194 |
+
return params, negation, is_injection_hybrid
|
| 195 |
+
|
| 196 |
+
return params, negation, is_injection_hybrid
|
| 197 |
+
|
| 198 |
+
def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0):
|
| 199 |
+
|
| 200 |
+
params = None
|
| 201 |
+
negation = None
|
| 202 |
+
|
| 203 |
+
lora_loader_args = self.get_lora_func_args(
|
| 204 |
+
lora_path,
|
| 205 |
+
use_lora,
|
| 206 |
+
model,
|
| 207 |
+
replace_modules,
|
| 208 |
+
r,
|
| 209 |
+
dropout,
|
| 210 |
+
self.lora_bias,
|
| 211 |
+
scale
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
if use_lora:
|
| 215 |
+
params, negation, is_injection_hybrid = self.do_lora_injection(
|
| 216 |
+
model,
|
| 217 |
+
replace_modules,
|
| 218 |
+
bias=self.lora_bias,
|
| 219 |
+
lora_loader_args=lora_loader_args,
|
| 220 |
+
dropout=dropout,
|
| 221 |
+
r=r
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if not is_injection_hybrid:
|
| 225 |
+
self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)
|
| 226 |
+
|
| 227 |
+
params = model if params is None else params
|
| 228 |
+
return params, negation
|
| 229 |
+
|
| 230 |
+
def save_cloneofsimo_lora(self, model, save_path, step, flag):
|
| 231 |
+
|
| 232 |
+
def save_lora(model, name, condition, replace_modules, step, save_path, flag=None):
|
| 233 |
+
if condition and replace_modules is not None:
|
| 234 |
+
save_path = f"{save_path}/{step}_{name}.pt"
|
| 235 |
+
save_lora_weight(model, save_path, replace_modules, flag)
|
| 236 |
+
|
| 237 |
+
save_lora(
|
| 238 |
+
model.unet,
|
| 239 |
+
FILE_BASENAMES[0],
|
| 240 |
+
self.use_unet_lora,
|
| 241 |
+
self.unet_replace_modules,
|
| 242 |
+
step,
|
| 243 |
+
save_path,
|
| 244 |
+
flag
|
| 245 |
+
)
|
| 246 |
+
save_lora(
|
| 247 |
+
model.image_encoder,
|
| 248 |
+
FILE_BASENAMES[1],
|
| 249 |
+
self.use_image_lora,
|
| 250 |
+
self.image_encoder_replace_modules,
|
| 251 |
+
step,
|
| 252 |
+
save_path,
|
| 253 |
+
flag
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)
|
| 257 |
+
|
| 258 |
+
def save_lora_weights(self, model: None, save_path: str ='',step: str = '', flag=None):
|
| 259 |
+
save_path = f"{save_path}/lora"
|
| 260 |
+
os.makedirs(save_path, exist_ok=True)
|
| 261 |
+
|
| 262 |
+
if self.is_cloneofsimo_lora():
|
| 263 |
+
if any([self.save_for_webui, self.only_for_webui]):
|
| 264 |
+
warnings.warn(
|
| 265 |
+
"""
|
| 266 |
+
You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implemention.
|
| 267 |
+
Only 'stable_lora' is supported for saving to a compatible webui file.
|
| 268 |
+
"""
|
| 269 |
+
)
|
| 270 |
+
self.save_cloneofsimo_lora(model, save_path, step, flag)
|
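`LoraHandler` wires cloneofsimo-style LoRA injection, loading, and saving around the SVD UNet and image encoder. A minimal usage sketch, assuming `load_primary_models` from `model_utils.py` below; the `unet_replace_modules` value and the save path are illustrative placeholders, not the project's exact configuration:

```python
# Hypothetical sketch: inject fresh trainable LoRA layers into the SVD UNet.
from i2vedit.utils.lora_handler import LoraHandler
from i2vedit.utils.model_utils import load_primary_models

_, _, image_encoder, vae, unet = load_primary_models(
    "stabilityai/stable-video-diffusion-img2vid-xt"
)

lora_manager = LoraHandler(
    use_unet_lora=True,
    unet_replace_modules=["TransformerSpatioTemporalModel"],  # assumed target module name
)

# Returns the LoRA parameters to optimize (or the model itself when none were injected).
unet_lora_params, _ = lora_manager.add_lora_to_model(
    use_lora=True,
    model=unet,
    replace_modules=lora_manager.unet_replace_modules,
    dropout=0.1,
    lora_path="",   # empty path: inject new LoRAs instead of loading pretrained ones
    r=32,
)

# ... after training, weights can be written with the handler's own helper, e.g.
# lora_manager.save_lora_weights(model=pipeline, save_path="./outputs", step="500")
```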
i2vedit/utils/model_utils.py
ADDED
@@ -0,0 +1,588 @@
+import argparse
+import datetime
+import logging
+import inspect
+import math
+import os
+import random
+import gc
+import copy
+import imageio
+import numpy as np
+import PIL
+from PIL import Image
+from scipy.stats import anderson
+from typing import Dict, Optional, Tuple, Callable, List, Union
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+
+import transformers
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from transformers.models.clip.modeling_clip import CLIPEncoder
+
+import diffusers
+from diffusers.models import AutoencoderKL
+from diffusers import DDIMScheduler, TextToVideoSDPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention_processor import AttnProcessor2_0, Attention
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.models.lora import LoRALinearLayer
+from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image, BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.unet_3d_blocks import \
+    (CrossAttnDownBlockSpatioTemporal,
+     DownBlockSpatioTemporal,
+     CrossAttnUpBlockSpatioTemporal,
+     UpBlockSpatioTemporal)
+from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteSchedulerOutput, EulerDiscreteScheduler
+
+
+def _append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+    return x[(...,) + (None,) * dims_to_append]
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+    # Based on:
+    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    return outputs
+
+@torch.no_grad()
+def tensor_to_vae_latent(t, vae):
+    video_length = t.shape[1]
+
+    t = rearrange(t, "b f c h w -> (b f) c h w")
+    latents = vae.encode(t).latent_dist.sample()
+    latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
+    latents = latents * vae.config.scaling_factor
+
+    return latents
+
+def load_primary_models(pretrained_model_path):
+    noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+        pretrained_model_path, subfolder="scheduler")
+    feature_extractor = CLIPImageProcessor.from_pretrained(
+        pretrained_model_path, subfolder="feature_extractor", revision=None
+    )
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+        pretrained_model_path, subfolder="image_encoder", revision=None, variant="fp16"
+    )
+    vae = AutoencoderKLTemporalDecoder.from_pretrained(
+        pretrained_model_path, subfolder="vae", revision=None, variant="fp16")
+    unet = UNetSpatioTemporalConditionModel.from_pretrained(
+        pretrained_model_path,
+        subfolder="unet",
+        low_cpu_mem_usage=True,
+        variant="fp16",
+    )
+
+    return noise_scheduler, feature_extractor, image_encoder, vae, unet
+
+def set_processors(attentions):
+    for attn in attentions: attn.set_processor(AttnProcessor2_0())
+
+def is_attn(name):
+    return ('attn1' or 'attn2' == name.split('.')[-1])
+
+def set_torch_2_attn(unet):
+    optim_count = 0
+
+    for name, module in unet.named_modules():
+        if is_attn(name):
+            if isinstance(module, torch.nn.ModuleList):
+                for m in module:
+                    if isinstance(m, BasicTransformerBlock):
+                        set_processors([m.attn1, m.attn2])
+                        optim_count += 1
+    if optim_count > 0:
+        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
+
+def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
+    try:
+        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
+        enable_torch_2 = is_torch_2 and enable_torch_2_attn
+
+        if enable_xformers_memory_efficient_attention and not enable_torch_2:
+            if is_xformers_available():
+                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+            else:
+                raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+        if enable_torch_2:
+            set_torch_2_attn(unet)
+
+    except Exception as e:
+        print(e)
+        print("Could not enable memory efficient attention for xformers or Torch 2.0.")
+
+
+class P2PEulerDiscreteScheduler(EulerDiscreteScheduler):
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        conditional_latents: torch.FloatTensor = None,
+        s_churn: float = 0.0,
+        s_tmin: float = 0.0,
+        s_tmax: float = float("inf"),
+        s_noise: float = 1.0,
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
+    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            s_churn (`float`):
+            s_tmin (`float`):
+            s_tmax (`float`):
+            s_noise (`float`, defaults to 1.0):
+                Scaling factor for noise added to the sample.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+                tuple.
+
+        Returns:
+            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
+        """
+
+        if (
+            isinstance(timestep, int)
+            or isinstance(timestep, torch.IntTensor)
+            or isinstance(timestep, torch.LongTensor)
+        ):
+            raise ValueError(
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
+            )
+
+        if not self.is_scale_input_called:
+            logger.warning(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+
+        sigma = self.sigmas[self.step_index]
+
+        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+
+        noise = randn_tensor(
+            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+        )
+
+        eps = noise * s_noise
+        sigma_hat = sigma * (gamma + 1)
+
+        if gamma > 0:
+            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
+        # backwards compatibility
+        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma_hat * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # denoised = model_output * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
+
+        for controller in self.controller:
+            pred_original_sample = controller.step_callback(pred_original_sample)
+        # color preservation
+        #color_delta = torch.mean(conditional_latents[:,0:1,:,:,:]) - torch.mean(pred_original_sample[:,0:1,:,:,:])
+        #print("color_delta", color_delta)
+        #pred_original_sample = pred_original_sample + color_delta
+
+        # 2. Convert to an ODE derivative
+        derivative = (sample - pred_original_sample) / sigma_hat
+
+        dt = self.sigmas[self.step_index + 1] - sigma_hat
+
+        prev_sample = sample + derivative * dt
+
+        # Cast sample back to model compatible dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
+        # upon completion increase step index by one
+        self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+@dataclass
+class StableVideoDiffusionPipelineOutput(BaseOutput):
+    r"""
+    Output class for zero-shot text-to-video pipeline.
+
+    Args:
+        frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
+            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+            num_channels)`.
+    """
+
+    frames: Union[List[PIL.Image.Image], np.ndarray]
+    latents: torch.Tensor
+
+class P2PStableVideoDiffusionPipeline(StableVideoDiffusionPipeline):
+
+    def _encode_vae_image(
+        self,
+        image: torch.Tensor,
+        device,
+        num_videos_per_prompt,
+        do_classifier_free_guidance,
+        image_latents: torch.Tensor = None
+    ):
+        if image_latents is None:
+            image = image.to(device=device)
+            image_latents = self.vae.encode(image).latent_dist.mode()
+        else:
+            image_latents = rearrange(image_latents, "b f c h w -> (b f) c h w")
+
+        if do_classifier_free_guidance:
+            negative_image_latents = torch.zeros_like(image_latents)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            image_latents = torch.cat([negative_image_latents, image_latents])
+
+        # duplicate image_latents for each generation per prompt, using mps friendly method
+        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+
+        return image_latents
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+        edited_firstframe: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] = None,
+        image_latents: torch.FloatTensor = None,
+        height: int = 576,
+        width: int = 1024,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 25,
+        min_guidance_scale: float = 1.0,
+        max_guidance_scale: float = 2.5,
+        fps: int = 7,
+        motion_bucket_id: int = 127,
+        noise_aug_strength: int = 0.02,
+        decode_chunk_size: Optional[int] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        return_dict: bool = True,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
+                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image.
+            num_frames (`int`, *optional*):
+                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
+            num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference. This parameter is modulated by `strength`.
+            min_guidance_scale (`float`, *optional*, defaults to 1.0):
+                The minimum guidance scale. Used for the classifier free guidance with first frame.
+            max_guidance_scale (`float`, *optional*, defaults to 3.0):
+                The maximum guidance scale. Used for the classifier free guidance with last frame.
+            fps (`int`, *optional*, defaults to 7):
+                Frames per second. The rate at which the generated images shall be exported to a video after generation.
+                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
+            motion_bucket_id (`int`, *optional*, defaults to 127):
+                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
+            noise_aug_strength (`int`, *optional*, defaults to 0.02):
+                The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
+            decode_chunk_size (`int`, *optional*):
+                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
+                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
+                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
+
+        Examples:
+
+        ```py
+        from diffusers import StableVideoDiffusionPipeline
+        from diffusers.utils import load_image, export_to_video
+
+        pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
+        pipe.to("cuda")
+
+        image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
+        image = image.resize((1024, 576))
+
+        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
+        export_to_video(frames, "generated.mp4", fps=7)
+        ```
+        """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(image, height, width)
+
+        # 2. Define call parameters
+        if isinstance(image, PIL.Image.Image):
+            batch_size = 1
+        elif isinstance(image, list):
+            batch_size = len(image)
+        else:
+            batch_size = image.shape[0]
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        self._guidance_scale = max_guidance_scale
+
+        # 3. Encode input image
+        edited_firstframe = image if edited_firstframe is None else edited_firstframe
+        image_embeddings = self._encode_image(edited_firstframe, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+
+        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
+        # is why it is reduced here.
+        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
+        fps = fps - 1
+
+        # 4. Encode input image using VAE
+        image = self.image_processor.preprocess(image, height=height, width=width)
+        edited_firstframe = self.image_processor.preprocess(edited_firstframe, height=height, width=width)
+        #print("before vae", image.min(), image.max())
+        #noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
+        #image = image + noise_aug_strength * noise
+        #edited_firstframe = edited_firstframe + noise_aug_strength * noise
+        if image_latents is not None:
+            #noise_tmp = randn_tensor(image_latents.shape, generator=generator, device=image_latents.device, dtype=image_latents.dtype)
+            image_latents = image_latents / self.vae.config.scaling_factor
+            #image_latents = image_latents + noise_aug_strength * noise_tmp
+
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float32)
+
+        #print("before vae", image.min(), image.max())
+        image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance, image_latents = image_latents)
+        firstframe_latents = self._encode_vae_image(edited_firstframe, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+        noise = randn_tensor(image_latents.shape, generator=generator, device=image_latents.device, dtype=image_latents.dtype)[1:]
+        image_latents[1:] = image_latents[1:] + noise_aug_strength * noise #/ self.vae.config.scaling_factor
+        #firstframe_latents = firstframe_latents + noise_aug_strength * noise / self.vae.config.scaling_factor
+
+        image_latents = image_latents.to(image_embeddings.dtype)
+        firstframe_latents = firstframe_latents.to(image_embeddings.dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
+        # Repeat the image latents for each frame so we can concatenate them with the noise
+        # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
+
+        skip=num_frames
+        image_latents = torch.cat(
+            [
+                image_latents.unsqueeze(1).repeat(1, skip, 1, 1, 1),
+                firstframe_latents.unsqueeze(1).repeat(1, num_frames-skip, 1, 1, 1)
+            ],
+            dim=1
+        )
+
+        #image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+        #print("image latents", image_latents.min(), image_latents.max())
+
+        # 5. Get Added Time IDs
+        added_time_ids = self._get_add_time_ids(
+            fps,
+            motion_bucket_id,
+            noise_aug_strength,
+            image_embeddings.dtype,
+            batch_size,
+            num_videos_per_prompt,
+            self.do_classifier_free_guidance,
+        )
+        added_time_ids = added_time_ids.to(device)
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_frames,
+            num_channels_latents,
+            height,
+            width,
+            image_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 7. Prepare guidance scale
+        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+        guidance_scale = guidance_scale.to(device, latents.dtype)
+        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+        guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+        self._guidance_scale = guidance_scale
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # Concatenate image_latents over channels dimention
+                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=image_embeddings,
+                    added_time_ids=added_time_ids,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                #conditional_latents = image_latents.chunk(2)[1] if self.do_classifier_free_guidance else image_latents
+                #latents = self.scheduler.step(noise_pred, t, latents, conditional_latents=conditional_latents*self.vae.config.scaling_factor).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if not output_type == "latent":
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
+            frames = tensor2vid(frames, self.image_processor, output_type=output_type)
+        else:
+            frames = latents
+
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return frames
+
+        return StableVideoDiffusionPipelineOutput(frames=frames, latents=latents)
i2vedit/utils/svd_util.py
ADDED
@@ -0,0 +1,397 @@
| 1 |
+
import functools
|
| 2 |
+
import importlib
|
| 3 |
+
import os
|
| 4 |
+
from functools import partial
|
| 5 |
+
from inspect import isfunction
|
| 6 |
+
|
| 7 |
+
import fsspec
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
+
from safetensors.torch import load_file as load_safetensors
|
| 13 |
+
|
| 14 |
+
import decord
|
| 15 |
+
from einops import rearrange, repeat
|
| 16 |
+
from torchvision.transforms import Resize, Pad, InterpolationMode
|
| 17 |
+
from torch import nn
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def disabled_train(self, mode=True):
|
| 21 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
| 22 |
+
does not change anymore."""
|
| 23 |
+
return self
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_string_from_tuple(s):
|
| 27 |
+
try:
|
| 28 |
+
# Check if the string starts and ends with parentheses
|
| 29 |
+
if s[0] == "(" and s[-1] == ")":
|
| 30 |
+
# Convert the string to a tuple
|
| 31 |
+
t = eval(s)
|
| 32 |
+
# Check if the type of t is tuple
|
| 33 |
+
if type(t) == tuple:
|
| 34 |
+
return t[0]
|
| 35 |
+
else:
|
| 36 |
+
pass
|
| 37 |
+
except:
|
| 38 |
+
pass
|
| 39 |
+
return s
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def is_power_of_two(n):
|
| 43 |
+
"""
|
| 44 |
+
chat.openai.com/chat
|
| 45 |
+
Return True if n is a power of 2, otherwise return False.
|
| 46 |
+
|
| 47 |
+
The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
|
| 48 |
+
The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
|
| 49 |
+
If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
|
| 50 |
+
Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
|
| 51 |
+
|
| 52 |
+
"""
|
| 53 |
+
if n <= 0:
|
| 54 |
+
return False
|
| 55 |
+
return (n & (n - 1)) == 0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def autocast(f, enabled=True):
|
| 59 |
+
def do_autocast(*args, **kwargs):
|
| 60 |
+
with torch.cuda.amp.autocast(
|
| 61 |
+
enabled=enabled,
|
| 62 |
+
dtype=torch.get_autocast_gpu_dtype(),
|
| 63 |
+
cache_enabled=torch.is_autocast_cache_enabled(),
|
| 64 |
+
):
|
| 65 |
+
return f(*args, **kwargs)
|
| 66 |
+
|
| 67 |
+
return do_autocast
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_partial_from_config(config):
|
| 71 |
+
return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def log_txt_as_img(wh, xc, size=10):
|
| 75 |
+
# wh a tuple of (width, height)
|
| 76 |
+
# xc a list of captions to plot
|
| 77 |
+
b = len(xc)
|
| 78 |
+
txts = list()
|
| 79 |
+
for bi in range(b):
|
| 80 |
+
txt = Image.new("RGB", wh, color="white")
|
| 81 |
+
draw = ImageDraw.Draw(txt)
|
| 82 |
+
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
|
| 83 |
+
nc = int(40 * (wh[0] / 256))
|
| 84 |
+
if isinstance(xc[bi], list):
|
| 85 |
+
text_seq = xc[bi][0]
|
| 86 |
+
else:
|
| 87 |
+
text_seq = xc[bi]
|
| 88 |
+
lines = "\n".join(
|
| 89 |
+
text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
| 94 |
+
except UnicodeEncodeError:
|
| 95 |
+
print("Cant encode string for logging. Skipping.")
|
| 96 |
+
|
| 97 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
| 98 |
+
txts.append(txt)
|
| 99 |
+
txts = np.stack(txts)
|
| 100 |
+
txts = torch.tensor(txts)
|
| 101 |
+
return txts
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def partialclass(cls, *args, **kwargs):
|
| 105 |
+
class NewCls(cls):
|
| 106 |
+
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
|
| 107 |
+
|
| 108 |
+
return NewCls
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def make_path_absolute(path):
|
| 112 |
+
fs, p = fsspec.core.url_to_fs(path)
|
| 113 |
+
if fs.protocol == "file":
|
| 114 |
+
return os.path.abspath(p)
|
| 115 |
+
return path
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def ismap(x):
|
| 119 |
+
if not isinstance(x, torch.Tensor):
|
| 120 |
+
return False
|
| 121 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def isimage(x):
|
| 125 |
+
if not isinstance(x, torch.Tensor):
|
| 126 |
+
return False
|
| 127 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def isheatmap(x):
|
| 131 |
+
if not isinstance(x, torch.Tensor):
|
| 132 |
+
return False
|
| 133 |
+
|
| 134 |
+
return x.ndim == 2
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def isneighbors(x):
|
| 138 |
+
if not isinstance(x, torch.Tensor):
|
| 139 |
+
return False
|
| 140 |
+
return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def exists(x):
|
| 144 |
+
return x is not None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def expand_dims_like(x, y):
|
| 148 |
+
while x.dim() != y.dim():
|
| 149 |
+
x = x.unsqueeze(-1)
|
| 150 |
+
return x
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def default(val, d):
|
| 154 |
+
if exists(val):
|
| 155 |
+
return val
|
| 156 |
+
return d() if isfunction(d) else d
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def mean_flat(tensor):
|
| 160 |
+
"""
|
| 161 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
| 162 |
+
Take the mean over all non-batch dimensions.
|
| 163 |
+
"""
|
| 164 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def count_params(model, verbose=False):
|
| 168 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 169 |
+
if verbose:
|
| 170 |
+
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
| 171 |
+
return total_params
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def instantiate_from_config(config):
|
| 175 |
+
if not "target" in config:
|
| 176 |
+
if config == "__is_first_stage__":
|
| 177 |
+
return None
|
| 178 |
+
elif config == "__is_unconditional__":
|
| 179 |
+
return None
|
| 180 |
+
raise KeyError("Expected key `target` to instantiate.")
|
| 181 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_obj_from_str(string, reload=False, invalidate_cache=True):
|
| 185 |
+
module, cls = string.rsplit(".", 1)
|
| 186 |
+
if invalidate_cache:
|
| 187 |
+
importlib.invalidate_caches()
|
| 188 |
+
if reload:
|
| 189 |
+
module_imp = importlib.import_module(module)
|
| 190 |
+
importlib.reload(module_imp)
|
| 191 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def append_zero(x):
|
| 195 |
+
return torch.cat([x, x.new_zeros([1])])
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def append_dims(x, target_dims):
|
| 199 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
| 200 |
+
dims_to_append = target_dims - x.ndim
|
| 201 |
+
if dims_to_append < 0:
|
| 202 |
+
raise ValueError(
|
| 203 |
+
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
| 204 |
+
)
|
| 205 |
+
return x[(...,) + (None,) * dims_to_append]
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
|
| 209 |
+
print(f"Loading model from {ckpt}")
|
| 210 |
+
if ckpt.endswith("ckpt"):
|
| 211 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 212 |
+
if "global_step" in pl_sd:
|
| 213 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
| 214 |
+
sd = pl_sd["state_dict"]
|
| 215 |
+
elif ckpt.endswith("safetensors"):
|
| 216 |
+
sd = load_safetensors(ckpt)
|
| 217 |
+
else:
|
| 218 |
+
raise NotImplementedError
|
| 219 |
+
|
| 220 |
+
model = instantiate_from_config(config.model)
|
| 221 |
+
|
| 222 |
+
m, u = model.load_state_dict(sd, strict=False)
|
| 223 |
+
|
| 224 |
+
if len(m) > 0 and verbose:
|
| 225 |
+
print("missing keys:")
|
| 226 |
+
print(m)
|
| 227 |
+
if len(u) > 0 and verbose:
|
| 228 |
+
print("unexpected keys:")
|
| 229 |
+
print(u)
|
| 230 |
+
|
| 231 |
+
if freeze:
|
| 232 |
+
for param in model.parameters():
|
| 233 |
+
param.requires_grad = False
|
| 234 |
+
|
| 235 |
+
model.eval()
|
| 236 |
+
return model
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def get_configs_path() -> str:
|
| 240 |
+
"""
|
| 241 |
+
Get the `configs` directory.
|
| 242 |
+
For a working copy, this is the one in the root of the repository,
|
| 243 |
+
but for an installed copy, it's in the `sgm` package (see pyproject.toml).
|
| 244 |
+
"""
|
| 245 |
+
this_dir = os.path.dirname(__file__)
|
| 246 |
+
candidates = (
|
| 247 |
+
os.path.join(this_dir, "configs"),
|
| 248 |
+
os.path.join(this_dir, "..", "configs"),
|
| 249 |
+
)
|
| 250 |
+
for candidate in candidates:
|
| 251 |
+
candidate = os.path.abspath(candidate)
|
| 252 |
+
if os.path.isdir(candidate):
|
| 253 |
+
return candidate
|
| 254 |
+
raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
|
| 258 |
+
"""
|
| 259 |
+
Will return the result of a recursive get attribute call.
|
| 260 |
+
E.g.:
|
| 261 |
+
a.b.c
|
| 262 |
+
= getattr(getattr(a, "b"), "c")
|
| 263 |
+
= get_nested_attribute(a, "b.c")
|
| 264 |
+
If any part of the attribute call is an integer x with current obj a, will
|
| 265 |
+
try to call a[x] instead of a.x first.
|
| 266 |
+
"""
|
| 267 |
+
attributes = attribute_path.split(".")
|
| 268 |
+
if depth is not None and depth > 0:
|
| 269 |
+
attributes = attributes[:depth]
|
| 270 |
+
assert len(attributes) > 0, "At least one attribute should be selected"
|
| 271 |
+
current_attribute = obj
|
| 272 |
+
current_key = None
|
| 273 |
+
for level, attribute in enumerate(attributes):
|
| 274 |
+
current_key = ".".join(attributes[: level + 1])
|
| 275 |
+
try:
|
| 276 |
+
id_ = int(attribute)
|
| 277 |
+
current_attribute = current_attribute[id_]
|
| 278 |
+
except ValueError:
|
| 279 |
+
current_attribute = getattr(current_attribute, attribute)
|
| 280 |
+
|
| 281 |
+
return (current_attribute, current_key) if return_key else current_attribute
|
| 282 |
+
|
| 283 |
+
def pad_with_ratio(frames, res):
|
| 284 |
+
_, _, ih, iw = frames.shape
|
| 285 |
+
#print("ih, iw", ih, iw)
|
| 286 |
+
i_ratio = ih / iw
|
| 287 |
+
h, w = res
|
| 288 |
+
#print("h,w", h ,w)
|
| 289 |
+
n_ratio = h / w
|
| 290 |
+
if i_ratio > n_ratio:
|
| 291 |
+
nw = int(ih / h * w)
|
| 292 |
+
#print("nw", nw)
|
| 293 |
+
frames = Pad((nw - iw)//2)(frames)
|
| 294 |
+
frames = frames[...,(nw - iw)//2:-(nw - iw)//2,:]
|
| 295 |
+
else:
|
| 296 |
+
nh = int(iw / w * h)
|
| 297 |
+
frames = Pad((nh - ih)//2)(frames)
|
| 298 |
+
frames = frames[...,:,(nh - ih)//2:-(nh - ih)//2]
|
| 299 |
+
#print("after pad", frames.shape)
|
| 300 |
+
return frames
|
| 301 |
+
|
| 302 |
+
def prepare_video(video_path:str, resolution, device, dtype, normalize=True, start_t:float=0, end_t:float=-1, output_fps:int=-1, pad_to_fix=False):
|
| 303 |
+
vr = decord.VideoReader(video_path)
|
| 304 |
+
initial_fps = vr.get_avg_fps()
|
| 305 |
+
if output_fps == -1:
|
| 306 |
+
output_fps = int(initial_fps)
|
| 307 |
+
if end_t == -1:
|
| 308 |
+
end_t = len(vr) / initial_fps
|
| 309 |
+
else:
|
| 310 |
+
end_t = min(len(vr) / initial_fps, end_t)
|
| 311 |
+
assert 0 <= start_t < end_t
|
| 312 |
+
assert output_fps > 0
|
| 313 |
+
    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    num_f = int((end_t - start_t) * output_fps)
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
    video = vr.get_batch(sample_idx)
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()
    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)

    # Use max if you want the larger side to be equal to resolution (e.g. 512)
    # k = float(resolution) / min(h, w)
    if pad_to_fix and resolution is not None:
        video = pad_with_ratio(video, resolution)
    if isinstance(resolution, tuple):
        #video = Resize(resolution, interpolation=InterpolationMode.BICUBIC, antialias=True)(video)
        video = nn.functional.interpolate(video, size=resolution, mode='bilinear')
    else:
        k = float(resolution) / max(h, w)
        h *= k
        w *= k
        h = int(np.round(h / 64.0)) * 64
        w = int(np.round(w / 64.0)) * 64
        video = Resize((h, w), interpolation=InterpolationMode.BICUBIC, antialias=True)(video)

    if normalize:
        video = video / 127.5 - 1.0

    return video, output_fps

def return_to_original_res(frames, res, pad_to_fix=False):
    #print("original res", res)
    _, _, h, w = frames.shape
    #print("h w", h, w)
    n_ratio = h / w
    ih, iw = res
    i_ratio = ih / iw
    if pad_to_fix:
        if i_ratio > n_ratio:
            nw = int(ih / h * w)
            frames = Resize((ih, iw+2*(nw - iw)//2), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
            frames = frames[...,:,(nw - iw)//2:-(nw - iw)//2]
        else:
            nh = int(iw / w * h)
            frames = Resize((ih+2*(nh - ih)//2, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
            frames = frames[...,(nh - ih)//2:-(nh - ih)//2,:]
    else:
        frames = Resize((ih, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)

    return frames

class SmoothAreaRandomDetection(object):

    def __init__(self, device="cuda", dtype=torch.float16):

        kernel_x = torch.zeros(3,3,3,3)
        for i in range(3):
            kernel_x[i,i,:,:] = torch.Tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        kernel_y = torch.zeros(3,3,3,3)
        for i in range(3):
            kernel_y[i,i,:,:] = torch.Tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]])
        kernel_x = kernel_x.to(device, dtype)
        kernel_y = kernel_y.to(device, dtype)
        self.weight_x = kernel_x
        self.weight_y = kernel_y

        self.eps = 1/256.

    def detection(self, x, thr=0.0):
        original_dim = x.ndim
        if x.ndim > 4:
            b, f, c, h, w = x.shape
            x = rearrange(x, "b f c h w -> (b f) c h w")
        grad_xx = F.conv2d(x, self.weight_x, stride=1, padding=1)
        grad_yx = F.conv2d(x, self.weight_y, stride=1, padding=1)
        gradient_x = torch.abs(grad_xx) + torch.abs(grad_yx)
        gradient_x = torch.mean(gradient_x, dim=1, keepdim=True)
        gradient_x = repeat(gradient_x, "b 1 ... -> b 3 ...")
        if original_dim > 4:
            gradient_x = rearrange(gradient_x, "(b f) c h w -> b f c h w", b=b)
        return gradient_x <= thr
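For reference, a minimal stand-alone sketch of how this smooth-area detector can be exercised. The CPU device, float32 dtype, random input tensor and threshold below are illustrative assumptions rather than values taken from the repo's configs; main.py itself constructs the detector as SmoothAreaRandomDetection(device, dtype=torch.float32).

import torch
from i2vedit.utils.svd_util import SmoothAreaRandomDetection

# 2 clips x 8 frames x 3 channels x 64x64 pixels, values in [0, 1] (shape: b f c h w)
video = torch.rand(2, 8, 3, 64, 64)
sard = SmoothAreaRandomDetection(device="cpu", dtype=torch.float32)
# Boolean mask with the same shape as the input: True where the channel-averaged
# Sobel gradient magnitude is <= thr, i.e. in "smooth" regions of each frame.
mask = sard.detection(video, thr=0.05)
print(mask.shape, mask.float().mean().item())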
i2vedit/version.py
ADDED
@@ -0,0 +1,5 @@
# GENERATED VERSION FILE
# TIME: Thu May 29 22:38:49 2025
__version__ = '0.1.0-dev'
__gitsha__ = 'unknown'
version_info = (0, 1, "0-dev")
main.py
ADDED
@@ -0,0 +1,595 @@
import argparse
import datetime
import logging
import inspect
import math
import os
import random
import gc
import copy
import imageio
import numpy as np
from PIL import Image
from scipy.stats import anderson
from typing import Dict, Optional, Tuple, List
from omegaconf import OmegaConf

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms
from torchvision.transforms import ToTensor
from tqdm.auto import tqdm

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

import transformers
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers.models.clip.modeling_clip import CLIPEncoder

import diffusers
from diffusers.models import AutoencoderKL
from diffusers import DDIMScheduler, TextToVideoSDPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.models.attention import BasicTransformerBlock
from diffusers import StableVideoDiffusionPipeline
from diffusers.models.lora import LoRALinearLayer
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.unet_3d_blocks import \
    (CrossAttnDownBlockSpatioTemporal,
     DownBlockSpatioTemporal,
     CrossAttnUpBlockSpatioTemporal,
     UpBlockSpatioTemporal)
from i2vedit.utils.dataset import VideoJsonDataset, SingleVideoDataset, \
    ImageDataset, VideoFolderDataset, CachedDataset, \
    pad_with_ratio, return_to_original_res
from einops import rearrange, repeat
from i2vedit.utils.lora_handler import LoraHandler
from i2vedit.utils.lora import extract_lora_child_module
from i2vedit.utils.euler_utils import euler_inversion
from i2vedit.utils.svd_util import SmoothAreaRandomDetection

from i2vedit.data import VideoIO, SingleClipDataset, ResolutionControl
#from utils.model_utils import load_primary_models
from i2vedit.utils.euler_utils import inverse_video
from i2vedit.train import train_motion_lora, load_images_from_list
from i2vedit.inference import initialize_pipeline
from i2vedit.utils.model_utils import P2PEulerDiscreteScheduler, P2PStableVideoDiffusionPipeline
from i2vedit.prompt_attention import attention_util

def create_output_folders(output_dir, config):
    os.makedirs(output_dir, exist_ok=True)
    OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
    return output_dir


def main(
    pretrained_model_path: str,
    data_params: Dict,
    train_motion_lora_params: Dict,
    sarp_params: Dict,
    attention_matching_params: Dict,
    long_video_params: Dict = {"mode": "skip-interval"},
    use_sarp: bool = True,
    use_motion_lora: bool = True,
    train_motion_lora_only: bool = False,
    retrain_motion_lora: bool = True,
    use_inversed_latents: bool = True,
    use_attention_matching: bool = True,
    use_consistency_attention_control: bool = False,
    output_dir: str = "./outputs",
    num_steps: int = 25,
    device: str = "cuda",
    seed: int = 23,
    enable_xformers_memory_efficient_attention: bool = True,
    enable_torch_2_attn: bool = False,
    dtype: str = 'fp16',
    load_from_last_frames_latents: List[str] = None,
    save_last_frames: bool = True,
    visualize_attention_store: bool = False,
    visualize_attention_store_steps: List[int] = None,
    use_latent_blend: bool = False,
    use_previous_latent_for_train: bool = False,
    use_latent_noise: bool = True,
    load_from_previous_consistency_store_controller: str = None,
    load_from_previous_consistency_edit_controller: List[str] = None
):
    *_, config = inspect.getargvalues(inspect.currentframe())

    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "fp32":
        dtype = torch.float32

    # create folder
    output_dir = create_output_folders(output_dir, config)

    # prepare video data
    data_params["output_dir"] = output_dir
    data_params["device"] = device

    videoio = VideoIO(**data_params, dtype=dtype)

    # smooth area random perturbation
    if use_sarp:
        sard = SmoothAreaRandomDetection(device, dtype=torch.float32)
    else:
        sard = None

    keyframe = None
    previous_last_frames = load_images_from_list(data_params.keyframe_paths)
    consistency_train_controller = None

    if load_from_last_frames_latents is not None:
        previous_last_frames_latents = [torch.load(thpath).to(device) for thpath in load_from_last_frames_latents]
    else:
        previous_last_frames_latents = [None,] * len(previous_last_frames)

    if use_consistency_attention_control and load_from_previous_consistency_store_controller is not None:
        previous_consistency_store_controller = attention_util.ConsistencyAttentionControl(
            additional_attention_store=None,
            use_inversion_attention=False,
            save_self_attention=True,
            save_latents=False,
            disk_store=True,
            load_attention_store=os.path.join(load_from_previous_consistency_store_controller, "clip_0")
        )
    else:
        previous_consistency_store_controller = None

    previous_consistency_edit_controller_list = [None,] * len(previous_last_frames)
    if use_consistency_attention_control and load_from_previous_consistency_edit_controller is not None:
        for i in range(len(load_from_previous_consistency_edit_controller)):
            previous_consistency_edit_controller_list[i] = attention_util.ConsistencyAttentionControl(
                additional_attention_store=None,
                use_inversion_attention=False,
                save_self_attention=True,
                save_latents=False,
                disk_store=True,
                load_attention_store=os.path.join(load_from_previous_consistency_edit_controller[i], "clip_0")
            )

    # read data and process
    for clip_id, video in enumerate(videoio.read_video_iter()):
        if clip_id >= data_params.get("end_clip_id", 9):
            break
        if clip_id < data_params.get("begin_clip_id", 0):
            continue
        video = video.unsqueeze(0)

        resctrl = ResolutionControl(video.shape[-2:], data_params.output_res, data_params.pad_to_fit, fill=-1)

        # update keyframe and edited keyframe
        if long_video_params.mode == "skip-interval":
            assert data_params.overlay_size > 0
            # save the first frame as the keyframe for cross-attention
            #if clip_id == 0:
            firstframe = video[:,0:1,:,:,:]
            keyframe = video[:,0:1,:,:,:]
            edited_keyframes = copy.deepcopy(previous_last_frames)
            edited_firstframes = edited_keyframes
            #edited_firstframes = load_images_from_list(data_params.keyframe_paths)

        elif long_video_params.mode == "auto-regressive":
            assert data_params.overlay_size == 1
            firstframe = video[:,0:1,:,:,:]
            keyframe = video[:,0:1,:,:,:]
            edited_keyframes = copy.deepcopy(previous_last_frames)
            edited_firstframes = edited_keyframes

        # register for unet, perform inversion
        load_attention_store = None
        if use_attention_matching:
            assert use_inversed_latents, "inversion is disabled."
            if attention_matching_params.get("load_attention_store") is not None:
                load_attention_store = os.path.join(attention_matching_params.get("load_attention_store"), f"clip_{clip_id}")
                if not os.path.exists(load_attention_store):
                    print(f"Load {load_attention_store} failed, folder doesn't exist.")
                    load_attention_store = None

            store_controller = attention_util.AttentionStore(
                disk_store=attention_matching_params.disk_store,
                save_latents = use_latent_blend,
                save_self_attention=True,
                load_attention_store=load_attention_store,
                store_path=os.path.join(output_dir, "attention_store", f"clip_{clip_id}")
            )
            print("store_controller.store_dir:", store_controller.store_dir)
        else:
            store_controller = None

        load_consistency_attention_store = None
        if use_consistency_attention_control:
            if clip_id==0 and attention_matching_params.get("load_consistency_attention_store") is not None:
                load_consistency_attention_store = os.path.join(attention_matching_params.get("load_consistency_attention_store"), f"clip_{clip_id}")
                if not os.path.exists(load_consistency_attention_store):
                    print(f"Load {load_consistency_attention_store} failed, folder doesn't exist.")
                    load_consistency_attention_store = None

            consistency_store_controller = attention_util.ConsistencyAttentionControl(
                additional_attention_store=previous_consistency_store_controller,
                use_inversion_attention=False,
                save_self_attention=(clip_id==0),
                load_attention_store=load_consistency_attention_store,
                save_latents=False,
                disk_store=True,
                store_path=os.path.join(output_dir, "consistency_attention_store", f"clip_{clip_id}")
            )
            print("consistency_store_controller.store_dir:", consistency_store_controller.store_dir)
        else:
            consistency_store_controller = None

        if train_motion_lora_only:
            assert use_motion_lora and retrain_motion_lora, "use_motion_lora/retrain_motion_lora should be enabled to train motion lora only."

        # perform smooth area random perturbation
        if use_inversed_latents:
            print("begin inversion sampling for inference...")
            inversion_noise = inverse_video(
                pretrained_model_path,
                video,
                keyframe,
                firstframe,
                num_steps,
                resctrl,
                sard,
                enable_xformers_memory_efficient_attention,
                enable_torch_2_attn,
                store_controller = store_controller,
                consistency_store_controller = consistency_store_controller,
                find_modules=attention_matching_params.registered_modules if load_attention_store is None else {},
                consistency_find_modules=long_video_params.registered_modules if load_consistency_attention_store is None else {},
                # dtype=dtype,
                **sarp_params,
            )
        else:
            if use_motion_lora and retrain_motion_lora:
                assert not any([np > 0 for np in train_motion_lora_params.validation_data.noise_prior]), "inversion noise is not calculated but validation during motion lora training aims to use inversion noise as input latents."
            inversion_noise = None


        if use_motion_lora:
            if retrain_motion_lora:
                if use_consistency_attention_control:
                    if data_params.output_res[0] != train_motion_lora_params.train_data.height or \
                       data_params.output_res[1] != train_motion_lora_params.train_data.width:
                        if consistency_train_controller is None:
                            load_consistency_train_attention_store = None
                            if attention_matching_params.get("load_consistency_train_attention_store") is not None:
                                load_consistency_train_attention_store = os.path.join(attention_matching_params.get("load_consistency_train_attention_store"), f"clip_0")
                                if not os.path.exists(load_consistency_train_attention_store):
                                    print(f"Load {load_consistency_train_attention_store} failed, folder doesn't exist.")
                                    load_consistency_train_attention_store = None
                            if load_consistency_train_attention_store is None and clip_id > 0:
                                raise IOError(f"load_consistency_train_attention_store can't be None for clip {clip_id}.")
                            consistency_train_controller = attention_util.ConsistencyAttentionControl(
                                additional_attention_store=None,
                                use_inversion_attention=False,
                                save_self_attention=True,
                                load_attention_store=load_consistency_train_attention_store,
                                save_latents=False,
                                disk_store=True,
                                store_path=os.path.join(output_dir, "consistency_train_attention_store", "clip_0")
                            )
                            print("consistency_train_controller.store_dir:", consistency_train_controller.store_dir)
                            resctrl_train = ResolutionControl(
                                video.shape[-2:],
                                (train_motion_lora_params.train_data.height,train_motion_lora_params.train_data.width),
                                data_params.pad_to_fit, fill=-1
                            )
                            print("begin inversion sampling for training...")
                            inversion_noise_train = inverse_video(
                                pretrained_model_path,
                                video,
                                keyframe,
                                firstframe,
                                num_steps,
                                resctrl_train,
                                sard,
                                enable_xformers_memory_efficient_attention,
                                enable_torch_2_attn,
                                store_controller = None,
                                consistency_store_controller = consistency_train_controller,
                                find_modules={},
                                consistency_find_modules=long_video_params.registered_modules if long_video_params.get("load_attention_store") is None else {},
                                # dtype=dtype,
                                **sarp_params,
                            )
                    else:
                        if consistency_train_controller is None:
                            consistency_train_controller = consistency_store_controller
                else:
                    consistency_train_controller = None

            if retrain_motion_lora:
                train_dataset = SingleClipDataset(
                    inversion_noise=inversion_noise,
                    video_clip=video,
                    keyframe=((ToTensor()(previous_last_frames[0])-0.5)/0.5).unsqueeze(0).unsqueeze(0) if use_previous_latent_for_train else keyframe,
                    keyframe_latent=previous_last_frames_latents[0] if use_previous_latent_for_train else None,
                    firstframe=firstframe,
                    height=train_motion_lora_params.train_data.height,
                    width=train_motion_lora_params.train_data.width,
                    use_data_aug=train_motion_lora_params.train_data.get("use_data_aug"),
                    pad_to_fit=train_motion_lora_params.train_data.get("pad_to_fit", False)
                )
                train_motion_lora_params.validation_data.num_inference_steps = num_steps
                train_motion_lora(
                    pretrained_model_path,
                    output_dir,
                    train_dataset,
                    edited_firstframes=edited_firstframes,
                    validation_images=edited_keyframes,
                    validation_images_latents=previous_last_frames_latents,
                    seed=seed,
                    clip_id=clip_id,
                    consistency_edit_controller_list=previous_consistency_edit_controller_list,
                    consistency_controller=consistency_train_controller if clip_id!=0 else None,
                    consistency_find_modules=long_video_params.registered_modules,
                    enable_xformers_memory_efficient_attention=enable_xformers_memory_efficient_attention,
                    enable_torch_2_attn=enable_torch_2_attn,
                    **train_motion_lora_params
                )

            if train_motion_lora_only:
                if not use_consistency_attention_control:
                    continue

            # choose and load motion lora
            best_checkpoint_index = attention_matching_params.get("best_checkpoint_index", 250)
            if retrain_motion_lora:
                lora_dir = f"{os.path.join(output_dir,'train_motion_lora')}/clip_{clip_id}"
                lora_path = f"{lora_dir}/checkpoint-{best_checkpoint_index}/temporal/lora"
            else:
                lora_path = f"/home/user/app/upload/lora"
            assert os.path.exists(lora_path), f"lora path: {lora_path} doesn't exist!"

            lora_rank = train_motion_lora_params.lora_rank
            lora_scale = attention_matching_params.get("lora_scale", 1.0)

            # prepare models
            pipe = initialize_pipeline(
                pretrained_model_path,
                device,
                enable_xformers_memory_efficient_attention,
                enable_torch_2_attn,
                lora_path,
                lora_rank,
                lora_scale,
                load_spatial_lora = False #(clip_id != 0)
            ).to(device, dtype=dtype)
        else:
            pipe = P2PStableVideoDiffusionPipeline.from_pretrained(
                pretrained_model_path
            ).to(device, dtype=dtype)

        if use_attention_matching or use_consistency_attention_control:
            pipe.scheduler = P2PEulerDiscreteScheduler.from_config(pipe.scheduler.config)

        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed)

| 384 |
+
previous_last_frames = []
|
| 385 |
+
|
| 386 |
+
editing_params = [item for name, item in attention_matching_params.params.items()]
|
| 387 |
+
with torch.no_grad():
|
| 388 |
+
with torch.autocast(device, dtype=dtype):
|
| 389 |
+
for kf_id, (edited_keyframe, editing_param) in enumerate(zip(edited_keyframes, editing_params)):
|
| 390 |
+
print(kf_id, editing_param)
|
| 391 |
+
|
| 392 |
+
# control resolution
|
| 393 |
+
iw, ih = edited_keyframe.size
|
| 394 |
+
resctrl = ResolutionControl(
|
| 395 |
+
(ih, iw),
|
| 396 |
+
data_params.output_res,
|
| 397 |
+
data_params.pad_to_fit,
|
| 398 |
+
fill=0
|
| 399 |
+
)
|
| 400 |
+
edited_keyframe = resctrl(edited_keyframe)
|
| 401 |
+
edited_firstframe = resctrl(edited_firstframes[kf_id])
|
| 402 |
+
|
| 403 |
+
# control attention
|
| 404 |
+
pipe.scheduler.controller = []
|
| 405 |
+
if use_attention_matching:
|
| 406 |
+
edit_controller = attention_util.AttentionControlEdit(
|
| 407 |
+
num_steps = num_steps,
|
| 408 |
+
cross_replace_steps = attention_matching_params.cross_replace_steps,
|
| 409 |
+
temporal_self_replace_steps = attention_matching_params.temporal_self_replace_steps,
|
| 410 |
+
spatial_self_replace_steps = attention_matching_params.spatial_self_replace_steps,
|
| 411 |
+
mask_thr = editing_param.get("mask_thr", 0.35),
|
| 412 |
+
temporal_step_thr = editing_param.get("temporal_step_thr", [0.5,0.8]),
|
| 413 |
+
control_mode = attention_matching_params.control_mode,
|
| 414 |
+
spatial_attention_chunk_size = attention_matching_params.get("spatial_attention_chunk_size", 1),
|
| 415 |
+
additional_attention_store = store_controller,
|
| 416 |
+
use_inversion_attention = True,
|
| 417 |
+
save_self_attention = False,
|
| 418 |
+
save_latents = False,
|
| 419 |
+
latent_blend = use_latent_blend,
|
| 420 |
+
disk_store = attention_matching_params.disk_store
|
| 421 |
+
)
|
| 422 |
+
pipe.scheduler.controller.append(edit_controller)
|
| 423 |
+
else:
|
| 424 |
+
edit_controller = None
|
| 425 |
+
|
| 426 |
+
if use_consistency_attention_control:
|
| 427 |
+
consistency_edit_controller = attention_util.ConsistencyAttentionControl(
|
| 428 |
+
additional_attention_store=previous_consistency_edit_controller_list[kf_id],
|
| 429 |
+
use_inversion_attention=False,
|
| 430 |
+
save_self_attention=(clip_id==0),
|
| 431 |
+
save_latents=False,
|
| 432 |
+
disk_store=True,
|
| 433 |
+
store_path=os.path.join(output_dir, f"consistency_edit{kf_id}_attention_store", f"clip_{clip_id}")
|
| 434 |
+
)
|
| 435 |
+
pipe.scheduler.controller.append(consistency_edit_controller)
|
| 436 |
+
else:
|
| 437 |
+
consistency_edit_controller = None
|
| 438 |
+
|
| 439 |
+
if use_attention_matching or use_consistency_attention_control:
|
| 440 |
+
attention_util.register_attention_control(
|
| 441 |
+
pipe.unet,
|
| 442 |
+
edit_controller,
|
| 443 |
+
consistency_edit_controller,
|
| 444 |
+
find_modules=attention_matching_params.registered_modules,
|
| 445 |
+
consistency_find_modules=long_video_params.registered_modules
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# should be reorganized to perform attention control
|
| 449 |
+
edited_output = pipe(
|
| 450 |
+
edited_keyframe,
|
| 451 |
+
edited_firstframe=edited_firstframe,
|
| 452 |
+
image_latents=previous_last_frames_latents[kf_id],
|
| 453 |
+
width=data_params.output_res[1],
|
| 454 |
+
height=data_params.output_res[0],
|
| 455 |
+
num_frames=video.shape[1],
|
| 456 |
+
num_inference_steps=num_steps,
|
| 457 |
+
decode_chunk_size=8,
|
| 458 |
+
motion_bucket_id=127,
|
| 459 |
+
fps=data_params.output_fps,
|
| 460 |
+
noise_aug_strength=0.02,
|
| 461 |
+
max_guidance_scale=attention_matching_params.get("max_guidance_scale", 2.5),
|
| 462 |
+
generator=generator,
|
| 463 |
+
latents=inversion_noise
|
| 464 |
+
)
|
| 465 |
+
edited_video = [img for sublist in edited_output.frames for img in sublist]
|
| 466 |
+
edited_video_latents = edited_output.latents
|
| 467 |
+
|
| 468 |
+
# callback to replace frames
|
| 469 |
+
videoio.write_video(edited_video, kf_id, resctrl)
|
| 470 |
+
|
| 471 |
+
# save previous frames
|
| 472 |
+
if long_video_params.mode == "skip-interval":
|
| 473 |
+
#previous_latents[kf_id] = edit_controller.get_all_last_latents(data_params.overlay_size)
|
| 474 |
+
previous_last_frames.append( resctrl.callback(edited_video[-1]) )
|
| 475 |
+
if use_latent_noise:
|
| 476 |
+
previous_last_frames_latents[kf_id] = edited_video_latents[:,-1:,:,:,:]
|
| 477 |
+
else:
|
| 478 |
+
previous_last_frames_latents[kf_id] = None
|
| 479 |
+
elif long_video_params.mode == "auto-regressive":
|
| 480 |
+
previous_last_frames.append( resctrl.callback(edited_video[-1]) )
|
| 481 |
+
if use_latent_noise:
|
| 482 |
+
previous_last_frames_latents[kf_id] = edited_video_latents[:,-1:,:,:,:]
|
| 483 |
+
else:
|
| 484 |
+
previous_last_frames_latents[kf_id] = None
|
| 485 |
+
|
| 486 |
+
# save last frames for convenient
|
| 487 |
+
if save_last_frames:
|
| 488 |
+
try:
|
| 489 |
+
fname = os.path.join(output_dir, f"clip_{clip_id}_lastframe_{kf_id}")
|
| 490 |
+
previous_last_frames[kf_id].save(fname+".png")
|
| 491 |
+
if use_latent_noise:
|
| 492 |
+
torch.save(previous_last_frames_latents[kf_id], fname+".pt")
|
| 493 |
+
except:
|
| 494 |
+
print("save fail")
|
| 495 |
+
|
| 496 |
+
if use_attention_matching or use_consistency_attention_control:
|
| 497 |
+
attention_util.register_attention_control(
|
| 498 |
+
pipe.unet,
|
| 499 |
+
edit_controller,
|
| 500 |
+
consistency_edit_controller,
|
| 501 |
+
find_modules=attention_matching_params.registered_modules,
|
| 502 |
+
consistency_find_modules=long_video_params.registered_modules,
|
| 503 |
+
undo=True
|
| 504 |
+
)
|
| 505 |
+
if edit_controller is not None:
|
| 506 |
+
if visualize_attention_store:
|
| 507 |
+
vis_save_path = os.path.join(output_dir, "visualization", f"{kf_id}", f"clip_{clip_id}")
|
| 508 |
+
os.makedirs(vis_save_path, exist_ok=True)
|
| 509 |
+
attention_util.show_avg_difference_maps(
|
| 510 |
+
edit_controller,
|
| 511 |
+
save_path = vis_save_path
|
| 512 |
+
)
|
| 513 |
+
assert visualize_attention_store_steps is not None
|
| 514 |
+
attention_util.show_self_attention(
|
| 515 |
+
edit_controller,
|
| 516 |
+
steps = visualize_attention_store_steps,
|
| 517 |
+
save_path = vis_save_path,
|
| 518 |
+
inversed = False
|
| 519 |
+
)
|
| 520 |
+
edit_controller.delete()
|
| 521 |
+
del edit_controller
|
| 522 |
+
if use_consistency_attention_control:
|
| 523 |
+
if clip_id == 0:
|
| 524 |
+
previous_consistency_edit_controller_list[kf_id] = consistency_edit_controller
|
| 525 |
+
else:
|
| 526 |
+
consistency_edit_controller.delete()
|
| 527 |
+
del consistency_edit_controller
|
| 528 |
+
print(f"previous_consistency_edit_controller_list[{kf_id}]", previous_consistency_edit_controller_list[kf_id].store_dir)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
if use_attention_matching:
|
| 532 |
+
del store_controller
|
| 533 |
+
|
| 534 |
+
if use_consistency_attention_control and clip_id == 0:
|
| 535 |
+
previous_consistency_store_controller = consistency_store_controller
|
| 536 |
+
|
| 537 |
+
videoio.close()
|
| 538 |
+
|
| 539 |
+
if use_consistency_attention_control:
|
| 540 |
+
print("consistency_store_controller for clip 0:", previous_consistency_store_controller.store_dir)
|
| 541 |
+
if retrain_motion_lora:
|
| 542 |
+
print("consistency_train_controller for clip 0:", consistency_train_controller.store_dir)
|
| 543 |
+
for kf_id in range(len(previous_consistency_edit_controller_list)):
|
| 544 |
+
print(f"previous_consistency_edit_controller_list[{kf_id}]:", previous_consistency_edit_controller_list[kf_id].store_dir)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
if __name__ == "__main__":
|
| 548 |
+
parser = argparse.ArgumentParser()
|
| 549 |
+
parser.add_argument("--config", type=str, default='./configs/svdedit/item2_2.yaml')
|
| 550 |
+
args = parser.parse_args()
|
| 551 |
+
main(**OmegaConf.load(args.config))
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
|
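main.py is driven entirely by a single OmegaConf YAML whose top-level keys mirror main()'s keyword arguments (pretrained_model_path, data_params, train_motion_lora_params, sarp_params, attention_matching_params, and so on). As a rough usage sketch, it can also be invoked programmatically; the choice of config/customize_train.yaml and the output_dir override below are illustrative, not prescribed by the repo:

# Equivalent to: python main.py --config config/customize_train.yaml
from omegaconf import OmegaConf
from main import main

cfg = OmegaConf.load("config/customize_train.yaml")  # one of the YAML files shipped under config/
cfg.output_dir = "./outputs/demo"                    # hypothetical override of a main() default
main(**cfg)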
mydata/source_and_edits/source.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5e206a8e21c8e3b51779ad09352fa96c93e1bd2ec8ad05e2042bd554b733d15
size 1127699
mydata/source_and_edits/white.jpg
ADDED
Git LFS Details
req.txt
ADDED
@@ -0,0 +1,165 @@
absl-py==2.1.0
accelerate==0.30.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
albucore==0.0.13
albumentations==1.4.13
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.4.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
av==12.1.0
black==24.8.0
brotli==1.1.0
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
coloredlogs==15.0.1
contourpy==1.2.1
cycler==0.12.1
cython==3.0.11
decorator==4.4.2
decord==0.6.0
defusedxml==0.7.1
diffusers==0.25.1
easydict==1.13
einops==0.8.0
et-xmlfile==1.1.0
eval-type-backport==0.2.0
exceptiongroup==1.2.2
fairscale==0.4.13
fastapi==0.112.0
ffmpy==0.3.2
filelock==3.14.0
fire==0.6.0
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.0
ftfy==6.2.0
fvcore==0.1.5.post20221221
gradio==4.41.0
gradio-client==1.3.0
grpcio==1.64.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.3
humanfriendly==10.0
hydra-core==1.3.2
idna==3.7
imageio==2.34.1
imageio-ffmpeg==0.5.1
importlib-metadata==7.1.0
importlib-resources==6.4.2
insightface==0.7.3
iopath==0.1.9
jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.5
kornia==0.7.2
kornia-rs==0.1.3
lazy-loader==0.4
lightning-utilities==0.3.0
markdown==3.6
markdown-it-py==3.0.0
markupsafe==2.1.5
matplotlib==3.9.2
mdurl==0.1.2
moviepy==1.0.3
mpmath==1.3.0
multidict==6.0.5
mutagen==1.47.0
mypy-extensions==1.0.0
networkx==3.2.1
ninja==1.11.1.1
numpy==1.26.4
omegaconf==2.3.0
onnx==1.16.2
onnxruntime==1.18.1
open-clip-torch==2.24.0
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
orjson==3.10.7
packaging==24.0
pandas==2.2.2
pathspec==0.12.1
peft==0.11.1
pillow==10.3.0
platformdirs==4.2.2
portalocker==2.10.1
prettytable==3.11.0
proglog==0.1.10
protobuf==4.25.3
psutil==5.9.8
pyasn1==0.6.0
av==12.1.0
pycocotools==2.0.8
pycryptodomex==3.20.0
pydantic==2.8.2
pydantic-core==2.20.1
pydub==0.25.1
pygments==2.18.0
pyparsing==3.1.2
pysocks==1.7.1
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytz==2024.1
pyyaml==6.0.1
regex==2023.12.25
requests==2.32.3
rich==13.7.1
rpds-py==0.18.1
ruff==0.6.0
safetensors==0.4.3
scenedetect==0.6.4
scikit-image==0.24.0
scikit-learn==1.5.1
scipy==1.13.1
segment-anything==1.0
semantic-version==2.10.0
sentencepiece==0.1.99
sentry-sdk==2.5.1
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
soupsieve==2.5
starlette==0.37.2
supervision==0.22.0
sympy==1.12.1
tabulate==0.9.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorboardx==2.6.2.2
termcolor==2.4.0
threadpoolctl==3.5.0
tifffile==2024.8.10
timm==1.0.8
tokenizers==0.15.2
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torchmetrics==0.11.4
tqdm==4.66.4
transformers==4.36.2
typer==0.12.3
typing-extensions==4.12.2
tzdata==2024.1
urllib3==2.2.1
uvicorn==0.30.6
wcwidth==0.2.13
websockets==12.0
werkzeug==3.0.3
xformers==0.0.26.post1
yacs==0.1.8
yapf==0.40.2
yarl==1.9.4
yt-dlp==2024.8.6
zipp==3.19.2
requirements.txt
ADDED
@@ -0,0 +1,176 @@
--extra-index-url https://download.pytorch.org/whl/cu121

absl-py==2.1.0
accelerate==0.30.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
albucore==0.0.13
albumentations==1.4.13
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.4.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
av==12.1.0
black==24.8.0
Brotli==1.1.0
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.2.1
cycler==0.12.1
Cython==3.0.11
decorator==4.4.2
decord==0.6.0
defusedxml==0.7.1
diffusers>=0.27.0
easydict==1.13
einops==0.8.0
et-xmlfile==1.1.0
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fairscale==0.4.13
fastapi==0.112.0
ffmpy==0.3.2
filelock==3.14.0
fire==0.6.0
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.0
ftfy==6.2.0
fvcore==0.1.5.post20221221
gradio==4.41.0
gradio_client==1.3.0
grpcio==1.64.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.3
humanfriendly==10.0
hydra-core==1.3.2
idna==3.7
imageio==2.34.1
imageio-ffmpeg==0.5.1
importlib_metadata==7.1.0
importlib_resources==6.4.2
insightface==0.7.3
intel-openmp==2021.4.0
iopath==0.1.9
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.5
kornia==0.7.2
kornia_rs==0.1.3
lazy_loader==0.4
lightning-utilities==0.3.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
mdurl==0.1.2
mkl==2021.4.0
moviepy==1.0.3
mpmath==1.3.0
multidict==6.0.5
mutagen==1.47.0
mypy-extensions==1.0.0
networkx==3.2.1
ninja==1.11.1.1
numpy==1.26.4
omegaconf==2.3.0
onnx==1.16.2
onnxruntime==1.18.1
open-clip-torch==2.24.0
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
orjson==3.10.7
packaging==24.0
pandas==2.2.2
pathspec==0.12.1
peft==0.11.1
pillow==10.3.0
platformdirs==4.2.2
portalocker==2.10.1
prettytable==3.11.0
proglog==0.1.10
protobuf==4.25.3
psutil==5.9.8
pyasn1==0.6.0
pycocotools==2.0.8
pycryptodomex==3.20.0
pydantic==2.8.2
pydantic_core==2.20.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
pyreadline3==3.5.4
PySocks==1.7.1
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.32.3
rich==13.7.1
rpds-py==0.18.1
ruff==0.6.0
safetensors==0.4.3
scenedetect==0.6.4
scikit-image==0.24.0
scikit-learn==1.5.1
scipy==1.13.1
segment-anything==1.0
semantic-version==2.10.0
sentencepiece==0.1.99
sentry-sdk==2.5.1
setuptools==69.5.1
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
soupsieve==2.5
starlette==0.37.2
supervision==0.22.0
sympy==1.12.1
tabulate==0.9.0
tbb==2021.13.1
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
termcolor==2.4.0
threadpoolctl==3.5.0
tifffile==2024.8.10
timm==1.0.8
tokenizers==0.15.2
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.3.0+cu121
torchaudio==2.3.0+cu121
torchmetrics==0.11.4
torchvision==0.18.0+cu121
tqdm==4.66.4
transformers==4.36.2
typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.1
uvicorn==0.30.6
wcwidth==0.2.13
websockets==12.0
Werkzeug==3.0.3
wheel==0.43.0
xformers==0.0.26.post1
yacs==0.1.8
yapf==0.40.2
yarl==1.9.4
yt-dlp==2024.8.6
zipp==3.19.2
setup.py
ADDED
@@ -0,0 +1,73 @@
import time
from setuptools import setup, find_packages
import sys
import os
import os.path as osp

WORK_DIR = "i2vedit"
NAME = "i2vedit"
author = "wenqi.oywq"
author_email = '[email protected]'

version_file = 'i2vedit/version.py'

def get_hash():
    if False:#os.path.exists('.git'):
        sha = get_git_hash()[:7]
    # currently ignore this
    # elif os.path.exists(version_file):
    #     try:
    #         from basicsr.version import __version__
    #         sha = __version__.split('+')[-1]
    #     except ImportError:
    #         raise ImportError('Unable to get git version')
    else:
        sha = 'unknown'

    return sha

def get_version():
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'))
    return locals()['__version__']

def write_version_py():
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])

    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)

REQUIRE = [
]

def install_requires(REQUIRE):
    for item in REQUIRE:
        os.system(f'pip install {item}')

write_version_py()
install_requires(REQUIRE)
setup(
    name=NAME,
    packages=find_packages(),
    version=get_version(),
    description="image-to-video editing",
    author=author,
    author_email=author_email,
    keywords=["image-to-video editing"],
    install_requires=[],
    include_package_data=False,
    exclude_package_data={'':['.gitignore','README.md','./configs','./outputs'],
                          },
    entry_points={'console_scripts': ['pyinstrument = pyinstrument.__main__:main']},
    zip_safe=False,
)
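Note that write_version_py() regenerates i2vedit/version.py from a plain-text VERSION file in the repository root each time setup.py runs, so such a file (for example containing 0.1.0-dev) is assumed to exist even though it is not part of this commit. A minimal sanity check after installing with pip install -e . could look like the sketch below; the expected output simply mirrors the generated i2vedit/version.py shown earlier:

from i2vedit.version import __version__, __gitsha__, version_info

# With a VERSION file of "0.1.0-dev" and no .git directory this prints:
#   0.1.0-dev unknown (0, 1, '0-dev')
print(__version__, __gitsha__, version_info)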