weiyuyeh committed on
Commit a45ed83 · 1 Parent(s): 0daf590
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ mydata/source_and_edits/source.mp4 filter=lfs diff=lfs merge=lfs -text
+ mydata/source_and_edits/white.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+ cache
+ ckpts
+ trash
+ outputs
+ run.sh
+ __pycache__
+ *_tmp
+ *.mp4
+ *.png
+ i2vedit.egg-info/
+ !mydata/**/*.mp4
+ customize_train_local.yaml
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: I2vedit
- emoji: 📚
- colorFrom: yellow
- colorTo: indigo
+ emoji: 📈
+ colorFrom: purple
+ colorTo: purple
  sdk: gradio
- sdk_version: 5.38.2
+ sdk_version: 5.32.1
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,173 @@
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import shutil
5
+ import sys
6
+
7
+ target_paths = {
8
+ "video": "/home/user/app/upload/source_and_edits/source.mp4",
9
+ "image": "/home/user/app/upload/source_and_edits/ref.jpg",
10
+ "config": "/home/user/app/upload/config/customize_train.yaml",
11
+ "lora": "/home/user/app/upload/lora/lora.pt",
12
+ "output_l": "/home/user/app/outputs/train_motion_lora",
13
+ "output_r": "/home/user/app/outputs/ref.mp4",
14
+ "zip": "/home/user/app/outputs/train_motion_lora.zip",
15
+ }
16
+
17
+
18
+ def zip_outputs():
19
+ if os.path.exists(target_paths["zip"]):
20
+ os.remove(target_paths["zip"])
21
+ shutil.make_archive(target_paths["zip"].replace(".zip", ""), 'zip', root_dir=target_paths["output_l"])
22
+ return target_paths["zip"]
23
+
24
+ def output_video():
25
+ if os.path.exists(target_paths["output_r"]):
26
+ return target_paths["output_r"]
27
+ return None
28
+
29
+
30
+ def start_training_stream():
31
+ process = subprocess.Popen(
32
+ ["python", "main.py", "--config=" + target_paths["config"]],
33
+ stdout=subprocess.PIPE,
34
+ stderr=subprocess.STDOUT,
35
+ text=True,
36
+ bufsize=1,
37
+ universal_newlines=True
38
+ )
39
+
40
+ output = []
41
+ for line in process.stdout:
42
+ output.append(line)
43
+ yield "".join(output)
44
+
45
+ def install_i2vedit():
46
+ try:
47
+ import i2vedit
48
+ print("i2vedit already installed")
49
+ except ImportError:
50
+ print("Installing i2vedit...")
51
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./i2vedit"])
52
+ print("i2vedit installed")
53
+
54
+
55
+ def install_package(package_name):
56
+ try:
57
+ result = subprocess.run(
58
+ [sys.executable, "-m", "pip", "install", package_name],
59
+ stdout=subprocess.PIPE,
60
+ stderr=subprocess.PIPE,
61
+ text=True,
62
+ )
63
+ output = result.stdout + "\n" + result.stderr
64
+ return output
65
+ except Exception as e:
66
+ return f"Error: {str(e)}"
67
+
68
+
69
+ def show_package(pkg_name):
70
+ try:
71
+ result = subprocess.run(
72
+ [sys.executable, "-m", "pip", "show", pkg_name],
73
+ stdout=subprocess.PIPE,
74
+ stderr=subprocess.PIPE,
75
+ text=True,
76
+ )
77
+ return result.stdout if result.stdout else result.stderr
78
+ except Exception as e:
79
+ return str(e)
80
+
81
+
82
+ def uninstall_package(package_name):
83
+ try:
84
+ result = subprocess.run(
85
+ [sys.executable, "-m", "pip", "uninstall", package_name, "-y"],
86
+ stdout=subprocess.PIPE,
87
+ stderr=subprocess.PIPE,
88
+ text=True,
89
+ )
90
+ output = result.stdout + "\n" + result.stderr
91
+ return output
92
+ except Exception as e:
93
+ return f"Error: {str(e)}"
94
+
95
+
96
+
97
+ def save_files(video_file, image_file, config_file, lora_file=None):
98
+ os.makedirs(os.path.dirname(target_paths["video"]), exist_ok=True)
99
+ os.makedirs(os.path.dirname(target_paths["config"]), exist_ok=True)
100
+
101
+ shutil.copy(video_file.name, target_paths["video"])
102
+ shutil.copy(image_file.name, target_paths["image"])
103
+ shutil.copy(config_file.name, target_paths["config"])
104
+ if lora_file:
105
+ os.makedirs(os.path.dirname(target_paths["lora"]), exist_ok=True)
106
+ shutil.copy(lora_file.name, target_paths["lora"])
107
+ return "Files uploaded and saved successfully!"
108
+
109
+
110
+ install_i2vedit()
111
+ install_package("huggingface_hub==0.25.1")
112
+ install_package("diffusers==0.25.1")
113
+ install_package("gradio==5.0.0")
114
+ uninstall_package("datasets")
115
+ print("package version set complete")
116
+
117
+
118
+ with gr.Blocks(theme=gr.themes.Origin()) as demo:
119
+ gr.Markdown("## Upload your files first")
120
+ with gr.Row():
121
+ video_input = gr.File(label="Source video", file_types=[".mp4"])
122
+ image_input = gr.File(label="Edited keyframe image", file_types=[".jpg", ".jpeg", ".png"])
123
+ config_input = gr.File(label="Config file", file_types=[".yaml", ".yml"])
124
+ lora_input = gr.File(label="LoRA file", file_types=[".pt"])
125
+
126
+ upload_button = gr.Button("Upload and save")
127
+ output = gr.Textbox(label="Status")
128
+
129
+
130
+ gr.Markdown("## Training")
131
+ with gr.Column():
132
+ log_output = gr.Textbox(label="Training Log", lines=20)
133
+ train_btn = gr.Button("Start Training")
134
+
135
+ gr.Markdown("## Pip Installer")
136
+ with gr.Column():
137
+ with gr.Row():
138
+ pkg_input = gr.Textbox(lines=1, placeholder="Package to install, e.g. diffusers or numpy==1.2.0")
139
+ install_output = gr.Textbox(label="Install Output", lines=10)
140
+ install_btn = gr.Button("Install Package")
141
+
142
+ gr.Markdown("## Pip Uninstaller")
143
+ with gr.Column():
144
+ with gr.Row():
145
+ pkg_input2 = gr.Textbox(lines=1, placeholder="Package to uninstall, e.g. diffusers or numpy")
146
+ uninstall_output = gr.Textbox(label="Uninstall Output", lines=10)
147
+ uninstall_btn = gr.Button("Uninstall Package")
148
+
149
+ gr.Markdown("## Pip show")
150
+ with gr.Column():
151
+ with gr.Row():
152
+ show_input = gr.Textbox(label="Package name (e.g. diffusers)")
153
+ show_output = gr.Textbox(label="Package info", lines=10)
154
+ show_btn = gr.Button("pip show")
155
+
156
+ gr.Markdown("## Download lora")
157
+ with gr.Column():
158
+ file_output = gr.File(label="Click to download", interactive=True)
159
+ download_btn = gr.Button("Download LoRA")
160
+
161
+ gr.Markdown("## Download results")
162
+ with gr.Column():
163
+ file_output2 = gr.File(label="Click to download", interactive=True)
164
+ download_btn2 = gr.Button("Download results")
165
+
166
+ show_btn.click(fn=show_package, inputs=show_input, outputs=show_output)
167
+ download_btn.click(fn=zip_outputs, outputs=file_output)
168
+ download_btn2.click(fn=output_video, outputs=file_output2)
169
+ install_btn.click(fn=install_package, inputs=pkg_input, outputs=install_output)
170
+ train_btn.click(fn=start_training_stream, outputs=log_output)
171
+ uninstall_btn.click(fn=uninstall_package, inputs=pkg_input2, outputs=uninstall_output)
172
+ upload_button.click(fn=save_files, inputs=[video_input, image_input, config_input, lora_input], outputs=output)
173
+ demo.launch()
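
A note on zip_outputs above: shutil.make_archive appends the ".zip" suffix itself, which is why the code strips it from target_paths["zip"] before archiving. A minimal, self-contained sketch of the same pattern, using stdlib-only temporary paths in place of the app's output directories:

import os
import shutil
import tempfile

with tempfile.TemporaryDirectory() as workdir:
    # Stand-in for target_paths["output_l"] (the trained motion-LoRA folder).
    output_dir = os.path.join(workdir, "train_motion_lora")
    os.makedirs(output_dir)
    with open(os.path.join(output_dir, "lora.pt"), "wb") as f:
        f.write(b"dummy checkpoint bytes")

    # Stand-in for target_paths["zip"]; make_archive adds ".zip" to the base name.
    zip_path = os.path.join(workdir, "train_motion_lora.zip")
    shutil.make_archive(zip_path.replace(".zip", ""), "zip", root_dir=output_dir)
    assert os.path.exists(zip_path)
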
config/customize_subsequent_edit.yaml ADDED
@@ -0,0 +1,152 @@
1
+ # Pretrained diffusers model path.
2
+ pretrained_model_path: "ckpts/stable-video-diffusion-img2vid"
3
+ # The folder where your training outputs will be placed.
4
+ output_dir: "./acc"
5
+ seed: 23
6
+ num_steps: 25
7
+ # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
8
+ enable_xformers_memory_efficient_attention: True
9
+ # Use scaled dot product attention (Only available with >= Torch 2.0)
10
+ enable_torch_2_attn: True
11
+
12
+ use_sarp: true
13
+
14
+ use_motion_lora: true
15
+ train_motion_lora_only: false
16
+ retrain_motion_lora: true
17
+
18
+ use_inversed_latents: true
19
+ use_attention_matching: true
20
+ use_consistency_attention_control: true
21
+ dtype: fp16
22
+
23
+ visualize_attention_store: false
24
+ visualize_attention_store_steps: [0, 5, 10, 15, 20, 24]
25
+
26
+ save_last_frames: True
27
+ load_from_last_frames_latents:
28
+ - "./cache/item1/i2vedit_2024-05-11T15-53-54/clip_0_lastframe_0.pt"
29
+ - "./cache/item1/i2vedit_2024-05-11T15-53-54/clip_0_lastframe_1.pt"
30
+ load_from_previous_consistency_edit_controller:
31
+ - "./cache/item1/i2vedit_2024-05-11T15-53-54/consistency_edit0_attention_store"
32
+ - "./cache/item1/i2vedit_2024-05-11T15-53-54/consistency_edit1_attention_store"
33
+ load_from_previous_consistency_store_controller:
34
+ "./cache/item1/i2vedit_2024-05-11T15-53-54/consistency_attention_store"
35
+
36
+ # data_params
37
+ data_params:
38
+ video_path: "../datasets/svdedit/item1/source.mp4"
39
+ keyframe_paths:
40
+ - "../datasets/svdedit/tmp/edit0.png"
41
+ - "../datasets/svdedit/tmp/edit1.png"
42
+ start_t: 0
43
+ end_t: -1
44
+ sample_fps: 7
45
+ chunk_size: 16
46
+ overlay_size: 1
47
+ normalize: true
48
+ output_fps: 7
49
+ save_sampled_frame: true
50
+ output_res: [576, 1024]
51
+ pad_to_fit: false
52
+ begin_clip_id: 1
53
+ end_clip_id: 2
54
+
55
+ train_motion_lora_params:
56
+ cache_latents: true
57
+ cached_latent_dir: null #/path/to/cached_latents
58
+ lora_rank: 32
59
+ # Use LoRA for the UNET model.
60
+ use_unet_lora: True
61
+ # LoRA Dropout. This sets the probability of randomly zeroing out LoRA elements, which helps prevent overfitting.
62
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
63
+ lora_unet_dropout: 0.1
64
+ # The only time you want this off is if you're doing full LoRA training.
65
+ save_pretrained_model: False
66
+ # Learning rate for AdamW
67
+ learning_rate: 5e-4
68
+ # Weight decay. Higher = more regularization. Lower = closer to dataset.
69
+ adam_weight_decay: 1e-2
70
+ # Maximum number of train steps. Model is saved after training.
71
+ max_train_steps: 250
72
+ # Saves a model every nth step.
73
+ checkpointing_steps: 250
74
+ # How many steps to do for validation if sample_preview is enabled.
75
+ validation_steps: 300
76
+ # Whether or not we want to use mixed precision with accelerate
77
+ mixed_precision: "fp16"
78
+ # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM.
79
+ # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2.
80
+ gradient_checkpointing: True
81
+ image_encoder_gradient_checkpointing: True
82
+
83
+ train_data:
84
+ # The width and height in which you want your training data to be resized to.
85
+ width: 896
86
+ height: 512
87
+ # This will find the closest aspect ratio to your input width and height.
88
+ # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256
89
+ use_data_aug: ~ #"controlnet"
90
+ pad_to_fit: false
91
+
92
+ validation_data:
93
+ # Whether or not to sample preview during training (Requires more VRAM).
94
+ sample_preview: True
95
+ # The number of frames to sample during validation.
96
+ num_frames: 14
97
+ # Height and width of validation sample.
98
+ width: 1024
99
+ height: 576
100
+ pad_to_fit: false
101
+ # scale of spatial LoRAs, default is 0
102
+ spatial_scale: 0
103
+ # scale of noise prior, i.e. the scale of inversion noises
104
+ noise_prior:
105
+ - 0.0
106
+ #- 1.0
107
+
108
+ sarp_params:
109
+ sarp_noise_scale: 0.005
110
+
111
+ attention_matching_params:
112
+ best_checkpoint_index: 250
113
+ lora_scale: 1.0
114
+ # lora path
115
+ lora_dir: ~
116
+ max_guidance_scale: 2.0
117
+ disk_store: True
118
+ load_attention_store: "./cache/item1/attention_store"
119
+ load_consistency_attention_store: ~
120
+ registered_modules:
121
+ BasicTransformerBlock:
122
+ - "attn1"
123
+ #- "attn2"
124
+ TemporalBasicTransformerBlock:
125
+ - "attn1"
126
+ #- "attn2"
127
+ control_mode:
128
+ spatial_self: "masked_copy"
129
+ temporal_self: "copy_v2"
130
+ cross_replace_steps: 0.0
131
+ temporal_self_replace_steps: 1.0
132
+ spatial_self_replace_steps: 1.0
133
+ spatial_attention_chunk_size: 1
134
+
135
+ params:
136
+ edit0:
137
+ temporal_step_thr: [0.5, 0.8]
138
+ mask_thr: [0.35, 0.35]
139
+ edit1:
140
+ temporal_step_thr: [0.5, 0.8]
141
+ mask_thr: [0.35, 0.35]
142
+
143
+ long_video_params:
144
+ mode: "skip-interval"
145
+ registered_modules:
146
+ BasicTransformerBlock:
147
+ #- "attn1"
148
+ #- "attn2"
149
+ TemporalBasicTransformerBlock:
150
+ - "attn1"
151
+ #- "attn2"
152
+
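
For reference, main.py (not included in this commit) is launched with --config=<path> pointing at a YAML like the one above. A minimal sketch of reading a few of its fields, assuming plain PyYAML; the project may well use a different loader such as OmegaConf:

import yaml  # assumes PyYAML is available

with open("config/customize_subsequent_edit.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["output_dir"])                             # "./acc"
print(cfg["data_params"]["chunk_size"])              # 16
print(cfg["attention_matching_params"]["lora_dir"])  # None ("~" in YAML)
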
config/customize_train.yaml ADDED
@@ -0,0 +1,149 @@
1
+ # Pretrained diffusers model path.
2
+ # Don't change
3
+ pretrained_model_path: "stabilityai/stable-video-diffusion-img2vid"
4
+ # The folder where your training outputs will be placed.
5
+ # Don't change
6
+ output_dir: "/home/user/app/outputs"
7
+ seed: 23
8
+ num_steps: 25
9
+ # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
10
+ enable_xformers_memory_efficient_attention: True
11
+ # Use scaled dot product attention (Only available with >= Torch 2.0)
12
+ enable_torch_2_attn: True
13
+
14
+ use_sarp: true
15
+
16
+ use_motion_lora: true
17
+ train_motion_lora_only: false
18
+ retrain_motion_lora: true
19
+
20
+ use_inversed_latents: true
21
+ use_attention_matching: true
22
+ use_consistency_attention_control: false
23
+ dtype: fp16
24
+
25
+ visualize_attention_store: false
26
+ visualize_attention_store_steps: #[0, 5, 10, 15, 20, 24]
27
+
28
+ save_last_frames: True
29
+ load_from_last_frames_latents:
30
+
31
+ # data_params
32
+ data_params:
33
+ # Don't change
34
+ video_path: "/home/user/app/upload/source_and_edits/source.mp4"
35
+ # Don't change
36
+ keyframe_paths:
37
+ - "/home/user/app/upload/source_and_edits/ref.jpg"
38
+ start_t: 0
39
+ end_t: 1.6
40
+ sample_fps: 10
41
+ chunk_size: 16
42
+ overlay_size: 1
43
+ normalize: true
44
+ output_fps: 3
45
+ save_sampled_frame: true
46
+ output_res: [576, 576]
47
+ pad_to_fit: true
48
+ begin_clip_id: 0
49
+ end_clip_id: 1
50
+
51
+ train_motion_lora_params:
52
+ cache_latents: true
53
+ cached_latent_dir: null #/path/to/cached_latents
54
+ lora_rank: 32
55
+ # Use LoRA for the UNET model.
56
+ use_unet_lora: True
57
+ # LoRA Dropout. This sets the probability of randomly zeroing out LoRA elements, which helps prevent overfitting.
58
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
59
+ lora_unet_dropout: 0.1
60
+ # The only time you want this off is if you're doing full LoRA training.
61
+ save_pretrained_model: False
62
+ # Learning rate for AdamW
63
+ learning_rate: 5e-4
64
+ # Weight decay. Higher = more regularization. Lower = closer to dataset.
65
+ adam_weight_decay: 1e-2
66
+ # Maximum number of train steps. Model is saved after training.
67
+ max_train_steps: 600
68
+ # Saves a model every nth step.
69
+ checkpointing_steps: 100
70
+ # How many steps to do for validation if sample_preview is enabled.
71
+ validation_steps: 100
72
+ # Whether or not we want to use mixed precision with accelerate
73
+ mixed_precision: "fp16"
74
+ # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM.
75
+ # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2.
76
+ gradient_checkpointing: True
77
+ image_encoder_gradient_checkpointing: True
78
+
79
+ train_data:
80
+ # The width and height in which you want your training data to be resized to.
81
+ width: 576
82
+ height: 576
83
+ # This will find the closest aspect ratio to your input width and height.
84
+ # For example, 576x576 width and height with a video of resolution 1280x720 will be resized to 576x256
85
+ use_data_aug: ~ #"controlnet"
86
+ pad_to_fit: true
87
+
88
+ validation_data:
89
+ # Whether or not to sample preview during training (Requires more VRAM).
90
+ sample_preview: True
91
+ # The number of frames to sample during validation.
92
+ num_frames: 16
93
+ # Height and width of validation sample.
94
+ width: 576
95
+ height: 576
96
+ pad_to_fit: true
97
+ # scale of spatial LoRAs, default is 0
98
+ spatial_scale: 0
99
+ # scale of noise prior, i.e. the scale of inversion noises
100
+ noise_prior:
101
+ #- 0.0
102
+ - 1.0
103
+
104
+ sarp_params:
105
+ sarp_noise_scale: 0.005
106
+
107
+ attention_matching_params:
108
+ best_checkpoint_index: 500
109
+ lora_scale: 1.0
110
+ # lora path
111
+ lora_dir: ~
112
+ max_guidance_scale: 2.0
113
+
114
+ disk_store: True
115
+ load_attention_store: ~
116
+ load_consistency_attention_store: ~
117
+ load_consistency_train_attention_store: ~
118
+ registered_modules:
119
+ BasicTransformerBlock:
120
+ - "attn1"
121
+ #- "attn2"
122
+ TemporalBasicTransformerBlock:
123
+ - "attn1"
124
+ #- "attn2"
125
+ control_mode:
126
+ spatial_self: "masked_copy"
127
+ temporal_self: "copy_v2"
128
+ cross_replace_steps: 0.0
129
+ temporal_self_replace_steps: 1.0
130
+ spatial_self_replace_steps: 1.0
131
+ spatial_attention_chunk_size: 1
132
+
133
+ params:
134
+ edit0:
135
+ temporal_step_thr: [0.5, 0.8]
136
+ mask_thr: [0.35, 0.35]
137
+ edit1:
138
+ temporal_step_thr: [0.5, 0.8]
139
+ mask_thr: [0.35, 0.35]
140
+
141
+ long_video_params:
142
+ mode: "skip-interval"
143
+ registered_modules:
144
+ BasicTransformerBlock:
145
+ #- "attn1"
146
+ #- "attn2"
147
+ TemporalBasicTransformerBlock:
148
+ - "attn1"
149
+ #- "attn2"
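
As a sanity check on data_params above: int((end_t - start_t) * sample_fps) = int(1.6 * 10) = 16 sampled frames, which fits in a single chunk of chunk_size 16, consistent with begin_clip_id 0 and end_clip_id 1. A small sketch of the chunk/overlay arithmetic, mirroring VideoIO.read_video_iter in i2vedit/data.py later in this commit:

def clip_ranges(num_frames: int, chunk_size: int, overlay_size: int):
    """Reproduce the chunking loop from VideoIO.read_video_iter (sketch only)."""
    ranges, begin = [], 0
    while begin < num_frames:
        start = max(begin - overlay_size, 0)  # re-read overlapping frames for continuity
        end = min(start + chunk_size, num_frames)
        ranges.append((start, end))
        begin = end
    return ranges

num_frames = int((1.6 - 0) * 10)  # start_t, end_t, sample_fps from this config
print(num_frames, clip_ranges(num_frames, chunk_size=16, overlay_size=1))
# -> 16 [(0, 16)], i.e. a single clip
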
config/customize_train_multi.yaml ADDED
@@ -0,0 +1,149 @@
1
+ # Pretrained diffusers model path.
2
+ # Don't change
3
+ pretrained_model_path: "stabilityai/stable-video-diffusion-img2vid"
4
+ # The folder where your training outputs will be placed.
5
+ # Don't change
6
+ output_dir: "/home/user/app/outputs"
7
+ seed: 23
8
+ num_steps: 25
9
+ # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
10
+ enable_xformers_memory_efficient_attention: True
11
+ # Use scaled dot product attention (Only available with >= Torch 2.0)
12
+ enable_torch_2_attn: True
13
+
14
+ use_sarp: true
15
+
16
+ use_motion_lora: true
17
+ train_motion_lora_only: false
18
+ retrain_motion_lora: true
19
+
20
+ use_inversed_latents: true
21
+ use_attention_matching: true
22
+ use_consistency_attention_control: true
23
+ dtype: fp16
24
+
25
+ visualize_attention_store: false
26
+ visualize_attention_store_steps: #[0, 5, 10, 15, 20, 24]
27
+
28
+ save_last_frames: True
29
+ load_from_last_frames_latents:
30
+
31
+ # data_params
32
+ data_params:
33
+ # Don't change
34
+ video_path: "/home/user/app/upload/source_and_edits/source.mp4"
35
+ # Don't change
36
+ keyframe_paths:
37
+ - "/home/user/app/upload/source_and_edits/ref.jpg"
38
+ start_t: 0
39
+ end_t: 4.0
40
+ sample_fps: 10
41
+ chunk_size: 12
42
+ overlay_size: 3
43
+ normalize: true
44
+ output_fps: 10
45
+ save_sampled_frame: true
46
+ output_res: [768, 768]
47
+ pad_to_fit: true
48
+ begin_clip_id: 0
49
+ end_clip_id: 4
50
+
51
+ train_motion_lora_params:
52
+ cache_latents: true
53
+ cached_latent_dir: null #/path/to/cached_latents
54
+ lora_rank: 32
55
+ # Use LoRA for the UNET model.
56
+ use_unet_lora: True
57
+ # LoRA Dropout. This sets the probability of randomly zeroing out LoRA elements, which helps prevent overfitting.
58
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
59
+ lora_unet_dropout: 0.1
60
+ # The only time you want this off is if you're doing full LoRA training.
61
+ save_pretrained_model: False
62
+ # Learning rate for AdamW
63
+ learning_rate: 5e-4
64
+ # Weight decay. Higher = more regularization. Lower = closer to dataset.
65
+ adam_weight_decay: 1e-2
66
+ # Maximum number of train steps. Model is saved after training.
67
+ max_train_steps: 600
68
+ # Saves a model every nth step.
69
+ checkpointing_steps: 200
70
+ # How many steps to do for validation if sample_preview is enabled.
71
+ validation_steps: 200
72
+ # Whether or not we want to use mixed precision with accelerate
73
+ mixed_precision: "fp16"
74
+ # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM.
75
+ # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2.
76
+ gradient_checkpointing: True
77
+ image_encoder_gradient_checkpointing: True
78
+
79
+ train_data:
80
+ # The width and height in which you want your training data to be resized to.
81
+ width: 768
82
+ height: 768
83
+ # This will find the closest aspect ratio to your input width and height.
84
+ # For example, 768x768 width and height with a video of resolution 1280x720 will be resized to 768x256
85
+ use_data_aug: ~ #"controlnet"
86
+ pad_to_fit: true
87
+
88
+ validation_data:
89
+ # Whether or not to sample preview during training (Requires more VRAM).
90
+ sample_preview: True
91
+ # The number of frames to sample during validation.
92
+ num_frames: 8
93
+ # Height and width of validation sample.
94
+ width: 768
95
+ height: 768
96
+ pad_to_fit: true
97
+ # scale of spatial LoRAs, default is 0
98
+ spatial_scale: 0
99
+ # scale of noise prior, i.e. the scale of inversion noises
100
+ noise_prior:
101
+ #- 0.0
102
+ - 1.0
103
+
104
+ sarp_params:
105
+ sarp_noise_scale: 0.005
106
+
107
+ attention_matching_params:
108
+ best_checkpoint_index: 600
109
+ lora_scale: 1.0
110
+ # lora path
111
+ lora_dir: ~
112
+ max_guidance_scale: 2.0
113
+
114
+ disk_store: True
115
+ load_attention_store: ~
116
+ load_consistency_attention_store: ~
117
+ load_consistency_train_attention_store: ~
118
+ registered_modules:
119
+ BasicTransformerBlock:
120
+ - "attn1"
121
+ #- "attn2"
122
+ TemporalBasicTransformerBlock:
123
+ - "attn1"
124
+ #- "attn2"
125
+ control_mode:
126
+ spatial_self: "masked_copy"
127
+ temporal_self: "copy_v2"
128
+ cross_replace_steps: 0.0
129
+ temporal_self_replace_steps: 1.0
130
+ spatial_self_replace_steps: 1.0
131
+ spatial_attention_chunk_size: 1
132
+
133
+ params:
134
+ edit0:
135
+ temporal_step_thr: [0.5, 0.8]
136
+ mask_thr: [0.35, 0.35]
137
+ edit1:
138
+ temporal_step_thr: [0.5, 0.8]
139
+ mask_thr: [0.35, 0.35]
140
+
141
+ long_video_params:
142
+ mode: "skip-interval"
143
+ registered_modules:
144
+ BasicTransformerBlock:
145
+ #- "attn1"
146
+ #- "attn2"
147
+ TemporalBasicTransformerBlock:
148
+ - "attn1"
149
+ #- "attn2"
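
Running the clip_ranges sketch shown after customize_train.yaml with this config's values (int(4.0 * 10) = 40 sampled frames, chunk_size 12, overlay_size 3) yields five overlapping clips, (0, 12), (9, 21), (18, 30), (27, 39) and (36, 40), which is consistent with begin_clip_id 0 and end_clip_id 4 here; the exact end-clip semantics are decided in main.py, which is not part of this commit.
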
i2vedit/__init__.py ADDED
File without changes
i2vedit/data.py ADDED
@@ -0,0 +1,317 @@
1
+ import os
2
+ import decord
3
+ import imageio
4
+ import numpy as np
5
+ import PIL
6
+ from PIL import Image
7
+ from einops import rearrange, repeat
8
+
9
+ from torchvision.transforms import Resize, Pad, InterpolationMode, ToTensor
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.utils.data import Dataset
14
+
15
+ #from i2vedit.utils.augment import ControlNetDataAugmentation, ColorDataAugmentation
16
+ # from utils.euler_utils import tensor_to_vae_latent
17
+
18
+ class ResolutionControl(object):
19
+
20
+ def __init__(self, input_res, output_res, pad_to_fit=False, fill=0, **kwargs):
21
+
22
+ self.ih, self.iw = input_res
23
+ self.output_res = output_res
24
+ self.pad_to_fit = pad_to_fit
25
+ self.fill=fill
26
+
27
+ def pad_with_ratio(self, frames, res, fill=0):
28
+ if isinstance(frames, torch.Tensor):
29
+ original_dim = frames.ndim
30
+ if frames.ndim > 4:
31
+ batch_size = frames.shape[0]
32
+ frames = rearrange(frames, "b f c h w -> (b f) c h w")
33
+ _, _, ih, iw = frames.shape
34
+ elif isinstance(frames, PIL.Image.Image):
35
+ iw, ih = frames.size
36
+ assert ih == self.ih and iw == self.iw, "resolution doesn't match."
37
+ #print("ih, iw", ih, iw)
38
+ i_ratio = ih / iw
39
+ h, w = res
40
+ #print("h,w", h ,w)
41
+ n_ratio = h / w
42
+ if i_ratio > n_ratio:
43
+ nw = int(ih / h * w)
44
+ #print("nw", nw)
45
+ frames = Pad(((nw - iw)//2,0), fill=fill)(frames)
46
+ else:
47
+ nh = int(iw / w * h)
48
+ frames = Pad((0,(nh - ih)//2), fill=fill)(frames)
49
+ #print("after pad", frames.shape)
50
+ if isinstance(frames, torch.Tensor):
51
+ if original_dim > 4:
52
+ frames = rearrange(frames, "(b f) c h w -> b f c h w", b=batch_size)
53
+
54
+ return frames
55
+
56
+ def return_to_original_res(self, frames):
57
+ if isinstance(frames, torch.Tensor):
58
+ original_dim = frames.ndim
59
+ if frames.ndim > 4:
60
+ batch_size = frames.shape[0]
61
+ frames = rearrange(frames, "b f c h w -> (b f) c h w")
62
+ _, _, h, w = frames.shape
63
+ elif isinstance(frames, PIL.Image.Image):
64
+ w, h = frames.size
65
+ #print("original res", (self.ih, self.iw))
66
+ #print("current res", (h, w))
67
+ assert h == self.output_res[0] and w == self.output_res[1], "resolution doesn't match."
68
+ n_ratio = h / w
69
+ ih, iw = self.ih, self.iw
70
+ i_ratio = ih / iw
71
+ if self.pad_to_fit:
72
+ if i_ratio > n_ratio:
73
+ nw = int(ih / h * w)
74
+ frames = Resize((ih, iw+2*(nw - iw)//2), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
75
+ if isinstance(frames, torch.Tensor):
76
+ frames = frames[...,:,(nw - iw)//2:-(nw - iw)//2]
77
+ elif isinstance(frames, PIL.Image.Image):
78
+ frames = frames.crop(((nw - iw)//2,0,iw+(nw - iw)//2,ih))
79
+ else:
80
+ nh = int(iw / w * h)
81
+ frames = Resize((ih+2*(nh - ih)//2, iw), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
82
+ if isinstance(frames, torch.Tensor):
83
+ frames = frames[...,(nh - ih)//2:-(nh - ih)//2,:]
84
+ elif isinstance(frames, PIL.Image.Image):
85
+ frames = frames.crop((0,(nh - ih)//2,iw,ih+(nh - ih)//2))
86
+ else:
87
+ frames = Resize((ih, iw), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
88
+
89
+ if isinstance(frames, torch.Tensor):
90
+ if original_dim > 4:
91
+ frames = rearrange(frames, "(b f) c h w -> b f c h w", b=batch_size)
92
+
93
+ return frames
94
+
95
+ def __call__(self, frames):
96
+ if self.pad_to_fit:
97
+ frames = self.pad_with_ratio(frames, self.output_res, fill=self.fill)
98
+
99
+ if isinstance(frames, torch.Tensor):
100
+ original_dim = frames.ndim
101
+ if frames.ndim > 4:
102
+ batch_size = frames.shape[0]
103
+ frames = rearrange(frames, "b f c h w -> (b f) c h w")
104
+ frames = (frames + 1) / 2.
105
+
106
+ frames = Resize(tuple(self.output_res), interpolation=InterpolationMode.BILINEAR, antialias=True)(frames)
107
+ if isinstance(frames, torch.Tensor):
108
+ if original_dim > 4:
109
+ frames = rearrange(frames, "(b f) c h w -> b f c h w", b=batch_size)
110
+ frames = frames * 2 - 1
111
+
112
+ return frames
113
+
114
+ def callback(self, frames):
115
+ return self.return_to_original_res(frames)
116
+
117
+ class VideoIO(object):
118
+
119
+ def __init__(
120
+ self,
121
+ video_path,
122
+ keyframe_paths,
123
+ output_dir,
124
+ device,
125
+ dtype,
126
+ start_t:int=0,
127
+ end_t:int=-1,
128
+ sample_fps:int=-1,
129
+ chunk_size: int=14,
130
+ overlay_size: int=-1,
131
+ normalize: bool=True,
132
+ output_fps: int=-1,
133
+ save_sampled_video: bool=True,
134
+ **kwargs
135
+ ):
136
+ self.video_path = video_path
137
+ self.keyframe_paths = keyframe_paths
138
+ self.device = device
139
+ self.dtype = dtype
140
+ self.start_t = start_t
141
+ self.end_t = end_t
142
+ self.sample_fps = sample_fps
143
+ self.chunk_size = chunk_size
144
+ self.overlay_size = overlay_size
145
+ self.normalize = normalize
146
+ self.save_sampled_video = save_sampled_video
147
+
148
+
149
+
150
+ vr = decord.VideoReader(video_path)
151
+ initial_fps = vr.get_avg_fps()
152
+ self.initial_fps = initial_fps
153
+
154
+ if output_fps == -1: output_fps = initial_fps
155
+
156
+ self.video_writer_list = []
157
+ for keyframe_path in keyframe_paths:
158
+ fname, ext = os.path.splitext(os.path.basename(keyframe_path))
159
+ output_video_path = os.path.join(output_dir, fname+".mp4")
160
+ self.video_writer_list.append( imageio.get_writer(output_video_path, fps=output_fps) )
161
+
162
+ if save_sampled_video:
163
+ fname, ext = os.path.splitext(os.path.basename(video_path))
164
+ output_sampled_video_path = os.path.join(output_dir, fname+f"_from{start_t}s_to{end_t}s{ext}")
165
+ self.sampled_video_writer = imageio.get_writer(output_sampled_video_path, fps=output_fps)
166
+
167
+ def read_keyframe_iter(self):
168
+ for keyframe_path in self.keyframe_paths:
169
+ image = Image.open(keyframe_path).convert("RGB")
170
+ yield image
171
+
172
+ def read_video_iter(self):
173
+ vr = decord.VideoReader(self.video_path)
174
+ if self.sample_fps == -1: self.sample_fps = self.initial_fps
175
+ if self.end_t == -1:
176
+ self.end_t = len(vr) / self.initial_fps
177
+ else:
178
+ self.end_t = min(len(vr) / self.initial_fps, self.end_t)
179
+ if self.overlay_size == -1: self.overlay_size = 0
180
+ assert 0 <= self.start_t < self.end_t
181
+ assert self.sample_fps > 0
182
+
183
+ start_f_ind = int(self.start_t * self.initial_fps)
184
+ end_f_ind = int(self.end_t * self.initial_fps)
185
+ num_f = int((self.end_t - self.start_t) * self.sample_fps)
186
+ sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
187
+ print("sample_idx", sample_idx)
188
+
189
+ assert len(sample_idx) > 0, f"sample_idx is empty!"
190
+
191
+ begin_frame_idx = 0
192
+ while begin_frame_idx < len(sample_idx):
193
+ self.begin_frame_idx = begin_frame_idx
194
+ begin_frame_idx = max(begin_frame_idx - self.overlay_size, 0)
195
+ next_frame_idx = min(begin_frame_idx + self.chunk_size, len(sample_idx))
196
+
197
+ video = vr.get_batch(sample_idx[begin_frame_idx:next_frame_idx])
198
+ begin_frame_idx = next_frame_idx
199
+
200
+ if self.save_sampled_video:
201
+ overlay_size = 0 if self.begin_frame_idx == 0 else self.overlay_size
202
+ print(type(video))
203
+ for frame in video[overlay_size:]:
204
+ self.sampled_video_writer.append_data(frame.detach().cpu().numpy())
205
+
206
+ video = torch.Tensor(video).to(self.device).to(self.dtype)
207
+ video = rearrange(video, "f h w c -> f c h w")
208
+
209
+ if self.normalize:
210
+ video = video / 127.5 - 1.0
211
+
212
+ yield video
213
+
214
+ def write_video(self, video, video_id, resctrl: ResolutionControl = None):
215
+ '''
216
+ video:
217
+ '''
218
+ overlay_size = 0 if self.begin_frame_idx == 0 else self.overlay_size
219
+ for img in video[overlay_size:]:
220
+ if resctrl is not None:
221
+ img = resctrl.callback(img)
222
+ self.video_writer_list[video_id].append_data(np.array(img))
223
+
224
+ def close(self):
225
+ for video_writer in self.video_writer_list:
226
+ video_writer.close()
227
+ if self.save_sampled_video:
228
+ self.sampled_video_writer.close()
229
+ self.begin_frame_idx = 0
230
+
231
+
232
+ class SingleClipDataset(Dataset):
233
+
234
+ # data_aug_class = {
235
+ # "rsfnet": ColorDataAugmentation,
236
+ # "controlnet": ControlNetDataAugmentation
237
+ # }
238
+
239
+ def __init__(
240
+ self,
241
+ inversion_noise,
242
+ video_clip,
243
+ keyframe,
244
+ firstframe,
245
+ height,
246
+ width,
247
+ use_data_aug=None,
248
+ pad_to_fit=False,
249
+ keyframe_latent=None
250
+ ):
251
+
252
+ self.resctrl = ResolutionControl(video_clip.shape[-2:],(height,width),pad_to_fit,fill=-1)
253
+
254
+ video_clip = rearrange(video_clip, "1 f c h w -> f c h w")
255
+ keyframe = rearrange(keyframe, "1 f c h w -> f c h w")
256
+ firstframe = rearrange(firstframe, "1 f c h w -> f c h w")
257
+
258
+ if inversion_noise is not None:
259
+ inversion_noise = rearrange(inversion_noise, "1 f c h w -> f c h w")
260
+
261
+ if use_data_aug is not None:
262
+ if use_data_aug in self.data_aug_class:  # note: data_aug_class is commented out above, so use_data_aug must stay unset (~) in the configs
263
+ self.data_augment = self.data_aug_class[use_data_aug]()
264
+ use_data_aug = True
265
+ print(f"Augmentation mode: {use_data_aug} is implemented.")
266
+ else:
267
+ raise NotImplementedError(f"Augmentation mode: {use_data_aug} is not implemented!")
268
+ else:
269
+ use_data_aug = False
270
+
271
+ self.video_clip = video_clip
272
+ self.keyframe = keyframe
273
+ self.firstframe = firstframe
274
+ self.inversion_noise = inversion_noise
275
+ self.use_data_aug = use_data_aug
276
+ self.keyframe_latent = keyframe_latent
277
+
278
+ @staticmethod
279
+ def __getname__(): return 'single_clip'
280
+
281
+ def __len__(self):
282
+ return 1
283
+
284
+ def __getitem__(self, index):
285
+
286
+ motion_values = torch.Tensor([127.])
287
+
288
+ pixel_values = self.resctrl(self.video_clip)
289
+ refer_pixel_values = self.resctrl(self.keyframe)
290
+ cross_pixel_values = self.resctrl(self.firstframe)
291
+
292
+ if self.use_data_aug:
293
+ print("pixel_values before augment", refer_pixel_values.min(), refer_pixel_values.max())
294
+ #pixel_values, refer_pixel_values, cross_pixel_values = \
295
+ #self.data_augment.augment(
296
+ # torch.cat([pixel_values, refer_pixel_values, cross_pixel_values], dim=0)
297
+ #).tensor_split([pixel_values.shape[0],pixel_values.shape[0]+refer_pixel_values.shape[0]],dim=0)
298
+ refer_pixel_values = self.data_augment.augment(refer_pixel_values)
299
+ print("pixel_values after augment", refer_pixel_values.min(), refer_pixel_values.max())
300
+
301
+ outputs = {
302
+ "pixel_values": pixel_values,
303
+ "refer_pixel_values": refer_pixel_values,
304
+ "cross_pixel_values": cross_pixel_values,
305
+ "motion_values": motion_values,
306
+ 'dataset': self.__getname__(),
307
+ }
308
+
309
+ if self.inversion_noise is not None:
310
+ outputs.update({
311
+ "inversion_noise": self.inversion_noise
312
+ })
313
+ if self.keyframe_latent is not None:
314
+ outputs.update({
315
+ "refer_latents": self.keyframe_latent
316
+ })
317
+ return outputs
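
A short usage sketch of ResolutionControl above, with a dummy tensor standing in for real video frames, to illustrate the resize round trip (shapes are illustrative only):

import torch
from i2vedit.data import ResolutionControl

# Source frames: (batch, frames, channels, H, W) in [-1, 1], e.g. a 720p clip.
frames = torch.zeros(1, 14, 3, 720, 1280)

resctrl = ResolutionControl(
    input_res=(720, 1280),   # must match the frames' spatial size
    output_res=(576, 1024),  # working resolution for the model
    pad_to_fit=False,
    fill=-1,                 # padding value in normalized [-1, 1] space
)

model_input = resctrl(frames)             # (optionally padded and) resized
print(model_input.shape)                  # torch.Size([1, 14, 3, 576, 1024])
restored = resctrl.callback(model_input)  # back to the original resolution
print(restored.shape)                     # torch.Size([1, 14, 3, 720, 1280])
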
i2vedit/inference.py ADDED
@@ -0,0 +1,89 @@
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import re
5
+ import warnings
6
+ import imageio
7
+ import random
8
+ from typing import Optional
9
+ from tqdm import trange
10
+ from einops import rearrange
11
+
12
+ import torch
13
+ from torch import Tensor
14
+ from torch.nn.functional import interpolate
15
+ from diffusers import StableVideoDiffusionPipeline, EulerDiscreteScheduler
16
+ from diffusers import TextToVideoSDPipeline
17
+
18
+ from i2vedit.train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
19
+ from i2vedit.utils.lora_handler import LoraHandler
20
+ from i2vedit.utils.model_utils import P2PStableVideoDiffusionPipeline
21
+
22
+
23
+ def initialize_pipeline(
24
+ model: str,
25
+ device: str = "cuda",
26
+ xformers: bool = False,
27
+ sdp: bool = False,
28
+ lora_path: str = "",
29
+ lora_rank: int = 64,
30
+ lora_scale: float = 1.0,
31
+ load_spatial_lora: bool = False,
32
+ dtype = torch.float16
33
+ ):
34
+ with warnings.catch_warnings():
35
+ warnings.simplefilter("ignore")
36
+
37
+ scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(model)
38
+
39
+ # Freeze any necessary models
40
+ freeze_models([vae, image_encoder, unet])
41
+
42
+ # Enable xformers if available
43
+ handle_memory_attention(xformers, sdp, unet)
44
+
45
+ lora_manager_temporal = LoraHandler(
46
+ version="cloneofsimo",
47
+ use_unet_lora=True,
48
+ use_image_lora=False,
49
+ save_for_webui=False,
50
+ only_for_webui=False,
51
+ unet_replace_modules=["TemporalBasicTransformerBlock"],
52
+ image_encoder_replace_modules=None,
53
+ lora_bias=None
54
+ )
55
+
56
+ unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
57
+ True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
58
+
59
+ if load_spatial_lora:
60
+ lora_manager_spatial = LoraHandler(
61
+ version="cloneofsimo",
62
+ use_unet_lora=True,
63
+ use_image_lora=False,
64
+ save_for_webui=False,
65
+ only_for_webui=False,
66
+ unet_replace_modules=["BasicTransformerBlock"],
67
+ image_encoder_replace_modules=None,
68
+ lora_bias=None
69
+ )
70
+
71
+ spatial_lora_path = lora_path.replace("temporal", "spatial")
72
+ unet_lora_params, unet_negation = lora_manager_spatial.add_lora_to_model(
73
+ True, unet, lora_manager_spatial.unet_replace_modules, 0, spatial_lora_path, r=lora_rank, scale=lora_scale)
74
+
75
+ unet.eval()
76
+ image_encoder.eval()
77
+ unet_and_text_g_c(unet, image_encoder, False, False)
78
+
79
+ pipe = P2PStableVideoDiffusionPipeline.from_pretrained(
80
+ model,
81
+ scheduler=scheduler,
82
+ feature_extractor=feature_extractor,
83
+ image_encoder=image_encoder.to(device=device, dtype=dtype),
84
+ vae=vae.to(device=device, dtype=dtype),
85
+ unet=unet.to(device=device, dtype=dtype)
86
+ )
87
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
88
+
89
+ return pipe
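
A hypothetical call to initialize_pipeline above, mirroring the values used in the configs; the lora_path layout is an assumption and depends on how the training stage saves its temporal LoRA checkpoints:

import torch
from i2vedit.inference import initialize_pipeline

pipe = initialize_pipeline(
    model="stabilityai/stable-video-diffusion-img2vid",
    device="cuda",
    xformers=False,
    sdp=True,
    lora_path="outputs/train_motion_lora/checkpoint-500/temporal/lora.pt",  # hypothetical path
    lora_rank=32,     # lora_rank in the configs
    lora_scale=1.0,   # attention_matching_params.lora_scale
    load_spatial_lora=False,
    dtype=torch.float16,
)
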
i2vedit/prompt_attention/__init__.py ADDED
File without changes
i2vedit/prompt_attention/attention_register.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ register the attention controller into the UNet of stable diffusion
3
+ Build a customized attention function `_attention'
4
+ Replace the original attention function with `forward' and `spatial_temporal_forward' in attention_controlled_forward function
5
+ Most of spatial_temporal_forward is copied directly from `video_diffusion/models/attention.py'
6
+ TODO FIXME: merge redundant code with attention.py
7
+ """
8
+
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.checkpoint import checkpoint
13
+
14
+ import logging
15
+ from einops import rearrange, repeat
16
+ import math
17
+ from inspect import isfunction
18
+ from typing import Any, Optional
19
+ from packaging import version
20
+
21
+ from diffusers.models.attention_processor import AttnProcessor2_0, Attention
22
+ from diffusers.utils import USE_PEFT_BACKEND
23
+
24
+ class AttnControllerProcessor:
25
+
26
+ def __init__(self, consistency_controller, controller, place_in_unet, attention_type):
27
+
28
+ self.consistency_controller = consistency_controller
29
+ self.controller = controller
30
+ self.place_in_unet = place_in_unet
31
+ self.attention_type = attention_type
32
+
33
+ def __call__(
34
+ self,
35
+ attn: Attention,
36
+ hidden_states: torch.FloatTensor,
37
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
38
+ attention_mask: Optional[torch.FloatTensor] = None,
39
+ temb: Optional[torch.FloatTensor] = None,
40
+ scale: float = 1.0,
41
+ ) -> torch.FloatTensor:
42
+ residual = hidden_states
43
+ if attn.spatial_norm is not None:
44
+ hidden_states = attn.spatial_norm(hidden_states, temb)
45
+
46
+ input_ndim = hidden_states.ndim
47
+
48
+ if input_ndim == 4:
49
+ batch_size, channel, height, width = hidden_states.shape
50
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
51
+
52
+ batch_size, sequence_length, _ = (
53
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
54
+ )
55
+
56
+ if attention_mask is not None:
57
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
58
+ # scaled_dot_product_attention expects attention_mask shape to be
59
+ # (batch, heads, source_length, target_length)
60
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
61
+
62
+ if attn.group_norm is not None:
63
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
64
+
65
+ args = () if USE_PEFT_BACKEND else (scale,)
66
+ query = attn.to_q(hidden_states, *args)
67
+
68
+ is_cross = True
69
+ if encoder_hidden_states is None:
70
+ encoder_hidden_states = hidden_states
71
+ is_cross = False
72
+ elif attn.norm_cross:
73
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
74
+
75
+ key = attn.to_k(encoder_hidden_states, *args)
76
+ value = attn.to_v(encoder_hidden_states, *args)
77
+
78
+ inner_dim = key.shape[-1]
79
+ head_dim = inner_dim // attn.heads
80
+
81
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
82
+
83
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
84
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
85
+
86
+ if self.consistency_controller is not None:
87
+ key = self.consistency_controller(
88
+ key, is_cross, f"{self.place_in_unet}_{self.attention_type}_k"
89
+ )
90
+ value = self.consistency_controller(
91
+ value, is_cross, f"{self.place_in_unet}_{self.attention_type}_v"
92
+ )
93
+
94
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
95
+ # TODO: add support for attn.scale when we move to Torch 2.1
96
+ if self.controller is not None:
97
+ hidden_states = self.controller.attention_control(
98
+ self.place_in_unet, self.attention_type, is_cross,
99
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
100
+ )
101
+ else:
102
+ hidden_states = F.scaled_dot_product_attention(
103
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
104
+ )
105
+
106
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
107
+ hidden_states = hidden_states.to(query.dtype)
108
+
109
+ # linear proj
110
+ hidden_states = attn.to_out[0](hidden_states, *args)
111
+ # dropout
112
+ hidden_states = attn.to_out[1](hidden_states)
113
+
114
+ if input_ndim == 4:
115
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
116
+
117
+ if attn.residual_connection:
118
+ hidden_states = hidden_states + residual
119
+
120
+ hidden_states = hidden_states / attn.rescale_output_factor
121
+
122
+ return hidden_states
123
+
124
+ def register_attention_control(
125
+ model,
126
+ controller=None,
127
+ consistency_controller=None,
128
+ find_modules = {},
129
+ consistency_find_modules = {},
130
+ undo=False
131
+ ):
132
+ "Connect a model with a controller"
133
+ class DummyController:
134
+
135
+ def __call__(self, *args):
136
+ return args[0]
137
+
138
+ def __init__(self):
139
+ self.num_att_layers = 0
140
+
141
+ #if controller is None:
142
+ # controller = DummyController()
143
+
144
+ f_keys = list(set(find_modules.keys()).difference(set(consistency_find_modules.keys())))
145
+ c_keys = list(set(consistency_find_modules.keys()).difference(set(find_modules.keys())))
146
+ common_keys = list(set(find_modules.keys()).intersection(set(consistency_find_modules.keys())))
147
+ new_find_modules = {}
148
+ for f_key in f_keys:
149
+ new_find_modules.update({
150
+ f_key: find_modules[f_key]
151
+ })
152
+ new_consistency_find_modules = {}
153
+ for c_key in c_keys:
154
+ new_consistency_find_modules.update({
155
+ c_key: consistency_find_modules[c_key]
156
+ })
157
+ common_modules = {}
158
+ for key in common_keys:
159
+ find_modules[key] = [] if find_modules[key] is None else find_modules[key]
160
+ consistency_find_modules[key] = [] if consistency_find_modules[key] is None else consistency_find_modules[key]
161
+ f_list = list(set(find_modules[key]).difference(set(consistency_find_modules[key])))
162
+ c_list = list(set(consistency_find_modules[key]).difference(set(find_modules[key])))
163
+ common_list = list(set(find_modules[key]).intersection(set(consistency_find_modules[key])))
164
+ if len(f_list) > 0:
165
+ new_find_modules.update({key: f_list})
166
+ if len(c_list) > 0:
167
+ new_consistency_find_modules.update({key: c_list})
168
+ if len(common_list) > 0:
169
+ common_modules.update({key: common_list})
170
+
171
+ find_modules = new_find_modules
172
+ consistency_find_modules = new_consistency_find_modules
173
+
174
+ print("common_modules", common_modules)
175
+ print("find_modules", find_modules)
176
+ print("consistency_find_modules", consistency_find_modules)
177
+ print("controller", controller, "consistency_controller", consistency_controller)
178
+
179
+ def register_recr(net_, count1, count2, place_in_unet):
180
+
181
+ if net_[1].__class__.__name__ == 'BasicTransformerBlock':
182
+ attention_type = 'spatial'
183
+ elif net_[1].__class__.__name__ == 'TemporalBasicTransformerBlock':
184
+ attention_type = 'temporal'
185
+
186
+ control1, control2 = None, None
187
+ if net_[1].__class__.__name__ in common_modules.keys():
188
+ control1, control2 = consistency_controller, controller
189
+ module_list = common_modules[net_[1].__class__.__name__]
190
+ elif net_[1].__class__.__name__ in find_modules.keys():
191
+ control1, control2 = None, controller
192
+ module_list = find_modules[net_[1].__class__.__name__]
193
+ elif net_[1].__class__.__name__ in consistency_find_modules.keys():
194
+ control1, control2 = consistency_controller, None
195
+ module_list = consistency_find_modules[net_[1].__class__.__name__]
196
+
197
+ if any([control is not None for control in [control1, control2]]):
198
+
199
+ if module_list is not None and 'attn1' in module_list:
200
+ if undo:
201
+ net_[1].attn1.set_processor(AttnProcessor2_0())
202
+ else:
203
+ net_[1].attn1.set_processor(AttnControllerProcessor(control1, control2, place_in_unet, attention_type = attention_type))
204
+ if control1 is not None: count1 += 1
205
+ if control2 is not None: count2 += 1
206
+
207
+ if module_list is not None and 'attn2' in module_list:
208
+ if undo:
209
+ net_[1].attn2.set_processor(AttnProcessor2_0())
210
+ else:
211
+ net_[1].attn2.set_processor(AttnControllerProcessor(control1, control2, place_in_unet, attention_type = attention_type))
212
+ if control1 is not None: count1 += 1
213
+ if control2 is not None: count2 += 1
214
+
215
+ return count1, count2
216
+
217
+ elif hasattr(net_[1], 'children'):
218
+ for net in net_[1].named_children():
219
+ count1, count2 = register_recr(net, count1, count2, place_in_unet)
220
+
221
+ return count1, count2
222
+
223
+ cross_att_count1 = 0
224
+ cross_att_count2 = 0
225
+ sub_nets = model.named_children()
226
+ for net in sub_nets:
227
+ if "down" in net[0]:
228
+ c1, c2 = register_recr(net, 0, 0, "down")
229
+ cross_att_count1 += c1
230
+ cross_att_count2 += c2
231
+ elif "up" in net[0]:
232
+ c1, c2 = register_recr(net, 0, 0, "up")
233
+ cross_att_count1 += c1
234
+ cross_att_count2 += c2
235
+ elif "mid" in net[0]:
236
+ c1, c2 = register_recr(net, 0, 0, "mid")
237
+ cross_att_count1 += c1
238
+ cross_att_count2 += c2
239
+ if undo:
240
+ print(f"Number of attention layer unregistered for controller: {cross_att_count2}")
241
+ print(f"Number of attention layer unregistered for consistency_controller: {cross_att_count1}")
242
+ else:
243
+ print(f"Number of attention layer registered for controller: {cross_att_count2}")
244
+ if controller is not None:
245
+ controller.num_att_layers = cross_att_count2
246
+ print(f"Number of attention layer registered for consistency_controller: {cross_att_count1}")
247
+ if consistency_controller is not None:
248
+ consistency_controller.num_att_layers = cross_att_count1
249
+
250
+
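
A short wiring sketch for register_attention_control above, using the same registered_modules layout as the configs (attn1 of both block types); loading the SVD UNet is only needed to have something to register against:

from diffusers import UNetSpatioTemporalConditionModel
from i2vedit.prompt_attention.attention_register import register_attention_control
from i2vedit.prompt_attention.attention_store import AttentionStore  # defined in the next file

unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="unet"
)
store = AttentionStore(save_self_attention=True, save_latents=False, disk_store=False)
modules = {"BasicTransformerBlock": ["attn1"], "TemporalBasicTransformerBlock": ["attn1"]}

# Swap in AttnControllerProcessor on the matching attention modules...
register_attention_control(unet, controller=store, find_modules=modules)
# ...run inversion / editing, then restore the default AttnProcessor2_0.
register_attention_control(unet, controller=store, find_modules=modules, undo=True)
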
i2vedit/prompt_attention/attention_store.py ADDED
@@ -0,0 +1,305 @@
1
+ """
2
+ Code of attention storer AttentionStore, which is a base class for attention editor in attention_util.py
3
+
4
+ """
5
+
6
+ import abc
7
+ import os
8
+ import copy
9
+ import shutil
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from packaging import version
13
+ from einops import rearrange
14
+ import math
15
+
16
+ from i2vedit.prompt_attention.common.util import get_time_string
17
+
18
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
19
+ SDP_IS_AVAILABLE = True
20
+ from torch.backends.cuda import SDPBackend, sdp_kernel
21
+
22
+ BACKEND_MAP = {
23
+ SDPBackend.MATH: {
24
+ "enable_math": True,
25
+ "enable_flash": False,
26
+ "enable_mem_efficient": False,
27
+ },
28
+ SDPBackend.FLASH_ATTENTION: {
29
+ "enable_math": False,
30
+ "enable_flash": True,
31
+ "enable_mem_efficient": False,
32
+ },
33
+ SDPBackend.EFFICIENT_ATTENTION: {
34
+ "enable_math": False,
35
+ "enable_flash": False,
36
+ "enable_mem_efficient": True,
37
+ },
38
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
39
+ }
40
+ else:
41
+ from contextlib import nullcontext
42
+
43
+ SDP_IS_AVAILABLE = False
44
+ sdp_kernel = nullcontext
45
+ BACKEND_MAP = {}
46
+ logpy.warn(
47
+ f"No SDP backend available, likely because you are running in pytorch "
48
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
49
+ f"You might want to consider upgrading."
50
+ )
51
+
52
+ class AttentionControl(abc.ABC):
53
+
54
+ def step_callback(self, x_t):
55
+ self.cur_att_layer = 0
56
+ self.cur_step += 1
57
+ self.between_steps()
58
+ return x_t
59
+
60
+ def between_steps(self):
61
+ return
62
+
63
+ @property
64
+ def num_uncond_att_layers(self):
65
+ """I guess the diffusion of google has some unconditional attention layer
66
+ No unconditional attention layer in Stable diffusion
67
+
68
+ Returns:
69
+ _type_: _description_
70
+ """
71
+ # return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0
72
+ return 0
73
+
74
+ @abc.abstractmethod
75
+ def forward (self, attn, is_cross: bool, place_in_unet: str):
76
+ raise NotImplementedError
77
+
78
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
79
+ if self.cur_att_layer >= self.num_uncond_att_layers:
80
+ if self.LOW_RESOURCE or 'mask' in place_in_unet:
81
+ # For inversion without null text file
82
+ attn = self.forward(attn, is_cross, place_in_unet)
83
+ else:
84
+ # For classifier-free guidance scale!=1
85
+ h = attn.shape[0]
86
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
87
+ self.cur_att_layer += 1
88
+
89
+ return attn
90
+
91
+ def reset(self):
92
+ self.cur_step = 0
93
+ self.cur_att_layer = 0
94
+
95
+ def __init__(self,
96
+ ):
97
+ self.LOW_RESOURCE = False # assume the edit have cfg
98
+ self.cur_step = 0
99
+ self.num_att_layers = -1
100
+ self.cur_att_layer = 0
101
+
102
+
103
+ class AttentionStore(AttentionControl):
104
+ def step_callback(self, x_t):
105
+
106
+ x_t = super().step_callback(x_t)
107
+ if self.save_latents:
108
+ self.latents_store.append(x_t.cpu().detach())
109
+ return x_t
110
+
111
+ @staticmethod
112
+ def get_empty_store():
113
+ return {"down_spatial_q_cross": [], "mid_spatial_q_cross": [], "up_spatial_q_cross": [],
114
+ "down_spatial_k_cross": [], "mid_spatial_k_cross": [], "up_spatial_k_cross": [],
115
+ "down_spatial_mask_cross": [], "mid_spatial_mask_cross": [], "up_spatial_mask_cross": [],
116
+ "down_temporal_cross": [], "mid_temporal_cross": [], "up_temporal_cross": [],
117
+ "down_spatial_q_self": [], "mid_spatial_q_self": [], "up_spatial_q_self": [],
118
+ "down_spatial_k_self": [], "mid_spatial_k_self": [], "up_spatial_k_self": [],
119
+ "down_spatial_mask_self": [], "mid_spatial_mask_self": [], "up_spatial_mask_self": [],
120
+ "down_spatial_self": [], "mid_spatial_self": [], "up_spatial_self": [],
121
+ "down_temporal_self": [], "mid_temporal_self": [], "up_temporal_self": []}
122
+
123
+ @staticmethod
124
+ def get_empty_cross_store():
125
+ return {"down_spatial_q_cross": [], "mid_spatial_q_cross": [], "up_spatial_q_cross": [],
126
+ "down_spatial_k_cross": [], "mid_spatial_k_cross": [], "up_spatial_k_cross": [],
127
+ "down_spatial_mask_cross": [], "mid_spatial_mask_cross": [], "up_spatial_mask_cross": [],
128
+ "down_temporal_cross": [], "mid_temporal_cross": [], "up_temporal_cross": [],
129
+ }
130
+
131
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
132
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
133
+ if attn.shape[-2] <= 8*9*8*16: # avoid memory overload
134
+ # print(f"Store attention map {key} of shape {attn.shape}")
135
+ if (is_cross or self.save_self_attention or 'mask' in key):
136
+ if False:#attn.shape[-2] >= 4*9*4*16:
137
+ append_tensor = attn.cpu().detach()
138
+ else:
139
+ append_tensor = attn
140
+ self.step_store[key].append(copy.deepcopy(append_tensor))
141
+ # FIXME: Are these deepcopy all necessary?
142
+ # self.step_store[key].append(append_tensor)
143
+ return attn
144
+
145
+ def between_steps(self):
146
+ if len(self.attention_store) == 0:
147
+ self.attention_store = {key: self.step_store[key] for key in self.step_store if 'mask' in key}
148
+ else:
149
+ for key in self.attention_store:
150
+ if 'mask' in key:
151
+ for i in range(len(self.attention_store[key])):
152
+ self.attention_store[key][i] += self.step_store[key][i]
153
+
154
+ if self.disk_store:
155
+ path = self.store_dir + f'/{self.cur_step:03d}.pt'
156
+ if self.load_attention_store is None:
157
+ torch.save(copy.deepcopy(self.step_store), path)
158
+ self.attention_store_all_step.append(path)
159
+ else:
160
+ self.attention_store_all_step.append(copy.deepcopy(self.step_store))
161
+ self.step_store = self.get_empty_store()
162
+
163
+ def get_average_attention(self):
164
+ "divide the attention map value in attention store by denoising steps"
165
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store if 'mask' in key}
166
+ return average_attention
167
+
168
+
169
+ def reset(self):
170
+ super(AttentionStore, self).reset()
171
+ self.step_store = self.get_empty_store()
172
+ self.attention_store_all_step = []
173
+ if self.disk_store:
174
+ if self.load_attention_store is not None:
175
+ flist = sorted(os.listdir(self.load_attention_store), key=lambda x: int(x[:-3]))
176
+ self.attention_store_all_step = [
177
+ os.path.join(self.load_attention_store, fn) for fn in flist
178
+ ]
179
+ self.attention_store = {}
180
+
181
+ def __init__(self,
182
+ save_self_attention:bool=True,
183
+ save_latents:bool=True,
184
+ disk_store=False,
185
+ load_attention_store:str=None,
186
+ store_path:str=None
187
+ ):
188
+ super(AttentionStore, self).__init__()
189
+ self.disk_store = disk_store
190
+ if load_attention_store is not None:
191
+ if not os.path.exists(load_attention_store):
192
+ print(f"can not load attentions from {load_attention_store}: file doesn't exist.")
193
+ load_attention_store = None
194
+ else:
195
+ assert self.disk_store, f"can not load attentions from {load_attention_store} because disk_store is disabled."
196
+ self.attention_store_all_step = []
197
+ if self.disk_store:
198
+ if load_attention_store is not None:
199
+ self.store_dir = load_attention_store
200
+ flist = sorted([fpath for fpath in os.listdir(load_attention_store) if "inverted" not in fpath], key=lambda x: int(x[:-3]))
201
+ self.attention_store_all_step = [
202
+ os.path.join(load_attention_store, fn) for fn in flist
203
+ ]
204
+ else:
205
+ if store_path is None:
206
+ time_string = get_time_string()
207
+ path = f'./trash/{self.__class__.__name__}_attention_cache_{time_string}'
208
+ else:
209
+ path = store_path
210
+ os.makedirs(path, exist_ok=True)
211
+ self.store_dir = path
212
+ else:
213
+ self.store_dir = None
214
+ self.step_store = self.get_empty_store()
215
+ self.attention_store = {}
216
+ self.save_self_attention = save_self_attention
217
+ self.latents_store = []
218
+
219
+ self.save_latents = save_latents
220
+ self.load_attention_store = load_attention_store
221
+
222
+ def delete(self):
223
+ if self.disk_store:
224
+ try:
225
+ shutil.rmtree(self.store_dir)
226
+ print(f"Successfully remove {self.store_dir}")
227
+ except:
228
+ print(f"Fail to remove {self.store_dir}")
229
+
230
+ def attention_control(
231
+ self, place_in_unet, attention_type, is_cross,
232
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False
233
+ ):
234
+ if attention_type == "temporal":
235
+
236
+ return self.temporal_attention_control(
237
+ place_in_unet, attention_type, is_cross,
238
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False
239
+ )
240
+
241
+ elif attention_type == "spatial":
242
+
243
+ return self.spatial_attention_control(
244
+ place_in_unet, attention_type, is_cross,
245
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False
246
+ )
247
+
248
+ def temporal_attention_control(
249
+ self, place_in_unet, attention_type, is_cross,
250
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False
251
+ ):
252
+
253
+ h = q.shape[1]
254
+ q, k, v = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q, k, v))
255
+ attention_scores = torch.baddbmm(
256
+ torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
257
+ q,
258
+ k.transpose(-1, -2),
259
+ beta=0,
260
+ alpha=1 / math.sqrt(q.size(-1)),
261
+ )
262
+
263
+ if attn_mask is not None:
264
+ if attn_mask.dtype == torch.bool:
265
+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
266
+ attention_scores = attention_scores + attn_mask
267
+
268
+ attention_probs = attention_scores.softmax(dim=-1)
269
+
270
+ # cast back to the original dtype
271
+ attention_probs = attention_probs.to(v.dtype)
272
+
273
+ # START OF CORE FUNCTION
274
+ # Record during inversion and edit the attention probs during editing
275
+ attention_probs = rearrange(
276
+ self.__call__(
277
+ rearrange(attention_probs, "(b h) n d -> b h n d", h=h),
278
+ is_cross,
279
+ f'{place_in_unet}_{attention_type}'
280
+ ),
281
+ "b h n d -> (b h) n d"
282
+ )
283
+ # END OF CORE FUNCTION
284
+
285
+ # compute attention output
286
+ hidden_states = torch.bmm(attention_probs, v)
287
+
288
+ # reshape hidden_states
289
+ hidden_states = rearrange(hidden_states, "(b h) n d -> b h n d", h=h)
290
+
291
+ return hidden_states
292
+
293
+ def spatial_attention_control(
294
+ self, place_in_unet, attention_type, is_cross,
295
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False
296
+ ):
297
+
298
+ q = self.__call__(q, is_cross, f"{place_in_unet}_{attention_type}_q")
299
+ k = self.__call__(k, is_cross, f"{place_in_unet}_{attention_type}_k")
300
+
301
+ hidden_states = F.scaled_dot_product_attention(
302
+ q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
303
+ )
304
+
305
+ return hidden_states
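The attention store above is normally driven through UNet hooks, so a small usage sketch may help. The snippet below is not part of the commit; it only uses the constructor and methods visible in this diff, and the cache path and tensor shape are illustrative assumptions.

# Hedged sketch: driving AttentionStore by hand (assumed path and dummy shape).
import torch
from i2vedit.prompt_attention.attention_store import AttentionStore

store = AttentionStore(
    save_self_attention=True,
    save_latents=True,
    disk_store=True,                     # per-step stores are torch.save()-d under store_path
    store_path="./cache/attn_example",   # hypothetical cache directory
)

# Attention layers call the store like a controller with (attn, is_cross, place_in_unet);
# "down_temporal" + "_self" matches one of the buckets returned by get_empty_store().
attn_probs = torch.rand(1, 8, 72, 72)    # (batch, heads, query_len, key_len), illustrative
_ = store(attn_probs, is_cross=False, place_in_unet="down_temporal")

store.between_steps()                    # flush the current step to disk and reset step_store
print(len(store.attention_store_all_step))   # one .pt path per completed denoising step
store.delete()                           # removes store_dir once the maps are no longer needed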
i2vedit/prompt_attention/attention_util.py ADDED
@@ -0,0 +1,621 @@
1
+ """
2
+ Collect all functions in the prompt_attention folder.
3
+ Provide an API `make_controller` that returns an initialized AttentionControlEdit object for the main validation loop.
4
+ """
5
+
6
+ from typing import Optional, Union, Tuple, List, Dict
7
+ import abc
8
+ import numpy as np
9
+ import copy
10
+ import math
11
+ from einops import rearrange
12
+ import os
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from packaging import version
18
+
19
+ from i2vedit.prompt_attention.visualization import (
20
+ show_cross_attention,
21
+ show_self_attention_comp,
22
+ show_self_attention,
23
+ show_self_attention_distance,
24
+ calculate_attention_mask,
25
+ show_avg_difference_maps
26
+ )
27
+ from i2vedit.prompt_attention.attention_store import AttentionStore, AttentionControl
28
+ from i2vedit.prompt_attention.attention_register import register_attention_control
29
+
30
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
31
+
32
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
33
+ SDP_IS_AVAILABLE = True
34
+ from torch.backends.cuda import SDPBackend, sdp_kernel
35
+
36
+ BACKEND_MAP = {
37
+ SDPBackend.MATH: {
38
+ "enable_math": True,
39
+ "enable_flash": False,
40
+ "enable_mem_efficient": False,
41
+ },
42
+ SDPBackend.FLASH_ATTENTION: {
43
+ "enable_math": False,
44
+ "enable_flash": True,
45
+ "enable_mem_efficient": False,
46
+ },
47
+ SDPBackend.EFFICIENT_ATTENTION: {
48
+ "enable_math": False,
49
+ "enable_flash": False,
50
+ "enable_mem_efficient": True,
51
+ },
52
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
53
+ }
54
+ else:
55
+ from contextlib import nullcontext
56
+
57
+ SDP_IS_AVAILABLE = False
58
+ sdp_kernel = nullcontext
59
+ BACKEND_MAP = {}
60
+ print(  # console warning; no logger is configured in this module
61
+ f"No SDP backend available, likely because you are running in pytorch "
62
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
63
+ f"You might want to consider upgrading."
64
+ )
65
+
66
+
67
+
68
+ class EmptyControl:
69
+
70
+
71
+ def step_callback(self, x_t):
72
+ return x_t
73
+
74
+ def between_steps(self):
75
+ return
76
+
77
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
78
+ return attn
79
+
80
+
81
+ class AttentionControlEdit(AttentionStore, abc.ABC):
82
+ """Decide self or cross-attention. Call the reweighting cross attention module
83
+
84
+ Args:
85
+ AttentionStore (_type_): ([1, 4, 8, 64, 64])
86
+ abc (_type_): [8, 8, 1024, 77]
87
+ """
88
+ def get_all_last_latents(self, overlay_size):
89
+ return [latents[:,-overlay_size:,...] for latents in self.latents_store]
90
+
91
+
92
+ def step_callback(self, x_t):
93
+ x_t = super().step_callback(x_t)
94
+ x_t_device = x_t.device
95
+ x_t_dtype = x_t.dtype
96
+
97
+ # if self.previous_latents is not None:
98
+ # # replace latents
99
+ # step_in_store = self.cur_step - 1
100
+ # previous_latents = self.previous_latents[step_in_store]
101
+ # x_t[:,:len(previous_latents),...] = previous_latents.to(x_t_device, x_t_dtype)
102
+ if self.latent_blend:
103
+
104
+ avg_attention = self.get_average_attention()
105
+ masks = []
106
+ for key in avg_attention:
107
+ if 'down' in key and 'mask' in key:
108
+ for attn in avg_attention[key]:
109
+ if attn.shape[-2] == 8 * 9:
110
+ masks.append( attn )
111
+ mask = sum(masks) / len(masks)
112
+ mask[mask > 0.2] = 1.0
113
+ if self.use_inversion_attention and self.additional_attention_store is not None:
114
+ step_in_store = len(self.additional_attention_store.latents_store) - self.cur_step
115
+ elif self.additional_attention_store is None:
116
+ pass
117
+ else:
118
+ step_in_store = self.cur_step - 1
119
+
120
+ inverted_latents = self.additional_attention_store.latents_store[step_in_store]
121
+ inverted_latents = inverted_latents.to(device =x_t_device, dtype=x_t_dtype)
122
+
123
+ x_t = (1 - mask) * inverted_latents + mask * x_t
124
+
125
+ self.step_in_store_atten_dict = None
126
+
127
+ return x_t
128
+
129
+ def replace_self_attention(self, attn_base, attn_replace, reshaped_mask=None, key=None):
130
+
131
+ target_device = attn_replace.device
132
+ target_dtype = attn_replace.dtype
133
+ attn_base = attn_base.to(target_device, dtype=target_dtype)
134
+
135
+ if "temporal" in key:
136
+
137
+ if self.control_mode["temporal_self"] == "copy_v2":
138
+ if self.cur_step < int(self.temporal_step_thr[0] * self.num_steps):
139
+ return attn_base
140
+ if self.cur_step >= int(self.temporal_step_thr[1] * self.num_steps):
141
+ return attn_replace
142
+ if ('down' in key and self.current_pos<4) or \
143
+ ('up' in key and self.current_pos>1):
144
+ return attn_replace
145
+ return attn_base
146
+
147
+ else:
148
+ raise NotImplementedError
149
+
150
+ elif "spatial" in key:
151
+
152
+ raise NotImplementedError
153
+
154
+ def replace_cross_attention(self, attn_base, attn_replace, key=None):
155
+ raise NotImplementedError
156
+
157
+ def update_attention_position_dict(self, current_attention_key):
158
+ self.attention_position_counter_dict[current_attention_key] +=1
159
+
160
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
161
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
162
+
163
+ if 'mask' in place_in_unet:
164
+ return attn
165
+
166
+ if (not is_cross and 'temporal' in place_in_unet and (self.cur_step < self.num_temporal_self_replace[0] or self.cur_step >=self.num_temporal_self_replace[1])):
167
+ if self.control_mode["temporal_self"] == "copy" or \
168
+ self.control_mode["temporal_self"] == "copy_v2":
169
+ return attn
170
+
171
+ if (not is_cross and 'spatial' in place_in_unet and (self.cur_step < self.num_spatial_self_replace[0] or self.cur_step >=self.num_spatial_self_replace[1])):
172
+ if self.control_mode["spatial_self"] == "copy":
173
+ return attn
174
+
175
+ if (is_cross and (self.cur_step < self.num_cross_replace[0] or self.cur_step >= self.num_cross_replace[1])):
176
+ return attn
177
+
178
+ if True:#'temporal' in place_in_unet:
179
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
180
+ current_pos = self.attention_position_counter_dict[key]
181
+
182
+ if self.use_inversion_attention and self.additional_attention_store is not None:
183
+ step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1
184
+ elif self.additional_attention_store is None:
185
+ return attn
186
+
187
+ else:
188
+ step_in_store = self.cur_step
189
+
190
+ step_in_store_atten_dict = self.additional_attention_store.attention_store_all_step[step_in_store]
191
+
192
+ if isinstance(step_in_store_atten_dict, str):
193
+ if self.step_in_store_atten_dict is None:
194
+ step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
195
+ self.step_in_store_atten_dict = step_in_store_atten_dict
196
+ else:
197
+ step_in_store_atten_dict = self.step_in_store_atten_dict
198
+
199
+ # Note that attn is appended to step_store;
201
+ # if attn was recorded clean -> noisy, the step order has to be reversed here.
201
+ #print(key)
202
+ attn_base = step_in_store_atten_dict[key][current_pos]
203
+ self.current_pos = current_pos
204
+
205
+ self.update_attention_position_dict(key)
206
+ # save in format of [temporal, head, resolution, text_embedding]
207
+ attn_base, attn_replace = attn_base, attn
208
+
209
+ if not is_cross:
210
+ attn = self.replace_self_attention(attn_base, attn_replace, None, key)
211
+
212
+ #elif is_cross and (self.num_cross_replace[0] <= self.cur_step < self.num_cross_replace[1]):
213
+ elif is_cross:
214
+ attn = self.replace_cross_attention(attn_base, attn_replace, key)
215
+
216
+ return attn
217
+
218
+ else:
219
+
220
+ raise NotImplementedError("Due to CUDA RAM limit, direct replace functions for spatial are not implemented.")
221
+
222
+ def between_steps(self):
223
+
224
+ super().between_steps()
225
+
226
+
227
+
228
+ self.step_store = self.get_empty_store()
229
+
230
+ self.attention_position_counter_dict = {
231
+ 'down_spatial_q_cross': 0,
232
+ 'mid_spatial_q_cross': 0,
233
+ 'up_spatial_q_cross': 0,
234
+ 'down_spatial_k_cross': 0,
235
+ 'mid_spatial_k_cross': 0,
236
+ 'up_spatial_k_cross': 0,
237
+ 'down_spatial_mask_cross': 0,
238
+ 'mid_spatial_mask_cross': 0,
239
+ 'up_spatial_mask_cross': 0,
240
+ 'down_spatial_q_self': 0,
241
+ 'mid_spatial_q_self': 0,
242
+ 'up_spatial_q_self': 0,
243
+ 'down_spatial_k_self': 0,
244
+ 'mid_spatial_k_self': 0,
245
+ 'up_spatial_k_self': 0,
246
+ 'down_spatial_mask_self': 0,
247
+ 'mid_spatial_mask_self': 0,
248
+ 'up_spatial_mask_self': 0,
249
+ 'down_temporal_cross': 0,
250
+ 'mid_temporal_cross': 0,
251
+ 'up_temporal_cross': 0,
252
+ 'down_temporal_self': 0,
253
+ 'mid_temporal_self': 0,
254
+ 'up_temporal_self': 0
255
+ }
256
+ return
257
+
258
+ def __init__(self, num_steps: int,
259
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
260
+ temporal_self_replace_steps: Union[float, Tuple[float, float]],
261
+ spatial_self_replace_steps: Union[float, Tuple[float, float]],
262
+ control_mode={"temporal_self":"copy","spatial_self":"copy"},
263
+ spatial_attention_chunk_size = 1,
264
+ additional_attention_store: AttentionStore =None,
265
+ use_inversion_attention: bool=False,
266
+ save_self_attention: bool=True,
267
+ save_latents: bool=True,
268
+ disk_store=False,
269
+ *args, **kwargs
270
+ ):
271
+ super(AttentionControlEdit, self).__init__(
272
+ save_self_attention=save_self_attention,
273
+ save_latents=save_latents,
274
+ disk_store=disk_store)
275
+ self.additional_attention_store = additional_attention_store
276
+ if type(temporal_self_replace_steps) is float:
277
+ temporal_self_replace_steps = 0, temporal_self_replace_steps
278
+ if type(spatial_self_replace_steps) is float:
279
+ spatial_self_replace_steps = 0, spatial_self_replace_steps
280
+ if type(cross_replace_steps) is float:
281
+ cross_replace_steps = 0, cross_replace_steps
282
+ self.num_temporal_self_replace = int(num_steps * temporal_self_replace_steps[0]), int(num_steps * temporal_self_replace_steps[1])
283
+ self.num_spatial_self_replace = int(num_steps * spatial_self_replace_steps[0]), int(num_steps * spatial_self_replace_steps[1])
284
+ self.num_cross_replace = int(num_steps * cross_replace_steps[0]), int(num_steps * cross_replace_steps[1])
285
+ self.control_mode = control_mode
286
+ self.spatial_attention_chunk_size = spatial_attention_chunk_size
287
+ self.step_in_store_atten_dict = None
288
+ # We need to know the current position in attention
289
+ self.prev_attention_key_name = 0
290
+ self.use_inversion_attention = use_inversion_attention
291
+ self.attention_position_counter_dict = {
292
+ 'down_spatial_q_cross': 0,
293
+ 'mid_spatial_q_cross': 0,
294
+ 'up_spatial_q_cross': 0,
295
+ 'down_spatial_k_cross': 0,
296
+ 'mid_spatial_k_cross': 0,
297
+ 'up_spatial_k_cross': 0,
298
+ 'down_spatial_mask_cross': 0,
299
+ 'mid_spatial_mask_cross': 0,
300
+ 'up_spatial_mask_cross': 0,
301
+ 'down_spatial_q_self': 0,
302
+ 'mid_spatial_q_self': 0,
303
+ 'up_spatial_q_self': 0,
304
+ 'down_spatial_k_self': 0,
305
+ 'mid_spatial_k_self': 0,
306
+ 'up_spatial_k_self': 0,
307
+ 'down_spatial_mask_self': 0,
308
+ 'mid_spatial_mask_self': 0,
309
+ 'up_spatial_mask_self': 0,
310
+ 'down_temporal_cross': 0,
311
+ 'mid_temporal_cross': 0,
312
+ 'up_temporal_cross': 0,
313
+ 'down_temporal_self': 0,
314
+ 'mid_temporal_self': 0,
315
+ 'up_temporal_self': 0
316
+ }
317
+ self.mask_thr = kwargs.get("mask_thr", 0.35)
318
+ self.latent_blend = kwargs.get('latent_blend', False)
319
+
320
+ self.temporal_step_thr = kwargs.get("temporal_step_thr", [0.4,0.8])
321
+ self.num_steps = num_steps
322
+
323
+ def spatial_attention_control(
324
+ self, place_in_unet, attention_type, is_cross,
325
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False
326
+ ):
327
+
328
+ return self.spatial_attention_matching(
329
+ place_in_unet, attention_type, is_cross,
330
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False,
331
+ mode = self.control_mode["spatial_self"]
332
+ )
333
+
334
+
335
+ def spatial_attention_matching(
336
+ self, place_in_unet, attention_type, is_cross,
337
+ q, k, v, attn_mask, dropout_p=0.0, is_causal=False,
338
+ mode = "matching"
339
+ ):
340
+ place_in_unet = f"{place_in_unet}_{attention_type}"
341
+ with sdp_kernel(**BACKEND_MAP[None]):
342
+ # print("register", q.shape, k.shape, v.shape)
343
+
344
+ # fetch inversion q and k
345
+ key_q = f"{place_in_unet}_q_{'cross' if is_cross else 'self'}"
346
+ key_k = f"{place_in_unet}_k_{'cross' if is_cross else 'self'}"
347
+ current_pos_q = self.attention_position_counter_dict[key_q]
348
+ current_pos_k = self.attention_position_counter_dict[key_k]
349
+
350
+ if self.use_inversion_attention and self.additional_attention_store is not None:
351
+ step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1
352
+ else:
353
+ step_in_store = self.cur_step
354
+
355
+ step_in_store_atten_dict = self.additional_attention_store.attention_store_all_step[step_in_store]
356
+
357
+ if isinstance(step_in_store_atten_dict, str):
358
+ if self.step_in_store_atten_dict is None:
359
+ step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
360
+ self.step_in_store_atten_dict = step_in_store_atten_dict
361
+ else:
362
+ step_in_store_atten_dict = self.step_in_store_atten_dict
363
+
364
+ q0s = step_in_store_atten_dict[key_q][current_pos_q].to(q.device)
365
+ k0s = step_in_store_atten_dict[key_k][current_pos_k].to(k.device)
366
+
367
+ self.update_attention_position_dict(key_q)
368
+ self.update_attention_position_dict(key_k)
369
+
370
+ qs, ks, vs = q, k, v
371
+
372
+ h = q.shape[1]
373
+ res = int(np.sqrt(q.shape[-2] / (9*16)))
374
+ if res == 0:
375
+ res = 1
376
+ #res = int(np.sqrt(q.shape[-2] / (8*14)))
377
+ bs = self.spatial_attention_chunk_size
378
+ if bs is None: bs = qs.shape[0]
379
+ N = qs.shape[0] // bs
380
+ assert qs.shape[0] % bs == 0
381
+ i1st, n1st = qs.shape[0]//2//bs, qs.shape[0]//2%bs
382
+ outs = []
383
+ masks = []
384
+
385
+ # this might reduce time costs but will introduce inaccurate motions
386
+ # if current_pos_q >= 6 and 'up' in key_q:
387
+ # return F.scaled_dot_product_attention(
388
+ # q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
389
+ # )
390
+
391
+ for i in range(N):
392
+ q = qs[i*bs:(i+1)*bs,...].type(torch.float32)
393
+ k = ks[i*bs:(i+1)*bs,...].type(torch.float32)
394
+ v = vs[i*bs:(i+1)*bs,...].type(torch.float32)
395
+
396
+ q, k, v = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q, k, v))
397
+
398
+ with torch.autocast("cuda", enabled=False):
399
+ attention_scores = torch.baddbmm(
400
+ torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
401
+ q,
402
+ k.transpose(-1, -2),
403
+ beta=0,
404
+ alpha=1 / math.sqrt(q.size(-1)),
405
+ )
406
+
407
+ if attn_mask is not None:
408
+ if attn_mask.dtype == torch.bool:
409
+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
410
+ attention_scores = attention_scores + attn_mask
411
+
412
+ attention_probs = attention_scores.softmax(dim=-1).to(vs.dtype)
413
+
414
+ # only compute conditional output
415
+ if i >= N//2:
416
+
417
+ q0 = q0s[(i-N//2)*bs:(i-N//2+1)*bs,...].type(torch.float32)
418
+ k0 = k0s[(i-N//2)*bs:(i-N//2+1)*bs,...].type(torch.float32)
419
+
420
+ q0, k0 = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q0, k0))
421
+
422
+ with torch.autocast("cuda", enabled=False):
423
+ attention_scores_0 = torch.baddbmm(
424
+ torch.empty(q0.shape[0], q0.shape[1], k0.shape[1], dtype=q0.dtype, device=q0.device),
425
+ q0,
426
+ k0.transpose(-1, -2),
427
+ beta=0,
428
+ alpha=1 / math.sqrt(q0.size(-1)),
429
+ )
430
+
431
+ attention_probs_0 = attention_scores_0.softmax(dim=-1).to(vs.dtype)
432
+
433
+ attention_probs, attention_probs_0 = \
434
+ map(lambda t: rearrange(t, "(b h) n d -> b h n d", h=h),
435
+ (attention_probs, attention_probs_0))
436
+
437
+ if mode == "masked_copy":
438
+
439
+ mask = torch.sum(
440
+ torch.mean(
441
+ torch.abs(attention_probs_0 - attention_probs),
442
+ dim=1
443
+ ),
444
+ dim=2
445
+ ).reshape(bs,1,-1,1).clamp(0,2)/2.0
446
+ mask_thr = (self.mask_thr[1]-self.mask_thr[0]) / (qs.shape[0]//2)*(i-N//2) + self.mask_thr[0]
447
+ mask_tmp = mask.clone()
448
+ mask[mask>=mask_thr] = 1.0
449
+ masks.append(mask)
450
+
451
+ # apply mask
452
+ attention_probs = (1 - mask) * attention_probs_0 + mask * attention_probs
453
+
454
+ else:
455
+ raise NotImplementedError
456
+
457
+ attention_probs = rearrange(attention_probs, "b h n d -> (b h) n d")
458
+
459
+ # compute attention output
460
+ hidden_states = torch.bmm(attention_probs, v)
461
+
462
+ # reshape hidden_states
463
+ hidden_states = rearrange(hidden_states, "(b h) n d -> b h n d", h=h)
464
+
465
+ outs.append(hidden_states)
466
+
467
+ if mode == "masked_copy":
468
+
469
+ # masks = rearrange(torch.cat(masks, 0), "b 1 (h w) 1 -> h (b w)", h=res*9)
470
+ masks = torch.cat(masks, 0)
471
+ #print(f"{place_in_unet}_masked_copy")
472
+ # save mask
473
+ _ = self.__call__(masks, is_cross, f"{place_in_unet}_mask")
474
+
475
+ return torch.cat(outs, 0)
476
+
477
+ class ConsistencyAttentionControl(AttentionStore, abc.ABC):
478
+ """Decide self or cross-attention. Call the reweighting cross attention module
479
+
480
+ Args:
481
+ AttentionStore (_type_): ([1, 4, 8, 64, 64])
482
+ abc (_type_): [8, 8, 1024, 77]
483
+ """
484
+ def step_callback(self, x_t):
485
+ x_t = super().step_callback(x_t)
486
+ x_t_device = x_t.device
487
+ x_t_dtype = x_t.dtype
488
+
489
+ # if self.previous_latents is not None:
490
+ # # replace latents
491
+ # step_in_store = self.cur_step - 1
492
+ # previous_latents = self.previous_latents[step_in_store]
493
+ # x_t[:,:len(previous_latents),...] = previous_latents.to(x_t_device, x_t_dtype)
494
+
495
+ self.step_in_store_atten_dict = None
496
+
497
+ return x_t
498
+
499
+ def update_attention_position_dict(self, current_attention_key):
500
+ self.attention_position_counter_dict[current_attention_key] +=1
501
+
502
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
503
+ if self.cur_att_layer >= self.num_uncond_att_layers:
504
+ attn = self.forward(attn, is_cross, place_in_unet)
505
+
506
+ self.cur_att_layer += 1
507
+
508
+ return attn
509
+
510
+ def set_cur_step(self, step: int = 0):
511
+ self.cur_step = step
512
+
513
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
514
+ super(ConsistencyAttentionControl, self).forward(attn, is_cross, place_in_unet)
515
+
516
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
517
+ current_pos = self.attention_position_counter_dict[key]
518
+
519
+ if self.use_inversion_attention and self.additional_attention_store is not None:
520
+ step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1
521
+ elif self.additional_attention_store is None:
522
+ return attn
523
+
524
+ else:
525
+ step_in_store = self.cur_step
526
+
527
+ step_in_store_atten_dict = self.additional_attention_store.attention_store_all_step[step_in_store]
528
+
529
+ if isinstance(step_in_store_atten_dict, str):
530
+ if self.step_in_store_atten_dict is None:
531
+ step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
532
+ self.step_in_store_atten_dict = step_in_store_atten_dict
533
+ else:
534
+ step_in_store_atten_dict = self.step_in_store_atten_dict
535
+
536
+ # Note that attn is appended to step_store;
537
+ # if attn was recorded clean -> noisy, the step order has to be reversed here.
538
+ #print("consistency", key)
539
+ attn_base = step_in_store_atten_dict[key][current_pos].to(attn.device, attn.dtype)
540
+ attn_base = attn_base.detach()
541
+
542
+ self.update_attention_position_dict(key)
543
+ # save in format of [temporal, head, resolution, text_embedding]
544
+
545
+ attn = torch.cat([attn_base, attn], dim=2)
546
+
547
+ return attn
548
+
549
+ @staticmethod
550
+ def get_empty_store():
551
+ return {
552
+ "down_temporal_k_self": [], "mid_temporal_k_self": [], "up_temporal_k_self": [],
553
+ "down_temporal_v_self": [], "mid_temporal_v_self": [], "up_temporal_v_self": []
554
+ }
555
+
556
+ def between_steps(self):
557
+
558
+ super().between_steps()
559
+
560
+ self.step_store = self.get_empty_store()
561
+
562
+ self.attention_position_counter_dict = {
563
+ 'down_temporal_k_self': 0,
564
+ 'mid_temporal_k_self': 0,
565
+ 'up_temporal_k_self': 0,
566
+ 'down_temporal_v_self': 0,
567
+ 'mid_temporal_v_self': 0,
568
+ 'up_temporal_v_self': 0
569
+ }
570
+ return
571
+
572
+ def __init__(self,
573
+ additional_attention_store: AttentionStore =None,
574
+ use_inversion_attention: bool=False,
575
+ load_attention_store: str = None,
576
+ save_self_attention: bool=True,
577
+ save_latents: bool=True,
578
+ disk_store=False,
579
+ store_path:str=None
580
+ ):
581
+ super(ConsistencyAttentionControl, self).__init__(
582
+ save_self_attention=save_self_attention,
583
+ load_attention_store=load_attention_store,
584
+ save_latents=save_latents,
585
+ disk_store=disk_store,
586
+ store_path=store_path
587
+ )
588
+
589
+ self.additional_attention_store = additional_attention_store
590
+ self.step_in_store_atten_dict = None
591
+ # We need to know the current position in attention
592
+ self.use_inversion_attention = use_inversion_attention
593
+ self.attention_position_counter_dict = {
594
+ 'down_temporal_k_self': 0,
595
+ 'mid_temporal_k_self': 0,
596
+ 'up_temporal_k_self': 0,
597
+ 'down_temporal_v_self': 0,
598
+ 'mid_temporal_v_self': 0,
599
+ 'up_temporal_v_self': 0
600
+ }
601
+
602
+
603
+
604
+ def make_controller(
605
+ cross_replace_steps: Dict[str, float], self_replace_steps: float=0.0,
606
+ additional_attention_store=None, use_inversion_attention = False,
607
+ NUM_DDIM_STEPS=None,
608
+ save_path = None,
609
+ save_self_attention = True,
610
+ disk_store = False
611
+ ) -> AttentionControlEdit:
612
+ controller = AttentionControlEdit(NUM_DDIM_STEPS,
613
+ cross_replace_steps=cross_replace_steps,
614
+ temporal_self_replace_steps=self_replace_steps, spatial_self_replace_steps=self_replace_steps,
615
+ additional_attention_store=additional_attention_store,
616
+ use_inversion_attention = use_inversion_attention,
617
+ save_self_attention = save_self_attention,
618
+ disk_store=disk_store
619
+ )
620
+ return controller
621
+
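For orientation, here is a hedged sketch of how the editing controller above could be paired with an inversion-time store; the step count, thresholds, and cache path are assumptions, not values taken from the repository.

# Hedged sketch: wiring an inversion AttentionStore into AttentionControlEdit.
from i2vedit.prompt_attention.attention_store import AttentionStore
from i2vedit.prompt_attention.attention_util import AttentionControlEdit

inversion_store = AttentionStore(
    disk_store=True,
    load_attention_store="./cache/attn_example",  # hypothetical directory written during inversion
)

controller = AttentionControlEdit(
    num_steps=30,                                  # must match the sampler's denoising steps
    cross_replace_steps=0.8,
    temporal_self_replace_steps=1.0,
    spatial_self_replace_steps=1.0,
    control_mode={"temporal_self": "copy_v2", "spatial_self": "masked_copy"},
    additional_attention_store=inversion_store,
    use_inversion_attention=True,
    mask_thr=(0.35, 0.6),                          # masked_copy reads this as a (low, high) pair
    temporal_step_thr=[0.4, 0.8],
)

# register_attention_control(...) from attention_register.py (not shown in this excerpt) is
# then expected to hook the UNet so the controller runs at every attention layer.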
i2vedit/prompt_attention/common/__init__.py ADDED
File without changes
i2vedit/prompt_attention/common/image_util.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import math
3
+ import textwrap
4
+
5
+ import imageio
6
+ import numpy as np
7
+ from typing import Sequence
8
+ import requests
9
+ import cv2
10
+ from PIL import Image, ImageDraw, ImageFont
11
+
12
+ import torch
13
+ from torchvision import transforms
14
+ from einops import rearrange
15
+
16
+
17
+ IMAGE_EXTENSION = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".JPEG")
18
+
19
+ FONT_URL = "https://raw.github.com/googlefonts/opensans/main/fonts/ttf/OpenSans-Regular.ttf"
20
+ FONT_PATH = "./docs/OpenSans-Regular.ttf"
21
+
22
+
23
+ def pad(image: Image.Image, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)) -> Image.Image:
24
+ new_image = Image.new(image.mode, (image.width + right + left, image.height + top + bottom), color)
25
+ new_image.paste(image, (left, top))
26
+ return new_image
27
+
28
+
29
+ def download_font_opensans(path=FONT_PATH):
30
+ font_url = FONT_URL
31
+ response = requests.get(font_url)
32
+ os.makedirs(os.path.dirname(path), exist_ok=True)
33
+ with open(path, "wb") as f:
34
+ f.write(response.content)
35
+
36
+
37
+ def annotate_image_with_font(image: Image.Image, text: str, font: ImageFont.FreeTypeFont) -> Image.Image:
38
+ image_w = image.width
39
+ _, _, text_w, text_h = font.getbbox(text)
40
+ line_size = math.floor(len(text) * image_w / text_w)
41
+
42
+ lines = textwrap.wrap(text, width=line_size)
43
+ padding = text_h * len(lines)
44
+ image = pad(image, top=padding + 3)
45
+
46
+ ImageDraw.Draw(image).text((0, 0), "\n".join(lines), fill=(0, 0, 0), font=font)
47
+ return image
48
+
49
+
50
+ def annotate_image(image: Image.Image, text: str, font_size: int = 15):
51
+ if not os.path.isfile(FONT_PATH):
52
+ download_font_opensans()
53
+ font = ImageFont.truetype(FONT_PATH, size=font_size)
54
+ return annotate_image_with_font(image=image, text=text, font=font)
55
+
56
+
57
+ def make_grid(images: Sequence[Image.Image], rows=None, cols=None) -> Image.Image:
58
+ if isinstance(images[0], np.ndarray):
59
+ images = [Image.fromarray(i) for i in images]
60
+
61
+ if rows is None:
62
+ assert cols is not None
63
+ rows = math.ceil(len(images) / cols)
64
+ else:
65
+ cols = math.ceil(len(images) / rows)
66
+
67
+ w, h = images[0].size
68
+ grid = Image.new("RGB", size=(cols * w, rows * h))
69
+ for i, image in enumerate(images):
70
+ if image.size != (w, h):
71
+ image = image.resize((w, h))
72
+ grid.paste(image, box=(i % cols * w, i // cols * h))
73
+ return grid
74
+
75
+
76
+ def save_images_as_gif(
77
+ images: Sequence[Image.Image],
78
+ save_path: str,
79
+ loop=0,
80
+ duration=100,
81
+ optimize=False,
82
+ ) -> None:
83
+
84
+ images[0].save(
85
+ save_path,
86
+ save_all=True,
87
+ append_images=images[1:],
88
+ optimize=optimize,
89
+ loop=loop,
90
+ duration=duration,
91
+ )
92
+
93
+ def save_images_as_mp4(
94
+ images: Sequence[Image.Image],
95
+ save_path: str,
96
+ ) -> None:
97
+
98
+ writer_edit = imageio.get_writer(
99
+ save_path,
100
+ fps=10)
101
+ for i in images:
102
+ init_image = i.convert("RGB")
103
+ writer_edit.append_data(np.array(init_image))
104
+ writer_edit.close()
105
+
106
+
107
+
108
+ def save_images_as_folder(
109
+ images: Sequence[Image.Image],
110
+ save_path: str,
111
+ ) -> None:
112
+ os.makedirs(save_path, exist_ok=True)
113
+ for index, image in enumerate(images):
114
+ init_image = image
115
+ if len(np.array(init_image).shape) == 3:
116
+ cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image)[:, :, ::-1])
117
+ else:
118
+ cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image))
119
+
120
+ def log_train_samples(
121
+ train_dataloader,
122
+ save_path,
123
+ num_batch: int = 4,
124
+ ):
125
+ train_samples = []
126
+ for idx, batch in enumerate(train_dataloader):
127
+ if idx >= num_batch:
128
+ break
129
+ train_samples.append(batch["images"])
130
+
131
+ train_samples = torch.cat(train_samples).numpy()
132
+ train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
133
+ train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
134
+ train_samples = numpy_batch_seq_to_pil(train_samples)
135
+ train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
136
+ # save_images_as_gif(train_samples, save_path)
137
+ save_gif_mp4_folder_type(train_samples, save_path)
138
+
139
+ def log_train_reg_samples(
140
+ train_dataloader,
141
+ save_path,
142
+ num_batch: int = 4,
143
+ ):
144
+ train_samples = []
145
+ for idx, batch in enumerate(train_dataloader):
146
+ if idx >= num_batch:
147
+ break
148
+ train_samples.append(batch["class_images"])
149
+
150
+ train_samples = torch.cat(train_samples).numpy()
151
+ train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
152
+ train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
153
+ train_samples = numpy_batch_seq_to_pil(train_samples)
154
+ train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
155
+ # save_images_as_gif(train_samples, save_path)
156
+ save_gif_mp4_folder_type(train_samples, save_path)
157
+
158
+
159
+ def save_gif_mp4_folder_type(images, save_path, save_gif=True):
160
+
161
+ if isinstance(images[0], np.ndarray):
162
+ images = [Image.fromarray(i) for i in images]
163
+ elif isinstance(images[0], torch.Tensor):
164
+ images = [transforms.ToPILImage()(i.cpu().clone()[0]) for i in images]
165
+ save_path_mp4 = save_path.replace('gif', 'mp4')
166
+ save_path_folder = save_path.replace('.gif', '')
167
+ if save_gif: save_images_as_gif(images, save_path)
168
+ save_images_as_mp4(images, save_path_mp4)
169
+ save_images_as_folder(images, save_path_folder)
170
+
171
+ # copy from video_diffusion/pipelines/stable_diffusion.py
172
+ def numpy_seq_to_pil(images):
173
+ """
174
+ Convert a numpy image or a batch of images to a PIL image.
175
+ """
176
+ if images.ndim == 3:
177
+ images = images[None, ...]
178
+ images = (images * 255).round().astype("uint8")
179
+ if images.shape[-1] == 1:
180
+ # special case for grayscale (single channel) images
181
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
182
+ else:
183
+ pil_images = [Image.fromarray(image) for image in images]
184
+
185
+ return pil_images
186
+
187
+ # copy from diffusers-0.11.1/src/diffusers/pipeline_utils.py
188
+ def numpy_batch_seq_to_pil(images):
189
+ pil_images = []
190
+ for sequence in images:
191
+ pil_images.append(numpy_seq_to_pil(sequence))
192
+ return pil_images
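A hedged sketch of the helpers above on dummy frames; paths are illustrative, and annotate_image fetches the OpenSans font on first use.

# Hedged sketch: grid, caption and export helpers on three dummy frames.
import os
import numpy as np
from PIL import Image
from i2vedit.prompt_attention.common.image_util import (
    annotate_image, make_grid, save_gif_mp4_folder_type,
)

frames = [Image.fromarray(np.full((64, 64, 3), v, dtype=np.uint8)) for v in (0, 128, 255)]

grid = make_grid(frames, cols=3)                      # 1 x 3 grid of equally sized tiles
labeled = annotate_image(grid, "three dummy frames")  # pads the top and draws the caption

os.makedirs("./trash", exist_ok=True)
save_gif_mp4_folder_type(frames, "./trash/example.gif")  # writes a .gif, an .mp4 and a .png folder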
i2vedit/prompt_attention/common/instantiate_from_config.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Copy from stable diffusion
3
+ """
4
+ import importlib
5
+
6
+
7
+ def instantiate_from_config(config:dict, **args_from_code):
8
+ """Util funciton to decompose differenct modules using config
9
+
10
+ Args:
11
+ config (dict): with key of "target" and "params", better from yaml
12
+ static
13
+ args_from_code: additional con
14
+
15
+
16
+ Returns:
17
+ a validation/training pipeline, a module
18
+ """
19
+ if not "target" in config:
20
+ if config == '__is_first_stage__':
21
+ return None
22
+ elif config == "__is_unconditional__":
23
+ return None
24
+ raise KeyError("Expected key `target` to instantiate.")
25
+ return get_obj_from_str(config["target"])(**config.get("params", dict()), **args_from_code)
26
+
27
+
28
+ def get_obj_from_str(string, reload=False):
29
+ module, cls = string.rsplit(".", 1)
30
+ if reload:
31
+ module_imp = importlib.import_module(module)
32
+ importlib.reload(module_imp)
33
+ return getattr(importlib.import_module(module, package=None), cls)
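The docstring above describes a "target"/"params" convention; below is a minimal, hedged example that uses a standard-library class as the target.

# Hedged sketch: instantiating an object from a config dict.
from i2vedit.prompt_attention.common.instantiate_from_config import instantiate_from_config

config = {
    "target": "collections.OrderedDict",  # any importable "module.Class" path
    "params": {"a": 1, "b": 2},
}
obj = instantiate_from_config(config, c=3)  # keyword args from code are merged in
print(obj)                                  # OrderedDict with keys a, b, c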
i2vedit/prompt_attention/common/logger.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+ import logging, logging.handlers
3
+ from accelerate.logging import get_logger
4
+
5
+ def get_logger_config_path(logdir):
6
+ # accelerate handles the logger in multiprocessing
7
+ logger = get_logger(__name__)
8
+ logging.basicConfig(
9
+ level=logging.INFO,
10
+ format='%(asctime)s:%(levelname)s : %(message)s',
11
+ datefmt='%a, %d %b %Y %H:%M:%S',
12
+ filename=os.path.join(logdir, 'log.log'),
13
+ filemode='w')
14
+ chlr = logging.StreamHandler()
15
+ chlr.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s : %(message)s'))
16
+ logger.logger.addHandler(chlr)
17
+ return logger
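A hedged usage sketch for the logger helper; the run directory is an assumption and must exist before the call, since the helper only opens log.log inside it.

# Hedged sketch: file + console logging for a run directory.
import os
from i2vedit.prompt_attention.common.logger import get_logger_config_path

log_dir = "./outputs/example_run"    # hypothetical run directory
os.makedirs(log_dir, exist_ok=True)
logger = get_logger_config_path(log_dir)
logger.info("training started")      # goes to ./outputs/example_run/log.log and to stdout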
i2vedit/prompt_attention/common/set_seed.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
3
+
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+
8
+ from accelerate.utils import set_seed
9
+
10
+
11
+ def video_set_seed(seed: int):
12
+ """
13
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
14
+
15
+ Args:
16
+ seed (`int`): The seed to set.
17
+ device_specific (`bool`, *optional*, defaults to `False`):
18
+ Whether to differ the seed on each device slightly with `self.process_index`.
19
+ """
20
+ set_seed(seed)
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ torch.backends.cudnn.benchmark = False
26
+ # torch.use_deterministic_algorithms(True, warn_only=True)
27
+ # [W Context.cpp:82] Warning: efficient_attention_forward_cutlass does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (function alertNotDeterministic)
28
+
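A short, hedged check that the seeding helper above makes repeated runs reproducible:

# Hedged sketch: same seed, same noise.
import torch
from i2vedit.prompt_attention.common.set_seed import video_set_seed

video_set_seed(33)
a = torch.randn(2, 2)
video_set_seed(33)
b = torch.randn(2, 2)
assert torch.equal(a, b)   # identical tensors after re-seeding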
i2vedit/prompt_attention/common/util.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import inspect
5
+ import datetime
6
+ from typing import List, Tuple, Optional, Dict
7
+
8
+
9
+ def glob_files(
10
+ root_path: str,
11
+ extensions: Tuple[str],
12
+ recursive: bool = True,
13
+ skip_hidden_directories: bool = True,
14
+ max_directories: Optional[int] = None,
15
+ max_files: Optional[int] = None,
16
+ relative_path: bool = False,
17
+ ) -> Tuple[List[str], bool, bool]:
18
+ """glob files with specified extensions
19
+
20
+ Args:
21
+ root_path (str): _description_
22
+ extensions (Tuple[str]): _description_
23
+ recursive (bool, optional): _description_. Defaults to True.
24
+ skip_hidden_directories (bool, optional): _description_. Defaults to True.
25
+ max_directories (Optional[int], optional): max number of directories to search. Defaults to None.
26
+ max_files (Optional[int], optional): max file number limit. Defaults to None.
27
+ relative_path (bool, optional): _description_. Defaults to False.
28
+
29
+ Returns:
30
+ Tuple[List[str], bool, bool]: _description_
31
+ """
32
+ paths = []
33
+ hit_max_directories = False
34
+ hit_max_files = False
35
+ for directory_idx, (directory, _, fnames) in enumerate(os.walk(root_path, followlinks=True)):
36
+ if skip_hidden_directories and os.path.basename(directory).startswith("."):
37
+ continue
38
+
39
+ if max_directories is not None and directory_idx >= max_directories:
40
+ hit_max_directories = True
41
+ break
42
+
43
+ paths += [
44
+ os.path.join(directory, fname)
45
+ for fname in sorted(fnames)
46
+ if fname.lower().endswith(extensions)
47
+ ]
48
+
49
+ if not recursive:
50
+ break
51
+
52
+ if max_files is not None and len(paths) > max_files:
53
+ hit_max_files = True
54
+ paths = paths[:max_files]
55
+ break
56
+
57
+ if relative_path:
58
+ paths = [os.path.relpath(p, root_path) for p in paths]
59
+
60
+ return paths, hit_max_directories, hit_max_files
61
+
62
+
63
+ def get_time_string() -> str:
64
+ x = datetime.datetime.now()
65
+ return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}"
66
+
67
+
68
+ def get_function_args() -> Dict:
69
+ frame = sys._getframe(1)
70
+ args, _, _, values = inspect.getargvalues(frame)
71
+ args_dict = copy.deepcopy({arg: values[arg] for arg in args})
72
+
73
+ return args_dict
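A hedged sketch of the file-globbing and timestamp helpers above; the data folder is an assumption.

# Hedged sketch: collect frame files and stamp a run id.
from i2vedit.prompt_attention.common.util import get_time_string, glob_files

paths, hit_max_dirs, hit_max_files = glob_files(
    root_path="./mydata",            # hypothetical folder of extracted frames
    extensions=(".jpg", ".png"),
    recursive=True,
    max_files=100,
    relative_path=True,
)
print(f"found {len(paths)} frames (truncated: {hit_max_files})")
print(f"run id: {get_time_string()}")   # e.g. "240611-142530" (YYMMDD-HHMMSS)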
i2vedit/prompt_attention/ptp_utils.py ADDED
@@ -0,0 +1,199 @@
1
+ '''
2
+ Utility code for image visualization.
3
+ '''
4
+
5
+
6
+ # Copyright 2022 Google LLC
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ import numpy as np
21
+ import torch
22
+ from PIL import Image
23
+ import cv2
24
+ from typing import Optional, Union, Tuple, List, Callable, Dict
25
+
26
+ import datetime
27
+
28
+
29
+ def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
30
+ h, w, c = image.shape
31
+ offset = int(h * .2)
32
+ img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
33
+ font = cv2.FONT_HERSHEY_SIMPLEX
34
+ # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
35
+ img[:h] = image
36
+ textsize = cv2.getTextSize(text, font, 1, 2)[0]
37
+ text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
38
+ cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
39
+ return img
40
+
41
+
42
+ def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None):
43
+ if type(images) is list:
44
+ num_empty = len(images) % num_rows
45
+ elif images.ndim == 4:
46
+ num_empty = images.shape[0] % num_rows
47
+ else:
48
+ images = [images]
49
+ num_empty = 0
50
+
51
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
52
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
53
+ num_items = len(images)
54
+
55
+ h, w, c = images[0].shape
56
+ offset = int(h * offset_ratio)
57
+ num_cols = num_items // num_rows
58
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
59
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
60
+ for i in range(num_rows):
61
+ for j in range(num_cols):
62
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
63
+ i * num_cols + j]
64
+
65
+ if save_path is not None:
66
+ pil_img = Image.fromarray(image_)
67
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
68
+ pil_img.save(f'{save_path}/{now}.png')
69
+ # display(pil_img)
70
+
71
+
72
+
73
+ def register_attention_control_p2p_deprecated(model, controller):
74
+ "Original code from prompt to prompt"
75
+ def ca_forward(self, place_in_unet):
76
+ to_out = self.to_out
77
+ if type(to_out) is torch.nn.modules.container.ModuleList:
78
+ to_out = self.to_out[0]
79
+ else:
80
+ to_out = self.to_out
81
+
82
+ # def forward(x, encoder_hidden_states=None, attention_mask=None):
83
+ def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
84
+ batch_size, sequence_length, _ = hidden_states.shape
85
+ attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
86
+
87
+ query = self.to_q(hidden_states)
88
+ query = self.head_to_batch_dim(query)
89
+
90
+ is_cross = encoder_hidden_states is not None
91
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
92
+ key = self.to_k(encoder_hidden_states)
93
+ value = self.to_v(encoder_hidden_states)
94
+ key = self.head_to_batch_dim(key)
95
+ value = self.head_to_batch_dim(value)
96
+
97
+ attention_probs = self.get_attention_scores(query, key, attention_mask) # [16, 4096, 4096]
98
+ attention_probs = controller(attention_probs, is_cross, place_in_unet)
99
+ hidden_states = torch.bmm(attention_probs, value)
100
+ hidden_states = self.batch_to_head_dim(hidden_states)
101
+
102
+ # linear proj
103
+ hidden_states = self.to_out[0](hidden_states)
104
+ # dropout
105
+ hidden_states = self.to_out[1](hidden_states)
106
+
107
+ return hidden_states
108
+
109
+ return forward
110
+
111
+ class DummyController:
112
+
113
+ def __call__(self, *args):
114
+ return args[0]
115
+
116
+ def __init__(self):
117
+ self.num_att_layers = 0
118
+
119
+ if controller is None:
120
+ controller = DummyController()
121
+
122
+ def register_recr(net_, count, place_in_unet):
123
+ if net_.__class__.__name__ == 'CrossAttention':
124
+ net_.forward = ca_forward(net_, place_in_unet)
125
+ return count + 1
126
+ elif hasattr(net_, 'children'):
127
+ for net__ in net_.children():
128
+ count = register_recr(net__, count, place_in_unet)
129
+ return count
130
+
131
+ cross_att_count = 0
132
+ sub_nets = model.unet.named_children()
133
+ for net in sub_nets:
134
+ if "down" in net[0]:
135
+ cross_att_count += register_recr(net[1], 0, "down")
136
+ elif "up" in net[0]:
137
+ cross_att_count += register_recr(net[1], 0, "up")
138
+ elif "mid" in net[0]:
139
+ cross_att_count += register_recr(net[1], 0, "mid")
140
+
141
+ controller.num_att_layers = cross_att_count
142
+
143
+
144
+ def get_word_inds(text: str, word_place: int, tokenizer):
145
+ split_text = text.split(" ")
146
+ if type(word_place) is str:
147
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
148
+ elif type(word_place) is int:
149
+ word_place = [word_place]
150
+ out = []
151
+ if len(word_place) > 0:
152
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
153
+ cur_len, ptr = 0, 0
154
+
155
+ for i in range(len(words_encode)):
156
+ cur_len += len(words_encode[i])
157
+ if ptr in word_place:
158
+ out.append(i + 1)
159
+ if cur_len >= len(split_text[ptr]):
160
+ ptr += 1
161
+ cur_len = 0
162
+ return np.array(out)
163
+
164
+
165
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
166
+ word_inds: Optional[torch.Tensor]=None):
167
+ # Edit the alpha map during attention map editing
168
+ if type(bounds) is float:
169
+ bounds = 0, bounds
170
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
171
+ if word_inds is None:
172
+ word_inds = torch.arange(alpha.shape[2])
173
+ alpha[: start, prompt_ind, word_inds] = 0
174
+ alpha[start: end, prompt_ind, word_inds] = 1
175
+ alpha[end:, prompt_ind, word_inds] = 0
176
+ return alpha
177
+
178
+ import omegaconf
179
+ def get_time_words_attention_alpha(prompts, num_steps,
180
+ cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
181
+ tokenizer, max_num_words=77):
182
+ # Builds per-step, per-word alpha weights for cross-attention replacement (adapted from prompt-to-prompt).
183
+ if (type(cross_replace_steps) is not dict) and \
184
+ (type(cross_replace_steps) is not omegaconf.dictconfig.DictConfig):
185
+ cross_replace_steps = {"default_": cross_replace_steps}
186
+ if "default_" not in cross_replace_steps:
187
+ cross_replace_steps["default_"] = (0., 1.)
188
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
189
+ for i in range(len(prompts) - 1):
190
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
191
+ i)
192
+ for key, item in cross_replace_steps.items():
193
+ if key != "default_":
194
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
195
+ for i, ind in enumerate(inds):
196
+ if len(ind) > 0:
197
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
198
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
199
+ return alpha_time_words
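A hedged sketch of the plain-numpy helpers above, captioning and tiling two dummy frames; the shapes and caption text are illustrative.

# Hedged sketch: caption a frame and view a small batch.
import numpy as np
from i2vedit.prompt_attention import ptp_utils

frame = np.full((64, 64, 3), 200, dtype=np.uint8)
captioned = ptp_utils.text_under_image(frame, "frame 0")   # appends a white strip with the text
batch = np.stack([captioned, captioned], axis=0)
ptp_utils.view_images(batch, num_rows=1, save_path=None)   # pass a directory path to also save a PNG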
i2vedit/prompt_attention/visualization.py ADDED
@@ -0,0 +1,391 @@
1
+ from typing import List
2
+ import os
3
+ import datetime
4
+ import numpy as np
5
+ from PIL import Image
6
+ from einops import rearrange, repeat
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from packaging import version
12
+
13
+ from i2vedit.prompt_attention import ptp_utils
14
+ from i2vedit.prompt_attention.common.image_util import save_gif_mp4_folder_type
15
+ from i2vedit.prompt_attention.attention_store import AttentionStore
16
+
17
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
18
+ SDP_IS_AVAILABLE = True
19
+ from torch.backends.cuda import SDPBackend, sdp_kernel
20
+
21
+ BACKEND_MAP = {
22
+ SDPBackend.MATH: {
23
+ "enable_math": True,
24
+ "enable_flash": False,
25
+ "enable_mem_efficient": False,
26
+ },
27
+ SDPBackend.FLASH_ATTENTION: {
28
+ "enable_math": False,
29
+ "enable_flash": True,
30
+ "enable_mem_efficient": False,
31
+ },
32
+ SDPBackend.EFFICIENT_ATTENTION: {
33
+ "enable_math": False,
34
+ "enable_flash": False,
35
+ "enable_mem_efficient": True,
36
+ },
37
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
38
+ }
39
+ else:
40
+ from contextlib import nullcontext
41
+
42
+ SDP_IS_AVAILABLE = False
43
+ sdp_kernel = nullcontext
44
+ BACKEND_MAP = {}
45
+ print(  # console warning; no logger is configured in this module
46
+ f"No SDP backend available, likely because you are running in pytorch "
47
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
48
+ f"You might want to consider upgrading."
49
+ )
50
+
51
+ def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
52
+ out = []
53
+ attention_maps = attention_store.get_average_attention()
54
+ num_pixels = res ** 2
55
+ for location in from_where:
56
+ for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
57
+ if item.dim() == 3:
58
+ if item.shape[1] == num_pixels:
59
+ cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
60
+ out.append(cross_maps)
61
+ elif item.dim() == 4:
62
+ t, h, res_sq, token = item.shape
63
+ if item.shape[2] == num_pixels:
64
+ cross_maps = item.reshape(len(prompts), t, -1, res, res, item.shape[-1])[select]
65
+ out.append(cross_maps)
66
+
67
+ out = torch.cat(out, dim=-4)
68
+ out = out.sum(-4) / out.shape[-4]
69
+ return out.cpu()
70
+
71
+
72
+ def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore,
73
+ res: int, from_where: List[str], select: int = 0, save_path = None):
74
+ """
75
+ attention_store (AttentionStore):
76
+ ["down", "mid", "up"] X ["self", "cross"]
77
+ 4, 1, 6
78
+ head*res*text_token_len = 8*res*77
79
+ res=1024 -> 64 -> 1024
80
+ res (int): res
81
+ from_where (List[str]): "up", "down'
82
+ """
83
+ if isinstance(prompts, str):
84
+ prompts = [prompts,]
85
+ tokens = tokenizer.encode(prompts[select])
86
+ decoder = tokenizer.decode
87
+
88
+ attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
89
+ os.makedirs('trash', exist_ok=True)
90
+ attention_list = []
91
+ if attention_maps.dim()==3: attention_maps=attention_maps[None, ...]
92
+ for j in range(attention_maps.shape[0]):
93
+ images = []
94
+ for i in range(len(tokens)):
95
+ image = attention_maps[j, :, :, i]
96
+ image = 255 * image / image.max()
97
+ image = image.unsqueeze(-1).expand(*image.shape, 3)
98
+ image = image.numpy().astype(np.uint8)
99
+ image = np.array(Image.fromarray(image).resize((256, 256)))
100
+ image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
101
+ images.append(image)
102
+ ptp_utils.view_images(np.stack(images, axis=0), save_path=save_path)
103
+ atten_j = np.concatenate(images, axis=1)
104
+ attention_list.append(atten_j)
105
+ if save_path is not None:
106
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
107
+ video_save_path = f'{save_path}/{now}.gif'
108
+ save_gif_mp4_folder_type(attention_list, video_save_path)
109
+ return attention_list
110
+
111
+
112
+ def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
113
+ max_com=10, select: int = 0):
114
+ attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
115
+ u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
116
+ images = []
117
+ for i in range(max_com):
118
+ image = vh[i].reshape(res, res)
119
+ image = image - image.min()
120
+ image = 255 * image / image.max()
121
+ image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
122
+ image = Image.fromarray(image).resize((256, 256))
123
+ image = np.array(image)
124
+ images.append(image)
125
+ ptp_utils.view_images(np.concatenate(images, axis=1))
126
+
127
+ def show_avg_difference_maps(
128
+ attention_store: AttentionStore,
129
+ save_path = None
130
+ ):
131
+ avg_attention = attention_store.get_average_attention()
132
+ masks = []
133
+ for key in avg_attention:
134
+ if 'mask' in key:
135
+ for cur_pos in range(len(avg_attention[key])):
136
+ mask = avg_attention[key][cur_pos]
137
+ res = mask.shape[0] / 9
138
+ file_path = os.path.join(
139
+ save_path,
140
+ f"avg_key_{key}_curpos_{cur_pos}_res_{res}_mask.png"
141
+ )
142
+ print(key, cur_pos, mask.shape)
143
+ image = 255 * mask #/ attn.max()
144
+ image = image.cpu().numpy().astype(np.uint8)
145
+ image = Image.fromarray(image)
146
+ image.save(file_path)
147
+
148
+
149
+
150
+
151
+ def show_self_attention(
152
+ attention_store: AttentionStore,
153
+ steps: List[int],
154
+ save_path = None,
155
+ inversed = False):
156
+ """
157
+ attention_store (AttentionStore):
158
+ ["down", "mid", "up"] X ["self", "cross"]
159
+ 4, 1, 6
160
+ head*res*text_token_len = 8*res*77
161
+ res=1024 -> 64 -> 1024
162
+ res (int): res
163
+ from_where (List[str]): "up", "down'
164
+ """
165
+ #os.system(f"rm -rf {save_path}")
166
+ os.makedirs(save_path, exist_ok=True)
167
+ for step in steps:
168
+ step_in_store = len(attention_store.attention_store_all_step) - step - 1 if inversed else step
169
+ print("step_in_store", step_in_store)
170
+ step_in_store_atten_dict = attention_store.attention_store_all_step[step_in_store]
171
+ if isinstance(step_in_store_atten_dict, str):
172
+ step_in_store_atten_dict = torch.load(step_in_store_atten_dict)
173
+
174
+ step_in_store_atten_dict_reorg = {}
175
+
176
+ for key in step_in_store_atten_dict:
177
+ if '_q_' not in key and ('_k_' not in key and '_v_' not in key):
178
+ step_in_store_atten_dict_reorg[key] = step_in_store_atten_dict[key]
179
+ elif '_q_' in key:
180
+ step_in_store_atten_dict_reorg[key.replace("_q_","_qxk_")] = \
181
+ [[step_in_store_atten_dict[key][i], \
182
+ step_in_store_atten_dict[key.replace("_q_","_k_")][i] \
183
+ ] \
184
+ for i in range(len(step_in_store_atten_dict[key]))]
185
+
186
+ for key in step_in_store_atten_dict_reorg:
187
+ if '_mask_' not in key and '_qxk_' not in key:
188
+ for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
189
+ attn = step_in_store_atten_dict_reorg[key][cur_pos]
190
+ attn = torch.mean(attn, dim=1)
191
+ s, t, d = attn.shape
192
+ res = int(np.sqrt(s / (9*16)))
193
+ attn = attn.reshape(res*9,res*16,t,d).permute(2,0,3,1).reshape(t*res*9,d*res*16)
194
+ file_path = os.path.join(
195
+ save_path,
196
+ f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}.png"
197
+ )
198
+ print(step, key, cur_pos, attn.shape)
199
+ image = 255 * attn #/ attn.max()
200
+ image = image.cpu().numpy().astype(np.uint8)
201
+ image = Image.fromarray(image)
202
+ image.save(file_path)
203
+
204
+ elif '_mask_' in key:
205
+ for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
206
+ mask = step_in_store_atten_dict_reorg[key][cur_pos]
207
+ res = mask.shape[0] / 9
208
+ file_path = os.path.join(
209
+ save_path,
210
+ f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_mask.png"
211
+ )
212
+ print(step, key, cur_pos, mask.shape)
213
+ image = 255 * mask #/ attn.max()
214
+ image = image.cpu().numpy().astype(np.uint8)
215
+ image = Image.fromarray(image)
216
+ image.save(file_path)
217
+
218
+ else:
219
+ for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
220
+ q, k = step_in_store_atten_dict_reorg[key][cur_pos]
221
+ q = q.to("cuda").type(torch.float32)
222
+ k = k.to("cuda").type(torch.float32)
223
+ res = int(np.sqrt(q.shape[-2] / (9*16)))
224
+ h = q.shape[1]
225
+ bs = 1
226
+ N = q.shape[0] // bs
227
+ vectors = []
228
+ vectors_diff = []
229
+ for i in range(N):
230
+ attn_prob = calculate_attention_probs(q[i*bs:(i+1)*bs], k[i*bs:(i+1)*bs])
231
+ print("attn_prob 1", attn_prob.min(), attn_prob.max())
232
+ attn_prob = torch.mean(attn_prob, dim=2).reshape(h, res*9, res*16)
233
+ print("attn_prob 2", attn_prob.min(), attn_prob.max())
234
+ attn_prob = torch.mean(attn_prob, dim=0)
235
+ print("attn_prob 3", attn_prob.min(), attn_prob.max())
236
+ vectors.append( attn_prob )
237
+ for i in range(1, len(vectors)):
238
+ vectors_diff.append(vectors[i] - vectors[i-1])
239
+ vectors = torch.cat(vectors, dim=1)
240
+ vectors_diff = torch.cat(vectors_diff, dim=1)
241
+ file_path = os.path.join(
242
+ save_path,
243
+ f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_vector.png"
244
+ )
245
+ print(step, key, cur_pos, vectors.shape)
246
+ image = 255 * vectors / vectors.max()
247
+ image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
248
+ image = Image.fromarray(image)
249
+ image.save(file_path)
250
+
251
+ file_path = os.path.join(
252
+ save_path,
253
+ f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_diff.png"
254
+ )
255
+ print(step, key, cur_pos, vectors_diff.shape)
256
+ image = 255 * vectors_diff / vectors_diff.max()
257
+ image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
258
+ image = Image.fromarray(image)
259
+ image.save(file_path)
260
+
261
+
262
+ # else:
263
+ # # only look at the last two frames
264
+ # for cur_pos in range(len(step_in_store_atten_dict_reorg[key])):
265
+ # q, k, v = step_in_store_atten_dict_reorg[key][cur_pos]
266
+ # q = q[-2:,...].to("cuda")
267
+ # k = k[-2:,...].to("cuda")
268
+ # v = v[-2:,...].to("cuda")
269
+ # res = int(np.sqrt(q.shape[-2] / (9*16)))
270
+ # attn = calculate_attention_probs(q,k,v)
271
+ # attn_d = torch.sum(torch.mean(torch.abs(attn[0,...] - attn[1,...]), dim=0), dim=1).reshape(res*9,res*16)
272
+ # print(step, key, cur_pos, attn_d.shape, attn_d.min(), attn_d.max())
273
+ # file_path = os.path.join(
274
+ # save_path,
275
+ # f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_attn_d.png"
276
+ # )
277
+ # image = (255 * attn_d + 1e-3) / 2.#attn_d.max()
278
+ # image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
279
+ # image = Image.fromarray(image)
280
+ # image.save(file_path)
281
+
282
+ def show_self_attention_distance(
283
+ attention_store: List[AttentionStore],
284
+ steps: List[int],
285
+ save_path = None,
286
+ ):
287
+ """
+ Compare the self-attention stored in two AttentionStores (e.g. inversion vs. editing)
+ and save the per-pixel attention-difference maps as PNG images.
+ attention_store (List[AttentionStore]): exactly two stores; the first is indexed
+ from the end (inversion order), the second from the start.
+ steps (List[int]): denoising steps to visualize.
+ save_path: output directory (removed and recreated).
+ """
296
+ os.system(f"rm -rf {save_path}")
297
+ os.makedirs(save_path, exist_ok=True)
298
+ assert len(attention_store) == 2
299
+ for step in steps:
300
+ step_in_store = [len(attention_store[0].attention_store_all_step) - step - 1, step]
301
+ step_in_store_atten_dict = [attention_store[i].attention_store_all_step[step_in_store[i]] \
302
+ for i in range(2)]
303
+ step_in_store_atten_dict = [ \
304
+ torch.load(step_in_store_atten_dict[i]) \
305
+ if isinstance(step_in_store_atten_dict[i], str) \
306
+ else step_in_store_atten_dict[i] \
307
+ for i in range(2)]
308
+
309
+ step_in_store_atten_dict_reorg = [{},{}]
310
+
311
+ for i in range(2):
312
+ item = step_in_store_atten_dict[i]
313
+ for key in item:
314
+ if '_q_' in key:
315
+ step_in_store_atten_dict_reorg[i][key.replace("_q_","_qxk_")] = \
316
+ [[step_in_store_atten_dict[i][key][j], \
317
+ step_in_store_atten_dict[i][key.replace("_q_","_k_")][j] \
318
+ ] \
319
+ for j in range(len(step_in_store_atten_dict[i][key]))]
320
+
321
+ for key in step_in_store_atten_dict_reorg[1]:
322
+ for cur_pos in range(len(step_in_store_atten_dict_reorg[1][key])):
323
+ q1, k1 = step_in_store_atten_dict_reorg[1][key][cur_pos]
324
+ q0, k0 = step_in_store_atten_dict_reorg[0][key][cur_pos]
325
+ res = int(np.sqrt(q1.shape[-2] / (9*16)))
326
+
327
+ attn_d = calculate_attention_mask(q0, k0, q1, k1, bs=1, device="cuda")
328
+ attn_d = rearrange(attn_d, "b h w -> h (b w)")
329
+
330
+ print(step, key, cur_pos, attn_d.shape, "attnd", attn_d.min(), attn_d.max())
331
+ file_path = os.path.join(
332
+ save_path,
333
+ f"step_{step}_key_{key}_curpos_{cur_pos}_res_{res}_attn_d.png"
334
+ )
335
+ image = 255 * attn_d#attn_d.max()
336
+ image = image.clamp(0,255).cpu().numpy().astype(np.uint8)
337
+ image = Image.fromarray(image)
338
+ image.save(file_path)
339
+
340
+ def calculate_attention_mask(q0, k0, q1, k1, bs=1, device="cuda"):
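+ # Soft difference mask between two attention maps: for each chunk of `bs` frames,
+ # recompute the attention probabilities of (q0, k0) and (q1, k1), average their absolute
+ # difference over heads, sum over key positions, reshape onto the 9x16-aspect spatial
+ # grid, and scale/clamp the result to [0, 1].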
341
+ q1 = q1.to(device)
342
+ k1 = k1.to(device)
343
+ q0 = q0.to(device)
344
+ k0 = k0.to(device)
345
+ res = int(np.sqrt(q1.shape[-2] / (9*16)))
346
+ N = q0.shape[0] // bs
347
+ attn_d = []
348
+ for i in range(N):
349
+ attn0 = calculate_attention_probs(q0[bs*i:bs*(i+1),...],k0[bs*i:bs*(i+1),...])
350
+ attn1 = calculate_attention_probs(q1[bs*i:bs*(i+1),...],k1[bs*i:bs*(i+1),...])
351
+ attn_d_i = torch.sum(torch.mean(torch.abs(attn0 - attn1), dim=1), dim=2).reshape(bs,res*9,res*16)
352
+ attn_d.append( attn_d_i )
353
+ attn_d = torch.cat(attn_d, dim=0) / 2.0
354
+ return attn_d.clamp(0,1)
355
+
356
+ def calculate_attention_probs(q, k, attn_mask=None):
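+ # Explicitly recompute softmax(q k^T / sqrt(d)) instead of using the fused SDPA path,
+ # so the full probability matrix is available for visualization; a boolean attn_mask
+ # is converted to additive -inf masking.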
357
+ with sdp_kernel(**BACKEND_MAP[None]):
358
+ h = q.shape[1]
359
+ q, k = map(lambda t: rearrange(t, "b h n d -> (b h) n d"), (q, k))
360
+
361
+ with torch.autocast("cuda", enabled=False):
362
+ attention_scores = torch.baddbmm(
363
+ torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
364
+ q,
365
+ k.transpose(-1, -2),
366
+ beta=0,
367
+ alpha=1 / math.sqrt(q.size(-1)),
368
+ )
369
+ #print("attention_scores", attention_scores.min(), attention_scores.max())
370
+
371
+ if attn_mask is not None:
372
+ if attn_mask.dtype == torch.bool:
373
+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
374
+ attention_scores = attention_scores + attn_mask
375
+
376
+ attention_probs = attention_scores.softmax(dim=-1)
377
+ #print("attention_softmax", attention_probs.min(), attention_probs.max())
378
+
379
+ # cast back to the original dtype
380
+ attention_probs = attention_probs.to(q.dtype)
381
+
382
+ # reshape hidden_states
383
+ attention_probs = rearrange(attention_probs, "(b h) n d -> b h n d", h=h)
384
+
385
+ # v = torch.eye(q.shape[-2], device=q.device)
386
+ # v = repeat(v, "... -> b h ...", b=q.shape[0], h=q.shape[1])
387
+ # attention_probs = F.scaled_dot_product_attention(
388
+ # q, k, v, attn_mask=attn_mask
389
+ # ) # scale is dim_head ** -0.5 per default
390
+
391
+ return attention_probs
i2vedit/train.py ADDED
@@ -0,0 +1,1488 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ import random
8
+ import gc
9
+ import copy
10
+ from scipy.stats import anderson
11
+ import imageio
12
+ import numpy as np
13
+
14
+ from typing import Dict, Optional, Tuple, List
15
+ from omegaconf import OmegaConf
16
+ from einops import rearrange, repeat
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torch.utils.checkpoint
21
+ import diffusers
22
+ import transformers
23
+
24
+ from torchvision import transforms
25
+ from tqdm.auto import tqdm
26
+ from PIL import Image
27
+
28
+ from accelerate import Accelerator
29
+ from accelerate.logging import get_logger
30
+ from accelerate.utils import set_seed
31
+
32
+ import diffusers
33
+ from diffusers.models import AutoencoderKL
34
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
35
+ from diffusers.optimization import get_scheduler
36
+ from diffusers.utils.import_utils import is_xformers_available
37
+ from diffusers.models.attention_processor import AttnProcessor2_0, Attention
38
+ from diffusers.models.attention import BasicTransformerBlock
39
+ from diffusers import StableVideoDiffusionPipeline
40
+ from diffusers.models.lora import LoRALinearLayer
41
+ from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
42
+ from diffusers.image_processor import VaeImageProcessor
43
+ from diffusers.optimization import get_scheduler
44
+ from diffusers.training_utils import EMAModel
45
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
46
+ from diffusers.utils.import_utils import is_xformers_available
47
+ from diffusers.models.unet_3d_blocks import \
48
+ (CrossAttnDownBlockSpatioTemporal,
49
+ DownBlockSpatioTemporal,
50
+ CrossAttnUpBlockSpatioTemporal,
51
+ UpBlockSpatioTemporal)
52
+
53
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
54
+ from transformers.models.clip.modeling_clip import CLIPEncoder
55
+
56
+ from i2vedit.utils.dataset import (
57
+ CachedDataset,
58
+ )
59
+
60
+ from i2vedit.utils.lora_handler import LoraHandler
61
+ from i2vedit.utils.lora import extract_lora_child_module
62
+ from i2vedit.utils.euler_utils import euler_inversion
63
+ from i2vedit.utils.svd_util import SmoothAreaRandomDetection
64
+ from i2vedit.utils.model_utils import (
65
+ tensor_to_vae_latent,
66
+ P2PEulerDiscreteScheduler,
67
+ P2PStableVideoDiffusionPipeline
68
+ )
69
+ from i2vedit.data import ResolutionControl, SingleClipDataset
70
+ from i2vedit.prompt_attention import attention_util
71
+
72
+ already_printed_trainables = False
73
+
74
+ logger = get_logger(__name__, log_level="INFO")
75
+
76
+
77
+ def create_logging(logging, logger, accelerator):
78
+ logging.basicConfig(
79
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
80
+ datefmt="%m/%d/%Y %H:%M:%S",
81
+ level=logging.INFO,
82
+ )
83
+ logger.info(accelerator.state, main_process_only=False)
84
+
85
+
86
+ def accelerate_set_verbose(accelerator):
87
+ if accelerator.is_local_main_process:
88
+ transformers.utils.logging.set_verbosity_warning()
89
+ diffusers.utils.logging.set_verbosity_info()
90
+ else:
91
+ transformers.utils.logging.set_verbosity_error()
92
+ diffusers.utils.logging.set_verbosity_error()
93
+
94
+ def extend_datasets(datasets, dataset_items, extend=False):
95
+ biggest_data_len = max(x.__len__() for x in datasets)
96
+ extended = []
97
+ for dataset in datasets:
98
+ if dataset.__len__() == 0:
99
+ del dataset
100
+ continue
101
+ if dataset.__len__() < biggest_data_len:
102
+ for item in dataset_items:
103
+ if extend and item not in extended and hasattr(dataset, item):
104
+ print(f"Extending {item}")
105
+
106
+ value = getattr(dataset, item)
107
+ value *= biggest_data_len
108
+ value = value[:biggest_data_len]
109
+
110
+ setattr(dataset, item, value)
111
+
112
+ print(f"New {item} dataset length: {dataset.__len__()}")
113
+ extended.append(item)
114
+
115
+
116
+ def export_to_video(video_frames, output_video_path, fps, resctrl:ResolutionControl):
117
+ flattened_video_frames = [img for sublist in video_frames for img in sublist]
118
+ video_writer = imageio.get_writer(output_video_path, fps=fps)
119
+ for img in flattened_video_frames:
120
+ img = resctrl.callback(img)
121
+ video_writer.append_data(np.array(img))
122
+ video_writer.close()
123
+
124
+
125
+ def create_output_folders(output_dir, config, clip_id):
126
+ out_dir = os.path.join(output_dir, f"train_motion_lora/clip_{clip_id}")
127
+
128
+ os.makedirs(out_dir, exist_ok=True)
129
+ os.makedirs(f"{out_dir}/samples", exist_ok=True)
130
+ # OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
131
+
132
+ return out_dir
133
+
134
+
135
+ def load_primary_models(pretrained_model_path):
136
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
137
+ pretrained_model_path, subfolder="scheduler")
138
+ feature_extractor = CLIPImageProcessor.from_pretrained(
139
+ pretrained_model_path, subfolder="feature_extractor", revision=None
140
+ )
141
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
142
+ pretrained_model_path, subfolder="image_encoder", revision=None, variant="fp16"
143
+ )
144
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
145
+ pretrained_model_path, subfolder="vae", revision=None, variant="fp16")
146
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
147
+ pretrained_model_path,
148
+ subfolder="unet",
149
+ low_cpu_mem_usage=True,
150
+ variant="fp16",
151
+ )
152
+
153
+ return noise_scheduler, feature_extractor, image_encoder, vae, unet
154
+
155
+
156
+ def unet_and_text_g_c(unet, image_encoder, unet_enable, image_enable):
157
+ unet.gradient_checkpointing = unet_enable
158
+ unet.mid_block.gradient_checkpointing = unet_enable
159
+ for module in unet.down_blocks + unet.up_blocks:
160
+ if isinstance(module,
161
+ (CrossAttnDownBlockSpatioTemporal,
162
+ DownBlockSpatioTemporal,
163
+ CrossAttnUpBlockSpatioTemporal,
164
+ UpBlockSpatioTemporal)):
165
+ module.gradient_checkpointing = unet_enable
166
+
167
+
168
+ def freeze_models(models_to_freeze):
169
+ for model in models_to_freeze:
170
+ if model is not None: model.requires_grad_(False)
171
+
172
+
173
+ def is_attn(name):
174
+ return ('attn1' or 'attn2' == name.split('.')[-1])
175
+
176
+
177
+ def set_processors(attentions):
178
+ for attn in attentions: attn.set_processor(AttnProcessor2_0())
179
+
180
+
181
+ def set_torch_2_attn(unet):
182
+ optim_count = 0
183
+
184
+ for name, module in unet.named_modules():
185
+ if is_attn(name):
186
+ if isinstance(module, torch.nn.ModuleList):
187
+ for m in module:
188
+ if isinstance(m, BasicTransformerBlock):
189
+ set_processors([m.attn1, m.attn2])
190
+ optim_count += 1
191
+ if optim_count > 0:
192
+ print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
193
+
194
+
195
+ def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
196
+ try:
197
+ is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
198
+ enable_torch_2 = is_torch_2 and enable_torch_2_attn
199
+
200
+ if enable_xformers_memory_efficient_attention and not enable_torch_2:
201
+ if is_xformers_available():
202
+ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
203
+ unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
204
+ else:
205
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
206
+
207
+ if enable_torch_2:
208
+ set_torch_2_attn(unet)
209
+
210
+ except:
211
+ print("Could not enable memory efficient attention for xformers or Torch 2.0.")
212
+
213
+
214
+ def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
215
+ extra_params = extra_params if len(extra_params.keys()) > 0 else None
216
+ return {
217
+ "model": model,
218
+ "condition": condition,
219
+ 'extra_params': extra_params,
220
+ 'is_lora': is_lora,
221
+ "negation": negation
222
+ }
223
+
224
+
225
+ def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
226
+ params = {
227
+ "name": name,
228
+ "params": params,
229
+ "lr": lr
230
+ }
231
+ if extra_params is not None:
232
+ for k, v in extra_params.items():
233
+ params[k] = v
234
+
235
+ return params
236
+
237
+
238
+ def negate_params(name, negation):
239
+ # We have to do this if we are co-training with LoRA.
240
+ # This ensures that parameter groups aren't duplicated.
241
+ if negation is None: return False
242
+ for n in negation:
243
+ if n in name and 'temp' not in name:
244
+ return True
245
+ return False
246
+
247
+
248
+ def create_optimizer_params(model_list, lr):
249
+ import itertools
250
+ optimizer_params = []
251
+
252
+ for optim in model_list:
253
+ model, condition, extra_params, is_lora, negation = optim.values()
254
+ # Check if we are doing LoRA training.
255
+ if is_lora and condition and isinstance(model, list):
256
+ params = create_optim_params(
257
+ params=itertools.chain(*model),
258
+ extra_params=extra_params
259
+ )
260
+ optimizer_params.append(params)
261
+ continue
262
+
263
+ if is_lora and condition and not isinstance(model, list):
264
+ for n, p in model.named_parameters():
265
+ if 'lora' in n:
266
+ params = create_optim_params(n, p, lr, extra_params)
267
+ optimizer_params.append(params)
268
+ continue
269
+
270
+ # If this is true, we can train it.
271
+ if condition:
272
+ for n, p in model.named_parameters():
273
+ should_negate = 'lora' in n and not is_lora
274
+ if should_negate: continue
275
+
276
+ params = create_optim_params(n, p, lr, extra_params)
277
+ optimizer_params.append(params)
278
+
279
+ return optimizer_params
280
+
281
+
282
+ def get_optimizer(use_8bit_adam):
283
+ if use_8bit_adam:
284
+ try:
285
+ import bitsandbytes as bnb
286
+ except ImportError:
287
+ raise ImportError(
288
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
289
+ )
290
+
291
+ return bnb.optim.AdamW8bit
292
+ else:
293
+ return torch.optim.AdamW
294
+
295
+
296
+ def is_mixed_precision(accelerator):
297
+ weight_dtype = torch.float32
298
+
299
+ if accelerator.mixed_precision == "fp16":
300
+ weight_dtype = torch.float16
301
+
302
+ elif accelerator.mixed_precision == "bf16":
303
+ weight_dtype = torch.bfloat16
304
+
305
+ return weight_dtype
306
+
307
+
308
+ def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
309
+ for model in model_list:
310
+ if model is not None: model.to(accelerator.device, dtype=weight_dtype)
311
+
312
+
313
+ def inverse_video(pipe, latents, num_steps, image):
314
+ euler_inv_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
315
+ euler_inv_scheduler.set_timesteps(num_steps)
316
+
317
+ euler_inv_latent = euler_inversion(
318
+ pipe, euler_inv_scheduler, video_latent=latents.to(pipe.device),
319
+ num_inv_steps=num_steps, image=image)[-1]
320
+ return euler_inv_latent
321
+
322
+
323
+ def handle_cache_latents(
324
+ should_cache,
325
+ output_dir,
326
+ train_dataloader,
327
+ train_batch_size,
328
+ vae,
329
+ unet,
330
+ cached_latent_dir=None,
331
+ ):
332
+ # Cache latents by storing them in VRAM.
333
+ # Speeds up training and saves memory by not encoding during the train loop.
334
+ if not should_cache: return None
335
+ vae.to('cuda', dtype=torch.float32)
336
+ #vae.enable_slicing()
337
+
338
+ cached_latent_dir = (
339
+ os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
340
+ )
341
+
342
+ if cached_latent_dir is None:
343
+ cache_save_dir = f"{output_dir}/cached_latents"
344
+ os.makedirs(cache_save_dir, exist_ok=True)
345
+
346
+ for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
347
+
348
+ save_name = f"cached_{i}"
349
+ full_out_path = f"{cache_save_dir}/{save_name}.pt"
350
+
351
+ pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float32)
352
+ refer_pixel_values = batch['refer_pixel_values'].to('cuda', dtype=torch.float32)
353
+ cross_pixel_values = batch['cross_pixel_values'].to('cuda', dtype=torch.float32)
354
+ batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
355
+ if batch.get("refer_latents") is None:
356
+ batch['refer_latents'] = tensor_to_vae_latent(refer_pixel_values, vae)
357
+ batch['cross_latents'] = tensor_to_vae_latent(cross_pixel_values, vae)
358
+
359
+ for k, v in batch.items(): batch[k] = v[0]
360
+
361
+ torch.save(batch, full_out_path)
362
+ del pixel_values
363
+ del batch
364
+
365
+ # We do this to avoid fragmentation from casting latents between devices.
366
+ torch.cuda.empty_cache()
367
+ else:
368
+ cache_save_dir = cached_latent_dir
369
+
370
+ return torch.utils.data.DataLoader(
371
+ CachedDataset(cache_dir=cache_save_dir),
372
+ batch_size=train_batch_size,
373
+ shuffle=True,
374
+ num_workers=0
375
+ )
376
+
377
+
378
+ def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
379
+ global already_printed_trainables
380
+
381
+ # This can most definitely be refactored :-)
382
+ unfrozen_params = 0
383
+ if trainable_modules is not None:
384
+ for name, module in model.named_modules():
385
+ for tm in tuple(trainable_modules):
386
+ if tm == 'all':
387
+ model.requires_grad_(is_enabled)
388
+ unfrozen_params = len(list(model.parameters()))
389
+ break
390
+
391
+ if tm in name and 'lora' not in name:
392
+ for m in module.parameters():
393
+ m.requires_grad_(is_enabled)
394
+ if is_enabled: unfrozen_params += 1
395
+
396
+ if unfrozen_params > 0 and not already_printed_trainables:
397
+ already_printed_trainables = True
398
+ print(f"{unfrozen_params} params have been unfrozen for training.")
399
+
400
+ def sample_noise(latents, noise_strength, use_offset_noise=False):
401
+ b, c, f, *_ = latents.shape
402
+ noise_latents = torch.randn_like(latents, device=latents.device)
403
+
404
+ if use_offset_noise:
405
+ offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
406
+ noise_latents = noise_latents + noise_strength * offset_noise
407
+
408
+ return noise_latents
409
+
410
+
411
+ def enforce_zero_terminal_snr(betas):
412
+ """
413
+ Corrects noise in diffusion schedulers.
414
+ From: Common Diffusion Noise Schedules and Sample Steps are Flawed
415
+ https://arxiv.org/pdf/2305.08891.pdf
416
+ """
417
+ # Convert betas to alphas_bar_sqrt
418
+ alphas = 1 - betas
419
+ alphas_bar = alphas.cumprod(0)
420
+ alphas_bar_sqrt = alphas_bar.sqrt()
421
+
422
+ # Store old values.
423
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
424
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
425
+
426
+ # Shift so the last timestep is zero.
427
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
428
+
429
+ # Scale so the first timestep is back to the old value.
430
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
431
+ alphas_bar_sqrt_0 - alphas_bar_sqrt_T
432
+ )
433
+
434
+ # Convert alphas_bar_sqrt to betas
435
+ alphas_bar = alphas_bar_sqrt ** 2
436
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
437
+ alphas = torch.cat([alphas_bar[0:1], alphas])
438
+ betas = 1 - alphas
439
+
440
+ return betas
441
+
442
+
443
+ def should_sample(global_step, validation_steps, validation_data):
444
+ return global_step % validation_steps == 0 and validation_data.sample_preview
445
+
446
+
447
+ def save_pipe(
448
+ path,
449
+ global_step,
450
+ accelerator,
451
+ unet,
452
+ image_encoder,
453
+ vae,
454
+ output_dir,
455
+ lora_manager_spatial: LoraHandler,
456
+ lora_manager_temporal: LoraHandler,
457
+ unet_target_replace_module=None,
458
+ image_target_replace_module=None,
459
+ is_checkpoint=False,
460
+ save_pretrained_model=True,
461
+ ):
462
+ if is_checkpoint:
463
+ save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
464
+ os.makedirs(save_path, exist_ok=True)
465
+ else:
466
+ save_path = output_dir
467
+
468
+ # Save the dtypes so we can continue training at the same precision.
469
+ u_dtype, i_dtype, v_dtype = unet.dtype, image_encoder.dtype, vae.dtype
470
+
471
+ # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
472
+ unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
473
+ image_encoder_out = copy.deepcopy(accelerator.unwrap_model(image_encoder.cpu(), keep_fp32_wrapper=False))
474
+ pipeline = P2PStableVideoDiffusionPipeline.from_pretrained(
475
+ path,
476
+ unet=unet_out,
477
+ image_encoder=image_encoder_out,
478
+ vae=accelerator.unwrap_model(vae),
479
+ # torch_dtype=weight_dtype,
480
+ ).to(torch_dtype=torch.float32)
481
+
482
+ # lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/spatial', step=global_step)
483
+ if lora_manager_temporal is not None:
484
+ lora_manager_temporal.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/temporal', step=global_step)
485
+
486
+ if save_pretrained_model:
487
+ pipeline.save_pretrained(save_path)
488
+
489
+ if is_checkpoint:
490
+ unet, image_encoder = accelerator.prepare(unet, image_encoder)
491
+ models_to_cast_back = [(unet, u_dtype), (image_encoder, i_dtype), (vae, v_dtype)]
492
+ [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
493
+
494
+ logger.info(f"Saved model at {save_path} on step {global_step}")
495
+
496
+ del pipeline
497
+ del unet_out
498
+ del image_encoder_out
499
+ torch.cuda.empty_cache()
500
+ gc.collect()
501
+
502
+ def load_images_from_list(img_list):
503
+ images = []
504
+ valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"} # Add or remove extensions as needed
505
+
506
+ # Function to extract frame number from the filename
507
+ def frame_number(filename):
508
+ parts = filename.split('_')
509
+ if len(parts) > 1 and parts[0] == 'frame':
510
+ try:
511
+ return int(parts[1].split('.')[0]) # Extracting the number part
512
+ except ValueError:
513
+ return float('inf') # In case of non-integer part, place this file at the end
514
+ return float('inf') # Non-frame files are placed at the end
515
+
516
+ # Sorting files based on frame number
517
+ #sorted_files = sorted(os.listdir(folder), key=frame_number)
518
+ sorted_files = img_list
519
+
520
+ # Load images in sorted order
521
+ for filename in sorted_files:
522
+ ext = os.path.splitext(filename)[1].lower()
523
+ if ext in valid_extensions:
524
+ img = Image.open(filename).convert('RGB')
525
+ images.append(img)
526
+
527
+ return images
528
+
529
+ # copy from https://github.com/crowsonkb/k-diffusion.git
530
+ def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
531
+ """Draws stratified samples from a uniform distribution."""
532
+ if groups <= 0:
533
+ raise ValueError(f"groups must be positive, got {groups}")
534
+ if group < 0 or group >= groups:
535
+ raise ValueError(f"group must be in [0, {groups})")
536
+ n = shape[-1] * groups
537
+ offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
538
+ u = torch.rand(shape, dtype=dtype, device=device)
539
+ return (offsets + u) / n
540
+
541
+ def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
542
+ """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""
543
+
544
+ def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
545
+ t_min = math.atan(math.exp(-0.5 * logsnr_max))
546
+ t_max = math.atan(math.exp(-0.5 * logsnr_min))
547
+ return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
548
+
549
+ def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
550
+ shift = 2 * math.log(noise_d / image_d)
551
+ return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift
552
+
553
+ def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
554
+ logsnr_low = logsnr_schedule_cosine_shifted(
555
+ t, image_d, noise_d_low, logsnr_min, logsnr_max)
556
+ logsnr_high = logsnr_schedule_cosine_shifted(
557
+ t, image_d, noise_d_high, logsnr_min, logsnr_max)
558
+ return torch.lerp(logsnr_low, logsnr_high, t)
559
+
560
+ logsnr_min = -2 * math.log(min_value / sigma_data)
561
+ logsnr_max = -2 * math.log(max_value / sigma_data)
562
+ u = stratified_uniform(
563
+ shape, group=0, groups=1, dtype=dtype, device=device
564
+ )
565
+ logsnr = logsnr_schedule_cosine_interpolated(
566
+ u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
567
+ return torch.exp(-logsnr / 2) * sigma_data, u
568
+
569
+
570
+ min_value = 0.002
571
+ max_value = 700
572
+ image_d = 64
573
+ noise_d_low = 32
574
+ noise_d_high = 64
575
+ sigma_data = 0.5
576
+
577
+ def _compute_padding(kernel_size):
578
+ """Compute padding tuple."""
579
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
580
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
581
+ if len(kernel_size) < 2:
582
+ raise AssertionError(kernel_size)
583
+ computed = [k - 1 for k in kernel_size]
584
+
585
+ # for even kernels we need to do asymmetric padding :(
586
+ out_padding = 2 * len(kernel_size) * [0]
587
+
588
+ for i in range(len(kernel_size)):
589
+ computed_tmp = computed[-(i + 1)]
590
+
591
+ pad_front = computed_tmp // 2
592
+ pad_rear = computed_tmp - pad_front
593
+
594
+ out_padding[2 * i + 0] = pad_front
595
+ out_padding[2 * i + 1] = pad_rear
596
+
597
+ return out_padding
598
+
599
+ def _filter2d(input, kernel):
600
+ # prepare kernel
601
+ b, c, h, w = input.shape
602
+ tmp_kernel = kernel[:, None, ...].to(
603
+ device=input.device, dtype=input.dtype)
604
+
605
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
606
+
607
+ height, width = tmp_kernel.shape[-2:]
608
+
609
+ padding_shape: list[int] = _compute_padding([height, width])
610
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
611
+
612
+ # kernel and input tensor reshape to align element-wise or batch-wise params
613
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
614
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
615
+
616
+ # convolve the tensor with the kernel.
617
+ output = torch.nn.functional.conv2d(
618
+ input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
619
+
620
+ out = output.view(b, c, h, w)
621
+ return out
622
+
623
+
624
+ def _gaussian(window_size: int, sigma):
625
+ if isinstance(sigma, float):
626
+ sigma = torch.tensor([[sigma]])
627
+
628
+ batch_size = sigma.shape[0]
629
+
630
+ x = (torch.arange(window_size, device=sigma.device,
631
+ dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
632
+
633
+ if window_size % 2 == 0:
634
+ x = x + 0.5
635
+
636
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
637
+
638
+ return gauss / gauss.sum(-1, keepdim=True)
639
+
640
+
641
+ def _gaussian_blur2d(input, kernel_size, sigma):
642
+ if isinstance(sigma, tuple):
643
+ sigma = torch.tensor([sigma], dtype=input.dtype)
644
+ else:
645
+ sigma = sigma.to(dtype=input.dtype)
646
+
647
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
648
+ bs = sigma.shape[0]
649
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
650
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
651
+ out_x = _filter2d(input, kernel_x[..., None, :])
652
+ out = _filter2d(out_x, kernel_y[..., None])
653
+
654
+ return out
655
+
656
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
657
+ h, w = input.shape[-2:]
658
+ factors = (h / size[0], w / size[1])
659
+
660
+ # First, we have to determine sigma
661
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
662
+ sigmas = (
663
+ max((factors[0] - 1.0) / 2.0, 0.001),
664
+ max((factors[1] - 1.0) / 2.0, 0.001),
665
+ )
666
+
667
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
668
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
669
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
670
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
671
+
672
+ # Make sure it is odd
673
+ if (ks[0] % 2) == 0:
674
+ ks = ks[0] + 1, ks[1]
675
+
676
+ if (ks[1] % 2) == 0:
677
+ ks = ks[0], ks[1] + 1
678
+
679
+ input = _gaussian_blur2d(input, ks, sigmas)
680
+
681
+ output = torch.nn.functional.interpolate(
682
+ input, size=size, mode=interpolation, align_corners=align_corners)
683
+ return output
684
+
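+ # Fine-tune temporal (and optionally per-clip spatial) motion LoRAs of the SVD UNet on a
+ # single source clip; a consistency attention controller can be registered so training
+ # reuses the attention stored during inversion.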
685
+ def train_motion_lora(
686
+ pretrained_model_path,
687
+ output_dir: str,
688
+ train_dataset: SingleClipDataset,
689
+ validation_data: Dict,
690
+ edited_firstframes: List[Image.Image],
691
+ train_data: Dict,
692
+ validation_images: List[Image.Image],
693
+ validation_images_latents: List[torch.Tensor],
694
+ clip_id: int,
695
+ consistency_controller: attention_util.ConsistencyAttentionControl = None,
696
+ consistency_edit_controller_list: List[attention_util.ConsistencyAttentionControl] = [None,],
697
+ consistency_find_modules: Dict = {},
698
+ single_spatial_lora: bool = False,
699
+ train_temporal_lora: bool = True,
700
+ validation_steps: int = 100,
701
+ trainable_modules: Tuple[str] = None, # Eg: ("attn1", "attn2")
702
+ extra_unet_params=None,
703
+ train_batch_size: int = 1,
704
+ max_train_steps: int = 500,
705
+ learning_rate: float = 5e-5,
706
+ lr_scheduler: str = "constant",
707
+ lr_warmup_steps: int = 0,
708
+ adam_beta1: float = 0.9,
709
+ adam_beta2: float = 0.999,
710
+ adam_weight_decay: float = 1e-2,
711
+ adam_epsilon: float = 1e-08,
712
+ gradient_accumulation_steps: int = 1,
713
+ gradient_checkpointing: bool = False,
714
+ image_encoder_gradient_checkpointing: bool = False,
715
+ checkpointing_steps: int = 500,
716
+ resume_from_checkpoint: Optional[str] = None,
717
+ resume_step: Optional[int] = None,
718
+ mixed_precision: Optional[str] = "fp16",
719
+ use_8bit_adam: bool = False,
720
+ enable_xformers_memory_efficient_attention: bool = True,
721
+ enable_torch_2_attn: bool = False,
722
+ seed: Optional[int] = None,
723
+ use_offset_noise: bool = False,
724
+ rescale_schedule: bool = False,
725
+ offset_noise_strength: float = 0.1,
726
+ extend_dataset: bool = False,
727
+ cache_latents: bool = False,
728
+ cached_latent_dir=None,
729
+ use_unet_lora: bool = False,
730
+ unet_lora_modules: Tuple[str] = [],
731
+ image_encoder_lora_modules: Tuple[str] = [],
732
+ save_pretrained_model: bool = True,
733
+ lora_rank: int = 16,
734
+ lora_path: str = '',
735
+ lora_unet_dropout: float = 0.1,
736
+ logger_type: str = 'tensorboard',
737
+ **kwargs
738
+ ):
739
+
740
+ *_, config = inspect.getargvalues(inspect.currentframe())
741
+
742
+ accelerator = Accelerator(
743
+ gradient_accumulation_steps=gradient_accumulation_steps,
744
+ mixed_precision=mixed_precision,
745
+ log_with=logger_type,
746
+ project_dir=output_dir
747
+ )
748
+
749
+ # Make one log on every process with the configuration for debugging.
750
+ create_logging(logging, logger, accelerator)
751
+
752
+ # Initialize accelerate, transformers, and diffusers warnings
753
+ accelerate_set_verbose(accelerator)
754
+
755
+ # Handle the output folder creation
756
+ if accelerator.is_main_process:
757
+ output_dir = create_output_folders(output_dir, config, clip_id)
758
+
759
+ # Load scheduler, tokenizer and models.
760
+ noise_scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(pretrained_model_path)
761
+
762
+ # Freeze any necessary models
763
+ freeze_models([vae, image_encoder, unet])
764
+
765
+ # Enable xformers if available
766
+ handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
767
+
768
+ # Initialize the optimizer
769
+ optimizer_cls = get_optimizer(use_8bit_adam)
770
+
771
+ # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
772
+ #extra_unet_params = extra_unet_params if extra_unet_params is not None else {}
773
+ #extra_text_encoder_params = extra_unet_params if extra_unet_params is not None else {}
774
+
775
+ # Temporal LoRA
776
+ if train_temporal_lora:
777
+ # one temporal lora
778
+ lora_manager_temporal = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["TemporalBasicTransformerBlock"])
779
+
780
+ unet_lora_params_temporal, unet_negation_temporal = lora_manager_temporal.add_lora_to_model(
781
+ use_unet_lora, unet, lora_manager_temporal.unet_replace_modules, lora_unet_dropout,
782
+ lora_path + '/temporal/lora/', r=lora_rank)
783
+
784
+ optimizer_temporal = optimizer_cls(
785
+ create_optimizer_params([param_optim(unet_lora_params_temporal, use_unet_lora, is_lora=True,
786
+ extra_params={**{"lr": learning_rate}}
787
+ )], learning_rate),
788
+ lr=learning_rate,
789
+ betas=(adam_beta1, adam_beta2),
790
+ weight_decay=adam_weight_decay,
791
+ eps=adam_epsilon,
792
+ )
793
+
794
+ lr_scheduler_temporal = get_scheduler(
795
+ lr_scheduler,
796
+ optimizer=optimizer_temporal,
797
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
798
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
799
+ )
800
+ else:
801
+ lora_manager_temporal = None
802
+ unet_lora_params_temporal, unet_negation_temporal = [], []
803
+ optimizer_temporal = None
804
+ lr_scheduler_temporal = None
805
+
806
+ # Spatial LoRAs
807
+ if single_spatial_lora:
808
+ spatial_lora_num = 1
809
+ else:
810
+ # one spatial lora for each video
811
+ spatial_lora_num = train_dataset.__len__()
812
+
813
+ lora_managers_spatial = []
814
+ unet_lora_params_spatial_list = []
815
+ optimizer_spatial_list = []
816
+ lr_scheduler_spatial_list = []
817
+ for i in range(spatial_lora_num):
818
+ lora_manager_spatial = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["BasicTransformerBlock"])
819
+ lora_managers_spatial.append(lora_manager_spatial)
820
+ unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
821
+ use_unet_lora, unet, lora_manager_spatial.unet_replace_modules, lora_unet_dropout,
822
+ lora_path + '/spatial/lora/', r=lora_rank)
823
+
824
+ unet_lora_params_spatial_list.append(unet_lora_params_spatial)
825
+
826
+ optimizer_spatial = optimizer_cls(
827
+ create_optimizer_params([param_optim(unet_lora_params_spatial, use_unet_lora, is_lora=True,
828
+ extra_params={**{"lr": learning_rate}}
829
+ )], learning_rate),
830
+ lr=learning_rate,
831
+ betas=(adam_beta1, adam_beta2),
832
+ weight_decay=adam_weight_decay,
833
+ eps=adam_epsilon,
834
+ )
835
+
836
+ optimizer_spatial_list.append(optimizer_spatial)
837
+
838
+ # Scheduler
839
+ lr_scheduler_spatial = get_scheduler(
840
+ lr_scheduler,
841
+ optimizer=optimizer_spatial,
842
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
843
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
844
+ )
845
+ lr_scheduler_spatial_list.append(lr_scheduler_spatial)
846
+
847
+ unet_negation_all = unet_negation_spatial + unet_negation_temporal
848
+
849
+ # DataLoaders creation:
850
+ train_dataloader = torch.utils.data.DataLoader(
851
+ train_dataset,
852
+ batch_size=train_batch_size,
853
+ shuffle=True
854
+ )
855
+
856
+ # Latents caching
857
+ cached_data_loader = handle_cache_latents(
858
+ cache_latents,
859
+ output_dir,
860
+ train_dataloader,
861
+ train_batch_size,
862
+ vae,
863
+ unet,
864
+ cached_latent_dir
865
+ )
866
+
867
+ if cached_data_loader is not None and train_data.get("use_data_aug") is None:
868
+ train_dataloader = cached_data_loader
869
+
870
+ # Prepare everything with our `accelerator`.
871
+ unet, optimizer_temporal, train_dataloader, lr_scheduler_temporal, image_encoder = accelerator.prepare(
872
+ unet,
873
+ optimizer_temporal,
874
+ train_dataloader,
875
+ lr_scheduler_temporal,
876
+ image_encoder
877
+ )
878
+
879
+ # Use Gradient Checkpointing if enabled.
880
+ unet_and_text_g_c(
881
+ unet,
882
+ image_encoder,
883
+ gradient_checkpointing,
884
+ image_encoder_gradient_checkpointing
885
+ )
886
+
887
+ # Enable VAE slicing to save memory.
888
+ #vae.enable_slicing()
889
+
890
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
891
+ # as these models are only used for inference, keeping weights in full precision is not required.
892
+ weight_dtype = is_mixed_precision(accelerator)
893
+
894
+ # Move text encoders, and VAE to GPU
895
+ models_to_cast = [image_encoder, vae]
896
+ cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)
897
+
898
+ # Fix noise schedules to predict light and dark areas if available.
899
+ # if not use_offset_noise and rescale_schedule:
900
+ # noise_scheduler.betas = enforce_zero_terminal_snr(noise_scheduler.betas)
901
+
902
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
903
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
904
+
905
+ # Afterwards we recalculate our number of training epochs
906
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
907
+
908
+ # We need to initialize the trackers we use, and also store our configuration.
909
+ # The trackers initialize automatically on the main process.
910
+ if accelerator.is_main_process:
911
+ accelerator.init_trackers("svd-finetune")
912
+
913
+ # Train!
914
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
915
+
916
+ logger.info("***** Running training for motion lora *****")
917
+ logger.info(f" Num examples = {len(train_dataset)}")
918
+ logger.info(f" Num Epochs = {num_train_epochs}")
919
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
920
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
921
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
922
+ logger.info(f" Total optimization steps = {max_train_steps}")
923
+ global_step = 0
924
+ first_epoch = 0
925
+
926
+ def encode_image(pixel_values):
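+ # Resize to 224x224 with antialiasing, map pixel values from [-1, 1] to [0, 1],
+ # normalize with the CLIP feature extractor, and return per-image CLIP embeddings
+ # with a singleton sequence dimension.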
927
+ # pixel_values = pixel_values * 2.0 - 1.0
928
+ pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
929
+ pixel_values = (pixel_values + 1.0) / 2.0
930
+
931
+ # Normalize the image for CLIP input
932
+ pixel_values = feature_extractor(
933
+ images=pixel_values,
934
+ do_normalize=True,
935
+ do_center_crop=False,
936
+ do_resize=False,
937
+ do_rescale=False,
938
+ return_tensors="pt",
939
+ ).pixel_values
940
+
941
+ pixel_values = pixel_values.to(
942
+ device=accelerator.device, dtype=weight_dtype)
943
+ image_embeddings = image_encoder(pixel_values).image_embeds
944
+ image_embeddings= image_embeddings.unsqueeze(1)
945
+ return image_embeddings
946
+
947
+ def _get_add_time_ids(
948
+ fps,
949
+ motion_bucket_ids, # Expecting a list of tensor floats
950
+ noise_aug_strength,
951
+ dtype,
952
+ batch_size,
953
+ unet=None,
954
+ device=None, # Add a device parameter
955
+ ):
956
+ # Determine the target device
957
+ target_device = device if device is not None else 'cpu'
958
+
959
+ # Ensure motion_bucket_ids is a tensor and on the target device
960
+ if not isinstance(motion_bucket_ids, torch.Tensor):
961
+ motion_bucket_ids = torch.tensor(motion_bucket_ids, dtype=dtype, device=target_device)
962
+ else:
963
+ motion_bucket_ids = motion_bucket_ids.to(device=target_device)
964
+
965
+ # Reshape motion_bucket_ids if necessary
966
+ if motion_bucket_ids.dim() == 1:
967
+ motion_bucket_ids = motion_bucket_ids.view(-1, 1)
968
+
969
+ # Check for batch size consistency
970
+ if motion_bucket_ids.size(0) != batch_size:
971
+ raise ValueError("The length of motion_bucket_ids must match the batch_size.")
972
+
973
+ # Create fps and noise_aug_strength tensors on the target device
974
+ add_time_ids = torch.tensor([fps, noise_aug_strength], dtype=dtype, device=target_device).repeat(batch_size, 1)
975
+
976
+ # Concatenate with motion_bucket_ids
977
+ add_time_ids = torch.cat([add_time_ids, motion_bucket_ids], dim=1)
978
+
979
+ # Checking the dimensions of the added time embedding
980
+ passed_add_embed_dim = unet.config.addition_time_embed_dim * add_time_ids.size(1)
981
+ expected_add_embed_dim = unet.add_embedding.linear_1.in_features
982
+
983
+ if expected_add_embed_dim != passed_add_embed_dim:
984
+ raise ValueError(
985
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
986
+ f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. "
987
+ "Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
988
+ )
989
+
990
+ return add_time_ids
991
+
992
+ # Only show the progress bar once on each machine.
993
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
994
+ progress_bar.set_description("Steps")
995
+
996
+ # set consistency controller
997
+ if consistency_controller is not None:
998
+ consistency_train_controller = attention_util.ConsistencyAttentionControl(
999
+ additional_attention_store=consistency_controller,
1000
+ use_inversion_attention=True,
1001
+ save_self_attention=False,
1002
+ save_latents=False,
1003
+ disk_store=True
1004
+ )
1005
+ attention_util.register_attention_control(
1006
+ unet,
1007
+ None,
1008
+ consistency_train_controller,
1009
+ find_modules={},
1010
+ consistency_find_modules=consistency_find_modules
1011
+ )
1012
+
1013
+ def finetune_unet(batch, step, mask_spatial_lora=False, mask_temporal_lora=False):
1014
+ nonlocal use_offset_noise
1015
+ nonlocal rescale_schedule
1016
+
1017
+
1018
+ # Unfreeze UNET Layers
1019
+ if global_step == 0:
1020
+ already_printed_trainables = False
1021
+ unet.train()
1022
+ handle_trainable_modules(
1023
+ unet,
1024
+ trainable_modules,
1025
+ is_enabled=True,
1026
+ negation=unet_negation_all
1027
+ )
1028
+
1029
+ # Convert videos to latent space
1030
+ #print("use_data_aug", train_data.get("use_data_aug"))
1031
+ if not cache_latents or train_data.get("use_data_aug") is not None:
1032
+ latents = tensor_to_vae_latent(batch["pixel_values"], vae)
1033
+ refer_latents = tensor_to_vae_latent(batch["refer_pixel_values"], vae)
1034
+ cross_latents = tensor_to_vae_latent(batch["cross_pixel_values"], vae)
1035
+ else:
1036
+ latents = batch["latents"]
1037
+ refer_latents = batch["refer_latents"]
1038
+ cross_latents = batch["cross_latents"]
1039
+
1040
+ # Sample noise that we'll add to the latents
1041
+ use_offset_noise = use_offset_noise and not rescale_schedule
1042
+ noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
1043
+ noise_1 = sample_noise(latents, offset_noise_strength, False)
1044
+ bsz = latents.shape[0]
1045
+
1046
+ # Sample a random timestep for each video
1047
+ sigmas, u = rand_cosine_interpolated(shape=[bsz,], image_d=image_d, noise_d_low=noise_d_low, noise_d_high=noise_d_high,
1048
+ sigma_data=sigma_data, min_value=min_value, max_value=max_value)
1049
+ noise_scheduler.set_timesteps(validation_data.num_inference_steps, device=latents.device)
1050
+ all_sigmas = noise_scheduler.sigmas
1051
+ sigmas = sigmas.to(latents.device)
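+ # Map the sampled continuous sigma onto the nearest discrete inference-step index,
+ # so the consistency controller (if any) can be stepped to the matching stored step.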
1052
+ timestep = (validation_data.num_inference_steps - torch.searchsorted(all_sigmas.to(latents.device).flip(dims=(0,)), sigmas, right=False)).clamp(0,validation_data.num_inference_steps-1)[0]
1053
+ u = u.item()
1054
+ if consistency_controller is not None:
1055
+ #timestep = int(u * (validation_data.num_inference_steps-1)+0.5)
1056
+ #print("u", u, "timestep", timestep, "sigmas", sigmas, "all_sigmas", all_sigmas)
1057
+ consistency_train_controller.set_cur_step(timestep)
1058
+ # Add noise to the latents according to the noise magnitude at each timestep
1059
+ # (this is the forward diffusion process)
1060
+ sigmas_reshaped = sigmas.clone()
1061
+ while len(sigmas_reshaped.shape) < len(latents.shape):
1062
+ sigmas_reshaped = sigmas_reshaped.unsqueeze(-1)
1063
+
1064
+ # add noise to the latents or the original image?
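+ # A small amount of noise (strength 0.02) is added to the first-frame conditioning
+ # latents, mirroring SVD's conditioning noise augmentation; the same value is passed
+ # below as noise_aug_strength in the added time ids.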
1065
+ train_noise_aug = 0.02
1066
+ conditional_latents = refer_latents / vae.config.scaling_factor
1067
+ small_noise_latents = conditional_latents + noise_1[:,0:1,:,:,:] * train_noise_aug
1068
+ conditional_latents = small_noise_latents[:, 0, :, :, :]
1069
+
1070
+ noisy_latents = latents + noise * sigmas_reshaped
1071
+
1072
+ timesteps = torch.Tensor(
1073
+ [0.25 * sigma.log() for sigma in sigmas]).to(latents.device)
1074
+
1075
+ inp_noisy_latents = noisy_latents / ((sigmas_reshaped**2 + 1) ** 0.5)
1076
+
1077
+ # *Potentially* Fixes gradient checkpointing training.
1078
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
1079
+ if kwargs.get('eval_train', False):
1080
+ unet.eval()
1081
+ image_encoder.eval()
1082
+
1083
+ # Get the text embedding for conditioning.
1084
+ encoder_hidden_states = encode_image(
1085
+ batch["cross_pixel_values"][:, 0, :, :, :])
1086
+ detached_encoder_state = encoder_hidden_states.clone().detach()
1087
+
1088
+ added_time_ids = _get_add_time_ids(
1089
+ 6,
1090
+ batch["motion_values"],
1091
+ train_noise_aug, # noise_aug_strength (train_noise_aug = 0.02, set above)
1092
+ encoder_hidden_states.dtype,
1093
+ bsz,
1094
+ unet,
1095
+ device=latents.device
1096
+ )
1097
+ added_time_ids = added_time_ids.to(latents.device)
1098
+
1099
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
1100
+ conditioning_dropout_prob = kwargs.get('conditioning_dropout_prob')
1101
+ if conditioning_dropout_prob is not None:
1102
+ random_p = torch.rand(
1103
+ bsz, device=latents.device, generator=generator)
1104
+ # Sample masks for the edit prompts.
1105
+ prompt_mask = random_p < 2 * conditioning_dropout_prob
1106
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
1107
+ # Final text conditioning.
1108
+ null_conditioning = torch.zeros_like(encoder_hidden_states)
1109
+ encoder_hidden_states = torch.where(
1110
+ prompt_mask, null_conditioning, encoder_hidden_states)
1111
+
1112
+ # Sample masks for the original images.
1113
+ image_mask_dtype = conditional_latents.dtype
1114
+ image_mask = 1 - (
1115
+ (random_p >= conditioning_dropout_prob).to(
1116
+ image_mask_dtype)
1117
+ * (random_p < 3 * conditioning_dropout_prob).to(image_mask_dtype)
1118
+ )
1119
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
1120
+ # Final image conditioning.
1121
+ conditional_latents = image_mask * conditional_latents
1122
+
1123
+ # Concatenate the `conditional_latents` with the `noisy_latents`.
1124
+ conditional_latents = conditional_latents.unsqueeze(
1125
+ 1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
1126
+ inp_noisy_latents = torch.cat(
1127
+ [inp_noisy_latents, conditional_latents], dim=2)
1128
+
1129
+ # Get the target for loss depending on the prediction type
1130
+ # if noise_scheduler.config.prediction_type == "epsilon":
1131
+ # target = latents # we are computing loss against denoise latents
1132
+ # elif noise_scheduler.config.prediction_type == "v_prediction":
1133
+ # target = noise_scheduler.get_velocity(
1134
+ # latents, noise, timesteps)
1135
+ # else:
1136
+ # raise ValueError(
1137
+ # f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1138
+
1139
+ target = latents
1140
+
1141
+ encoder_hidden_states = detached_encoder_state
1142
+
1143
+ if True:#mask_spatial_lora:
1144
+ loras = extract_lora_child_module(unet, target_replace_module=["BasicTransformerBlock"])
1145
+ for lora_i in loras:
1146
+ lora_i.scale = 0.
1147
+ loss_spatial = None
1148
+ else:
1149
+ loras = extract_lora_child_module(unet, target_replace_module=["BasicTransformerBlock"])
1150
+
1151
+ if spatial_lora_num == 1:
1152
+ for lora_i in loras:
1153
+ lora_i.scale = 1.
1154
+ else:
1155
+ for lora_i in loras:
1156
+ lora_i.scale = 0.
1157
+
1158
+ for lora_idx in range(0, len(loras), spatial_lora_num):
1159
+ loras[lora_idx + step].scale = 1.
1160
+
1161
+ loras = extract_lora_child_module(unet, target_replace_module=["TemporalBasicTransformerBlock"])
1162
+ if len(loras) > 0:
1163
+ for lora_i in loras:
1164
+ lora_i.scale = 0.
1165
+
1166
+ ran_idx = 0#torch.randint(0, noisy_latents.shape[2], (1,)).item()
1167
+
1168
+ #spatial_inp_noisy_latents = inp_noisy_refer_latents[:, ran_idx:ran_idx+1, :, :, :]
1169
+ inp_noisy_spatial_latents = inp_noisy_latents#[:, ran_idx:ran_idx+1, :, :, :]
1170
+
1171
+ target_spatial = latents#[:, ran_idx:ran_idx+1, :, :, :]
1172
+ # Predict the noise residual
1173
+ model_pred = unet(
1174
+ inp_noisy_spatial_latents, timesteps, encoder_hidden_states,
1175
+ added_time_ids
1176
+ ).sample
1177
+
1178
+ sigmas = sigmas_reshaped
1179
+ # Denoise the latents
1180
+ c_out = -sigmas / ((sigmas**2 + 1)**0.5)
1181
+ c_skip = 1 / (sigmas**2 + 1)
1182
+ denoised_latents = model_pred * c_out + c_skip * noisy_latents#[:, ran_idx:ran_idx+1, :, :, :]
1183
+ weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
1184
+
1185
+ # MSE loss
1186
+ loss_spatial = torch.mean(
1187
+ (weighing.float() * (denoised_latents.float() -
1188
+ target_spatial.float()) ** 2).reshape(target_spatial.shape[0], -1),
1189
+ dim=1,
1190
+ )
1191
+ loss_spatial = loss_spatial.mean()
1192
+
1193
+ if mask_temporal_lora:
1194
+ loras = extract_lora_child_module(unet, target_replace_module=["TemporalBasicTransformerBlock"])
1195
+ for lora_i in loras:
1196
+ lora_i.scale = 0.
1197
+ loss_temporal = None
1198
+ else:
1199
+ loras = extract_lora_child_module(unet, target_replace_module=["TemporalBasicTransformerBlock"])
1200
+ for lora_i in loras:
1201
+ lora_i.scale = 1.
1202
+ # Predict the noise residual
1203
+ model_pred = unet(
1204
+ inp_noisy_latents, timesteps, encoder_hidden_states,
1205
+ added_time_ids=added_time_ids,
1206
+ ).sample
1207
+
1208
+ sigmas = sigmas_reshaped
1209
+ # Denoise the latents
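+ # c_skip and c_out are EDM-style preconditioning coefficients (effectively sigma_data = 1):
+ # the raw UNet output is combined with the noisy latents to form a prediction of the clean
+ # latents, which the weighted MSE below compares against the target.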
1210
+ c_out = -sigmas / ((sigmas**2 + 1)**0.5)
1211
+ c_skip = 1 / (sigmas**2 + 1)
1212
+ denoised_latents = model_pred * c_out + c_skip * noisy_latents
1213
+ if consistency_controller is not None:
1214
+ consistency_train_controller.step_callback(denoised_latents.detach())
1215
+ weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
1216
+
1217
+ # MSE loss
1218
+ loss_temporal = torch.mean(
1219
+ (weighing.float() * (denoised_latents.float() -
1220
+ target.float()) ** 2).reshape(target.shape[0], -1),
1221
+ dim=1,
1222
+ )
1223
+ loss_temporal = loss_temporal.mean()
1224
+
1225
+ # beta = 1
1226
+ # alpha = (beta ** 2 + 1) ** 0.5
1227
+ # ran_idx = torch.randint(0, model_pred.shape[1], (1,)).item()
1228
+ # model_pred_decent = alpha * model_pred - beta * model_pred[:, ran_idx, :, :, :].unsqueeze(1)
1229
+ # target_decent = alpha * target - beta * target[:, ran_idx, :, :, :].unsqueeze(1)
1230
+ # loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
1231
+ loss_temporal = loss_temporal #+ loss_ad_temporal
1232
+
1233
+ return loss_spatial, loss_temporal, latents, noise
1234
+
1235
+ for epoch in range(first_epoch, num_train_epochs):
1236
+ train_loss_spatial = 0.0
1237
+ train_loss_temporal = 0.0
1238
+
1239
+
1240
+ for step, batch in enumerate(train_dataloader):
1241
+ #torch.cuda.empty_cache()
1242
+ # Skip steps until we reach the resumed step
1243
+ if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1244
+ if step % gradient_accumulation_steps == 0:
1245
+ progress_bar.update(1)
1246
+ continue
1247
+
1248
+ with accelerator.accumulate(unet):
1249
+
1250
+ for optimizer_spatial in optimizer_spatial_list:
1251
+ optimizer_spatial.zero_grad(set_to_none=True)
1252
+
1253
+ if optimizer_temporal is not None:
1254
+ optimizer_temporal.zero_grad(set_to_none=True)
1255
+
1256
+ if train_temporal_lora:
1257
+ mask_temporal_lora = False
1258
+ else:
1259
+ mask_temporal_lora = True
1260
+ if False:#clip_id != 0:
1261
+ mask_spatial_lora = random.uniform(0, 1) < 0.2 and not mask_temporal_lora
1262
+ else:
1263
+ mask_spatial_lora = True
1264
+
1265
+ with accelerator.autocast():
1266
+ loss_spatial, loss_temporal, latents, init_noise = finetune_unet(batch, step, mask_spatial_lora=mask_spatial_lora, mask_temporal_lora=mask_temporal_lora)
1267
+
1268
+ # Gather the losses across all processes for logging (if we use distributed training).
1269
+ if not mask_spatial_lora:
1270
+ avg_loss_spatial = accelerator.gather(loss_spatial.repeat(train_batch_size)).mean()
1271
+ train_loss_spatial += avg_loss_spatial.item() / gradient_accumulation_steps
1272
+
1273
+ if not mask_temporal_lora and train_temporal_lora:
1274
+ avg_loss_temporal = accelerator.gather(loss_temporal.repeat(train_batch_size)).mean()
1275
+ train_loss_temporal += avg_loss_temporal.item() / gradient_accumulation_steps
1276
+
1277
+ # Backpropagate
1278
+ if not mask_spatial_lora:
1279
+ accelerator.backward(loss_spatial, retain_graph=True)
1280
+ if spatial_lora_num == 1:
1281
+ optimizer_spatial_list[0].step()
1282
+ else:
1283
+ optimizer_spatial_list[step].step()
1284
+ if spatial_lora_num == 1:
1285
+ lr_scheduler_spatial_list[0].step()
1286
+ else:
1287
+ lr_scheduler_spatial_list[step].step()
1288
+
1289
+ if not mask_temporal_lora and train_temporal_lora:
1290
+ accelerator.backward(loss_temporal)
1291
+ optimizer_temporal.step()
1292
+
1293
+ if lr_scheduler_temporal is not None:
1294
+ lr_scheduler_temporal.step()
1295
+
1296
+ # Checks if the accelerator has performed an optimization step behind the scenes
1297
+ if accelerator.sync_gradients:
1298
+ progress_bar.update(1)
1299
+ global_step += 1
1300
+ accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
1301
+ train_loss_temporal = 0.0
1302
+ if global_step % checkpointing_steps == 0 and global_step > 0:
1303
+ save_pipe(
1304
+ pretrained_model_path,
1305
+ global_step,
1306
+ accelerator,
1307
+ unet,
1308
+ image_encoder,
1309
+ vae,
1310
+ output_dir,
1311
+ lora_manager_spatial,
1312
+ lora_manager_temporal,
1313
+ unet_lora_modules,
1314
+ image_encoder_lora_modules,
1315
+ is_checkpoint=True,
1316
+ save_pretrained_model=save_pretrained_model
1317
+ )
1318
+
1319
+ if should_sample(global_step, validation_steps, validation_data):
1320
+ if accelerator.is_main_process:
1321
+ with accelerator.autocast():
1322
+ unet.eval()
1323
+ image_encoder.eval()
1324
+ generator = torch.Generator(device="cpu")
1325
+ generator.manual_seed(seed)
1326
+ unet_and_text_g_c(unet, image_encoder, False, False)
1327
+ loras = extract_lora_child_module(unet, target_replace_module=["BasicTransformerBlock"])
1328
+ for lora_i in loras:
1329
+ lora_i.scale = 0.0
1330
+
1331
+ if consistency_controller is not None:
1332
+ attention_util.register_attention_control(
1333
+ unet,
1334
+ None,
1335
+ consistency_train_controller,
1336
+ find_modules={},
1337
+ consistency_find_modules=consistency_find_modules,
1338
+ undo=True
1339
+ )
1340
+
1341
+ pipeline = P2PStableVideoDiffusionPipeline.from_pretrained(
1342
+ pretrained_model_path,
1343
+ image_encoder=image_encoder,
1344
+ vae=vae,
1345
+ unet=unet
1346
+ )
1347
+ if consistency_controller is not None:
1348
+ pipeline.scheduler = P2PEulerDiscreteScheduler.from_config(pipeline.scheduler.config)
1349
+
1350
+ # # recalculate inversed noise latent
1351
+ # if any([np > 0. for np in validation_data.noise_prior]):
1352
+ # pixel_values_for_inv = batch['pixel_values_for_inv'].to('cuda', dtype=torch.float16)
1353
+ # batch['inversion_noise'] = inverse_video(pipeline, batch['latents_for_inv'], 25, pixel_values_for_inv[:,0,:,:,:])
1354
+
1355
+ preset_noises = []
1356
+ for noise_prior in validation_data.noise_prior:
1357
+ if noise_prior > 0:
1358
+ assert batch['inversion_noise'] is not None, "inversion_noise should not be None when noise_prior > 0"
1359
+ preset_noise = (noise_prior) ** 0.5 * batch['inversion_noise'] + (
1360
+ 1-noise_prior) ** 0.5 * torch.randn_like(batch['inversion_noise'])
1361
+ #print("preset noise", torch.mean(preset_noise), torch.std(preset_noise))
1362
+ else:
1363
+ preset_noise = None
1364
+ preset_noises.append( preset_noise )
1365
+
1366
+ for val_img_idx in range(len(validation_images)):
1367
+ for i in range(len(preset_noises)):
1368
+
1369
+ if consistency_controller is not None:
1370
+ consistency_edit_controller = attention_util.ConsistencyAttentionControl(
1371
+ additional_attention_store=consistency_edit_controller_list[val_img_idx],
1372
+ use_inversion_attention=False,
1373
+ save_self_attention=False,
1374
+ save_latents=False,
1375
+ disk_store=True
1376
+ )
1377
+ attention_util.register_attention_control(
1378
+ pipeline.unet,
1379
+ None,
1380
+ consistency_edit_controller,
1381
+ find_modules={},
1382
+ consistency_find_modules=consistency_find_modules,
1383
+ )
1384
+ pipeline.scheduler.controller = [consistency_edit_controller]
1385
+
1386
+ preset_noise = preset_noises[i]
1387
+ save_filename = f"step_{global_step}_noise_{i}_{val_img_idx}"
1388
+
1389
+ out_file = f"{output_dir}/samples/{save_filename}.mp4"
1390
+
1391
+ val_img = validation_images[val_img_idx]
1392
+ edited_firstframe = edited_firstframes[val_img_idx]
1393
+ original_res = val_img.size
1394
+ resctrl = ResolutionControl(
1395
+ (original_res[1],original_res[0]),
1396
+ (validation_data.height, validation_data.width),
1397
+ validation_data.get("pad_to_fit", False),
1398
+ fill=0
1399
+ )
1400
+
1401
+ #val_img = Image.open("white.png").convert("RGB")
1402
+ val_img = resctrl(val_img)
1403
+ edited_firstframe = resctrl(edited_firstframe)
1404
+
1405
+ with torch.no_grad():
1406
+ video_frames = pipeline(
1407
+ val_img,
1408
+ edited_firstframe=edited_firstframe,
1409
+ image_latents=validation_images_latents[val_img_idx],
1410
+ width=validation_data.width,
1411
+ height=validation_data.height,
1412
+ num_frames=batch["pixel_values"].shape[1],
1413
+ decode_chunk_size=8,
1414
+ motion_bucket_id=127,
1415
+ fps=validation_data.get('fps', 7),
1416
+ noise_aug_strength=0.02,
1417
+ generator=generator,
1418
+ num_inference_steps=validation_data.num_inference_steps,
1419
+ latents=preset_noise
1420
+ ).frames
1421
+ export_to_video(video_frames, out_file, validation_data.get('fps', 7), resctrl)
1422
+ if consistency_controller is not None:
1423
+ attention_util.register_attention_control(
1424
+ pipeline.unet,
1425
+ None,
1426
+ consistency_edit_controller,
1427
+ find_modules={},
1428
+ consistency_find_modules=consistency_find_modules,
1429
+ undo=True
1430
+ )
1431
+ consistency_edit_controller.delete()
1432
+ del consistency_edit_controller
1433
+ logger.info(f"Saved a new sample to {out_file}")
1434
+ if consistency_controller is not None:
1435
+ attention_util.register_attention_control(
1436
+ unet,
1437
+ None,
1438
+ consistency_train_controller,
1439
+ find_modules={},
1440
+ consistency_find_modules=consistency_find_modules,
1441
+ )
1442
+ del pipeline
1443
+ torch.cuda.empty_cache()
1444
+
1445
+ unet_and_text_g_c(
1446
+ unet,
1447
+ image_encoder,
1448
+ gradient_checkpointing,
1449
+ image_encoder_gradient_checkpointing
1450
+ )
1451
+
1452
+ if loss_temporal is not None:
1453
+ accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)
1454
+
1455
+ if global_step >= max_train_steps:
1456
+ break
1457
+
1458
+ # Create the pipeline using the trained modules and save it.
1459
+ accelerator.wait_for_everyone()
1460
+ if accelerator.is_main_process:
1461
+ save_pipe(
1462
+ pretrained_model_path,
1463
+ global_step,
1464
+ accelerator,
1465
+ unet,
1466
+ image_encoder,
1467
+ vae,
1468
+ output_dir,
1469
+ lora_manager_spatial,
1470
+ lora_manager_temporal,
1471
+ unet_lora_modules,
1472
+ image_encoder_lora_modules,
1473
+ is_checkpoint=False,
1474
+ save_pretrained_model=save_pretrained_model
1475
+ )
1476
+ accelerator.end_training()
1477
+
1478
+ if consistency_controller is not None:
1479
+ consistency_train_controller.delete()
1480
+ del consistency_train_controller
1481
+
1482
+
1483
+ if __name__ == "__main__":
1484
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
1485
+ parser = argparse.ArgumentParser()
1486
+ parser.add_argument("--config", type=str, default='./configs/config_multi_videos.yaml')
1487
+ args = parser.parse_args()
1488
+ train_motion_lora(**OmegaConf.load(args.config))
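The spatial and temporal objectives in finetune_unet above share the same EDM/EulerDiscrete preconditioning: the UNet output is folded back into the noisy latents through c_out and c_skip, and the squared error against the clean latents is reweighted by (1 + sigma^2) / sigma^2. A minimal, self-contained sketch of just that step follows; the helper name edm_denoise_and_loss and the toy tensor shapes are illustrative assumptions, not code from the repository.

import torch

def edm_denoise_and_loss(model_pred, noisy_latents, target, sigmas):
    # sigmas must broadcast against the latents, e.g. shape (B, 1, 1, 1, 1)
    c_out = -sigmas / ((sigmas ** 2 + 1) ** 0.5)
    c_skip = 1 / (sigmas ** 2 + 1)
    denoised = model_pred * c_out + c_skip * noisy_latents
    # (1 + sigma^2) / sigma^2 weighting, as in the training loop above
    weighing = (1 + sigmas ** 2) * (sigmas ** -2.0)
    loss = torch.mean(
        (weighing.float() * (denoised.float() - target.float()) ** 2).reshape(target.shape[0], -1),
        dim=1,
    ).mean()
    return denoised, loss

# toy usage with illustrative shapes (batch, frames, channels, height, width)
latents = torch.randn(1, 14, 4, 8, 8)
sigmas = torch.full((1, 1, 1, 1, 1), 0.7)
noisy = latents + torch.randn_like(latents) * sigmas
_, loss = edm_denoise_and_loss(torch.randn_like(latents), noisy, latents, sigmas)
print(loss.item())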
i2vedit/utils/__init__.py ADDED
File without changes
i2vedit/utils/bucketing.py ADDED
@@ -0,0 +1,32 @@
1
+ from PIL import Image
2
+
3
+ def min_res(size, min_size): return min_size if size < min_size else size
4
+
5
+ def up_down_bucket(m_size, in_size, direction):
6
+ if direction == 'down': return abs(int(m_size - in_size))
7
+ if direction == 'up': return abs(int(m_size + in_size))
8
+
9
+ def get_bucket_sizes(size, direction, min_size):
10
+ multipliers = [64, 128]
11
+ for i, m in enumerate(multipliers):
12
+ res = up_down_bucket(m, size, direction)
13
+ multipliers[i] = min_res(res, min_size=min_size)
14
+ return multipliers
15
+
16
+ def closest_bucket(m_size, size, direction, min_size):
17
+ lst = get_bucket_sizes(m_size, direction, min_size)
18
+ return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))]
19
+
20
+ def resolve_bucket(i,h,w): return (i / (h / w))
21
+
22
+ def sensible_buckets(m_width, m_height, w, h, min_size=192):
23
+ if h > w:
24
+ w = resolve_bucket(m_width, h, w)
25
+ w = closest_bucket(m_width, w, 'down', min_size=min_size)
26
+ return w, m_height
27
+ if h < w:
28
+ h = resolve_bucket(m_height, w, h)
29
+ h = closest_bucket(m_height, h, 'down', min_size=min_size)
30
+ return m_width, h
31
+
32
+ return m_width, m_height
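For orientation, sensible_buckets keeps the requested size on the longer side of the input and snaps the other side to a nearby bucket derived from the 64/128 multipliers, so the source aspect ratio is approximately preserved with a minimum edge of min_size. A small illustrative call, assuming the package is installed so that i2vedit.utils.bucketing is importable:

from i2vedit.utils.bucketing import sensible_buckets

# ask for 512x512 buckets given a portrait 720x1280 source clip
width, height = sensible_buckets(512, 512, w=720, h=1280)
print(width, height)  # 384 512 for this input: the width is pulled toward an aspect-preserving bucket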
i2vedit/utils/dataset.py ADDED
@@ -0,0 +1,705 @@
1
+ import os
2
+ import decord
3
+ import numpy as np
4
+ import random
5
+ import json
6
+ import torchvision
7
+ import torchvision.transforms as T
8
+ import torch
9
+ from torchvision.transforms import Resize, Pad, InterpolationMode, ToTensor
10
+
11
+ from glob import glob
12
+ from PIL import Image
13
+ from itertools import islice
14
+ from pathlib import Path
15
+ from .bucketing import sensible_buckets
16
+
17
+ decord.bridge.set_bridge('torch')
18
+
19
+ from torch.utils.data import Dataset
20
+ from einops import rearrange, repeat
21
+
22
+ def pad_with_ratio(frames, res, fill=0):
23
+ process = False
24
+ if not isinstance(frames, torch.Tensor):
25
+ frames = ToTensor()(frames).unsqueeze(0)
26
+ process = True
27
+ _, _, ih, iw = frames.shape
28
+ # print("ih, iw", ih, iw)
29
+ i_ratio = ih / iw
30
+ h, w = res
31
+ # print("h,w", h ,w)
32
+ n_ratio = h / w
33
+ if i_ratio > n_ratio:
34
+ nw = int(ih / h * w)
35
+ # print("nw", nw)
36
+ frames = Pad((nw - iw)//2, fill=fill)(frames)
37
+ frames = frames[...,(nw - iw)//2:-(nw - iw)//2,:]
38
+ else:
39
+ nh = int(iw / w * h)
40
+ frames = Pad((nh - ih)//2, fill=fill)(frames)
41
+ frames = frames[...,:,(nh - ih)//2:-(nh - ih)//2]
42
+ # print("after pad", frames.shape)
43
+ if process:
44
+ frames = (frames * 255.).type(torch.uint8).permute(0,2,3,1).squeeze().cpu().numpy()
45
+ frames = Image.fromarray(frames)
46
+ return frames
47
+
48
+ def return_to_original_res(frames, res, pad_to_fix=False):
49
+ process = False
50
+ if not isinstance(frames, torch.Tensor):
51
+ frames = ToTensor()(frames).unsqueeze(0)
52
+ process = True
53
+
54
+ # print("original res", res)
55
+ _, _, h, w = frames.shape
56
+ # print("h w", h, w)
57
+ n_ratio = h / w
58
+ ih, iw = res
59
+ i_ratio = ih / iw
60
+ if pad_to_fix:
61
+ if i_ratio > n_ratio:
62
+ nw = int(ih / h * w)
63
+ frames = Resize((ih, iw+2*(nw - iw)//2), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
64
+ frames = frames[...,:,(nw - iw)//2:-(nw - iw)//2]
65
+ else:
66
+ nh = int(iw / w * h)
67
+ frames = Resize((ih+2*(nh - ih)//2, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
68
+
69
+ frames = frames[...,(nh - ih)//2:-(nh - ih)//2,:]
70
+ else:
71
+ frames = Resize((ih, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
72
+
73
+ if process:
74
+ frames = (frames * 255.).type(torch.uint8).permute(0,2,3,1).squeeze().cpu().numpy()
75
+ frames = Image.fromarray(frames)
76
+
77
+ return frames
78
+
79
+ def get_prompt_ids(prompt, tokenizer):
80
+ prompt_ids = tokenizer(
81
+ prompt,
82
+ truncation=True,
83
+ padding="max_length",
84
+ max_length=tokenizer.model_max_length,
85
+ return_tensors="pt",
86
+ ).input_ids
87
+
88
+ return prompt_ids
89
+
90
+
91
+ def read_caption_file(caption_file):
92
+ with open(caption_file, 'r', encoding="utf8") as t:
93
+ return t.read()
94
+
95
+
96
+ def get_text_prompt(
97
+ text_prompt: str = '',
98
+ fallback_prompt: str= '',
99
+ file_path:str = '',
100
+ ext_types=['.mp4'],
101
+ use_caption=False
102
+ ):
103
+ try:
104
+ if use_caption:
105
+ if len(text_prompt) > 1: return text_prompt
106
+ caption_file = ''
107
+ # Use caption on per-video basis (One caption PER video)
108
+ for ext in ext_types:
109
+ maybe_file = file_path.replace(ext, '.txt')
110
+ if maybe_file.endswith(tuple(ext_types)): continue
111
+ if os.path.exists(maybe_file):
112
+ caption_file = maybe_file
113
+ break
114
+
115
+ if os.path.exists(caption_file):
116
+ return read_caption_file(caption_file)
117
+
118
+ # Return fallback prompt if no conditions are met.
119
+ return fallback_prompt
120
+
121
+ return text_prompt
122
+ except:
123
+ print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
124
+ return fallback_prompt
125
+
126
+
127
+ def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
128
+ max_range = len(vr)
129
+ frame_number = sorted((0, start_idx, max_range))[1]
130
+
131
+ frame_range = range(frame_number, max_range, sample_rate)
132
+ frame_range_indices = list(frame_range)[:max_frames]
133
+
134
+ return frame_range_indices
135
+
136
+
137
+ def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch, pad_to_fix=False, use_aug=False):
138
+ use_aug = False
139
+ if use_bucketing:
140
+ vr = decord.VideoReader(vid_path)
141
+ resize = get_frame_buckets(vr)
142
+ video = get_frame_batch(vr, resize=resize)
143
+
144
+ else:
145
+ if not pad_to_fix:
146
+ vr = decord.VideoReader(vid_path, width=w, height=h)
147
+ video = get_frame_batch(vr, use_aug=use_aug)
148
+ else:
149
+ vr = decord.VideoReader(vid_path)
150
+ video = get_frame_batch(vr, use_aug=use_aug)
151
+ video = pad_with_ratio(video, (h, w))
152
+ video = T.transforms.Resize((h, w), antialias=True)(video)
153
+
154
+ return video, vr
155
+
156
+
157
+ # https://github.com/ExponentialML/Video-BLIP2-Preprocessor
158
+ class VideoJsonDataset(Dataset):
159
+ def __init__(
160
+ self,
161
+ tokenizer = None,
162
+ width: int = 256,
163
+ height: int = 256,
164
+ n_sample_frames: int = 4,
165
+ sample_start_idx: int = 1,
166
+ frame_step: int = 1,
167
+ json_path: str ="",
168
+ json_data = None,
169
+ vid_data_key: str = "video_path",
170
+ preprocessed: bool = False,
171
+ use_bucketing: bool = False,
172
+ **kwargs
173
+ ):
174
+ self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
175
+ self.use_bucketing = use_bucketing
176
+ self.tokenizer = tokenizer
177
+ self.preprocessed = preprocessed
178
+
179
+ self.vid_data_key = vid_data_key
180
+ self.train_data = self.load_from_json(json_path, json_data)
181
+
182
+ self.width = width
183
+ self.height = height
184
+
185
+ self.n_sample_frames = n_sample_frames
186
+ self.sample_start_idx = sample_start_idx
187
+ self.frame_step = frame_step
188
+
189
+ def build_json(self, json_data):
190
+ extended_data = []
191
+ for data in json_data['data']:
192
+ for nested_data in data['data']:
193
+ self.build_json_dict(
194
+ data,
195
+ nested_data,
196
+ extended_data
197
+ )
198
+ json_data = extended_data
199
+ return json_data
200
+
201
+ def build_json_dict(self, data, nested_data, extended_data):
202
+ clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None
203
+
204
+ extended_data.append({
205
+ self.vid_data_key: data[self.vid_data_key],
206
+ 'frame_index': nested_data['frame_index'],
207
+ 'prompt': nested_data['prompt'],
208
+ 'clip_path': clip_path
209
+ })
210
+
211
+ def load_from_json(self, path, json_data):
212
+ try:
213
+ with open(path) as jpath:
214
+ print(f"Loading JSON from {path}")
215
+ json_data = json.load(jpath)
216
+
217
+ return self.build_json(json_data)
218
+
219
+ except:
220
+ self.train_data = []
221
+ print("Non-existent JSON path. Skipping.")
222
+
223
+ def validate_json(self, base_path, path):
224
+ return os.path.exists(f"{base_path}/{path}")
225
+
226
+ def get_frame_range(self, vr):
227
+ return get_video_frames(
228
+ vr,
229
+ self.sample_start_idx,
230
+ self.frame_step,
231
+ self.n_sample_frames
232
+ )
233
+
234
+ def get_vid_idx(self, vr, vid_data=None):
235
+ frames = self.n_sample_frames
236
+
237
+ if vid_data is not None:
238
+ idx = vid_data['frame_index']
239
+ else:
240
+ idx = self.sample_start_idx
241
+
242
+ return idx
243
+
244
+ def get_frame_buckets(self, vr):
245
+ _, h, w = vr[0].shape
246
+ width, height = sensible_buckets(self.width, self.height, h, w)
247
+ # width, height = self.width, self.height
248
+ resize = T.transforms.Resize((height, width), antialias=True)
249
+
250
+ return resize
251
+
252
+ def get_frame_batch(self, vr, resize=None):
253
+ frame_range = self.get_frame_range(vr)
254
+ frames = vr.get_batch(frame_range)
255
+ video = rearrange(frames, "f h w c -> f c h w")
256
+
257
+ if resize is not None: video = resize(video)
258
+ return video
259
+
260
+ def process_video_wrapper(self, vid_path):
261
+ video, vr = process_video(
262
+ vid_path,
263
+ self.use_bucketing,
264
+ self.width,
265
+ self.height,
266
+ self.get_frame_buckets,
267
+ self.get_frame_batch
268
+ )
269
+
270
+ return video, vr
271
+
272
+ def train_data_batch(self, index):
273
+
274
+ # If we are training on individual clips.
275
+ if 'clip_path' in self.train_data[index] and \
276
+ self.train_data[index]['clip_path'] is not None:
277
+
278
+ vid_data = self.train_data[index]
279
+
280
+ clip_path = vid_data['clip_path']
281
+
282
+ # Get video prompt
283
+ prompt = vid_data['prompt']
284
+
285
+ video, _ = self.process_video_wrapper(clip_path)
286
+
287
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
288
+
289
+ return video, prompt, prompt_ids
290
+
291
+ # Assign train data
292
+ train_data = self.train_data[index]
293
+
294
+ # Get the frame of the current index.
295
+ self.sample_start_idx = train_data['frame_index']
296
+
297
+ # Initialize resize
298
+ resize = None
299
+
300
+ video, vr = self.process_video_wrapper(train_data[self.vid_data_key])
301
+
302
+ # Get video prompt
303
+ prompt = train_data['prompt']
304
+ vr.seek(0)
305
+
306
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
307
+
308
+ return video, prompt, prompt_ids
309
+
310
+ @staticmethod
311
+ def __getname__(): return 'json'
312
+
313
+ def __len__(self):
314
+ if self.train_data is not None:
315
+ return len(self.train_data)
316
+ else:
317
+ return 0
318
+
319
+ def __getitem__(self, index):
320
+
321
+ # Initialize variables
322
+ video = None
323
+ prompt = None
324
+ prompt_ids = None
325
+
326
+ # Use default JSON training
327
+ if self.train_data is not None:
328
+ video, prompt, prompt_ids = self.train_data_batch(index)
329
+
330
+ example = {
331
+ "pixel_values": (video / 127.5 - 1.0),
332
+ "prompt_ids": prompt_ids[0],
333
+ "text_prompt": prompt,
334
+ 'dataset': self.__getname__()
335
+ }
336
+
337
+ return example
338
+
339
+
340
+ class SingleVideoDataset(Dataset):
341
+ def __init__(
342
+ self,
343
+ width: int = 256,
344
+ height: int = 256,
345
+ inversion_width: int = 256,
346
+ inversion_height: int = 256,
347
+ start_t: float=0,
348
+ end_t: float=-1,
349
+ sample_fps: int=-1,
350
+ single_video_path: str = "",
351
+ refer_image_path: str = "",
352
+ use_caption: bool = False,
353
+ use_bucketing: bool = False,
354
+ pad_to_fix: bool = False,
355
+ use_aug: bool = False,
356
+ **kwargs
357
+ ):
358
+ self.use_bucketing = use_bucketing
359
+ self.frames = []
360
+ self.index = 1
361
+
362
+ self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
363
+ self.start_t = start_t
364
+ self.end_t = end_t
365
+ self.output_fps = sample_fps
366
+
367
+ self.single_video_path = single_video_path
368
+ self.refer_image_path = refer_image_path
369
+
370
+ self.width = width
371
+ self.height = height
372
+ self.inversion_width = inversion_width
373
+ self.inversion_height = inversion_height
374
+
375
+ self.pad_to_fix = pad_to_fix
376
+
377
+ self.use_aug = use_aug
378
+ #self.data_augment = ControlNetDataAugmentation()
379
+
380
+ def create_video_chunks(self):
381
+ output_fps = self.output_fps
382
+ start_t = self.start_t
383
+ end_t = self.end_t
384
+ vr = decord.VideoReader(self.single_video_path)
385
+ initial_fps = vr.get_avg_fps()
386
+ if output_fps == -1:
387
+ output_fps = int(initial_fps)
388
+ if end_t == -1:
389
+ end_t = len(vr) / initial_fps
390
+ else:
391
+ end_t = min(len(vr) / initial_fps, end_t)
392
+ assert 0 <= start_t < end_t
393
+ assert output_fps > 0
394
+ start_f_ind = int(start_t * initial_fps)
395
+ end_f_ind = int(end_t * initial_fps)
396
+ num_f = int((end_t - start_t) * output_fps)
397
+ sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
398
+ self.frames = [sample_idx]
399
+ return self.frames
400
+
401
+ def chunk(self, it, size):
402
+ it = iter(it)
403
+ return iter(lambda: tuple(islice(it, size)), ())
404
+
405
+ def get_frame_batch(self, vr, resize=None, use_aug=False):
406
+ index = self.index
407
+ frames = vr.get_batch(self.frames[self.index])
408
+
409
+ if use_aug:
410
+ frames = self.data_augment.augment(frames)
411
+ print(frames.min(), frames.max())
412
+
413
+ video = rearrange(frames, "f h w c -> f c h w")
414
+
415
+ if resize is not None: video = resize(video)
416
+ return video
417
+
418
+ def get_frame_buckets(self, vr):
419
+ h, w, c = vr[0].shape
420
+ width, height = sensible_buckets(self.width, self.height, w, h)
421
+ resize = T.transforms.Resize((height, width), antialias=True)
422
+
423
+ return resize
424
+
425
+ def process_video_wrapper(self, vid_path):
426
+ video, vr = process_video(
427
+ vid_path,
428
+ self.use_bucketing,
429
+ self.width,
430
+ self.height,
431
+ self.get_frame_buckets,
432
+ self.get_frame_batch,
433
+ self.pad_to_fix,
434
+ self.use_aug
435
+ )
436
+ video_for_inversion, vr = process_video(
437
+ vid_path,
438
+ self.use_bucketing,
439
+ self.inversion_width,
440
+ self.inversion_height,
441
+ self.get_frame_buckets,
442
+ self.get_frame_batch,
443
+ self.pad_to_fix
444
+ )
445
+
446
+ return video, video_for_inversion, vr
447
+
448
+ def image_batch(self):
449
+ train_data = self.refer_image_path
450
+ img = train_data
451
+
452
+ try:
453
+ img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
454
+ except:
455
+ img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
456
+
457
+ width = self.width
458
+ height = self.height
459
+
460
+ if self.use_bucketing:
461
+ _, h, w = img.shape
462
+ width, height = sensible_buckets(width, height, w, h)
463
+
464
+ resize = T.transforms.Resize((height, width), antialias=True)
465
+
466
+ img = resize(img)
467
+ img = repeat(img, 'c h w -> f c h w', f=1)
468
+
469
+ return img
470
+
471
+ def single_video_batch(self, index):
472
+ train_data = self.single_video_path
473
+ self.index = index
474
+
475
+ if train_data.endswith(self.vid_types):
476
+ video, video_for_inv, _ = self.process_video_wrapper(train_data)
477
+
478
+ return video, video_for_inv
479
+ else:
480
+ raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
481
+
482
+ @staticmethod
483
+ def __getname__(): return 'single_video'
484
+
485
+ def __len__(self):
486
+
487
+ return len(self.create_video_chunks())
488
+
489
+ def __getitem__(self, index):
490
+
491
+ video, video_for_inv = self.single_video_batch(index)
492
+ image = self.image_batch()
493
+ motion_values = torch.Tensor([127.])
494
+
495
+ example = {
496
+ "pixel_values": (video / 127.5 - 1.0),
497
+ "pixel_values_for_inv": (video_for_inv / 127.5 - 1.0),
498
+ "refer_pixel_values": (image / 127.5 - 1.0),
499
+ "motion_values": motion_values,
500
+ 'dataset': self.__getname__()
501
+ }
502
+
503
+ return example
504
+
505
+
506
+ class ImageDataset(Dataset):
507
+
508
+ def __init__(
509
+ self,
510
+ tokenizer = None,
511
+ width: int = 256,
512
+ height: int = 256,
513
+ base_width: int = 256,
514
+ base_height: int = 256,
515
+ use_caption: bool = False,
516
+ image_dir: str = '',
517
+ single_img_prompt: str = '',
518
+ use_bucketing: bool = False,
519
+ fallback_prompt: str = '',
520
+ **kwargs
521
+ ):
522
+ self.tokenizer = tokenizer
523
+ self.img_types = (".png", ".jpg", ".jpeg", '.bmp')
524
+ self.use_bucketing = use_bucketing
525
+
526
+ self.image_dir = self.get_images_list(image_dir)
527
+ self.fallback_prompt = fallback_prompt
528
+
529
+ self.use_caption = use_caption
530
+ self.single_img_prompt = single_img_prompt
531
+
532
+ self.width = width
533
+ self.height = height
534
+
535
+ def get_images_list(self, image_dir):
536
+ if os.path.exists(image_dir):
537
+ imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)]
538
+ full_img_dir = []
539
+
540
+ for img in imgs:
541
+ full_img_dir.append(f"{image_dir}/{img}")
542
+
543
+ return sorted(full_img_dir)
544
+
545
+ return ['']
546
+
547
+ def image_batch(self, index):
548
+ train_data = self.image_dir[index]
549
+ img = train_data
550
+
551
+ try:
552
+ img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
553
+ except:
554
+ img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
555
+
556
+ width = self.width
557
+ height = self.height
558
+
559
+ if self.use_bucketing:
560
+ _, h, w = img.shape
561
+ width, height = sensible_buckets(width, height, w, h)
562
+
563
+ resize = T.transforms.Resize((height, width), antialias=True)
564
+
565
+ img = resize(img)
566
+ img = repeat(img, 'c h w -> f c h w', f=16)
567
+
568
+ prompt = get_text_prompt(
569
+ file_path=train_data,
570
+ text_prompt=self.single_img_prompt,
571
+ fallback_prompt=self.fallback_prompt,
572
+ ext_types=self.img_types,
573
+ use_caption=True
574
+ )
575
+ prompt_ids = get_prompt_ids(prompt, self.tokenizer)
576
+
577
+ return img, prompt, prompt_ids
578
+
579
+ @staticmethod
580
+ def __getname__(): return 'image'
581
+
582
+ def __len__(self):
583
+ # Image directory
584
+ if os.path.exists(self.image_dir[0]):
585
+ return len(self.image_dir)
586
+ else:
587
+ return 0
588
+
589
+ def __getitem__(self, index):
590
+ img, prompt, prompt_ids = self.image_batch(index)
591
+ example = {
592
+ "pixel_values": (img / 127.5 - 1.0),
593
+ "prompt_ids": prompt_ids[0],
594
+ "text_prompt": prompt,
595
+ 'dataset': self.__getname__()
596
+ }
597
+
598
+ return example
599
+
600
+
601
+ class VideoFolderDataset(Dataset):
602
+ def __init__(
603
+ self,
604
+ tokenizer=None,
605
+ width: int = 256,
606
+ height: int = 256,
607
+ n_sample_frames: int = 16,
608
+ fps: int = 8,
609
+ path: str = "./data",
610
+ fallback_prompt: str = "",
611
+ use_bucketing: bool = False,
612
+ **kwargs
613
+ ):
614
+ self.tokenizer = tokenizer
615
+ self.use_bucketing = use_bucketing
616
+
617
+ self.fallback_prompt = fallback_prompt
618
+
619
+ self.video_files = glob(f"{path}/*.mp4")
620
+
621
+ self.width = width
622
+ self.height = height
623
+
624
+ self.n_sample_frames = n_sample_frames
625
+ self.fps = fps
626
+
627
+ def get_frame_buckets(self, vr):
628
+ h, w, c = vr[0].shape
629
+ width, height = sensible_buckets(self.width, self.height, w, h)
630
+ resize = T.transforms.Resize((height, width), antialias=True)
631
+
632
+ return resize
633
+
634
+ def get_frame_batch(self, vr, resize=None):
635
+ n_sample_frames = self.n_sample_frames
636
+ native_fps = vr.get_avg_fps()
637
+
638
+ every_nth_frame = max(1, round(native_fps / self.fps))
639
+ every_nth_frame = min(len(vr), every_nth_frame)
640
+
641
+ effective_length = len(vr) // every_nth_frame
642
+ if effective_length < n_sample_frames:
643
+ n_sample_frames = effective_length
644
+
645
+ effective_idx = random.randint(0, (effective_length - n_sample_frames))
646
+ idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)
647
+
648
+ video = vr.get_batch(idxs)
649
+ video = rearrange(video, "f h w c -> f c h w")
650
+
651
+ if resize is not None: video = resize(video)
652
+ return video, vr
653
+
654
+ def process_video_wrapper(self, vid_path):
655
+ video, vr = process_video(
656
+ vid_path,
657
+ self.use_bucketing,
658
+ self.width,
659
+ self.height,
660
+ self.get_frame_buckets,
661
+ self.get_frame_batch
662
+ )
663
+ return video, vr
664
+
665
+ def get_prompt_ids(self, prompt):
666
+ return self.tokenizer(
667
+ prompt,
668
+ truncation=True,
669
+ padding="max_length",
670
+ max_length=self.tokenizer.model_max_length,
671
+ return_tensors="pt",
672
+ ).input_ids
673
+
674
+ @staticmethod
675
+ def __getname__(): return 'folder'
676
+
677
+ def __len__(self):
678
+ return len(self.video_files)
679
+
680
+ def __getitem__(self, index):
681
+
682
+ video, _ = self.process_video_wrapper(self.video_files[index])
683
+
684
+ prompt = self.fallback_prompt
685
+
686
+ prompt_ids = self.get_prompt_ids(prompt)
687
+
688
+ return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}
689
+
690
+
691
+ class CachedDataset(Dataset):
692
+ def __init__(self,cache_dir: str = ''):
693
+ self.cache_dir = cache_dir
694
+ self.cached_data_list = self.get_files_list()
695
+
696
+ def get_files_list(self):
697
+ tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')]
698
+ return sorted(tensors_list)
699
+
700
+ def __len__(self):
701
+ return len(self.cached_data_list)
702
+
703
+ def __getitem__(self, index):
704
+ cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0')
705
+ return cached_latent
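A minimal sketch of how SingleVideoDataset might be consumed during LoRA training; the paths and resolutions below are placeholders (in practice they come from the YAML config), and decord must be installed for the video reader to work.

from torch.utils.data import DataLoader
from i2vedit.utils.dataset import SingleVideoDataset

dataset = SingleVideoDataset(
    single_video_path="path/to/source.mp4",   # placeholder path
    refer_image_path="path/to/ref.jpg",       # placeholder path
    width=512, height=512,
    inversion_width=512, inversion_height=512,
    start_t=0, end_t=-1, sample_fps=7,
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
batch = next(iter(loader))
# pixel_values: (1, num_frames, 3, 512, 512), scaled to roughly [-1, 1]
print(batch["pixel_values"].shape, batch["refer_pixel_values"].shape)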
i2vedit/utils/euler_utils.py ADDED
@@ -0,0 +1,226 @@
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import Union
5
+ import copy
6
+ from scipy.stats import anderson
7
+
8
+ import torch
9
+
10
+ from tqdm import tqdm
11
+ from diffusers import StableVideoDiffusionPipeline
12
+ from i2vedit.prompt_attention import attention_util
13
+
14
+ # Euler Inversion
15
+ @torch.no_grad()
16
+ def init_image(image, firstframe, pipeline):
17
+ if isinstance(image, torch.Tensor):
18
+ height, width = image.shape[-2:]
19
+ image = (image + 1) / 2. * 255.
20
+ image = image.type(torch.uint8).squeeze().permute(1,2,0).cpu().numpy()
21
+ image = Image.fromarray(image)
22
+ if isinstance(firstframe, torch.Tensor):
23
+ firstframe = (firstframe + 1) / 2. * 255.
24
+ firstframe = firstframe.type(torch.uint8).squeeze().permute(1,2,0).cpu().numpy()
25
+ firstframe = Image.fromarray(firstframe)
26
+
27
+ device = pipeline._execution_device
28
+ image_embeddings = pipeline._encode_image(firstframe, device, 1, False)
29
+ image = pipeline.image_processor.preprocess(image, height=height, width=width)
30
+ firstframe = pipeline.image_processor.preprocess(firstframe, height=height, width=width)
31
+ #print(image.dtype)
32
+ noise = torch.randn(image.shape, device=image.device, dtype=image.dtype)
33
+ image = image + 0.02 * noise
34
+ firstframe = firstframe + 0.02 * noise
35
+ #print(image.dtype)
36
+ image_latents = pipeline._encode_vae_image(image, device, 1, False)
37
+ firstframe_latents = pipeline._encode_vae_image(firstframe, device, 1, False)
38
+ image_latents = image_latents.to(image_embeddings.dtype)
39
+ firstframe_latents = firstframe_latents.to(image_embeddings.dtype)
40
+
41
+ return image_embeddings, image_latents, firstframe_latents
42
+
43
+
44
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], sigma, sigma_next,
45
+ sample: Union[torch.FloatTensor, np.ndarray], euler_scheduler, controller=None, consistency_controller=None):
46
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
47
+ if controller is not None:
48
+ pred_original_sample = controller.step_callback(pred_original_sample)
49
+ if consistency_controller is not None:
50
+ pred_original_sample = consistency_controller.step_callback(pred_original_sample)
51
+ #print("sample", sample.mean(), sample.std(), "pred_original_sample", pred_original_sample.mean())
52
+ #pred_original_sample = sample.mean() - pred_original_sample.mean() + pred_original_sample
53
+ next_sample = sample + (sigma_next - sigma) * (sample - pred_original_sample) / sigma
54
+ #print(sigma, sigma_next)
55
+ #print("next sample", torch.mean(next_sample), torch.std(next_sample))
56
+ return next_sample
57
+
58
+
59
+ def get_model_pred_single(latents, t, image_embeddings, added_time_ids, unet):
60
+ noise_pred = unet(
61
+ latents,
62
+ t,
63
+ encoder_hidden_states=image_embeddings,
64
+ added_time_ids=added_time_ids,
65
+ return_dict=False,
66
+ )[0]
67
+ return noise_pred
68
+
69
+ @torch.no_grad()
70
+ def euler_loop(pipeline, euler_scheduler, latents, num_inv_steps, image, firstframe, controller=None, consistency_controller=None):
71
+ device = pipeline._execution_device
72
+
73
+ # prepare image conditions
74
+ image_embeddings, image_latents, firstframe_latents = init_image(image, firstframe, pipeline)
75
+ skip = 1#latents.shape[1]
76
+ image_latents = torch.cat(
77
+ [
78
+ image_latents.unsqueeze(1).repeat(1, skip, 1, 1, 1),
79
+ firstframe_latents.unsqueeze(1).repeat(1, latents.shape[1]-skip, 1, 1, 1)
80
+ ],
81
+ dim=1
82
+ )
83
+ #image_latents = image_latents.unsqueeze(1).repeat(1, latents.shape[1], 1, 1, 1)
84
+
85
+ # Get Added Time IDs
86
+ added_time_ids = pipeline._get_add_time_ids(
87
+ 8,
88
+ 127,
89
+ 0.02,
90
+ image_embeddings.dtype,
91
+ 1,
92
+ 1,
93
+ False
94
+ )
95
+ added_time_ids = added_time_ids.to(device)
96
+
97
+ # Prepare timesteps
98
+ euler_scheduler.set_timesteps(num_inv_steps, device=device)
99
+ sigmas_0 = euler_scheduler.sigmas[-2] * euler_scheduler.sigmas[-2] / euler_scheduler.sigmas[-3]
100
+ timesteps = torch.cat([euler_scheduler.timesteps[1:],torch.Tensor([0.25 * sigmas_0.log()]).to(device)])
101
+ sigmas = copy.deepcopy(euler_scheduler.sigmas)
102
+ sigmas[-1] = sigmas_0
103
+ #print(sigmas)
104
+
105
+ # prepare latents
106
+ all_latent = [latents]
107
+ latents = latents.clone().detach()
108
+
109
+ for i in tqdm(range(num_inv_steps)):
110
+ t = timesteps[len(timesteps) -i -1]
111
+ sigma = sigmas[len(sigmas) -i -1]
112
+ sigma_next = sigmas[len(sigmas) -i -2]
113
+ latent_model_input = latents
114
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
115
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
116
+ model_pred = get_model_pred_single(latent_model_input, t, image_embeddings, added_time_ids, pipeline.unet)
117
+ latents = next_step(model_pred, sigma, sigma_next, latents, euler_scheduler, controller=controller, consistency_controller=consistency_controller)
118
+ all_latent.append(latents)
119
+ all_latent[-1] = all_latent[-1] / ((sigmas[0]**2 + 1)**0.5)
120
+ return all_latent
121
+
122
+
123
+ @torch.no_grad()
124
+ def euler_inversion(pipeline, euler_scheduler, video_latent, num_inv_steps, image, firstframe, controller=None, consistency_controller=None):
125
+ euler_latents = euler_loop(pipeline, euler_scheduler, video_latent, num_inv_steps, image, firstframe, controller=controller, consistency_controller=consistency_controller)
126
+ return euler_latents
127
+
128
+ from diffusers import EulerDiscreteScheduler
129
+ from .model_utils import tensor_to_vae_latent, load_primary_models, handle_memory_attention
130
+
131
+ def inverse_video(
132
+ pretrained_model_path,
133
+ video,
134
+ keyframe,
135
+ firstframe,
136
+ num_steps,
137
+ resctrl=None,
138
+ sard=None,
139
+ enable_xformers_memory_efficient_attention=True,
140
+ enable_torch_2_attn=False,
141
+ store_controller = None,
142
+ consistency_store_controller = None,
143
+ find_modules={},
144
+ consistency_find_modules={},
145
+ sarp_noise_scale=0.002,
146
+ ):
147
+ dtype = torch.float32
148
+
149
+ # check if inverted latents exists
150
+ for _controller in [store_controller, consistency_store_controller]:
151
+ if _controller is not None:
152
+ if os.path.exists(os.path.join(_controller.store_dir, "inverted_latents.pt")):
153
+ euler_inv_latent = torch.load(os.path.join(_controller.store_dir, "inverted_latents.pt")).to("cuda", dtype)
154
+ print(f"Successfully load inverted latents from {os.path.join(_controller.store_dir, 'inverted_latents.pt')}")
155
+ return euler_inv_latent
156
+
157
+ # prepare model, Load scheduler, tokenizer and models.
158
+ noise_scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(pretrained_model_path)
159
+
160
+ # Enable xformers if available
161
+ handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
162
+
163
+ vae.to('cuda', dtype=dtype)
164
+ unet.to('cuda')
165
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
166
+ pretrained_model_path,
167
+ feature_extractor=feature_extractor,
168
+ image_encoder=image_encoder,
169
+ vae=vae,
170
+ unet=unet
171
+ )
172
+ pipe.image_encoder.to('cuda')
173
+
174
+ attention_util.register_attention_control(
175
+ pipe.unet,
176
+ store_controller,
177
+ consistency_store_controller,
178
+ find_modules=find_modules,
179
+ consistency_find_modules=consistency_find_modules
180
+ )
181
+ if store_controller is not None:
182
+ store_controller.LOW_RESOURCE = True
183
+
184
+ video_for_inv = torch.cat([firstframe,keyframe,video],dim=1).to(dtype)
185
+ #print(video_for_inv.shape)
186
+ if resctrl is not None:
187
+ video_for_inv = resctrl(video_for_inv)
188
+ if sard is not None:
189
+ indx = sard.detection(video_for_inv, 0.001)
190
+ #import cv2
191
+ #cv2.imwrite("indx.png", indx[0,0,:,:,:].permute(1,2,0).type(torch.uint8).cpu().numpy()*255)
192
+ noise = torch.randn(video_for_inv.shape, device=video.device, dtype=video.dtype)
193
+ video_for_inv[indx] = video_for_inv[indx] + noise[indx] * sarp_noise_scale
194
+ video_for_inv = video_for_inv.clamp(-1,1)
195
+
196
+ firstframe, keyframe, video_for_inv = video_for_inv.tensor_split([1,2],dim=1)
197
+
198
+ #print("video for inv", video_for_inv.mean(), video_for_inv.std())
199
+ latents_for_inv = tensor_to_vae_latent(video_for_inv, vae)
200
+ #noise = torch.randn(latents_for_inv.shape, device=video.device, dtype=video.dtype)
201
+ #latents_for_inv = latents_for_inv + noise * sarp_noise_scale
202
+ #print("video latent for inv", latents_for_inv.mean(), latents_for_inv.std(), latents_for_inv.shape)
203
+
204
+ euler_inv_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
205
+ euler_inv_scheduler.set_timesteps(num_steps)
206
+
207
+ euler_inv_latent = euler_inversion(
208
+ pipe, euler_inv_scheduler, video_latent=latents_for_inv.to(pipe.device),
209
+ num_inv_steps=num_steps, image=keyframe[:,0,:,:,:], firstframe=firstframe[:,0,:,:,:], controller=store_controller, consistency_controller=consistency_store_controller)[-1]
210
+
211
+ torch.cuda.empty_cache()
212
+ del pipe
213
+
214
+ #res = anderson(euler_inv_latent.cpu().view(-1).numpy())
215
+ #print(euler_inv_latent.mean(), euler_inv_latent.std())
216
+ #print(res.statistic)
217
+ #print(res.critical_values)
218
+ #print(res.significance_level)
219
+
220
+ # save inverted latents
221
+ for _controller in [store_controller, consistency_store_controller]:
222
+ if _controller is not None:
223
+ torch.save(euler_inv_latent, os.path.join(_controller.store_dir, "inverted_latents.pt"))
224
+ break
225
+
226
+ return euler_inv_latent
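The inversion above walks the EulerDiscrete probability-flow ODE backwards: at each step the denoised estimate is recovered with the same c_skip/c_out preconditioning used during training, and the latents are then extrapolated from the current sigma to the next, larger sigma. A stripped-down sketch of a single inversion step with toy values; the helper name euler_inversion_step and the shapes are illustrative only.

import torch

def euler_inversion_step(sample, model_output, sigma, sigma_next):
    # denoised estimate under the EDM parameterisation
    pred_original = model_output * (-sigma / (sigma ** 2 + 1) ** 0.5) + sample / (sigma ** 2 + 1)
    # Euler step of the probability-flow ODE toward the larger sigma
    derivative = (sample - pred_original) / sigma
    return sample + (sigma_next - sigma) * derivative

x = torch.randn(1, 14, 4, 8, 8)   # toy latents
eps = torch.randn_like(x)         # stand-in for the UNet output
x_next = euler_inversion_step(x, eps, sigma=0.5, sigma_next=0.8)
print(x_next.shape)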
i2vedit/utils/lora.py ADDED
@@ -0,0 +1,1493 @@
1
+ import json
2
+ import math
3
+ from itertools import groupby
4
+ import os
5
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ try:
14
+ from safetensors.torch import safe_open
15
+ from safetensors.torch import save_file as safe_save
16
+
17
+ safetensors_available = True
18
+ except ImportError:
19
+ from .safe_open import safe_open
20
+
21
+ def safe_save(
22
+ tensors: Dict[str, torch.Tensor],
23
+ filename: str,
24
+ metadata: Optional[Dict[str, str]] = None,
25
+ ) -> None:
26
+ raise EnvironmentError(
27
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
28
+ )
29
+
30
+ safetensors_available = False
31
+
32
+
33
+ class LoraInjectedLinear(nn.Module):
34
+ def __init__(
35
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
36
+ ):
37
+ super().__init__()
38
+
39
+ if r > min(in_features, out_features):
40
+ #raise ValueError(
41
+ # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
42
+ #)
43
+ print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
44
+ r = min(in_features, out_features)
45
+
46
+ self.r = r
47
+ self.linear = nn.Linear(in_features, out_features, bias)
48
+ self.lora_down = nn.Linear(in_features, r, bias=False)
49
+ self.dropout = nn.Dropout(dropout_p)
50
+ self.lora_up = nn.Linear(r, out_features, bias=False)
51
+ self.scale = scale
52
+ self.selector = nn.Identity()
53
+
54
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
55
+ nn.init.zeros_(self.lora_up.weight)
56
+
57
+ def forward(self, input):
58
+ return (
59
+ self.linear(input)
60
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
61
+ * self.scale
62
+ )
63
+
64
+ def realize_as_lora(self):
65
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
66
+
67
+ def set_selector_from_diag(self, diag: torch.Tensor):
68
+ # diag is a 1D tensor of size (r,)
69
+ assert diag.shape == (self.r,)
70
+ self.selector = nn.Linear(self.r, self.r, bias=False)
71
+ self.selector.weight.data = torch.diag(diag)
72
+ self.selector.weight.data = self.selector.weight.data.to(
73
+ self.lora_up.weight.device
74
+ ).to(self.lora_up.weight.dtype)
75
+
76
+
77
+ class MultiLoraInjectedLinear(nn.Module):
78
+ def __init__(
79
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, lora_num=1, scales=[1.0]
80
+ ):
81
+ super().__init__()
82
+
83
+ if r > min(in_features, out_features):
84
+ #raise ValueError(
85
+ # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
86
+ #)
87
+ print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
88
+ r = min(in_features, out_features)
89
+
90
+ self.r = r
91
+ self.linear = nn.Linear(in_features, out_features, bias)
92
+
93
+ for i in range(lora_num):
94
+ if i==0:
95
+ self.lora_down =[nn.Linear(in_features, r, bias=False)]
96
+ self.dropout = [nn.Dropout(dropout_p)]
97
+ self.lora_up = [nn.Linear(r, out_features, bias=False)]
98
+ self.scale = scales[i]
99
+ self.selector = [nn.Identity()]
100
+ else:
101
+ self.lora_down.append(nn.Linear(in_features, r, bias=False))
102
+ self.dropout.append( nn.Dropout(dropout_p))
103
+ self.lora_up.append( nn.Linear(r, out_features, bias=False))
104
+ self.scale.append(scales[i])
105
+
106
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
107
+ nn.init.zeros_(self.lora_up.weight)
108
+
109
+ def forward(self, input):
110
+ return (
111
+ self.linear(input)
112
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
113
+ * self.scale
114
+ )
115
+
116
+ def realize_as_lora(self):
117
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
118
+
119
+ def set_selector_from_diag(self, diag: torch.Tensor):
120
+ # diag is a 1D tensor of size (r,)
121
+ assert diag.shape == (self.r,)
122
+ self.selector = nn.Linear(self.r, self.r, bias=False)
123
+ self.selector.weight.data = torch.diag(diag)
124
+ self.selector.weight.data = self.selector.weight.data.to(
125
+ self.lora_up.weight.device
126
+ ).to(self.lora_up.weight.dtype)
127
+
128
+
129
+ class LoraInjectedConv2d(nn.Module):
130
+ def __init__(
131
+ self,
132
+ in_channels: int,
133
+ out_channels: int,
134
+ kernel_size,
135
+ stride=1,
136
+ padding=0,
137
+ dilation=1,
138
+ groups: int = 1,
139
+ bias: bool = True,
140
+ r: int = 4,
141
+ dropout_p: float = 0.1,
142
+ scale: float = 1.0,
143
+ ):
144
+ super().__init__()
145
+ if r > min(in_channels, out_channels):
146
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
147
+ r = min(in_channels, out_channels)
148
+
149
+ self.r = r
150
+ self.conv = nn.Conv2d(
151
+ in_channels=in_channels,
152
+ out_channels=out_channels,
153
+ kernel_size=kernel_size,
154
+ stride=stride,
155
+ padding=padding,
156
+ dilation=dilation,
157
+ groups=groups,
158
+ bias=bias,
159
+ )
160
+
161
+ self.lora_down = nn.Conv2d(
162
+ in_channels=in_channels,
163
+ out_channels=r,
164
+ kernel_size=kernel_size,
165
+ stride=stride,
166
+ padding=padding,
167
+ dilation=dilation,
168
+ groups=groups,
169
+ bias=False,
170
+ )
171
+ self.dropout = nn.Dropout(dropout_p)
172
+ self.lora_up = nn.Conv2d(
173
+ in_channels=r,
174
+ out_channels=out_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0,
178
+ bias=False,
179
+ )
180
+ self.selector = nn.Identity()
181
+ self.scale = scale
182
+
183
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
184
+ nn.init.zeros_(self.lora_up.weight)
185
+
186
+ def forward(self, input):
187
+ return (
188
+ self.conv(input)
189
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
190
+ * self.scale
191
+ )
192
+
193
+ def realize_as_lora(self):
194
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
195
+
196
+ def set_selector_from_diag(self, diag: torch.Tensor):
197
+ # diag is a 1D tensor of size (r,)
198
+ assert diag.shape == (self.r,)
199
+ self.selector = nn.Conv2d(
200
+ in_channels=self.r,
201
+ out_channels=self.r,
202
+ kernel_size=1,
203
+ stride=1,
204
+ padding=0,
205
+ bias=False,
206
+ )
207
+ self.selector.weight.data = torch.diag(diag)
208
+
209
+ # same device + dtype as lora_up
210
+ self.selector.weight.data = self.selector.weight.data.to(
211
+ self.lora_up.weight.device
212
+ ).to(self.lora_up.weight.dtype)
213
+
214
+ class LoraInjectedConv3d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ kernel_size: (3, 1, 1),
220
+ padding: (1, 0, 0),
221
+ bias: bool = False,
222
+ r: int = 4,
223
+ dropout_p: float = 0,
224
+ scale: float = 1.0,
225
+ ):
226
+ super().__init__()
227
+ if r > min(in_channels, out_channels):
228
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
229
+ r = min(in_channels, out_channels)
230
+
231
+ self.r = r
232
+ self.kernel_size = kernel_size
233
+ self.padding = padding
234
+ self.conv = nn.Conv3d(
235
+ in_channels=in_channels,
236
+ out_channels=out_channels,
237
+ kernel_size=kernel_size,
238
+ padding=padding,
239
+ )
240
+
241
+ self.lora_down = nn.Conv3d(
242
+ in_channels=in_channels,
243
+ out_channels=r,
244
+ kernel_size=kernel_size,
245
+ bias=False,
246
+ padding=padding
247
+ )
248
+ self.dropout = nn.Dropout(dropout_p)
249
+ self.lora_up = nn.Conv3d(
250
+ in_channels=r,
251
+ out_channels=out_channels,
252
+ kernel_size=1,
253
+ stride=1,
254
+ padding=0,
255
+ bias=False,
256
+ )
257
+ self.selector = nn.Identity()
258
+ self.scale = scale
259
+
260
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
261
+ nn.init.zeros_(self.lora_up.weight)
262
+
263
+ def forward(self, input):
264
+ return (
265
+ self.conv(input)
266
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
267
+ * self.scale
268
+ )
269
+
270
+ def realize_as_lora(self):
271
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
272
+
273
+ def set_selector_from_diag(self, diag: torch.Tensor):
274
+ # diag is a 1D tensor of size (r,)
275
+ assert diag.shape == (self.r,)
276
+ self.selector = nn.Conv3d(
277
+ in_channels=self.r,
278
+ out_channels=self.r,
279
+ kernel_size=1,
280
+ stride=1,
281
+ padding=0,
282
+ bias=False,
283
+ )
284
+ self.selector.weight.data = torch.diag(diag)
285
+
286
+ # same device + dtype as lora_up
287
+ self.selector.weight.data = self.selector.weight.data.to(
288
+ self.lora_up.weight.device
289
+ ).to(self.lora_up.weight.dtype)
290
+
291
+ UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
292
+
293
+ UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
294
+
295
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
296
+
297
+ TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
298
+
299
+ DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
300
+
301
+ EMBED_FLAG = "<embed>"
302
+
303
+
304
+ def _find_children(
305
+ model,
306
+ search_class: List[Type[nn.Module]] = [nn.Linear],
307
+ ):
308
+ """
309
+ Find all modules of a certain class (or union of classes).
310
+
311
+ Returns all matching modules, along with the parent of those modules and the
312
+ names they are referenced by.
313
+ """
314
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
315
+ for parent in model.modules():
316
+ for name, module in parent.named_children():
317
+ if any([isinstance(module, _class) for _class in search_class]):
318
+ yield parent, name, module
319
+
320
+
321
+ def _find_modules_v2(
322
+ model,
323
+ ancestor_class: Optional[Set[str]] = None,
324
+ search_class: List[Type[nn.Module]] = [nn.Linear],
325
+ exclude_children_of: Optional[List[Type[nn.Module]]] = None,
326
+ # [
327
+ # LoraInjectedLinear,
328
+ # LoraInjectedConv2d,
329
+ # LoraInjectedConv3d
330
+ # ],
331
+ ):
332
+ """
333
+ Find all modules of a certain class (or union of classes) that are direct or
334
+ indirect descendants of other modules of a certain class (or union of classes).
335
+
336
+ Returns all matching modules, along with the parent of those modules and the
337
+ names they are referenced by.
338
+ """
339
+
340
+ # Get the targets we should replace all linears under
341
+ if ancestor_class is not None:
342
+ ancestors = (
343
+ module
344
+ for name, module in model.named_modules()
345
+ if module.__class__.__name__ in ancestor_class # and ('transformer_in' not in name)
346
+ )
347
+ else:
348
+ # this, in case you want to naively iterate over all modules.
349
+ ancestors = [module for module in model.modules()]
350
+
351
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
352
+ for ancestor in ancestors:
353
+ for fullname, module in ancestor.named_modules():
354
+ if any([isinstance(module, _class) for _class in search_class]):
355
+ continue_flag = True
356
+ if 'Transformer2DModel' in ancestor_class and ('attn1' in fullname or 'ff' in fullname):
357
+ continue_flag = False
358
+ if 'TransformerTemporalModel' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname):
359
+ continue_flag = False
360
+ if 'TemporalBasicTransformerBlock' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname):
361
+ continue_flag = False
362
+ #if 'TemporalBasicTransformerBlock' in ancestor_class and ('attn1' in fullname or 'ff' in fullname):
363
+ # continue_flag = False
364
+ if 'BasicTransformerBlock' in ancestor_class and ('attn2' in fullname or 'ff' in fullname):
365
+ continue_flag = False
366
+ if continue_flag:
367
+ continue
368
+ # Find the direct parent if this is a descendant, not a child, of target
369
+ *path, name = fullname.split(".")
370
+ parent = ancestor
371
+ while path:
372
+ parent = parent.get_submodule(path.pop(0))
373
+ # Skip this linear if it's a child of a LoraInjectedLinear
374
+ if exclude_children_of and any(
375
+ [isinstance(parent, _class) for _class in exclude_children_of]
376
+ ):
377
+ continue
378
+ if name in ['lora_up', 'dropout', 'lora_down']:
379
+ continue
380
+ # Otherwise, yield it
381
+ yield parent, name, module
382
+
383
+
384
+ def _find_modules_old(
385
+ model,
386
+ ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
387
+ search_class: List[Type[nn.Module]] = [nn.Linear],
388
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
389
+ ):
390
+ ret = []
391
+ for _module in model.modules():
392
+ if _module.__class__.__name__ in ancestor_class:
393
+
394
+ for name, _child_module in _module.named_modules():
395
+ if _child_module.__class__ in search_class:
396
+ ret.append((_module, name, _child_module))
397
+ print(ret)
398
+ return ret
399
+
400
+
401
+ _find_modules = _find_modules_v2
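To make the search helpers above concrete, a hedged sketch of how `_find_modules` is typically consumed; `unet` is an assumed, already-loaded UNet, and note that the name filters in `_find_modules_v2` restrict matches to the attn/ff sub-layers of each ancestor class.

import torch.nn as nn
from i2vedit.utils.lora import _find_modules

# Collect the (parent, attribute name, module) triples that the injection
# functions below would wrap when targeting temporal transformer blocks.
targets = list(
    _find_modules(
        unet,  # assumption: a loaded UNetSpatioTemporalConditionModel
        ancestor_class=["TemporalBasicTransformerBlock"],
        search_class=[nn.Linear],
    )
)
print(f"{len(targets)} linear layers would receive LoRA")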
402
+
403
+
404
+ def inject_trainable_lora(
405
+ model: nn.Module,
406
+ target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
407
+ r: int = 4,
408
+ loras=None, # path to lora .pt
409
+ verbose: bool = False,
410
+ dropout_p: float = 0.0,
411
+ scale: float = 1.0,
412
+ ):
413
+ """
414
+ Inject LoRA into the model and return the LoRA parameter groups.
415
+ """
416
+
417
+ require_grad_params = []
418
+ names = []
419
+
420
+ if loras != None:
421
+ loras = torch.load(loras)
422
+
423
+ for _module, name, _child_module in _find_modules(
424
+ model, target_replace_module, search_class=[nn.Linear]
425
+ ):
426
+ weight = _child_module.weight
427
+ bias = _child_module.bias
428
+ if verbose:
429
+ print("LoRA Injection : injecting lora into ", name)
430
+ print("LoRA Injection : weight shape", weight.shape)
431
+ _tmp = LoraInjectedLinear(
432
+ _child_module.in_features,
433
+ _child_module.out_features,
434
+ _child_module.bias is not None,
435
+ r=r,
436
+ dropout_p=dropout_p,
437
+ scale=scale,
438
+ )
439
+ _tmp.linear.weight = weight
440
+ if bias is not None:
441
+ _tmp.linear.bias = bias
442
+
443
+ # switch the module
444
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
445
+ _module._modules[name] = _tmp
446
+
447
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
448
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
449
+
450
+ if loras != None:
451
+ _module._modules[name].lora_up.weight = loras.pop(0)
452
+ _module._modules[name].lora_down.weight = loras.pop(0)
453
+
454
+ _module._modules[name].lora_up.weight.requires_grad = True
455
+ _module._modules[name].lora_down.weight.requires_grad = True
456
+ names.append(name)
457
+
458
+ return require_grad_params, names
459
+
460
+
461
+ def inject_trainable_lora_extended(
462
+ model: nn.Module,
463
+ target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
464
+ r: int = 4,
465
+ loras=None, # path to lora .pt
466
+ dropout_p: float = 0.0,
467
+ scale: float = 1.0,
468
+ ):
469
+ """
470
+ Inject LoRA into the model and return the LoRA parameter groups.
471
+ """
472
+
473
+ require_grad_params = []
474
+ names = []
475
+
476
+ if loras != None:
477
+ print(f"Load from lora: {loras} ...")
478
+ loras = torch.load(loras)
479
+ if True:
480
+ for target_replace_module_i in target_replace_module:
481
+ for _module, name, _child_module in _find_modules(
482
+ model, [target_replace_module_i], search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
483
+ ):
484
+ # if name == 'to_q':
485
+ # continue
486
+ if _child_module.__class__ == nn.Linear:
487
+ weight = _child_module.weight
488
+ bias = _child_module.bias
489
+ _tmp = LoraInjectedLinear(
490
+ _child_module.in_features,
491
+ _child_module.out_features,
492
+ _child_module.bias is not None,
493
+ r=r,
494
+ dropout_p=dropout_p,
495
+ scale=scale,
496
+ )
497
+ _tmp.linear.weight = weight
498
+ if bias is not None:
499
+ _tmp.linear.bias = bias
500
+ elif _child_module.__class__ == nn.Conv2d:
501
+ weight = _child_module.weight
502
+ bias = _child_module.bias
503
+ _tmp = LoraInjectedConv2d(
504
+ _child_module.in_channels,
505
+ _child_module.out_channels,
506
+ _child_module.kernel_size,
507
+ _child_module.stride,
508
+ _child_module.padding,
509
+ _child_module.dilation,
510
+ _child_module.groups,
511
+ _child_module.bias is not None,
512
+ r=r,
513
+ dropout_p=dropout_p,
514
+ scale=scale,
515
+ )
516
+
517
+ _tmp.conv.weight = weight
518
+ if bias is not None:
519
+ _tmp.conv.bias = bias
520
+
521
+ elif _child_module.__class__ == nn.Conv3d:
522
+ weight = _child_module.weight
523
+ bias = _child_module.bias
524
+ _tmp = LoraInjectedConv3d(
525
+ _child_module.in_channels,
526
+ _child_module.out_channels,
527
+ bias=_child_module.bias is not None,
528
+ kernel_size=_child_module.kernel_size,
529
+ padding=_child_module.padding,
530
+ r=r,
531
+ dropout_p=dropout_p,
532
+ scale=scale,
533
+ )
534
+
535
+ _tmp.conv.weight = weight
536
+ if bias is not None:
537
+ _tmp.conv.bias = bias
538
+ # LoRA layer
539
+ else:
540
+ _tmp = _child_module
541
+ # switch the module
542
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
543
+ try:
544
+ if bias is not None:
545
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
546
+ except:
547
+ pass
548
+
549
+ _module._modules[name] = _tmp
550
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
551
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
552
+
553
+ if loras != None:
554
+ _module._modules[name].lora_up.weight = loras.pop(0)
555
+ _module._modules[name].lora_down.weight = loras.pop(0)
556
+
557
+ _module._modules[name].lora_up.weight.requires_grad = True
558
+ _module._modules[name].lora_down.weight.requires_grad = True
559
+ names.append(name)
560
+ else:
561
+ for _module, name, _child_module in _find_modules(
562
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
563
+ ):
564
+ if _child_module.__class__ == nn.Linear:
565
+ weight = _child_module.weight
566
+ bias = _child_module.bias
567
+ _tmp = LoraInjectedLinear(
568
+ _child_module.in_features,
569
+ _child_module.out_features,
570
+ _child_module.bias is not None,
571
+ r=r,
572
+ dropout_p=dropout_p,
573
+ scale=scale,
574
+ )
575
+ _tmp.linear.weight = weight
576
+ if bias is not None:
577
+ _tmp.linear.bias = bias
578
+ elif _child_module.__class__ == nn.Conv2d:
579
+ weight = _child_module.weight
580
+ bias = _child_module.bias
581
+ _tmp = LoraInjectedConv2d(
582
+ _child_module.in_channels,
583
+ _child_module.out_channels,
584
+ _child_module.kernel_size,
585
+ _child_module.stride,
586
+ _child_module.padding,
587
+ _child_module.dilation,
588
+ _child_module.groups,
589
+ _child_module.bias is not None,
590
+ r=r,
591
+ dropout_p=dropout_p,
592
+ scale=scale,
593
+ )
594
+
595
+ _tmp.conv.weight = weight
596
+ if bias is not None:
597
+ _tmp.conv.bias = bias
598
+
599
+ elif _child_module.__class__ == nn.Conv3d:
600
+ weight = _child_module.weight
601
+ bias = _child_module.bias
602
+ _tmp = LoraInjectedConv3d(
603
+ _child_module.in_channels,
604
+ _child_module.out_channels,
605
+ bias=_child_module.bias is not None,
606
+ kernel_size=_child_module.kernel_size,
607
+ padding=_child_module.padding,
608
+ r=r,
609
+ dropout_p=dropout_p,
610
+ scale=scale,
611
+ )
612
+
613
+ _tmp.conv.weight = weight
614
+ if bias is not None:
615
+ _tmp.conv.bias = bias
616
+ # switch the module
617
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
618
+ if bias is not None:
619
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
620
+
621
+ _module._modules[name] = _tmp
622
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
623
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
624
+
625
+ if loras != None:
626
+ _module._modules[name].lora_up.weight = loras.pop(0)
627
+ _module._modules[name].lora_down.weight = loras.pop(0)
628
+
629
+ _module._modules[name].lora_up.weight.requires_grad = True
630
+ _module._modules[name].lora_down.weight.requires_grad = True
631
+ names.append(name)
632
+
633
+ return require_grad_params, names
634
+
635
+
636
+ def inject_inferable_lora(
637
+ model,
638
+ lora_path='',
639
+ unet_replace_modules=["UNet3DConditionModel"],
640
+ text_encoder_replace_modules=["CLIPEncoderLayer"],
641
+ is_extended=False,
642
+ r=16
643
+ ):
644
+ from transformers.models.clip import CLIPTextModel
645
+ from diffusers import UNet3DConditionModel
646
+
647
+ def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
648
+ def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
649
+
650
+ if os.path.exists(lora_path):
651
+ try:
652
+ for f in os.listdir(lora_path):
653
+ if f.endswith('.pt'):
654
+ lora_file = os.path.join(lora_path, f)
655
+
656
+ if is_text_model(f):
657
+ monkeypatch_or_replace_lora(
658
+ model.text_encoder,
659
+ torch.load(lora_file),
660
+ target_replace_module=text_encoder_replace_modules,
661
+ r=r
662
+ )
663
+ print("Successfully loaded Text Encoder LoRa.")
664
+ continue
665
+
666
+ if is_unet(f):
667
+ monkeypatch_or_replace_lora_extended(
668
+ model.unet,
669
+ torch.load(lora_file),
670
+ target_replace_module=unet_replace_modules,
671
+ r=r
672
+ )
673
+ print("Successfully loaded UNET LoRa.")
674
+ continue
675
+
676
+ print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")
677
+
678
+ except Exception as e:
679
+ print(e)
680
+ print("Couldn't inject LoRA's due to an error.")
681
+
682
+ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
683
+
684
+ loras = []
685
+
686
+ for target_replace_module_i in target_replace_module:
687
+
688
+ for _m, _n, _child_module in _find_modules(
689
+ model,
690
+ [target_replace_module_i],
691
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
692
+ ):
693
+ loras.append((_child_module.lora_up, _child_module.lora_down))
694
+
695
+ if len(loras) == 0:
696
+ raise ValueError("No lora injected.")
697
+
698
+ return loras
699
+
700
+
701
+ def extract_lora_child_module(model, target_replace_module=DEFAULT_TARGET_REPLACE):
702
+
703
+ loras = []
704
+
705
+ for target_replace_module_i in target_replace_module:
706
+
707
+ for _m, _n, _child_module in _find_modules(
708
+ model,
709
+ [target_replace_module_i],
710
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
711
+ ):
712
+ loras.append(_child_module)
713
+
714
+ return loras
715
+
716
+ def extract_lora_as_tensor(
717
+ model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
718
+ ):
719
+
720
+ loras = []
721
+
722
+ for _m, _n, _child_module in _find_modules(
723
+ model,
724
+ target_replace_module,
725
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
726
+ ):
727
+ up, down = _child_module.realize_as_lora()
728
+ if as_fp16:
729
+ up = up.to(torch.float16)
730
+ down = down.to(torch.float16)
731
+
732
+ loras.append((up, down))
733
+
734
+ if len(loras) == 0:
735
+ raise ValueError("No lora injected.")
736
+
737
+ return loras
738
+
739
+
740
+ def save_lora_weight(
741
+ model,
742
+ path="./lora.pt",
743
+ target_replace_module=DEFAULT_TARGET_REPLACE,
744
+ flag=None
745
+ ):
746
+ weights = []
747
+ for _up, _down in extract_lora_ups_down(
748
+ model, target_replace_module=target_replace_module
749
+ ):
750
+ weights.append(_up.weight.to("cpu").to(torch.float32))
751
+ weights.append(_down.weight.to("cpu").to(torch.float32))
752
+ if not flag:
753
+ torch.save(weights, path)
754
+ else:
755
+ weights_new=[]
756
+ for i in range(0, len(weights), 4):
757
+ subset = weights[i+(flag-1)*2:i+(flag-1)*2+2]
758
+ weights_new.extend(subset)
759
+ torch.save(weights_new, path)
760
+
761
+ def save_lora_as_json(model, path="./lora.json"):
762
+ weights = []
763
+ for _up, _down in extract_lora_ups_down(model):
764
+ weights.append(_up.weight.detach().cpu().numpy().tolist())
765
+ weights.append(_down.weight.detach().cpu().numpy().tolist())
766
+
767
+ import json
768
+
769
+ with open(path, "w") as f:
770
+ json.dump(weights, f)
771
+
772
+
773
+ def save_safeloras_with_embeds(
774
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
775
+ embeds: Dict[str, torch.Tensor] = {},
776
+ outpath="./lora.safetensors",
777
+ ):
778
+ """
779
+ Saves the Lora from multiple modules in a single safetensor file.
780
+
781
+ modelmap is a dictionary of {
782
+ "module name": (module, target_replace_module)
783
+ }
784
+ """
785
+ weights = {}
786
+ metadata = {}
787
+
788
+ for name, (model, target_replace_module) in modelmap.items():
789
+ metadata[name] = json.dumps(list(target_replace_module))
790
+
791
+ for i, (_up, _down) in enumerate(
792
+ extract_lora_as_tensor(model, target_replace_module)
793
+ ):
794
+ rank = _down.shape[0]
795
+
796
+ metadata[f"{name}:{i}:rank"] = str(rank)
797
+ weights[f"{name}:{i}:up"] = _up
798
+ weights[f"{name}:{i}:down"] = _down
799
+
800
+ for token, tensor in embeds.items():
801
+ metadata[token] = EMBED_FLAG
802
+ weights[token] = tensor
803
+
804
+ print(f"Saving weights to {outpath}")
805
+ safe_save(weights, outpath, metadata)
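A short sketch of the `modelmap` format described in the docstring above. It assumes LoRA layers were already injected into `unet` (otherwise `extract_lora_as_tensor` raises) and that the output path is writable; both names are assumptions.

modelmap = {
    # "module name": (module, set of target class names)
    "unet": (unet, {"TemporalBasicTransformerBlock"}),
}
save_safeloras_with_embeds(modelmap, embeds={}, outpath="./lora.safetensors")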
806
+
807
+
808
+ def save_safeloras(
809
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
810
+ outpath="./lora.safetensors",
811
+ ):
812
+ return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
813
+
814
+
815
+ def convert_loras_to_safeloras_with_embeds(
816
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
817
+ embeds: Dict[str, torch.Tensor] = {},
818
+ outpath="./lora.safetensors",
819
+ ):
820
+ """
821
+ Converts the Lora from multiple pytorch .pt files into a single safetensor file.
822
+
823
+ modelmap is a dictionary of {
824
+ "module name": (pytorch_model_path, target_replace_module, rank)
825
+ }
826
+ """
827
+
828
+ weights = {}
829
+ metadata = {}
830
+
831
+ for name, (path, target_replace_module, r) in modelmap.items():
832
+ metadata[name] = json.dumps(list(target_replace_module))
833
+
834
+ lora = torch.load(path)
835
+ for i, weight in enumerate(lora):
836
+ is_up = i % 2 == 0
837
+ i = i // 2
838
+
839
+ if is_up:
840
+ metadata[f"{name}:{i}:rank"] = str(r)
841
+ weights[f"{name}:{i}:up"] = weight
842
+ else:
843
+ weights[f"{name}:{i}:down"] = weight
844
+
845
+ for token, tensor in embeds.items():
846
+ metadata[token] = EMBED_FLAG
847
+ weights[token] = tensor
848
+
849
+ print(f"Saving weights to {outpath}")
850
+ safe_save(weights, outpath, metadata)
851
+
852
+
853
+ def convert_loras_to_safeloras(
854
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
855
+ outpath="./lora.safetensors",
856
+ ):
857
+ convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
858
+
859
+
860
+ def parse_safeloras(
861
+ safeloras,
862
+ ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
863
+ """
864
+ Converts a loaded safetensor file that contains a set of module Loras
865
+ into Parameters and other information
866
+
867
+ Output is a dictionary of {
868
+ "module name": (
869
+ [list of weights],
870
+ [list of ranks],
871
+ target_replacement_modules
872
+ )
873
+ }
874
+ """
875
+ loras = {}
876
+ metadata = safeloras.metadata()
877
+
878
+ get_name = lambda k: k.split(":")[0]
879
+
880
+ keys = list(safeloras.keys())
881
+ keys.sort(key=get_name)
882
+
883
+ for name, module_keys in groupby(keys, get_name):
884
+ info = metadata.get(name)
885
+
886
+ if not info:
887
+ raise ValueError(
888
+ f"Tensor {name} has no metadata - is this a Lora safetensor?"
889
+ )
890
+
891
+ # Skip Textual Inversion embeds
892
+ if info == EMBED_FLAG:
893
+ continue
894
+
895
+ # Handle Loras
896
+ # Extract the targets
897
+ target = json.loads(info)
898
+
899
+ # Build the result lists - Python needs us to preallocate lists to insert into them
900
+ module_keys = list(module_keys)
901
+ ranks = [4] * (len(module_keys) // 2)
902
+ weights = [None] * len(module_keys)
903
+
904
+ for key in module_keys:
905
+ # Split the model name and index out of the key
906
+ _, idx, direction = key.split(":")
907
+ idx = int(idx)
908
+
909
+ # Add the rank
910
+ ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
911
+
912
+ # Insert the weight into the list
913
+ idx = idx * 2 + (1 if direction == "down" else 0)
914
+ weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
915
+
916
+ loras[name] = (weights, ranks, target)
917
+
918
+ return loras
919
+
920
+
921
+ def parse_safeloras_embeds(
922
+ safeloras,
923
+ ) -> Dict[str, torch.Tensor]:
924
+ """
925
+ Converts a loaded safetensor file that contains Textual Inversion embeds into
926
+ a dictionary of embed_token: Tensor
927
+ """
928
+ embeds = {}
929
+ metadata = safeloras.metadata()
930
+
931
+ for key in safeloras.keys():
932
+ # Only handle Textual Inversion embeds
933
+ meta = metadata.get(key)
934
+ if not meta or meta != EMBED_FLAG:
935
+ continue
936
+
937
+ embeds[key] = safeloras.get_tensor(key)
938
+
939
+ return embeds
940
+
941
+
942
+ def load_safeloras(path, device="cpu"):
943
+ safeloras = safe_open(path, framework="pt", device=device)
944
+ return parse_safeloras(safeloras)
945
+
946
+
947
+ def load_safeloras_embeds(path, device="cpu"):
948
+ safeloras = safe_open(path, framework="pt", device=device)
949
+ return parse_safeloras_embeds(safeloras)
950
+
951
+
952
+ def load_safeloras_both(path, device="cpu"):
953
+ safeloras = safe_open(path, framework="pt", device=device)
954
+ return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
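And the reverse direction, reading such a file back with the loaders above (the path is an assumption):

loras = load_safeloras("./lora.safetensors", device="cpu")
for name, (weights, ranks, target) in loras.items():
    # `weights` alternates up/down parameters, `ranks` holds one rank per pair,
    # and `target` lists the module class names the LoRA was extracted from.
    print(name, len(weights) // 2, ranks[0], target)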
955
+
956
+
957
+ def collapse_lora(model, alpha=1.0):
958
+
959
+ for _module, name, _child_module in _find_modules(
960
+ model,
961
+ UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
962
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
963
+ ):
964
+
965
+ if isinstance(_child_module, LoraInjectedLinear):
966
+ print("Collapsing Lin Lora in", name)
967
+
968
+ _child_module.linear.weight = nn.Parameter(
969
+ _child_module.linear.weight.data
970
+ + alpha
971
+ * (
972
+ _child_module.lora_up.weight.data
973
+ @ _child_module.lora_down.weight.data
974
+ )
975
+ .type(_child_module.linear.weight.dtype)
976
+ .to(_child_module.linear.weight.device)
977
+ )
978
+
979
+ else:
980
+ print("Collapsing Conv Lora in", name)
981
+ _child_module.conv.weight = nn.Parameter(
982
+ _child_module.conv.weight.data
983
+ + alpha
984
+ * (
985
+ _child_module.lora_up.weight.data.flatten(start_dim=1)
986
+ @ _child_module.lora_down.weight.data.flatten(start_dim=1)
987
+ )
988
+ .reshape(_child_module.conv.weight.data.shape)
989
+ .type(_child_module.conv.weight.dtype)
990
+ .to(_child_module.conv.weight.device)
991
+ )
992
+
993
+
994
+ def monkeypatch_or_replace_lora(
995
+ model,
996
+ loras,
997
+ target_replace_module=DEFAULT_TARGET_REPLACE,
998
+ r: Union[int, List[int]] = 4,
999
+ ):
1000
+ for _module, name, _child_module in _find_modules(
1001
+ model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
1002
+ ):
1003
+ _source = (
1004
+ _child_module.linear
1005
+ if isinstance(_child_module, LoraInjectedLinear)
1006
+ else _child_module
1007
+ )
1008
+
1009
+ weight = _source.weight
1010
+ bias = _source.bias
1011
+ _tmp = LoraInjectedLinear(
1012
+ _source.in_features,
1013
+ _source.out_features,
1014
+ _source.bias is not None,
1015
+ r=r.pop(0) if isinstance(r, list) else r,
1016
+ )
1017
+ _tmp.linear.weight = weight
1018
+
1019
+ if bias is not None:
1020
+ _tmp.linear.bias = bias
1021
+
1022
+ # switch the module
1023
+ _module._modules[name] = _tmp
1024
+
1025
+ up_weight = loras.pop(0)
1026
+ down_weight = loras.pop(0)
1027
+
1028
+ _module._modules[name].lora_up.weight = nn.Parameter(
1029
+ up_weight.type(weight.dtype)
1030
+ )
1031
+ _module._modules[name].lora_down.weight = nn.Parameter(
1032
+ down_weight.type(weight.dtype)
1033
+ )
1034
+
1035
+ _module._modules[name].to(weight.device)
1036
+
1037
+
1038
+ def monkeypatch_or_replace_lora_extended(
1039
+ model,
1040
+ loras,
1041
+ target_replace_module=DEFAULT_TARGET_REPLACE,
1042
+ r: Union[int, List[int]] = 4,
1043
+ ):
1044
+ for _module, name, _child_module in _find_modules(
1045
+ model,
1046
+ target_replace_module,
1047
+ search_class=[
1048
+ nn.Linear,
1049
+ nn.Conv2d,
1050
+ nn.Conv3d,
1051
+ LoraInjectedLinear,
1052
+ LoraInjectedConv2d,
1053
+ LoraInjectedConv3d,
1054
+ ],
1055
+ ):
1056
+
1057
+ if (_child_module.__class__ == nn.Linear) or (
1058
+ _child_module.__class__ == LoraInjectedLinear
1059
+ ):
1060
+ if len(loras[0].shape) != 2:
1061
+ continue
1062
+
1063
+ _source = (
1064
+ _child_module.linear
1065
+ if isinstance(_child_module, LoraInjectedLinear)
1066
+ else _child_module
1067
+ )
1068
+
1069
+ weight = _source.weight
1070
+ bias = _source.bias
1071
+ _tmp = LoraInjectedLinear(
1072
+ _source.in_features,
1073
+ _source.out_features,
1074
+ _source.bias is not None,
1075
+ r=r.pop(0) if isinstance(r, list) else r,
1076
+ )
1077
+ _tmp.linear.weight = weight
1078
+
1079
+ if bias is not None:
1080
+ _tmp.linear.bias = bias
1081
+
1082
+ elif (_child_module.__class__ == nn.Conv2d) or (
1083
+ _child_module.__class__ == LoraInjectedConv2d
1084
+ ):
1085
+ if len(loras[0].shape) != 4:
1086
+ continue
1087
+ _source = (
1088
+ _child_module.conv
1089
+ if isinstance(_child_module, LoraInjectedConv2d)
1090
+ else _child_module
1091
+ )
1092
+
1093
+ weight = _source.weight
1094
+ bias = _source.bias
1095
+ _tmp = LoraInjectedConv2d(
1096
+ _source.in_channels,
1097
+ _source.out_channels,
1098
+ _source.kernel_size,
1099
+ _source.stride,
1100
+ _source.padding,
1101
+ _source.dilation,
1102
+ _source.groups,
1103
+ _source.bias is not None,
1104
+ r=r.pop(0) if isinstance(r, list) else r,
1105
+ )
1106
+
1107
+ _tmp.conv.weight = weight
1108
+
1109
+ if bias is not None:
1110
+ _tmp.conv.bias = bias
1111
+
1112
+ elif _child_module.__class__ == nn.Conv3d or(
1113
+ _child_module.__class__ == LoraInjectedConv3d
1114
+ ):
1115
+
1116
+ if len(loras[0].shape) != 5:
1117
+ continue
1118
+
1119
+ _source = (
1120
+ _child_module.conv
1121
+ if isinstance(_child_module, LoraInjectedConv3d)
1122
+ else _child_module
1123
+ )
1124
+
1125
+ weight = _source.weight
1126
+ bias = _source.bias
1127
+ _tmp = LoraInjectedConv3d(
1128
+ _source.in_channels,
1129
+ _source.out_channels,
1130
+ bias=_source.bias is not None,
1131
+ kernel_size=_source.kernel_size,
1132
+ padding=_source.padding,
1133
+ r=r.pop(0) if isinstance(r, list) else r,
1134
+ )
1135
+
1136
+ _tmp.conv.weight = weight
1137
+
1138
+ if bias is not None:
1139
+ _tmp.conv.bias = bias
1140
+
1141
+ # switch the module
1142
+ _module._modules[name] = _tmp
1143
+
1144
+ up_weight = loras.pop(0)
1145
+ down_weight = loras.pop(0)
1146
+
1147
+ _module._modules[name].lora_up.weight = nn.Parameter(
1148
+ up_weight.type(weight.dtype)
1149
+ )
1150
+ _module._modules[name].lora_down.weight = nn.Parameter(
1151
+ down_weight.type(weight.dtype)
1152
+ )
1153
+
1154
+ _module._modules[name].to(weight.device)
1155
+
1156
+
1157
+ def monkeypatch_or_replace_safeloras(models, safeloras):
1158
+ loras = parse_safeloras(safeloras)
1159
+
1160
+ for name, (lora, ranks, target) in loras.items():
1161
+ model = getattr(models, name, None)
1162
+
1163
+ if not model:
1164
+ print(f"No model provided for {name}, contained in Lora")
1165
+ continue
1166
+
1167
+ monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
1168
+
1169
+
1170
+ def monkeypatch_remove_lora(model):
1171
+ for _module, name, _child_module in _find_modules(
1172
+ model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
1173
+ ):
1174
+ if isinstance(_child_module, LoraInjectedLinear):
1175
+ _source = _child_module.linear
1176
+ weight, bias = _source.weight, _source.bias
1177
+
1178
+ _tmp = nn.Linear(
1179
+ _source.in_features, _source.out_features, bias is not None
1180
+ )
1181
+
1182
+ _tmp.weight = weight
1183
+ if bias is not None:
1184
+ _tmp.bias = bias
1185
+
1186
+ else:
1187
+ _source = _child_module.conv
1188
+ weight, bias = _source.weight, _source.bias
1189
+
1190
+ if isinstance(_source, nn.Conv2d):
1191
+ _tmp = nn.Conv2d(
1192
+ in_channels=_source.in_channels,
1193
+ out_channels=_source.out_channels,
1194
+ kernel_size=_source.kernel_size,
1195
+ stride=_source.stride,
1196
+ padding=_source.padding,
1197
+ dilation=_source.dilation,
1198
+ groups=_source.groups,
1199
+ bias=bias is not None,
1200
+ )
1201
+
1202
+ _tmp.weight = weight
1203
+ if bias is not None:
1204
+ _tmp.bias = bias
1205
+
1206
+ if isinstance(_source, nn.Conv3d):
1207
+ _tmp = nn.Conv3d(
1208
+ _source.in_channels,
1209
+ _source.out_channels,
1210
+ bias=_source.bias is not None,
1211
+ kernel_size=_source.kernel_size,
1212
+ padding=_source.padding,
1213
+ )
1214
+
1215
+ _tmp.weight = weight
1216
+ if bias is not None:
1217
+ _tmp.bias = bias
1218
+
1219
+ _module._modules[name] = _tmp
1220
+
1221
+
1222
+ def monkeypatch_add_lora(
1223
+ model,
1224
+ loras,
1225
+ target_replace_module=DEFAULT_TARGET_REPLACE,
1226
+ alpha: float = 1.0,
1227
+ beta: float = 1.0,
1228
+ ):
1229
+ for _module, name, _child_module in _find_modules(
1230
+ model, target_replace_module, search_class=[LoraInjectedLinear]
1231
+ ):
1232
+ weight = _child_module.linear.weight
1233
+
1234
+ up_weight = loras.pop(0)
1235
+ down_weight = loras.pop(0)
1236
+
1237
+ _module._modules[name].lora_up.weight = nn.Parameter(
1238
+ up_weight.type(weight.dtype).to(weight.device) * alpha
1239
+ + _module._modules[name].lora_up.weight.to(weight.device) * beta
1240
+ )
1241
+ _module._modules[name].lora_down.weight = nn.Parameter(
1242
+ down_weight.type(weight.dtype).to(weight.device) * alpha
1243
+ + _module._modules[name].lora_down.weight.to(weight.device) * beta
1244
+ )
1245
+
1246
+ _module._modules[name].to(weight.device)
1247
+
1248
+
1249
+ def tune_lora_scale(model, alpha: float = 1.0):
1250
+ for _module in model.modules():
1251
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1252
+ _module.scale = alpha
1253
+
1254
+
1255
+ def set_lora_diag(model, diag: torch.Tensor):
1256
+ for _module in model.modules():
1257
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1258
+ _module.set_selector_from_diag(diag)
1259
+
1260
+
1261
+ def _text_lora_path(path: str) -> str:
1262
+ assert path.endswith(".pt"), "Only .pt files are supported"
1263
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
1264
+
1265
+
1266
+ def _ti_lora_path(path: str) -> str:
1267
+ assert path.endswith(".pt"), "Only .pt files are supported"
1268
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
1269
+
1270
+
1271
+ def apply_learned_embed_in_clip(
1272
+ learned_embeds,
1273
+ text_encoder,
1274
+ tokenizer,
1275
+ token: Optional[Union[str, List[str]]] = None,
1276
+ idempotent=False,
1277
+ ):
1278
+ if isinstance(token, str):
1279
+ trained_tokens = [token]
1280
+ elif isinstance(token, list):
1281
+ assert len(learned_embeds.keys()) == len(
1282
+ token
1283
+ ), "The number of tokens and the number of embeds should be the same"
1284
+ trained_tokens = token
1285
+ else:
1286
+ trained_tokens = list(learned_embeds.keys())
1287
+
1288
+ for token in trained_tokens:
1289
+ print(token)
1290
+ embeds = learned_embeds[token]
1291
+
1292
+ # cast to dtype of text_encoder
1293
+ dtype = text_encoder.get_input_embeddings().weight.dtype
1294
+ num_added_tokens = tokenizer.add_tokens(token)
1295
+
1296
+ i = 1
1297
+ if not idempotent:
1298
+ while num_added_tokens == 0:
1299
+ print(f"The tokenizer already contains the token {token}.")
1300
+ token = f"{token[:-1]}-{i}>"
1301
+ print(f"Attempting to add the token {token}.")
1302
+ num_added_tokens = tokenizer.add_tokens(token)
1303
+ i += 1
1304
+ elif num_added_tokens == 0 and idempotent:
1305
+ print(f"The tokenizer already contains the token {token}.")
1306
+ print(f"Replacing {token} embedding.")
1307
+
1308
+ # resize the token embeddings
1309
+ text_encoder.resize_token_embeddings(len(tokenizer))
1310
+
1311
+ # get the id for the token and assign the embeds
1312
+ token_id = tokenizer.convert_tokens_to_ids(token)
1313
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
1314
+ return token
1315
+
1316
+
1317
+ def load_learned_embed_in_clip(
1318
+ learned_embeds_path,
1319
+ text_encoder,
1320
+ tokenizer,
1321
+ token: Optional[Union[str, List[str]]] = None,
1322
+ idempotent=False,
1323
+ ):
1324
+ learned_embeds = torch.load(learned_embeds_path)
1325
+ apply_learned_embed_in_clip(
1326
+ learned_embeds, text_encoder, tokenizer, token, idempotent
1327
+ )
1328
+
1329
+
1330
+ def patch_pipe(
1331
+ pipe,
1332
+ maybe_unet_path,
1333
+ token: Optional[str] = None,
1334
+ r: int = 4,
1335
+ patch_unet=True,
1336
+ patch_text=True,
1337
+ patch_ti=True,
1338
+ idempotent_token=True,
1339
+ unet_target_replace_module=DEFAULT_TARGET_REPLACE,
1340
+ text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1341
+ ):
1342
+ if maybe_unet_path.endswith(".pt"):
1343
+ # torch format
1344
+
1345
+ if maybe_unet_path.endswith(".ti.pt"):
1346
+ unet_path = maybe_unet_path[:-6] + ".pt"
1347
+ elif maybe_unet_path.endswith(".text_encoder.pt"):
1348
+ unet_path = maybe_unet_path[:-16] + ".pt"
1349
+ else:
1350
+ unet_path = maybe_unet_path
1351
+
1352
+ ti_path = _ti_lora_path(unet_path)
1353
+ text_path = _text_lora_path(unet_path)
1354
+
1355
+ if patch_unet:
1356
+ print("LoRA : Patching Unet")
1357
+ monkeypatch_or_replace_lora(
1358
+ pipe.unet,
1359
+ torch.load(unet_path),
1360
+ r=r,
1361
+ target_replace_module=unet_target_replace_module,
1362
+ )
1363
+
1364
+ if patch_text:
1365
+ print("LoRA : Patching text encoder")
1366
+ monkeypatch_or_replace_lora(
1367
+ pipe.text_encoder,
1368
+ torch.load(text_path),
1369
+ target_replace_module=text_target_replace_module,
1370
+ r=r,
1371
+ )
1372
+ if patch_ti:
1373
+ print("LoRA : Patching token input")
1374
+ token = load_learned_embed_in_clip(
1375
+ ti_path,
1376
+ pipe.text_encoder,
1377
+ pipe.tokenizer,
1378
+ token=token,
1379
+ idempotent=idempotent_token,
1380
+ )
1381
+
1382
+ elif maybe_unet_path.endswith(".safetensors"):
1383
+ safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
1384
+ monkeypatch_or_replace_safeloras(pipe, safeloras)
1385
+ tok_dict = parse_safeloras_embeds(safeloras)
1386
+ if patch_ti:
1387
+ apply_learned_embed_in_clip(
1388
+ tok_dict,
1389
+ pipe.text_encoder,
1390
+ pipe.tokenizer,
1391
+ token=token,
1392
+ idempotent=idempotent_token,
1393
+ )
1394
+ return tok_dict
1395
+
1396
+
1397
+ def train_patch_pipe(pipe, patch_unet, patch_text):
1398
+ if patch_unet:
1399
+ print("LoRA : Patching Unet")
1400
+ collapse_lora(pipe.unet)
1401
+ monkeypatch_remove_lora(pipe.unet)
1402
+
1403
+ if patch_text:
1404
+ print("LoRA : Patching text encoder")
1405
+
1406
+ collapse_lora(pipe.text_encoder)
1407
+ monkeypatch_remove_lora(pipe.text_encoder)
1408
+
1409
+ @torch.no_grad()
1410
+ def inspect_lora(model):
1411
+ moved = {}
1412
+
1413
+ for name, _module in model.named_modules():
1414
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1415
+ ups = _module.lora_up.weight.data.clone()
1416
+ downs = _module.lora_down.weight.data.clone()
1417
+
1418
+ wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
1419
+
1420
+ dist = wght.flatten().abs().mean().item()
1421
+ if name in moved:
1422
+ moved[name].append(dist)
1423
+ else:
1424
+ moved[name] = [dist]
1425
+
1426
+ return moved
1427
+
1428
+
1429
+ def save_all(
1430
+ unet,
1431
+ text_encoder,
1432
+ save_path,
1433
+ placeholder_token_ids=None,
1434
+ placeholder_tokens=None,
1435
+ save_lora=True,
1436
+ save_ti=True,
1437
+ target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1438
+ target_replace_module_unet=DEFAULT_TARGET_REPLACE,
1439
+ safe_form=True,
1440
+ ):
1441
+ if not safe_form:
1442
+ # save ti
1443
+ if save_ti:
1444
+ ti_path = _ti_lora_path(save_path)
1445
+ learned_embeds_dict = {}
1446
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1447
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1448
+ print(
1449
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1450
+ learned_embeds[:4],
1451
+ )
1452
+ learned_embeds_dict[tok] = learned_embeds.detach().cpu()
1453
+
1454
+ torch.save(learned_embeds_dict, ti_path)
1455
+ print("Ti saved to ", ti_path)
1456
+
1457
+ # save text encoder
1458
+ if save_lora:
1459
+ save_lora_weight(
1460
+ unet, save_path, target_replace_module=target_replace_module_unet
1461
+ )
1462
+ print("Unet saved to ", save_path)
1463
+
1464
+ save_lora_weight(
1465
+ text_encoder,
1466
+ _text_lora_path(save_path),
1467
+ target_replace_module=target_replace_module_text,
1468
+ )
1469
+ print("Text Encoder saved to ", _text_lora_path(save_path))
1470
+
1471
+ else:
1472
+ assert save_path.endswith(
1473
+ ".safetensors"
1474
+ ), f"Save path : {save_path} should end with .safetensors"
1475
+
1476
+ loras = {}
1477
+ embeds = {}
1478
+
1479
+ if save_lora:
1480
+
1481
+ loras["unet"] = (unet, target_replace_module_unet)
1482
+ loras["text_encoder"] = (text_encoder, target_replace_module_text)
1483
+
1484
+ if save_ti:
1485
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1486
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1487
+ print(
1488
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1489
+ learned_embeds[:4],
1490
+ )
1491
+ embeds[tok] = learned_embeds.detach().cpu()
1492
+
1493
+ save_safeloras_with_embeds(loras, embeds, save_path)
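Taken together, a typical round trip with the helpers in this file might look like the sketch below. The `unet` variable and the target module name are assumptions consistent with how the rest of the repository calls these functions.

from i2vedit.utils.lora import inject_trainable_lora_extended, save_lora_weight

# 1. Inject trainable LoRA into the temporal transformer blocks.
lora_params, lora_names = inject_trainable_lora_extended(
    unet,  # assumption: a loaded UNet
    target_replace_module={"TemporalBasicTransformerBlock"},
    r=16,
    scale=1.0,
)

# 2. ...optimise the parameter groups returned in `lora_params`...

# 3. Dump the trained low-rank weights to a plain .pt file.
save_lora_weight(
    unet,
    path="./unet_lora.pt",
    target_replace_module={"TemporalBasicTransformerBlock"},
)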
i2vedit/utils/lora_handler.py ADDED
@@ -0,0 +1,270 @@
1
+ import os
2
+ import warnings
3
+ import torch
4
+ from typing import Union
5
+ from types import SimpleNamespace
6
+ from diffusers import UNetSpatioTemporalConditionModel
7
+ from transformers import CLIPTextModel
8
+
9
+ from .lora import (
10
+ extract_lora_ups_down,
11
+ inject_trainable_lora_extended,
12
+ save_lora_weight,
13
+ train_patch_pipe,
14
+ monkeypatch_or_replace_lora,
15
+ monkeypatch_or_replace_lora_extended
16
+ )
17
+
18
+
19
+ FILE_BASENAMES = ['unet', 'text_encoder']
20
+ LORA_FILE_TYPES = ['.pt', '.safetensors']
21
+ CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
22
+ STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']
23
+
24
+ lora_versions = dict(
25
+ stable_lora = "stable_lora",
26
+ cloneofsimo = "cloneofsimo"
27
+ )
28
+
29
+ lora_func_types = dict(
30
+ loader = "loader",
31
+ injector = "injector"
32
+ )
33
+
34
+ lora_args = dict(
35
+ model = None,
36
+ loras = None,
37
+ target_replace_module = [],
38
+ target_module = [],
39
+ r = 4,
40
+ search_class = [torch.nn.Linear],
41
+ dropout = 0,
42
+ lora_bias = 'none'
43
+ )
44
+
45
+ LoraVersions = SimpleNamespace(**lora_versions)
46
+ LoraFuncTypes = SimpleNamespace(**lora_func_types)
47
+
48
+ LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
49
+ LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
50
+
51
+ def filter_dict(_dict, keys=[]):
52
+ if len(keys) == 0:
53
+ assert "Keys cannot empty for filtering return dict."
54
+
55
+ for k in keys:
56
+ if k not in lora_args.keys():
57
+ assert f"{k} does not exist in available LoRA arguments"
58
+
59
+ return {k: v for k, v in _dict.items() if k in keys}
60
+
61
+ class LoraHandler(object):
62
+ def __init__(
63
+ self,
64
+ version: LORA_VERSIONS = LoraVersions.cloneofsimo,
65
+ use_unet_lora: bool = False,
66
+ use_image_lora: bool = False,
67
+ save_for_webui: bool = False,
68
+ only_for_webui: bool = False,
69
+ lora_bias: str = 'none',
70
+ unet_replace_modules: list = None,
71
+ image_encoder_replace_modules: list = None
72
+ ):
73
+ self.version = version
74
+ self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
75
+ self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
76
+ self.lora_bias = lora_bias
77
+ self.use_unet_lora = use_unet_lora
78
+ self.use_image_lora = use_image_lora
79
+ self.save_for_webui = save_for_webui
80
+ self.only_for_webui = only_for_webui
81
+ self.unet_replace_modules = unet_replace_modules
82
+ self.image_encoder_replace_modules = image_encoder_replace_modules
83
+ self.use_lora = any([use_image_lora, use_unet_lora])
84
+
85
+ def is_cloneofsimo_lora(self):
86
+ return self.version == LoraVersions.cloneofsimo
87
+
88
+
89
+ def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader):
90
+
91
+ if self.is_cloneofsimo_lora():
92
+
93
+ if func_type == LoraFuncTypes.loader:
94
+ return monkeypatch_or_replace_lora_extended
95
+
96
+ if func_type == LoraFuncTypes.injector:
97
+ return inject_trainable_lora_extended
98
+
99
+ assert "LoRA Version does not exist."
100
+
101
+ def check_lora_ext(self, lora_file: str):
102
+ return lora_file.endswith(tuple(LORA_FILE_TYPES))
103
+
104
+ def get_lora_file_path(
105
+ self,
106
+ lora_path: str,
107
+ model: Union[UNetSpatioTemporalConditionModel, CLIPTextModel]
108
+ ):
109
+ if os.path.exists(lora_path):
110
+ lora_filenames = [fns for fns in os.listdir(lora_path)]
111
+ is_lora = self.check_lora_ext(lora_path)
112
+
113
+ is_unet = isinstance(model, UNetSpatioTemporalConditionModel)
114
+ #is_text = isinstance(model, CLIPTextModel)
115
+ idx = 0 if is_unet else 1
116
+
117
+ base_name = FILE_BASENAMES[idx]
118
+
119
+ for lora_filename in lora_filenames:
120
+ is_lora = self.check_lora_ext(lora_filename)
121
+ if not is_lora:
122
+ continue
123
+
124
+ if base_name in lora_filename:
125
+ return os.path.join(lora_path, lora_filename)
126
+ else:
127
+ print(f"lora_path: {lora_path} does not exist. Inject without pretrained loras...")
128
+
129
+ return None
130
+
131
+ def handle_lora_load(self, file_name:str, lora_loader_args: dict = None):
132
+ self.lora_loader(**lora_loader_args)
133
+ print(f"Successfully loaded LoRA from: {file_name}")
134
+
135
+ def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,):
136
+ try:
137
+ lora_file = self.get_lora_file_path(lora_path, model)
138
+
139
+ if lora_file is not None:
140
+ lora_loader_args.update({"lora_path": lora_file})
141
+ self.handle_lora_load(lora_file, lora_loader_args)
142
+
143
+ else:
144
+ print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
145
+
146
+ except Exception as e:
147
+ print(f"An error occurred while loading a LoRA file: {e}")
148
+
149
+ def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale):
150
+ return_dict = lora_args.copy()
151
+
152
+ if self.is_cloneofsimo_lora():
153
+ return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
154
+ return_dict.update({
155
+ "model": model,
156
+ "loras": self.get_lora_file_path(lora_path, model),
157
+ "target_replace_module": replace_modules,
158
+ "r": r,
159
+ "scale": scale,
160
+ "dropout_p": dropout,
161
+ })
162
+
163
+ return return_dict
164
+
165
+ def do_lora_injection(
166
+ self,
167
+ model,
168
+ replace_modules,
169
+ bias='none',
170
+ dropout=0,
171
+ r=4,
172
+ lora_loader_args=None,
173
+ ):
174
+ REPLACE_MODULES = replace_modules
175
+
176
+ params = None
177
+ negation = None
178
+ is_injection_hybrid = False
179
+
180
+ if self.is_cloneofsimo_lora():
181
+ is_injection_hybrid = True
182
+ injector_args = lora_loader_args
183
+
184
+ params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended
185
+ for _up, _down in extract_lora_ups_down(
186
+ model,
187
+ target_replace_module=REPLACE_MODULES):
188
+
189
+ if all(x is not None for x in [_up, _down]):
190
+ print(f"Lora successfully injected into {model.__class__.__name__}.")
191
+
192
+ break
193
+
194
+ return params, negation, is_injection_hybrid
195
+
196
+ return params, negation, is_injection_hybrid
197
+
198
+ def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0):
199
+
200
+ params = None
201
+ negation = None
202
+
203
+ lora_loader_args = self.get_lora_func_args(
204
+ lora_path,
205
+ use_lora,
206
+ model,
207
+ replace_modules,
208
+ r,
209
+ dropout,
210
+ self.lora_bias,
211
+ scale
212
+ )
213
+
214
+ if use_lora:
215
+ params, negation, is_injection_hybrid = self.do_lora_injection(
216
+ model,
217
+ replace_modules,
218
+ bias=self.lora_bias,
219
+ lora_loader_args=lora_loader_args,
220
+ dropout=dropout,
221
+ r=r
222
+ )
223
+
224
+ if not is_injection_hybrid:
225
+ self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)
226
+
227
+ params = model if params is None else params
228
+ return params, negation
229
+
230
+ def save_cloneofsimo_lora(self, model, save_path, step, flag):
231
+
232
+ def save_lora(model, name, condition, replace_modules, step, save_path, flag=None):
233
+ if condition and replace_modules is not None:
234
+ save_path = f"{save_path}/{step}_{name}.pt"
235
+ save_lora_weight(model, save_path, replace_modules, flag)
236
+
237
+ save_lora(
238
+ model.unet,
239
+ FILE_BASENAMES[0],
240
+ self.use_unet_lora,
241
+ self.unet_replace_modules,
242
+ step,
243
+ save_path,
244
+ flag
245
+ )
246
+ save_lora(
247
+ model.image_encoder,
248
+ FILE_BASENAMES[1],
249
+ self.use_image_lora,
250
+ self.image_encoder_replace_modules,
251
+ step,
252
+ save_path,
253
+ flag
254
+ )
255
+
256
+ # train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)
257
+
258
+ def save_lora_weights(self, model, save_path: str = '', step: str = '', flag=None):
259
+ save_path = f"{save_path}/lora"
260
+ os.makedirs(save_path, exist_ok=True)
261
+
262
+ if self.is_cloneofsimo_lora():
263
+ if any([self.save_for_webui, self.only_for_webui]):
264
+ warnings.warn(
265
+ """
266
+ You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implementation.
267
+ Only 'stable_lora' is supported for saving to a compatible webui file.
268
+ """
269
+ )
270
+ self.save_cloneofsimo_lora(model, save_path, step, flag)
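For context, a minimal sketch of how this handler is driven. The `unet` and `pipeline` objects, the step counter, the rank, and the module list are assumptions consistent with the defaults in this file; `pipeline` must expose `.unet` and `.image_encoder` for checkpointing.

from i2vedit.utils.lora_handler import LoraHandler

lora_manager = LoraHandler(
    version="cloneofsimo",
    use_unet_lora=True,
    unet_replace_modules=["TemporalBasicTransformerBlock"],
)

# Inject fresh LoRA (empty lora_path) and collect the trainable parameter groups.
unet_lora_params, _ = lora_manager.add_lora_to_model(
    use_lora=True,
    model=unet,
    replace_modules=lora_manager.unet_replace_modules,
    dropout=0.0,
    lora_path="",
    r=16,
)

# Periodically checkpoint the LoRA weights, e.g. every N steps.
lora_manager.save_lora_weights(model=pipeline, save_path="./outputs", step=str(global_step))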
i2vedit/utils/model_utils.py ADDED
@@ -0,0 +1,588 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ import random
8
+ import gc
9
+ import copy
10
+ import imageio
11
+ import numpy as np
12
+ import PIL
13
+ from PIL import Image
14
+ from scipy.stats import anderson
15
+ from typing import Dict, Optional, Tuple, Callable, List, Union
16
+ from omegaconf import OmegaConf
17
+ from einops import rearrange, repeat
18
+ from dataclasses import dataclass
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ from torchvision import transforms
24
+ from tqdm.auto import tqdm
25
+
26
+ from accelerate import Accelerator
27
+ from accelerate.logging import get_logger
28
+ from accelerate.utils import set_seed
29
+
30
+ import transformers
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
33
+ from transformers.models.clip.modeling_clip import CLIPEncoder
34
+
35
+ import diffusers
36
+ from diffusers.models import AutoencoderKL
37
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.utils.import_utils import is_xformers_available
40
+ from diffusers.models.attention_processor import AttnProcessor2_0, Attention
41
+ from diffusers.models.attention import BasicTransformerBlock
42
+ from diffusers import StableVideoDiffusionPipeline
43
+ from diffusers.models.lora import LoRALinearLayer
44
+ from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
45
+ from diffusers.image_processor import VaeImageProcessor
46
+ from diffusers.optimization import get_scheduler
47
+ from diffusers.training_utils import EMAModel
48
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image, BaseOutput
49
+ from diffusers.utils.import_utils import is_xformers_available
50
+ from diffusers.utils.torch_utils import randn_tensor
51
+ from diffusers.models.unet_3d_blocks import \
52
+ (CrossAttnDownBlockSpatioTemporal,
53
+ DownBlockSpatioTemporal,
54
+ CrossAttnUpBlockSpatioTemporal,
55
+ UpBlockSpatioTemporal)
56
+ from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteSchedulerOutput, EulerDiscreteScheduler
57
+
58
+
59
+ def _append_dims(x, target_dims):
60
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
61
+ dims_to_append = target_dims - x.ndim
62
+ if dims_to_append < 0:
63
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
64
+ return x[(...,) + (None,) * dims_to_append]
65
+
66
+ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
67
+ # Based on:
68
+ # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
69
+
70
+ batch_size, channels, num_frames, height, width = video.shape
71
+ outputs = []
72
+ for batch_idx in range(batch_size):
73
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
74
+ batch_output = processor.postprocess(batch_vid, output_type)
75
+
76
+ outputs.append(batch_output)
77
+
78
+ return outputs
79
+
80
+ @torch.no_grad()
81
+ def tensor_to_vae_latent(t, vae):
82
+ video_length = t.shape[1]
83
+
84
+ t = rearrange(t, "b f c h w -> (b f) c h w")
85
+ latents = vae.encode(t).latent_dist.sample()
86
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
87
+ latents = latents * vae.config.scaling_factor
88
+
89
+ return latents
90
+
91
+ def load_primary_models(pretrained_model_path):
92
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
93
+ pretrained_model_path, subfolder="scheduler")
94
+ feature_extractor = CLIPImageProcessor.from_pretrained(
95
+ pretrained_model_path, subfolder="feature_extractor", revision=None
96
+ )
97
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
98
+ pretrained_model_path, subfolder="image_encoder", revision=None, variant="fp16"
99
+ )
100
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
101
+ pretrained_model_path, subfolder="vae", revision=None, variant="fp16")
102
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
103
+ pretrained_model_path,
104
+ subfolder="unet",
105
+ low_cpu_mem_usage=True,
106
+ variant="fp16",
107
+ )
108
+
109
+ return noise_scheduler, feature_extractor, image_encoder, vae, unet
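A hedged sketch of loading these components; the checkpoint identifier is an assumption (any local or Hub copy of an SVD image-to-video model that ships fp16 variants should work, since the loaders above request variant="fp16").

from i2vedit.utils.model_utils import load_primary_models

pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid-xt"  # assumed checkpoint
noise_scheduler, feature_extractor, image_encoder, vae, unet = load_primary_models(
    pretrained_model_path
)

# Freeze everything; only LoRA parameters injected later are trained.
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet.requires_grad_(False)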
110
+
111
+ def set_processors(attentions):
112
+ for attn in attentions: attn.set_processor(AttnProcessor2_0())
113
+
114
+ def is_attn(name):
115
+ return ('attn1' or 'attn2' == name.split('.')[-1])
116
+
117
+ def set_torch_2_attn(unet):
118
+ optim_count = 0
119
+
120
+ for name, module in unet.named_modules():
121
+ if is_attn(name):
122
+ if isinstance(module, torch.nn.ModuleList):
123
+ for m in module:
124
+ if isinstance(m, BasicTransformerBlock):
125
+ set_processors([m.attn1, m.attn2])
126
+ optim_count += 1
127
+ if optim_count > 0:
128
+ print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
129
+
130
+ def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
131
+ try:
132
+ is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
133
+ enable_torch_2 = is_torch_2 and enable_torch_2_attn
134
+
135
+ if enable_xformers_memory_efficient_attention and not enable_torch_2:
136
+ if is_xformers_available():
137
+ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
138
+ unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
139
+ else:
140
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
141
+
142
+ if enable_torch_2:
143
+ set_torch_2_attn(unet)
144
+
145
+ except Exception as e:
146
+ print(e)
147
+ print("Could not enable memory efficient attention for xformers or Torch 2.0.")
148
+
149
+
150
+ class P2PEulerDiscreteScheduler(EulerDiscreteScheduler):
151
+
152
+ def step(
153
+ self,
154
+ model_output: torch.FloatTensor,
155
+ timestep: Union[float, torch.FloatTensor],
156
+ sample: torch.FloatTensor,
157
+ conditional_latents: torch.FloatTensor = None,
158
+ s_churn: float = 0.0,
159
+ s_tmin: float = 0.0,
160
+ s_tmax: float = float("inf"),
161
+ s_noise: float = 1.0,
162
+ generator: Optional[torch.Generator] = None,
163
+ return_dict: bool = True,
164
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
165
+ """
166
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
167
+ process from the learned model outputs (most often the predicted noise).
168
+
169
+ Args:
170
+ model_output (`torch.FloatTensor`):
171
+ The direct output from learned diffusion model.
172
+ timestep (`float`):
173
+ The current discrete timestep in the diffusion chain.
174
+ sample (`torch.FloatTensor`):
175
+ A current instance of a sample created by the diffusion process.
176
+ s_churn (`float`):
177
+ s_tmin (`float`):
178
+ s_tmax (`float`):
179
+ s_noise (`float`, defaults to 1.0):
180
+ Scaling factor for noise added to the sample.
181
+ generator (`torch.Generator`, *optional*):
182
+ A random number generator.
183
+ return_dict (`bool`):
184
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
185
+ tuple.
186
+
187
+ Returns:
188
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
189
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
190
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
191
+ """
192
+
193
+ if (
194
+ isinstance(timestep, int)
195
+ or isinstance(timestep, torch.IntTensor)
196
+ or isinstance(timestep, torch.LongTensor)
197
+ ):
198
+ raise ValueError(
199
+ (
200
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
201
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
202
+ " one of the `scheduler.timesteps` as a timestep."
203
+ ),
204
+ )
205
+
206
+ if not self.is_scale_input_called:
207
+ logger.warning(
208
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
209
+ "See `StableDiffusionPipeline` for a usage example."
210
+ )
211
+
212
+ if self.step_index is None:
213
+ self._init_step_index(timestep)
214
+
215
+ # Upcast to avoid precision issues when computing prev_sample
216
+ sample = sample.to(torch.float32)
217
+
218
+ sigma = self.sigmas[self.step_index]
219
+
220
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
221
+
222
+ noise = randn_tensor(
223
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
224
+ )
225
+
226
+ eps = noise * s_noise
227
+ sigma_hat = sigma * (gamma + 1)
228
+
229
+ if gamma > 0:
230
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
231
+
232
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
233
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
234
+ # backwards compatibility
235
+ if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
236
+ pred_original_sample = model_output
237
+ elif self.config.prediction_type == "epsilon":
238
+ pred_original_sample = sample - sigma_hat * model_output
239
+ elif self.config.prediction_type == "v_prediction":
240
+ # denoised = model_output * c_out + input * c_skip
241
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
242
+ else:
243
+ raise ValueError(
244
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
245
+ )
246
+
247
+ for controller in self.controller:
248
+ pred_original_sample = controller.step_callback(pred_original_sample)
249
+ # color preservation
250
+ #color_delta = torch.mean(conditional_latents[:,0:1,:,:,:]) - torch.mean(pred_original_sample[:,0:1,:,:,:])
251
+ #print("color_delta", color_delta)
252
+ #pred_original_sample = pred_original_sample + color_delta
253
+
254
+
255
+
256
+ # 2. Convert to an ODE derivative
257
+ derivative = (sample - pred_original_sample) / sigma_hat
258
+
259
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
260
+
261
+ prev_sample = sample + derivative * dt
262
+
263
+ # Cast sample back to model compatible dtype
264
+ prev_sample = prev_sample.to(model_output.dtype)
265
+
266
+ # upon completion increase step index by one
267
+ self._step_index += 1
268
+
269
+ if not return_dict:
270
+ return (prev_sample,)
271
+
272
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
273
+
274
+ @dataclass
275
+ class StableVideoDiffusionPipelineOutput(BaseOutput):
276
+ r"""
277
+ Output class for the Stable Video Diffusion image-to-video pipeline.
278
+
279
+ Args:
280
+ frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
281
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
282
+ num_channels)`.
283
+ """
284
+
285
+ frames: Union[List[PIL.Image.Image], np.ndarray]
286
+ latents: torch.Tensor
287
+
288
+ class P2PStableVideoDiffusionPipeline(StableVideoDiffusionPipeline):
289
+
290
+ def _encode_vae_image(
291
+ self,
292
+ image: torch.Tensor,
293
+ device,
294
+ num_videos_per_prompt,
295
+ do_classifier_free_guidance,
296
+ image_latents: torch.Tensor = None
297
+ ):
298
+ if image_latents is None:
299
+ image = image.to(device=device)
300
+ image_latents = self.vae.encode(image).latent_dist.mode()
301
+ else:
302
+ image_latents = rearrange(image_latents, "b f c h w -> (b f) c h w")
303
+
304
+ if do_classifier_free_guidance:
305
+ negative_image_latents = torch.zeros_like(image_latents)
306
+
307
+ # For classifier free guidance, we need to do two forward passes.
308
+ # Here we concatenate the unconditional and text embeddings into a single batch
309
+ # to avoid doing two forward passes
310
+ image_latents = torch.cat([negative_image_latents, image_latents])
311
+
312
+ # duplicate image_latents for each generation per prompt, using mps friendly method
313
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
314
+
315
+ return image_latents
316
+
317
+ @torch.no_grad()
318
+ def __call__(
319
+ self,
320
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
321
+ edited_firstframe: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] = None,
322
+ image_latents: torch.FloatTensor = None,
323
+ height: int = 576,
324
+ width: int = 1024,
325
+ num_frames: Optional[int] = None,
326
+ num_inference_steps: int = 25,
327
+ min_guidance_scale: float = 1.0,
328
+ max_guidance_scale: float = 2.5,
329
+ fps: int = 7,
330
+ motion_bucket_id: int = 127,
331
+ noise_aug_strength: float = 0.02,
332
+ decode_chunk_size: Optional[int] = None,
333
+ num_videos_per_prompt: Optional[int] = 1,
334
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
335
+ latents: Optional[torch.FloatTensor] = None,
336
+ output_type: Optional[str] = "pil",
337
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
338
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
339
+ return_dict: bool = True,
340
+ ):
341
+ r"""
342
+ The call function to the pipeline for generation.
343
+
344
+ Args:
345
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
346
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
347
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
348
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
349
+ The height in pixels of the generated image.
350
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
351
+ The width in pixels of the generated image.
352
+ num_frames (`int`, *optional*):
353
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
354
+ num_inference_steps (`int`, *optional*, defaults to 25):
355
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
356
+ expense of slower inference. This parameter is modulated by `strength`.
357
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
358
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
359
+ max_guidance_scale (`float`, *optional*, defaults to 2.5):
360
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
361
+ fps (`int`, *optional*, defaults to 7):
362
+ Frames per second. The rate at which the generated images shall be exported to a video after generation.
363
+ Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
364
+ motion_bucket_id (`int`, *optional*, defaults to 127):
365
+ The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
366
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
367
+ The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
368
+ decode_chunk_size (`int`, *optional*):
369
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
370
+ between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
371
+ for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
372
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
373
+ The number of images to generate per prompt.
374
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
375
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
376
+ generation deterministic.
377
+ latents (`torch.FloatTensor`, *optional*):
378
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
379
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
380
+ tensor is generated by sampling using the supplied random `generator`.
381
+ output_type (`str`, *optional*, defaults to `"pil"`):
382
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
383
+ callback_on_step_end (`Callable`, *optional*):
384
+ A function that calls at the end of each denoising steps during the inference. The function is called
385
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
386
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
387
+ `callback_on_step_end_tensor_inputs`.
388
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
389
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
390
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
391
+ `._callback_tensor_inputs` attribute of your pipeline class.
392
+ return_dict (`bool`, *optional*, defaults to `True`):
393
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
394
+ plain tuple.
395
+
396
+ Returns:
397
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
398
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
399
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
400
+
401
+ Examples:
402
+
403
+ ```py
404
+ from diffusers import StableVideoDiffusionPipeline
405
+ from diffusers.utils import load_image, export_to_video
406
+
407
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
408
+ pipe.to("cuda")
409
+
410
+ image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
411
+ image = image.resize((1024, 576))
412
+
413
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
414
+ export_to_video(frames, "generated.mp4", fps=7)
415
+ ```
416
+ """
417
+ # 0. Default height and width to unet
418
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
419
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
420
+
421
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
422
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
423
+
424
+ # 1. Check inputs. Raise error if not correct
425
+ self.check_inputs(image, height, width)
426
+
427
+ # 2. Define call parameters
428
+ if isinstance(image, PIL.Image.Image):
429
+ batch_size = 1
430
+ elif isinstance(image, list):
431
+ batch_size = len(image)
432
+ else:
433
+ batch_size = image.shape[0]
434
+ device = self._execution_device
435
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
436
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
437
+ # corresponds to doing no classifier free guidance.
438
+ self._guidance_scale = max_guidance_scale
439
+
440
+ # 3. Encode input image
441
+ edited_firstframe = image if edited_firstframe is None else edited_firstframe
442
+ image_embeddings = self._encode_image(edited_firstframe, device, num_videos_per_prompt, self.do_classifier_free_guidance)
443
+
444
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
445
+ # is why it is reduced here.
446
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
447
+ fps = fps - 1
448
+
449
+ # 4. Encode input image using VAE
450
+ image = self.image_processor.preprocess(image, height=height, width=width)
451
+ edited_firstframe = self.image_processor.preprocess(edited_firstframe, height=height, width=width)
452
+ #print("before vae", image.min(), image.max())
453
+ #noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
454
+ #image = image + noise_aug_strength * noise
455
+ #edited_firstframe = edited_firstframe + noise_aug_strength * noise
456
+ if image_latents is not None:
457
+ #noise_tmp = randn_tensor(image_latents.shape, generator=generator, device=image_latents.device, dtype=image_latents.dtype)
458
+ image_latents = image_latents / self.vae.config.scaling_factor
459
+ #image_latents = image_latents + noise_aug_strength * noise_tmp
460
+
461
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
462
+ if needs_upcasting:
463
+ self.vae.to(dtype=torch.float32)
464
+
465
+ #print("before vae", image.min(), image.max())
466
+ image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance, image_latents = image_latents)
467
+ firstframe_latents = self._encode_vae_image(edited_firstframe, device, num_videos_per_prompt, self.do_classifier_free_guidance)
468
+ noise = randn_tensor(image_latents.shape, generator=generator, device=image_latents.device, dtype=image_latents.dtype)[1:]
469
+ image_latents[1:] = image_latents[1:] + noise_aug_strength * noise #/ self.vae.config.scaling_factor
470
+ #firstframe_latents = firstframe_latents + noise_aug_strength * noise / self.vae.config.scaling_factor
471
+
472
+ image_latents = image_latents.to(image_embeddings.dtype)
473
+ firstframe_latents = firstframe_latents.to(image_embeddings.dtype)
474
+
475
+ # cast back to fp16 if needed
476
+ if needs_upcasting:
477
+ self.vae.to(dtype=torch.float16)
478
+
479
+ # Repeat the image latents for each frame so we can concatenate them with the noise
480
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
481
+
482
+ skip=num_frames
483
+ image_latents = torch.cat(
484
+ [
485
+ image_latents.unsqueeze(1).repeat(1, skip, 1, 1, 1),
486
+ firstframe_latents.unsqueeze(1).repeat(1, num_frames-skip, 1, 1, 1)
487
+ ],
488
+ dim=1
489
+ )
490
+
491
+ #image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
492
+ #print("image latents", image_latents.min(), image_latents.max())
493
+
494
+ # 5. Get Added Time IDs
495
+ added_time_ids = self._get_add_time_ids(
496
+ fps,
497
+ motion_bucket_id,
498
+ noise_aug_strength,
499
+ image_embeddings.dtype,
500
+ batch_size,
501
+ num_videos_per_prompt,
502
+ self.do_classifier_free_guidance,
503
+ )
504
+ added_time_ids = added_time_ids.to(device)
505
+
506
+ # 4. Prepare timesteps
507
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
508
+ timesteps = self.scheduler.timesteps
509
+
510
+ # 5. Prepare latent variables
511
+ num_channels_latents = self.unet.config.in_channels
512
+ latents = self.prepare_latents(
513
+ batch_size * num_videos_per_prompt,
514
+ num_frames,
515
+ num_channels_latents,
516
+ height,
517
+ width,
518
+ image_embeddings.dtype,
519
+ device,
520
+ generator,
521
+ latents,
522
+ )
523
+
524
+ # 7. Prepare guidance scale
525
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
526
+ guidance_scale = guidance_scale.to(device, latents.dtype)
527
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
528
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
529
+
530
+ self._guidance_scale = guidance_scale
531
+
532
+ # 8. Denoising loop
533
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
534
+ self._num_timesteps = len(timesteps)
535
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
536
+ for i, t in enumerate(timesteps):
537
+ # expand the latents if we are doing classifier free guidance
538
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
539
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
540
+
541
+ # Concatenate image_latents over the channels dimension
542
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
543
+
544
+ # predict the noise residual
545
+ noise_pred = self.unet(
546
+ latent_model_input,
547
+ t,
548
+ encoder_hidden_states=image_embeddings,
549
+ added_time_ids=added_time_ids,
550
+ return_dict=False,
551
+ )[0]
552
+
553
+ # perform guidance
554
+ if self.do_classifier_free_guidance:
555
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
556
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
557
+
558
+ # compute the previous noisy sample x_t -> x_t-1
559
+ #conditional_latents = image_latents.chunk(2)[1] if self.do_classifier_free_guidance else image_latents
560
+ #latents = self.scheduler.step(noise_pred, t, latents, conditional_latents=conditional_latents*self.vae.config.scaling_factor).prev_sample
561
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
562
+
563
+ if callback_on_step_end is not None:
564
+ callback_kwargs = {}
565
+ for k in callback_on_step_end_tensor_inputs:
566
+ callback_kwargs[k] = locals()[k]
567
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
568
+
569
+ latents = callback_outputs.pop("latents", latents)
570
+
571
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
572
+ progress_bar.update()
573
+
574
+ if not output_type == "latent":
575
+ # cast back to fp16 if needed
576
+ if needs_upcasting:
577
+ self.vae.to(dtype=torch.float16)
578
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
579
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
580
+ else:
581
+ frames = latents
582
+
583
+ self.maybe_free_model_hooks()
584
+
585
+ if not return_dict:
586
+ return frames
587
+
588
+ return StableVideoDiffusionPipelineOutput(frames=frames, latents=latents)
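For orientation, a minimal sketch of how `P2PStableVideoDiffusionPipeline` and `P2PEulerDiscreteScheduler` defined above might be driven with an edited first frame; the checkpoint id, image paths, and parameter values are illustrative assumptions rather than values taken from this commit:
```py
import torch
from diffusers.utils import load_image, export_to_video

# Hypothetical usage sketch: checkpoint id, file paths, and settings are placeholders.
pipe = P2PStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = P2PEulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.scheduler.controller = []  # attention controllers (e.g. AttentionControlEdit) can be appended here

source_frame = load_image("source_firstframe.png").resize((1024, 576))
edited_frame = load_image("edited_firstframe.png").resize((1024, 576))

output = pipe(
    source_frame,
    edited_firstframe=edited_frame,
    width=1024,
    height=576,
    num_frames=14,
    num_inference_steps=25,
    decode_chunk_size=8,
    max_guidance_scale=2.5,
    generator=torch.Generator("cpu").manual_seed(23),
)
export_to_video(output.frames[0], "edited_clip.mp4", fps=7)
# output.latents keeps the final denoised latents, which main.py reuses across clips.
```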
i2vedit/utils/svd_util.py ADDED
@@ -0,0 +1,397 @@
1
+ import functools
2
+ import importlib
3
+ import os
4
+ from functools import partial
5
+ from inspect import isfunction
6
+
7
+ import fsspec
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ from safetensors.torch import load_file as load_safetensors
13
+
14
+ import decord
15
+ from einops import rearrange, repeat
16
+ from torchvision.transforms import Resize, Pad, InterpolationMode
17
+ from torch import nn
18
+
19
+
20
+ def disabled_train(self, mode=True):
21
+ """Overwrite model.train with this function to make sure train/eval mode
22
+ does not change anymore."""
23
+ return self
24
+
25
+
26
+ def get_string_from_tuple(s):
27
+ try:
28
+ # Check if the string starts and ends with parentheses
29
+ if s[0] == "(" and s[-1] == ")":
30
+ # Convert the string to a tuple
31
+ t = eval(s)
32
+ # Check if the type of t is tuple
33
+ if type(t) == tuple:
34
+ return t[0]
35
+ else:
36
+ pass
37
+ except:
38
+ pass
39
+ return s
40
+
41
+
42
+ def is_power_of_two(n):
43
+ """
44
+ chat.openai.com/chat
45
+ Return True if n is a power of 2, otherwise return False.
46
+
47
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
48
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
49
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
50
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
51
+
52
+ """
53
+ if n <= 0:
54
+ return False
55
+ return (n & (n - 1)) == 0
56
+
57
+
58
+ def autocast(f, enabled=True):
59
+ def do_autocast(*args, **kwargs):
60
+ with torch.cuda.amp.autocast(
61
+ enabled=enabled,
62
+ dtype=torch.get_autocast_gpu_dtype(),
63
+ cache_enabled=torch.is_autocast_cache_enabled(),
64
+ ):
65
+ return f(*args, **kwargs)
66
+
67
+ return do_autocast
68
+
69
+
70
+ def load_partial_from_config(config):
71
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
72
+
73
+
74
+ def log_txt_as_img(wh, xc, size=10):
75
+ # wh a tuple of (width, height)
76
+ # xc a list of captions to plot
77
+ b = len(xc)
78
+ txts = list()
79
+ for bi in range(b):
80
+ txt = Image.new("RGB", wh, color="white")
81
+ draw = ImageDraw.Draw(txt)
82
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
83
+ nc = int(40 * (wh[0] / 256))
84
+ if isinstance(xc[bi], list):
85
+ text_seq = xc[bi][0]
86
+ else:
87
+ text_seq = xc[bi]
88
+ lines = "\n".join(
89
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
90
+ )
91
+
92
+ try:
93
+ draw.text((0, 0), lines, fill="black", font=font)
94
+ except UnicodeEncodeError:
95
+ print("Cant encode string for logging. Skipping.")
96
+
97
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
98
+ txts.append(txt)
99
+ txts = np.stack(txts)
100
+ txts = torch.tensor(txts)
101
+ return txts
102
+
103
+
104
+ def partialclass(cls, *args, **kwargs):
105
+ class NewCls(cls):
106
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
107
+
108
+ return NewCls
109
+
110
+
111
+ def make_path_absolute(path):
112
+ fs, p = fsspec.core.url_to_fs(path)
113
+ if fs.protocol == "file":
114
+ return os.path.abspath(p)
115
+ return path
116
+
117
+
118
+ def ismap(x):
119
+ if not isinstance(x, torch.Tensor):
120
+ return False
121
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
122
+
123
+
124
+ def isimage(x):
125
+ if not isinstance(x, torch.Tensor):
126
+ return False
127
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
128
+
129
+
130
+ def isheatmap(x):
131
+ if not isinstance(x, torch.Tensor):
132
+ return False
133
+
134
+ return x.ndim == 2
135
+
136
+
137
+ def isneighbors(x):
138
+ if not isinstance(x, torch.Tensor):
139
+ return False
140
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
141
+
142
+
143
+ def exists(x):
144
+ return x is not None
145
+
146
+
147
+ def expand_dims_like(x, y):
148
+ while x.dim() != y.dim():
149
+ x = x.unsqueeze(-1)
150
+ return x
151
+
152
+
153
+ def default(val, d):
154
+ if exists(val):
155
+ return val
156
+ return d() if isfunction(d) else d
157
+
158
+
159
+ def mean_flat(tensor):
160
+ """
161
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
162
+ Take the mean over all non-batch dimensions.
163
+ """
164
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
165
+
166
+
167
+ def count_params(model, verbose=False):
168
+ total_params = sum(p.numel() for p in model.parameters())
169
+ if verbose:
170
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
171
+ return total_params
172
+
173
+
174
+ def instantiate_from_config(config):
175
+ if not "target" in config:
176
+ if config == "__is_first_stage__":
177
+ return None
178
+ elif config == "__is_unconditional__":
179
+ return None
180
+ raise KeyError("Expected key `target` to instantiate.")
181
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
182
+
183
+
184
+ def get_obj_from_str(string, reload=False, invalidate_cache=True):
185
+ module, cls = string.rsplit(".", 1)
186
+ if invalidate_cache:
187
+ importlib.invalidate_caches()
188
+ if reload:
189
+ module_imp = importlib.import_module(module)
190
+ importlib.reload(module_imp)
191
+ return getattr(importlib.import_module(module, package=None), cls)
192
+
193
+
194
+ def append_zero(x):
195
+ return torch.cat([x, x.new_zeros([1])])
196
+
197
+
198
+ def append_dims(x, target_dims):
199
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
200
+ dims_to_append = target_dims - x.ndim
201
+ if dims_to_append < 0:
202
+ raise ValueError(
203
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
204
+ )
205
+ return x[(...,) + (None,) * dims_to_append]
206
+
207
+
208
+ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
209
+ print(f"Loading model from {ckpt}")
210
+ if ckpt.endswith("ckpt"):
211
+ pl_sd = torch.load(ckpt, map_location="cpu")
212
+ if "global_step" in pl_sd:
213
+ print(f"Global Step: {pl_sd['global_step']}")
214
+ sd = pl_sd["state_dict"]
215
+ elif ckpt.endswith("safetensors"):
216
+ sd = load_safetensors(ckpt)
217
+ else:
218
+ raise NotImplementedError
219
+
220
+ model = instantiate_from_config(config.model)
221
+
222
+ m, u = model.load_state_dict(sd, strict=False)
223
+
224
+ if len(m) > 0 and verbose:
225
+ print("missing keys:")
226
+ print(m)
227
+ if len(u) > 0 and verbose:
228
+ print("unexpected keys:")
229
+ print(u)
230
+
231
+ if freeze:
232
+ for param in model.parameters():
233
+ param.requires_grad = False
234
+
235
+ model.eval()
236
+ return model
237
+
238
+
239
+ def get_configs_path() -> str:
240
+ """
241
+ Get the `configs` directory.
242
+ For a working copy, this is the one in the root of the repository,
243
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
244
+ """
245
+ this_dir = os.path.dirname(__file__)
246
+ candidates = (
247
+ os.path.join(this_dir, "configs"),
248
+ os.path.join(this_dir, "..", "configs"),
249
+ )
250
+ for candidate in candidates:
251
+ candidate = os.path.abspath(candidate)
252
+ if os.path.isdir(candidate):
253
+ return candidate
254
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
255
+
256
+
257
+ def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
258
+ """
259
+ Will return the result of a recursive get attribute call.
260
+ E.g.:
261
+ a.b.c
262
+ = getattr(getattr(a, "b"), "c")
263
+ = get_nested_attribute(a, "b.c")
264
+ If any part of the attribute call is an integer x with current obj a, will
265
+ try to call a[x] instead of a.x first.
266
+ """
267
+ attributes = attribute_path.split(".")
268
+ if depth is not None and depth > 0:
269
+ attributes = attributes[:depth]
270
+ assert len(attributes) > 0, "At least one attribute should be selected"
271
+ current_attribute = obj
272
+ current_key = None
273
+ for level, attribute in enumerate(attributes):
274
+ current_key = ".".join(attributes[: level + 1])
275
+ try:
276
+ id_ = int(attribute)
277
+ current_attribute = current_attribute[id_]
278
+ except ValueError:
279
+ current_attribute = getattr(current_attribute, attribute)
280
+
281
+ return (current_attribute, current_key) if return_key else current_attribute
282
+
283
+ def pad_with_ratio(frames, res):
284
+ _, _, ih, iw = frames.shape
285
+ #print("ih, iw", ih, iw)
286
+ i_ratio = ih / iw
287
+ h, w = res
288
+ #print("h,w", h ,w)
289
+ n_ratio = h / w
290
+ if i_ratio > n_ratio:
291
+ nw = int(ih / h * w)
292
+ #print("nw", nw)
293
+ frames = Pad((nw - iw)//2)(frames)
294
+ frames = frames[...,(nw - iw)//2:-(nw - iw)//2,:]
295
+ else:
296
+ nh = int(iw / w * h)
297
+ frames = Pad((nh - ih)//2)(frames)
298
+ frames = frames[...,:,(nh - ih)//2:-(nh - ih)//2]
299
+ #print("after pad", frames.shape)
300
+ return frames
301
+
302
+ def prepare_video(video_path:str, resolution, device, dtype, normalize=True, start_t:float=0, end_t:float=-1, output_fps:int=-1, pad_to_fix=False):
303
+ vr = decord.VideoReader(video_path)
304
+ initial_fps = vr.get_avg_fps()
305
+ if output_fps == -1:
306
+ output_fps = int(initial_fps)
307
+ if end_t == -1:
308
+ end_t = len(vr) / initial_fps
309
+ else:
310
+ end_t = min(len(vr) / initial_fps, end_t)
311
+ assert 0 <= start_t < end_t
312
+ assert output_fps > 0
313
+ start_f_ind = int(start_t * initial_fps)
314
+ end_f_ind = int(end_t * initial_fps)
315
+ num_f = int((end_t - start_t) * output_fps)
316
+ sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
317
+ video = vr.get_batch(sample_idx)
318
+ if torch.is_tensor(video):
319
+ video = video.detach().cpu().numpy()
320
+ else:
321
+ video = video.asnumpy()
322
+ _, h, w, _ = video.shape
323
+ video = rearrange(video, "f h w c -> f c h w")
324
+ video = torch.Tensor(video).to(device).to(dtype)
325
+
326
+ # Use max if you want the larger side to be equal to resolution (e.g. 512)
327
+ # k = float(resolution) / min(h, w)
328
+ if pad_to_fix and resolution is not None:
329
+ video = pad_with_ratio(video, resolution)
330
+ if isinstance(resolution, tuple):
331
+ #video = Resize(resolution, interpolation=InterpolationMode.BICUBIC, antialias=True)(video)
332
+ video = nn.functional.interpolate(video, size=resolution, mode='bilinear')
333
+ else:
334
+ k = float(resolution) / max(h, w)
335
+ h *= k
336
+ w *= k
337
+ h = int(np.round(h / 64.0)) * 64
338
+ w = int(np.round(w / 64.0)) * 64
339
+ video = Resize((h, w), interpolation=InterpolationMode.BICUBIC, antialias=True)(video)
340
+
341
+ if normalize:
342
+ video = video / 127.5 - 1.0
343
+
344
+ return video, output_fps
345
+
346
+ def return_to_original_res(frames, res, pad_to_fix=False):
347
+ #print("original res", res)
348
+ _, _, h, w = frames.shape
349
+ #print("h w", h, w)
350
+ n_ratio = h / w
351
+ ih, iw = res
352
+ i_ratio = ih / iw
353
+ if pad_to_fix:
354
+ if i_ratio > n_ratio:
355
+ nw = int(ih / h * w)
356
+ frames = Resize((ih, iw+2*(nw - iw)//2), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
357
+ frames = frames[...,:,(nw - iw)//2:-(nw - iw)//2]
358
+ else:
359
+ nh = int(iw / w * h)
360
+ frames = Resize((ih+2*(nh - ih)//2, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
361
+
362
+ frames = frames[...,(nh - ih)//2:-(nh - ih)//2,:]
363
+ else:
364
+ frames = Resize((ih, iw), interpolation=InterpolationMode.BICUBIC, antialias=True)(frames)
365
+
366
+ return frames
367
+
368
+ class SmoothAreaRandomDetection(object):
369
+
370
+ def __init__(self, device="cuda", dtype=torch.float16):
371
+
372
+ kernel_x = torch.zeros(3,3,3,3)
373
+ for i in range(3):
374
+ kernel_x[i,i,:,:] = torch.Tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
375
+ kernel_y = torch.zeros(3,3,3,3)
376
+ for i in range(3):
377
+ kernel_y[i,i,:,:] = torch.Tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]])
378
+ kernel_x = kernel_x.to(device, dtype)
379
+ kernel_y = kernel_y.to(device, dtype)
380
+ self.weight_x = kernel_x
381
+ self.weight_y = kernel_y
382
+
383
+ self.eps = 1/256.
384
+
385
+ def detection(self, x, thr=0.0):
386
+ original_dim = x.ndim
387
+ if x.ndim > 4:
388
+ b, f, c, h, w = x.shape
389
+ x = rearrange(x, "b f c h w -> (b f) c h w")
390
+ grad_xx = F.conv2d(x, self.weight_x, stride=1, padding=1)
391
+ grad_yx = F.conv2d(x, self.weight_y, stride=1, padding=1)
392
+ gradient_x = torch.abs(grad_xx) + torch.abs(grad_yx)
393
+ gradient_x = torch.mean(gradient_x, dim=1, keepdim=True)
394
+ gradient_x = repeat(gradient_x, "b 1 ... -> b 3 ...")
395
+ if original_dim > 4:
396
+ gradient_x = rearrange(gradient_x, "(b f) c h w -> b f c h w", b=b)
397
+ return gradient_x <= thr
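As a rough usage sketch (the video path, target resolution, and gradient threshold below are illustrative placeholders), `prepare_video` loads and resizes a clip into a normalized `(f, c, h, w)` tensor, and `SmoothAreaRandomDetection.detection` returns a boolean mask of low-gradient "smooth" pixels computed with the Sobel kernels above:
```py
import torch

# Placeholder path / resolution / threshold, shown only to illustrate the call pattern.
video, fps = prepare_video(
    "mydata/source_and_edits/source.mp4",
    resolution=(576, 1024),   # (height, width); passing a single int resizes the longer side instead
    device="cuda",
    dtype=torch.float16,
    normalize=True,           # pixels scaled to [-1, 1]
    output_fps=7,
)

sard = SmoothAreaRandomDetection(device="cuda", dtype=torch.float32)
# detection expects 3-channel frames; add a batch dim and match the detector's dtype.
smooth_mask = sard.detection(video.unsqueeze(0).float(), thr=0.1)  # bool tensor, shape (1, f, 3, h, w)
```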
i2vedit/version.py ADDED
@@ -0,0 +1,5 @@
1
+ # GENERATED VERSION FILE
2
+ # TIME: Thu May 29 22:38:49 2025
3
+ __version__ = '0.1.0-dev'
4
+ __gitsha__ = 'unknown'
5
+ version_info = (0, 1, "0-dev")
main.py ADDED
@@ -0,0 +1,595 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ import random
8
+ import gc
9
+ import copy
10
+ import imageio
11
+ import numpy as np
12
+ from PIL import Image
13
+ from scipy.stats import anderson
14
+ from typing import Dict, Optional, Tuple, List
15
+ from omegaconf import OmegaConf
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
+ from torchvision import transforms
21
+ from torchvision.transforms import ToTensor
22
+ from tqdm.auto import tqdm
23
+
24
+ from accelerate import Accelerator
25
+ from accelerate.logging import get_logger
26
+ from accelerate.utils import set_seed
27
+
28
+ import transformers
29
+ from transformers import CLIPTextModel, CLIPTokenizer
30
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
31
+ from transformers.models.clip.modeling_clip import CLIPEncoder
32
+
33
+ import diffusers
34
+ from diffusers.models import AutoencoderKL
35
+ from diffusers import DDIMScheduler, TextToVideoSDPipeline
36
+ from diffusers.optimization import get_scheduler
37
+ from diffusers.utils.import_utils import is_xformers_available
38
+ from diffusers.models.attention_processor import AttnProcessor2_0, Attention
39
+ from diffusers.models.attention import BasicTransformerBlock
40
+ from diffusers import StableVideoDiffusionPipeline
41
+ from diffusers.models.lora import LoRALinearLayer
42
+ from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
43
+ from diffusers.image_processor import VaeImageProcessor
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.training_utils import EMAModel
46
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
47
+ from diffusers.utils.import_utils import is_xformers_available
48
+ from diffusers.models.unet_3d_blocks import \
49
+ (CrossAttnDownBlockSpatioTemporal,
50
+ DownBlockSpatioTemporal,
51
+ CrossAttnUpBlockSpatioTemporal,
52
+ UpBlockSpatioTemporal)
53
+ from i2vedit.utils.dataset import VideoJsonDataset, SingleVideoDataset, \
54
+ ImageDataset, VideoFolderDataset, CachedDataset, \
55
+ pad_with_ratio, return_to_original_res
56
+ from einops import rearrange, repeat
57
+ from i2vedit.utils.lora_handler import LoraHandler
58
+ from i2vedit.utils.lora import extract_lora_child_module
59
+ from i2vedit.utils.euler_utils import euler_inversion
60
+ from i2vedit.utils.svd_util import SmoothAreaRandomDetection
61
+
62
+ from i2vedit.data import VideoIO, SingleClipDataset, ResolutionControl
63
+ #from utils.model_utils import load_primary_models
64
+ from i2vedit.utils.euler_utils import inverse_video
65
+ from i2vedit.train import train_motion_lora, load_images_from_list
66
+ from i2vedit.inference import initialize_pipeline
67
+ from i2vedit.utils.model_utils import P2PEulerDiscreteScheduler, P2PStableVideoDiffusionPipeline
68
+ from i2vedit.prompt_attention import attention_util
69
+
70
+ def create_output_folders(output_dir, config):
71
+ os.makedirs(output_dir, exist_ok=True)
72
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
73
+ return output_dir
74
+
75
+
76
+
77
+ def main(
78
+ pretrained_model_path: str,
79
+ data_params: Dict,
80
+ train_motion_lora_params: Dict,
81
+ sarp_params: Dict,
82
+ attention_matching_params: Dict,
83
+ long_video_params: Dict = {"mode": "skip-interval"},
84
+ use_sarp: bool = True,
85
+ use_motion_lora: bool = True,
86
+ train_motion_lora_only: bool = False,
87
+ retrain_motion_lora: bool = True,
88
+ use_inversed_latents: bool = True,
89
+ use_attention_matching: bool = True,
90
+ use_consistency_attention_control: bool = False,
91
+ output_dir: str = "./outputs",
92
+ num_steps: int = 25,
93
+ device: str = "cuda",
94
+ seed: int = 23,
95
+ enable_xformers_memory_efficient_attention: bool = True,
96
+ enable_torch_2_attn: bool = False,
97
+ dtype: str = 'fp16',
98
+ load_from_last_frames_latents: List[str] = None,
99
+ save_last_frames: bool = True,
100
+ visualize_attention_store: bool = False,
101
+ visualize_attention_store_steps: List[int] = None,
102
+ use_latent_blend: bool = False,
103
+ use_previous_latent_for_train: bool = False,
104
+ use_latent_noise: bool = True,
105
+ load_from_previous_consistency_store_controller: str = None,
106
+ load_from_previous_consistency_edit_controller: List[str] = None
107
+ ):
108
+ *_, config = inspect.getargvalues(inspect.currentframe())
109
+
110
+ if dtype == "fp16":
111
+ dtype = torch.float16
112
+ elif dtype == "fp32":
113
+ dtype = torch.float32
114
+
115
+ # create folder
116
+ output_dir = create_output_folders(output_dir, config)
117
+
118
+ # prepare video data
119
+ data_params["output_dir"] = output_dir
120
+ data_params["device"] = device
121
+
122
+ videoio = VideoIO(**data_params, dtype=dtype)
123
+
124
+ # smooth area random perturbation
125
+ if use_sarp:
126
+ sard = SmoothAreaRandomDetection(device, dtype=torch.float32)
127
+ else:
128
+ sard = None
129
+
130
+ keyframe = None
131
+ previous_last_frames = load_images_from_list(data_params.keyframe_paths)
132
+ consistency_train_controller = None
133
+
134
+ if load_from_last_frames_latents is not None:
135
+ previous_last_frames_latents = [torch.load(thpath).to(device) for thpath in load_from_last_frames_latents]
136
+ else:
137
+ previous_last_frames_latents = [None,] * len(previous_last_frames)
138
+
139
+ if use_consistency_attention_control and load_from_previous_consistency_store_controller is not None:
140
+ previous_consistency_store_controller = attention_util.ConsistencyAttentionControl(
141
+ additional_attention_store=None,
142
+ use_inversion_attention=False,
143
+ save_self_attention=True,
144
+ save_latents=False,
145
+ disk_store=True,
146
+ load_attention_store=os.path.join(load_from_previous_consistency_store_controller, "clip_0")
147
+ )
148
+ else:
149
+ previous_consistency_store_controller = None
150
+
151
+ previous_consistency_edit_controller_list = [None,] * len(previous_last_frames)
152
+ if use_consistency_attention_control and load_from_previous_consistency_edit_controller is not None:
153
+ for i in range(len(load_from_previous_consistency_edit_controller)):
154
+ previous_consistency_edit_controller_list[i] = attention_util.ConsistencyAttentionControl(
155
+ additional_attention_store=None,
156
+ use_inversion_attention=False,
157
+ save_self_attention=True,
158
+ save_latents=False,
159
+ disk_store=True,
160
+ load_attention_store=os.path.join(load_from_previous_consistency_edit_controller[i], "clip_0")
161
+ )
162
+
163
+
164
+ # read data and process
165
+ for clip_id, video in enumerate(videoio.read_video_iter()):
166
+ if clip_id >= data_params.get("end_clip_id", 9):
167
+ break
168
+ if clip_id < data_params.get("begin_clip_id", 0):
169
+ continue
170
+ video = video.unsqueeze(0)
171
+
172
+ resctrl = ResolutionControl(video.shape[-2:], data_params.output_res, data_params.pad_to_fit, fill=-1)
173
+
174
+ # update keyframe and edited keyframe
175
+ if long_video_params.mode == "skip-interval":
176
+ assert data_params.overlay_size > 0
177
+ # save the first frame as the keyframe for cross-attention
178
+ #if clip_id == 0:
179
+ firstframe = video[:,0:1,:,:,:]
180
+ keyframe = video[:,0:1,:,:,:]
181
+ edited_keyframes = copy.deepcopy(previous_last_frames)
182
+ edited_firstframes = edited_keyframes
183
+ #edited_firstframes = load_images_from_list(data_params.keyframe_paths)
184
+
185
+ elif long_video_params.mode == "auto-regressive":
186
+ assert data_params.overlay_size == 1
187
+ firstframe = video[:,0:1,:,:,:]
188
+ keyframe = video[:,0:1,:,:,:]
189
+ edited_keyframes = copy.deepcopy(previous_last_frames)
190
+ edited_firstframes = edited_keyframes
191
+
192
+ # register for unet, perform inversion
193
+ load_attention_store = None
194
+ if use_attention_matching:
195
+ assert use_inversed_latents, "inversion is disabled."
196
+ if attention_matching_params.get("load_attention_store") is not None:
197
+ load_attention_store = os.path.join(attention_matching_params.get("load_attention_store"), f"clip_{clip_id}")
198
+ if not os.path.exists(load_attention_store):
199
+ print(f"Load {load_attention_store} failed, folder doesn't exists.")
200
+ load_attention_store = None
201
+
202
+ store_controller = attention_util.AttentionStore(
203
+ disk_store=attention_matching_params.disk_store,
204
+ save_latents = use_latent_blend,
205
+ save_self_attention=True,
206
+ load_attention_store=load_attention_store,
207
+ store_path=os.path.join(output_dir, "attention_store", f"clip_{clip_id}")
208
+ )
209
+ print("store_controller.store_dir:", store_controller.store_dir)
210
+ else:
211
+ store_controller = None
212
+
213
+ load_consistency_attention_store = None
214
+ if use_consistency_attention_control:
215
+ if clip_id==0 and attention_matching_params.get("load_consistency_attention_store") is not None:
216
+ load_consistency_attention_store = os.path.join(attention_matching_params.get("load_consistency_attention_store"), f"clip_{clip_id}")
217
+ if not os.path.exists(load_consistency_attention_store):
218
+ print(f"Load {load_consistency_attention_store} failed, folder doesn't exists.")
219
+ load_consistency_attention_store = None
220
+
221
+ consistency_store_controller = attention_util.ConsistencyAttentionControl(
222
+ additional_attention_store=previous_consistency_store_controller,
223
+ use_inversion_attention=False,
224
+ save_self_attention=(clip_id==0),
225
+ load_attention_store=load_consistency_attention_store,
226
+ save_latents=False,
227
+ disk_store=True,
228
+ store_path=os.path.join(output_dir, "consistency_attention_store", f"clip_{clip_id}")
229
+ )
230
+ print("consistency_store_controller.store_dir:", consistency_store_controller.store_dir)
231
+ else:
232
+ consistency_store_controller = None
233
+
234
+ if train_motion_lora_only:
235
+ assert use_motion_lora and retrain_motion_lora, "use_motion_lora/retrain_motion_lora should be enabled to train motion lora only."
236
+
237
+ # perform smooth area random perturbation
238
+ if use_inversed_latents:
239
+ print("begin inversion sampling for inference...")
240
+ inversion_noise = inverse_video(
241
+ pretrained_model_path,
242
+ video,
243
+ keyframe,
244
+ firstframe,
245
+ num_steps,
246
+ resctrl,
247
+ sard,
248
+ enable_xformers_memory_efficient_attention,
249
+ enable_torch_2_attn,
250
+ store_controller = store_controller,
251
+ consistency_store_controller = consistency_store_controller,
252
+ find_modules=attention_matching_params.registered_modules if load_attention_store is None else {},
253
+ consistency_find_modules=long_video_params.registered_modules if load_consistency_attention_store is None else {},
254
+ # dtype=dtype,
255
+ **sarp_params,
256
+ )
257
+ else:
258
+ if use_motion_lora and retrain_motion_lora:
259
+ assert not any([np > 0 for np in train_motion_lora_params.validation_data.noise_prior]), "inversion noise is not calculated but validation during motion lora training aims to use inversion noise as input latents."
260
+ inversion_noise = None
261
+
262
+
263
+ if use_motion_lora:
264
+ if retrain_motion_lora:
265
+ if use_consistency_attention_control:
266
+ if data_params.output_res[0] != train_motion_lora_params.train_data.height or \
267
+ data_params.output_res[1] != train_motion_lora_params.train_data.width:
268
+ if consistency_train_controller is None:
269
+ load_consistency_train_attention_store = None
270
+ if attention_matching_params.get("load_consistency_train_attention_store") is not None:
271
+ load_consistency_train_attention_store = os.path.join(attention_matching_params.get("load_consistency_train_attention_store"), f"clip_0")
272
+ if not os.path.exists(load_consistency_train_attention_store):
273
+ print(f"Load {load_consistency_train_attention_store} failed, folder doesn't exists.")
274
+ load_consistency_train_attention_store = None
275
+ if load_consistency_train_attention_store is None and clip_id > 0:
276
+ raise IOError(f"load_consistency_train_attention_store can't be None for clip {clip_id}.")
277
+ consistency_train_controller = attention_util.ConsistencyAttentionControl(
278
+ additional_attention_store=None,
279
+ use_inversion_attention=False,
280
+ save_self_attention=True,
281
+ load_attention_store=load_consistency_train_attention_store,
282
+ save_latents=False,
283
+ disk_store=True,
284
+ store_path=os.path.join(output_dir, "consistency_train_attention_store", "clip_0")
285
+ )
286
+ print("consistency_train_controller.store_dir:", consistency_train_controller.store_dir)
287
+ resctrl_train = ResolutionControl(
288
+ video.shape[-2:],
289
+ (train_motion_lora_params.train_data.height,train_motion_lora_params.train_data.width),
290
+ data_params.pad_to_fit, fill=-1
291
+ )
292
+ print("begin inversion sampling for training...")
293
+ inversion_noise_train = inverse_video(
294
+ pretrained_model_path,
295
+ video,
296
+ keyframe,
297
+ firstframe,
298
+ num_steps,
299
+ resctrl_train,
300
+ sard,
301
+ enable_xformers_memory_efficient_attention,
302
+ enable_torch_2_attn,
303
+ store_controller = None,
304
+ consistency_store_controller = consistency_train_controller,
305
+ find_modules={},
306
+ consistency_find_modules=long_video_params.registered_modules if long_video_params.get("load_attention_store") is None else {},
307
+ # dtype=dtype,
308
+ **sarp_params,
309
+ )
310
+ else:
311
+ if consistency_train_controller is None:
312
+ consistency_train_controller = consistency_store_controller
313
+ else:
314
+ consistency_train_controller = None
315
+
316
+ if retrain_motion_lora:
317
+ train_dataset = SingleClipDataset(
318
+ inversion_noise=inversion_noise,
319
+ video_clip=video,
320
+ keyframe=((ToTensor()(previous_last_frames[0])-0.5)/0.5).unsqueeze(0).unsqueeze(0) if use_previous_latent_for_train else keyframe,
321
+ keyframe_latent=previous_last_frames_latents[0] if use_previous_latent_for_train else None,
322
+ firstframe=firstframe,
323
+ height=train_motion_lora_params.train_data.height,
324
+ width=train_motion_lora_params.train_data.width,
325
+ use_data_aug=train_motion_lora_params.train_data.get("use_data_aug"),
326
+ pad_to_fit=train_motion_lora_params.train_data.get("pad_to_fit", False)
327
+ )
328
+ train_motion_lora_params.validation_data.num_inference_steps = num_steps
329
+ train_motion_lora(
330
+ pretrained_model_path,
331
+ output_dir,
332
+ train_dataset,
333
+ edited_firstframes=edited_firstframes,
334
+ validation_images=edited_keyframes,
335
+ validation_images_latents=previous_last_frames_latents,
336
+ seed=seed,
337
+ clip_id=clip_id,
338
+ consistency_edit_controller_list=previous_consistency_edit_controller_list,
339
+ consistency_controller=consistency_train_controller if clip_id!=0 else None,
340
+ consistency_find_modules=long_video_params.registered_modules,
341
+ enable_xformers_memory_efficient_attention=enable_xformers_memory_efficient_attention,
342
+ enable_torch_2_attn=enable_torch_2_attn,
343
+ **train_motion_lora_params
344
+ )
345
+
346
+ if train_motion_lora_only:
347
+ if not use_consistency_attention_control:
348
+ continue
349
+
350
+ # choose and load motion lora
351
+ best_checkpoint_index = attention_matching_params.get("best_checkpoint_index", 250)
352
+ if retrain_motion_lora:
353
+ lora_dir = f"{os.path.join(output_dir,'train_motion_lora')}/clip_{clip_id}"
354
+ lora_path = f"{lora_dir}/checkpoint-{best_checkpoint_index}/temporal/lora"
355
+ else:
356
+ lora_path = f"/homw/user/app/upload/lora"
357
+ assert os.path.exists(lora_path), f"lora path: {lora_path} doesn't exist!"
358
+
359
+ lora_rank = train_motion_lora_params.lora_rank
360
+ lora_scale = attention_matching_params.get("lora_scale", 1.0)
361
+
362
+ # prepare models
363
+ pipe = initialize_pipeline(
364
+ pretrained_model_path,
365
+ device,
366
+ enable_xformers_memory_efficient_attention,
367
+ enable_torch_2_attn,
368
+ lora_path,
369
+ lora_rank,
370
+ lora_scale,
371
+ load_spatial_lora = False #(clip_id != 0)
372
+ ).to(device, dtype=dtype)
373
+ else:
374
+ pipe = P2PStableVideoDiffusionPipeline.from_pretrained(
375
+ pretrained_model_path
376
+ ).to(device, dtype=dtype)
377
+
378
+ if use_attention_matching or use_consistency_attention_control:
379
+ pipe.scheduler = P2PEulerDiscreteScheduler.from_config(pipe.scheduler.config)
380
+
381
+ generator = torch.Generator(device="cpu")
382
+ generator.manual_seed(seed)
383
+
384
+ previous_last_frames = []
385
+
386
+ editing_params = [item for name, item in attention_matching_params.params.items()]
387
+ with torch.no_grad():
388
+ with torch.autocast(device, dtype=dtype):
389
+ for kf_id, (edited_keyframe, editing_param) in enumerate(zip(edited_keyframes, editing_params)):
390
+ print(kf_id, editing_param)
391
+
392
+ # control resolution
393
+ iw, ih = edited_keyframe.size
394
+ resctrl = ResolutionControl(
395
+ (ih, iw),
396
+ data_params.output_res,
397
+ data_params.pad_to_fit,
398
+ fill=0
399
+ )
400
+ edited_keyframe = resctrl(edited_keyframe)
401
+ edited_firstframe = resctrl(edited_firstframes[kf_id])
402
+
403
+ # control attention
404
+ pipe.scheduler.controller = []
405
+ if use_attention_matching:
406
+ edit_controller = attention_util.AttentionControlEdit(
407
+ num_steps = num_steps,
408
+ cross_replace_steps = attention_matching_params.cross_replace_steps,
409
+ temporal_self_replace_steps = attention_matching_params.temporal_self_replace_steps,
410
+ spatial_self_replace_steps = attention_matching_params.spatial_self_replace_steps,
411
+ mask_thr = editing_param.get("mask_thr", 0.35),
412
+ temporal_step_thr = editing_param.get("temporal_step_thr", [0.5,0.8]),
413
+ control_mode = attention_matching_params.control_mode,
414
+ spatial_attention_chunk_size = attention_matching_params.get("spatial_attention_chunk_size", 1),
415
+ additional_attention_store = store_controller,
416
+ use_inversion_attention = True,
417
+ save_self_attention = False,
418
+ save_latents = False,
419
+ latent_blend = use_latent_blend,
420
+ disk_store = attention_matching_params.disk_store
421
+ )
422
+ pipe.scheduler.controller.append(edit_controller)
423
+ else:
424
+ edit_controller = None
425
+
426
+ if use_consistency_attention_control:
427
+ consistency_edit_controller = attention_util.ConsistencyAttentionControl(
428
+ additional_attention_store=previous_consistency_edit_controller_list[kf_id],
429
+ use_inversion_attention=False,
430
+ save_self_attention=(clip_id==0),
431
+ save_latents=False,
432
+ disk_store=True,
433
+ store_path=os.path.join(output_dir, f"consistency_edit{kf_id}_attention_store", f"clip_{clip_id}")
434
+ )
435
+ pipe.scheduler.controller.append(consistency_edit_controller)
436
+ else:
437
+ consistency_edit_controller = None
438
+
439
+ if use_attention_matching or use_consistency_attention_control:
440
+ attention_util.register_attention_control(
441
+ pipe.unet,
442
+ edit_controller,
443
+ consistency_edit_controller,
444
+ find_modules=attention_matching_params.registered_modules,
445
+ consistency_find_modules=long_video_params.registered_modules
446
+ )
447
+
448
+ # should be reorganized to perform attention control
449
+ edited_output = pipe(
450
+ edited_keyframe,
451
+ edited_firstframe=edited_firstframe,
452
+ image_latents=previous_last_frames_latents[kf_id],
453
+ width=data_params.output_res[1],
454
+ height=data_params.output_res[0],
455
+ num_frames=video.shape[1],
456
+ num_inference_steps=num_steps,
457
+ decode_chunk_size=8,
458
+ motion_bucket_id=127,
459
+ fps=data_params.output_fps,
460
+ noise_aug_strength=0.02,
461
+ max_guidance_scale=attention_matching_params.get("max_guidance_scale", 2.5),
462
+ generator=generator,
463
+ latents=inversion_noise
464
+ )
465
+ edited_video = [img for sublist in edited_output.frames for img in sublist]
466
+ edited_video_latents = edited_output.latents
467
+
468
+ # callback to replace frames
469
+ videoio.write_video(edited_video, kf_id, resctrl)
470
+
471
+ # save previous frames
472
+ if long_video_params.mode == "skip-interval":
473
+ #previous_latents[kf_id] = edit_controller.get_all_last_latents(data_params.overlay_size)
474
+ previous_last_frames.append( resctrl.callback(edited_video[-1]) )
475
+ if use_latent_noise:
476
+ previous_last_frames_latents[kf_id] = edited_video_latents[:,-1:,:,:,:]
477
+ else:
478
+ previous_last_frames_latents[kf_id] = None
479
+ elif long_video_params.mode == "auto-regressive":
480
+ previous_last_frames.append( resctrl.callback(edited_video[-1]) )
481
+ if use_latent_noise:
482
+ previous_last_frames_latents[kf_id] = edited_video_latents[:,-1:,:,:,:]
483
+ else:
484
+ previous_last_frames_latents[kf_id] = None
485
+
486
+ # save last frames for convenience
487
+ if save_last_frames:
488
+ try:
489
+ fname = os.path.join(output_dir, f"clip_{clip_id}_lastframe_{kf_id}")
490
+ previous_last_frames[kf_id].save(fname+".png")
491
+ if use_latent_noise:
492
+ torch.save(previous_last_frames_latents[kf_id], fname+".pt")
493
+ except Exception as e:
494
+ print(f"failed to save last frame for clip {clip_id}: {e}")
495
+
496
+ if use_attention_matching or use_consistency_attention_control:
497
+ attention_util.register_attention_control(
498
+ pipe.unet,
499
+ edit_controller,
500
+ consistency_edit_controller,
501
+ find_modules=attention_matching_params.registered_modules,
502
+ consistency_find_modules=long_video_params.registered_modules,
503
+ undo=True
504
+ )
505
+ if edit_controller is not None:
506
+ if visualize_attention_store:
507
+ vis_save_path = os.path.join(output_dir, "visualization", f"{kf_id}", f"clip_{clip_id}")
508
+ os.makedirs(vis_save_path, exist_ok=True)
509
+ attention_util.show_avg_difference_maps(
510
+ edit_controller,
511
+ save_path = vis_save_path
512
+ )
513
+ assert visualize_attention_store_steps is not None
514
+ attention_util.show_self_attention(
515
+ edit_controller,
516
+ steps = visualize_attention_store_steps,
517
+ save_path = vis_save_path,
518
+ inversed = False
519
+ )
520
+ edit_controller.delete()
521
+ del edit_controller
522
+ if use_consistency_attention_control:
523
+ if clip_id == 0:
524
+ previous_consistency_edit_controller_list[kf_id] = consistency_edit_controller
525
+ else:
526
+ consistency_edit_controller.delete()
527
+ del consistency_edit_controller
528
+ print(f"previous_consistency_edit_controller_list[{kf_id}]", previous_consistency_edit_controller_list[kf_id].store_dir)
529
+
530
+
531
+ if use_attention_matching:
532
+ del store_controller
533
+
534
+ if use_consistency_attention_control and clip_id == 0:
535
+ previous_consistency_store_controller = consistency_store_controller
536
+
537
+ videoio.close()
538
+
539
+ if use_consistency_attention_control:
540
+ print("consistency_store_controller for clip 0:", previous_consistency_store_controller.store_dir)
541
+ if retrain_motion_lora:
542
+ print("consistency_train_controller for clip 0:", consistency_train_controller.store_dir)
543
+ for kf_id in range(len(previous_consistency_edit_controller_list)):
544
+ print(f"previous_consistency_edit_controller_list[{kf_id}]:", previous_consistency_edit_controller_list[kf_id].store_dir)
545
+
546
+
547
+ if __name__ == "__main__":
548
+ parser = argparse.ArgumentParser()
549
+ parser.add_argument("--config", type=str, default='./configs/svdedit/item2_2.yaml')
550
+ args = parser.parse_args()
551
+ main(**OmegaConf.load(args.config))
552
+
553
+
554
+
555
+
556
+
557
+
558
+
559
+
560
+
561
+
562
+
563
+
564
+
565
+
566
+
567
+
568
+
569
+
570
+
571
+
572
+
573
+
574
+
575
+
576
+
577
+
578
+
579
+
580
+
581
+
582
+
583
+
584
+
585
+
586
+
587
+
588
+
589
+
590
+
591
+
592
+
593
+
594
+
595
+
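The entry point expands the YAML config directly into `main(**...)`, so the config's top-level keys mirror `main()`'s keyword parameters. A minimal driver sketch follows; the config path and the keys listed in the comment are illustrative, not the configuration shipped in this commit:
```py
from omegaconf import OmegaConf
from main import main

# Placeholder config path; equivalent to running `python main.py --config=<path>`.
cfg = OmegaConf.load("upload/config/customize_train.yaml")
# Expected top-level keys include, e.g.:
#   pretrained_model_path, data_params, train_motion_lora_params,
#   sarp_params, attention_matching_params, long_video_params, output_dir, ...
main(**cfg)
```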
mydata/source_and_edits/source.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5e206a8e21c8e3b51779ad09352fa96c93e1bd2ec8ad05e2042bd554b733d15
3
+ size 1127699
mydata/source_and_edits/white.jpg ADDED
Git LFS Details
  • SHA256: 64bfe81d18ce7936cbac6684e9160852c613e65d70b89375592e25edae119489
  • Pointer size: 130 Bytes
  • Size of remote file: 78.2 kB
req.txt ADDED
@@ -0,0 +1,165 @@
1
+ absl-py==2.1.0
2
+ accelerate==0.30.1
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aiohttp==3.9.5
6
+ aiosignal==1.3.1
7
+ albucore==0.0.13
8
+ albumentations==1.4.13
9
+ annotated-types==0.7.0
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==4.4.0
12
+ appdirs==1.4.4
13
+ async-timeout==4.0.3
14
+ attrs==23.2.0
15
+ av==12.1.0
16
+ black==24.8.0
17
+ brotli==1.1.0
18
+ certifi==2024.6.2
19
+ charset-normalizer==3.3.2
20
+ click==8.1.7
21
+ cloudpickle==3.0.0
22
+ coloredlogs==15.0.1
23
+ contourpy==1.2.1
24
+ cycler==0.12.1
25
+ cython==3.0.11
26
+ decorator==4.4.2
27
+ decord==0.6.0
28
+ defusedxml==0.7.1
29
+ diffusers==0.25.1
30
+ easydict==1.13
31
+ einops==0.8.0
32
+ et-xmlfile==1.1.0
33
+ eval-type-backport==0.2.0
34
+ exceptiongroup==1.2.2
35
+ fairscale==0.4.13
36
+ fastapi==0.112.0
37
+ ffmpy==0.3.2
38
+ filelock==3.14.0
39
+ fire==0.6.0
40
+ flatbuffers==24.3.25
41
+ fonttools==4.53.1
42
+ frozenlist==1.4.1
43
+ fsspec==2024.6.0
44
+ ftfy==6.2.0
45
+ fvcore==0.1.5.post20221221
46
+ gradio==4.41.0
47
+ gradio-client==1.3.0
48
+ grpcio==1.64.1
49
+ h11==0.14.0
50
+ httpcore==1.0.5
51
+ httpx==0.27.0
52
+ huggingface-hub==0.23.3
53
+ humanfriendly==10.0
54
+ hydra-core==1.3.2
55
+ idna==3.7
56
+ imageio==2.34.1
57
+ imageio-ffmpeg==0.5.1
58
+ importlib-metadata==7.1.0
59
+ importlib-resources==6.4.2
60
+ insightface==0.7.3
61
+ iopath==0.1.9
62
+ jinja2==3.1.4
63
+ joblib==1.4.2
64
+ kiwisolver==1.4.5
65
+ kornia==0.7.2
66
+ kornia-rs==0.1.3
67
+ lazy-loader==0.4
68
+ lightning-utilities==0.3.0
69
+ markdown==3.6
70
+ markdown-it-py==3.0.0
71
+ markupsafe==2.1.5
72
+ matplotlib==3.9.2
73
+ mdurl==0.1.2
74
+ moviepy==1.0.3
75
+ mpmath==1.3.0
76
+ multidict==6.0.5
77
+ mutagen==1.47.0
78
+ mypy-extensions==1.0.0
79
+ networkx==3.2.1
80
+ ninja==1.11.1.1
81
+ numpy==1.26.4
82
+ omegaconf==2.3.0
83
+ onnx==1.16.2
84
+ onnxruntime==1.18.1
85
+ open-clip-torch==2.24.0
86
+ opencv-python==4.10.0.84
87
+ opencv-python-headless==4.10.0.84
88
+ openpyxl==3.1.5
89
+ orjson==3.10.7
90
+ packaging==24.0
91
+ pandas==2.2.2
92
+ pathspec==0.12.1
93
+ peft==0.11.1
94
+ pillow==10.3.0
95
+ platformdirs==4.2.2
96
+ portalocker==2.10.1
97
+ prettytable==3.11.0
98
+ proglog==0.1.10
99
+ protobuf==4.25.3
100
+ psutil==5.9.8
101
+ pyasn1==0.6.0
102
+ av==12.1.0
103
+ pycocotools==2.0.8
104
+ pycryptodomex==3.20.0
105
+ pydantic==2.8.2
106
+ pydantic-core==2.20.1
107
+ pydub==0.25.1
108
+ pygments==2.18.0
109
+ pyparsing==3.1.2
110
+ pysocks==1.7.1
111
+ python-dateutil==2.9.0.post0
112
+ python-multipart==0.0.9
113
+ pytz==2024.1
114
+ pyyaml==6.0.1
115
+ regex==2023.12.25
116
+ requests==2.32.3
117
+ rich==13.7.1
118
+ rpds-py==0.18.1
119
+ ruff==0.6.0
120
+ safetensors==0.4.3
121
+ scenedetect==0.6.4
122
+ scikit-image==0.24.0
123
+ scikit-learn==1.5.1
124
+ scipy==1.13.1
125
+ segment-anything==1.0
126
+ semantic-version==2.10.0
127
+ sentencepiece==0.1.99
128
+ sentry-sdk==2.5.1
129
+ shellingham==1.5.4
130
+ six==1.16.0
131
+ smmap==5.0.1
132
+ sniffio==1.3.1
133
+ soupsieve==2.5
134
+ starlette==0.37.2
135
+ supervision==0.22.0
136
+ sympy==1.12.1
137
+ tabulate==0.9.0
138
+ tensorboard==2.17.0
139
+ tensorboard-data-server==0.7.2
140
+ tensorboardx==2.6.2.2
141
+ termcolor==2.4.0
142
+ threadpoolctl==3.5.0
143
+ tifffile==2024.8.10
144
+ timm==1.0.8
145
+ tokenizers==0.15.2
146
+ tomli==2.0.1
147
+ tomlkit==0.12.0
148
+ toolz==0.12.1
149
+ torchmetrics==0.11.4
150
+ tqdm==4.66.4
151
+ transformers==4.36.2
152
+ typer==0.12.3
153
+ typing-extensions==4.12.2
154
+ tzdata==2024.1
155
+ urllib3==2.2.1
156
+ uvicorn==0.30.6
157
+ wcwidth==0.2.13
158
+ websockets==12.0
159
+ werkzeug==3.0.3
160
+ xformers==0.0.26.post1
161
+ yacs==0.1.8
162
+ yapf==0.40.2
163
+ yarl==1.9.4
164
+ yt-dlp==2024.8.6
165
+ zipp==3.19.2
requirements.txt ADDED
@@ -0,0 +1,176 @@
+ --extra-index-url https://download.pytorch.org/whl/cu121
+
+ absl-py==2.1.0
+ accelerate==0.30.1
+ addict==2.4.0
+ aiofiles==23.2.1
+ aiohttp==3.9.5
+ aiosignal==1.3.1
+ albucore==0.0.13
+ albumentations==1.4.13
+ annotated-types==0.7.0
+ antlr4-python3-runtime==4.9.3
+ anyio==4.4.0
+ appdirs==1.4.4
+ async-timeout==4.0.3
+ attrs==23.2.0
+ av==12.1.0
+ black==24.8.0
+ Brotli==1.1.0
+ certifi==2024.6.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ cloudpickle==3.0.0
+ colorama==0.4.6
+ coloredlogs==15.0.1
+ contourpy==1.2.1
+ cycler==0.12.1
+ Cython==3.0.11
+ decorator==4.4.2
+ decord==0.6.0
+ defusedxml==0.7.1
+ diffusers>=0.27.0
+ easydict==1.13
+ einops==0.8.0
+ et-xmlfile==1.1.0
+ eval_type_backport==0.2.0
+ exceptiongroup==1.2.2
+ fairscale==0.4.13
+ fastapi==0.112.0
+ ffmpy==0.3.2
+ filelock==3.14.0
+ fire==0.6.0
+ flatbuffers==24.3.25
+ fonttools==4.53.1
+ frozenlist==1.4.1
+ fsspec==2024.6.0
+ ftfy==6.2.0
+ fvcore==0.1.5.post20221221
+ gradio==4.41.0
+ gradio_client==1.3.0
+ grpcio==1.64.1
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.23.3
+ humanfriendly==10.0
+ hydra-core==1.3.2
+ idna==3.7
+ imageio==2.34.1
+ imageio-ffmpeg==0.5.1
+ importlib_metadata==7.1.0
+ importlib_resources==6.4.2
+ insightface==0.7.3
+ intel-openmp==2021.4.0
+ iopath==0.1.9
+ Jinja2==3.1.4
+ joblib==1.4.2
+ kiwisolver==1.4.5
+ kornia==0.7.2
+ kornia_rs==0.1.3
+ lazy_loader==0.4
+ lightning-utilities==0.3.0
+ Markdown==3.6
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.2
+ mdurl==0.1.2
+ mkl==2021.4.0
+ moviepy==1.0.3
+ mpmath==1.3.0
+ multidict==6.0.5
+ mutagen==1.47.0
+ mypy-extensions==1.0.0
+ networkx==3.2.1
+ ninja==1.11.1.1
+ numpy==1.26.4
+ omegaconf==2.3.0
+ onnx==1.16.2
+ onnxruntime==1.18.1
+ open-clip-torch==2.24.0
+ opencv-python==4.10.0.84
+ opencv-python-headless==4.10.0.84
+ openpyxl==3.1.5
+ orjson==3.10.7
+ packaging==24.0
+ pandas==2.2.2
+ pathspec==0.12.1
+ peft==0.11.1
+ pillow==10.3.0
+ platformdirs==4.2.2
+ portalocker==2.10.1
+ prettytable==3.11.0
+ proglog==0.1.10
+ protobuf==4.25.3
+ psutil==5.9.8
+ pyasn1==0.6.0
+ pycocotools==2.0.8
+ pycryptodomex==3.20.0
+ pydantic==2.8.2
+ pydantic_core==2.20.1
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.2
+ pyreadline3==3.5.4
+ PySocks==1.7.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ regex==2023.12.25
+ requests==2.32.3
+ rich==13.7.1
+ rpds-py==0.18.1
+ ruff==0.6.0
+ safetensors==0.4.3
+ scenedetect==0.6.4
+ scikit-image==0.24.0
+ scikit-learn==1.5.1
+ scipy==1.13.1
+ segment-anything==1.0
+ semantic-version==2.10.0
+ sentencepiece==0.1.99
+ sentry-sdk==2.5.1
+ setuptools==69.5.1
+ shellingham==1.5.4
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ soupsieve==2.5
+ starlette==0.37.2
+ supervision==0.22.0
+ sympy==1.12.1
+ tabulate==0.9.0
+ tbb==2021.13.1
+ tensorboard==2.17.0
+ tensorboard-data-server==0.7.2
+ tensorboardX==2.6.2.2
+ termcolor==2.4.0
+ threadpoolctl==3.5.0
+ tifffile==2024.8.10
+ timm==1.0.8
+ tokenizers==0.15.2
+ tomli==2.0.1
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.3.0+cu121
+ torchaudio==2.3.0+cu121
+ torchmetrics==0.11.4
+ torchvision==0.18.0+cu121
+ tqdm==4.66.4
+ transformers==4.36.2
+ typer==0.12.3
+ typing_extensions==4.12.2
+ tzdata==2024.1
+ urllib3==2.2.1
+ uvicorn==0.30.6
+ wcwidth==0.2.13
+ websockets==12.0
+ Werkzeug==3.0.3
+ wheel==0.43.0
+ xformers==0.0.26.post1
+ yacs==0.1.8
+ yapf==0.40.2
+ yarl==1.9.4
+ yt-dlp==2024.8.6
+ zipp==3.19.2
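
Compared with req.txt, requirements.txt adds the PyTorch cu121 extra index on its first line and pins torch==2.3.0+cu121, torchvision==0.18.0+cu121, and torchaudio==2.3.0+cu121, along with a few platform helpers (colorama, pyreadline3, mkl, tbb). A small sanity-check sketch, assuming the environment was created with pip install -r requirements.txt, to confirm the cu121 wheels were actually resolved:

import torch

# torch==2.3.0+cu121 should report CUDA 12.1 as its build toolkit
assert torch.__version__.startswith("2.3.0"), torch.__version__
assert torch.version.cuda == "12.1", torch.version.cuda
print("CUDA runtime available:", torch.cuda.is_available())
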
setup.py ADDED
@@ -0,0 +1,73 @@
+ import time
+ from setuptools import setup, find_packages
+ import sys
+ import os
+ import os.path as osp
+
+ WORK_DIR = "i2vedit"
+ NAME = "i2vedit"
+ author = "wenqi.oywq"
+ author_email = '[email protected]'
+
+ version_file = 'i2vedit/version.py'
+
+ def get_hash():
+     if False:#os.path.exists('.git'):
+         sha = get_git_hash()[:7]
+     # currently ignore this
+     # elif os.path.exists(version_file):
+     #     try:
+     #         from basicsr.version import __version__
+     #         sha = __version__.split('+')[-1]
+     #     except ImportError:
+     #         raise ImportError('Unable to get git version')
+     else:
+         sha = 'unknown'
+
+     return sha
+
+ def get_version():
+     with open(version_file, 'r') as f:
+         exec(compile(f.read(), version_file, 'exec'))
+     return locals()['__version__']
+
+ def write_version_py():
+     content = """# GENERATED VERSION FILE
+ # TIME: {}
+ __version__ = '{}'
+ __gitsha__ = '{}'
+ version_info = ({})
+ """
+     sha = get_hash()
+     with open('VERSION', 'r') as f:
+         SHORT_VERSION = f.read().strip()
+     VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
+
+     version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+     with open(version_file, 'w') as f:
+         f.write(version_file_str)
+
+ REQUIRE = [
+ ]
+
+ def install_requires(REQUIRE):
+     for item in REQUIRE:
+         os.system(f'pip install {item}')
+
+ write_version_py()
+ install_requires(REQUIRE)
+ setup(
+     name=NAME,
+     packages=find_packages(),
+     version=get_version(),
+     description="image-to-video editing",
+     author=author,
+     author_email=author_email,
+     keywords=["image-to-video editing"],
+     install_requires=[],
+     include_package_data=False,
+     exclude_package_data={'':['.gitignore','README.md','./configs','./outputs'],
+                           },
+     entry_points={'console_scripts': ['pyinstrument = pyinstrument.__main__:main']},
+     zip_safe=False,
+ )
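
For reference, write_version_py() renders i2vedit/version.py from the top-level VERSION file before setup() runs, and because the .git check is hard-coded to False the git sha is always recorded as 'unknown'. Assuming VERSION holds 0.0.1 (an illustrative value, as is the timestamp), the generated file would look roughly like:

# GENERATED VERSION FILE
# TIME: Mon Jan  1 00:00:00 2024
__version__ = '0.0.1'
__gitsha__ = 'unknown'
version_info = (0, 0, 1)
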