# ImageGen-Z-Image / ui/events.py
# Uploaded by RioShiina ("Upload folder using huggingface_hub", commit 9f5732f, verified)
import gradio as gr
import yaml
import os
import shutil
from functools import lru_cache
from core.settings import *
from utils.app_utils import *
from core.generation_logic import *
from comfy_integration.nodes import SAMPLER_CHOICES, SCHEDULER_CHOICES
from core.pipelines.controlnet_preprocessor import CPU_ONLY_PREPROCESSORS
from utils.app_utils import PREPROCESSOR_MODEL_MAP, PREPROCESSOR_PARAMETER_MAP, save_uploaded_file_with_hash
from ui.shared.ui_components import RESOLUTION_MAP, MAX_CONTROLNETS, MAX_EMBEDDINGS, MAX_CONDITIONINGS, MAX_LORAS
@lru_cache(maxsize=1)
def load_controlnet_config():
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_CN_MODEL_LIST_PATH = os.path.join(_PROJECT_ROOT, 'yaml', 'controlnet_models.yaml')
try:
print("--- Loading controlnet_models.yaml ---")
with open(_CN_MODEL_LIST_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
print("--- ✅ controlnet_models.yaml loaded successfully ---")
return config.get("ControlNet", {}).get("SDXL", [])
except Exception as e:
print(f"Error loading controlnet_models.yaml: {e}")
return []
@lru_cache(maxsize=1)
def load_diffsynth_controlnet_config():
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_CN_MODEL_LIST_PATH = os.path.join(_PROJECT_ROOT, 'yaml', 'diffsynth_controlnet_models.yaml')
try:
print("--- Loading diffsynth_controlnet_models.yaml ---")
with open(_CN_MODEL_LIST_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
print("--- ✅ diffsynth_controlnet_models.yaml loaded successfully ---")
return config.get("DiffSynth_ControlNet", {}).get("Z-Image", [])
except Exception as e:
print(f"Error loading diffsynth_controlnet_models.yaml: {e}")
return []
def attach_event_handlers(ui_components, demo):
def update_cn_input_visibility(choice):
return {
ui_components["cn_image_input"]: gr.update(visible=choice == "Image"),
ui_components["cn_video_input"]: gr.update(visible=choice == "Video")
}
ui_components["cn_input_type"].change(
fn=update_cn_input_visibility,
inputs=[ui_components["cn_input_type"]],
outputs=[ui_components["cn_image_input"], ui_components["cn_video_input"]]
)
def update_preprocessor_models_dropdown(preprocessor_name):
models = PREPROCESSOR_MODEL_MAP.get(preprocessor_name)
if models:
model_filenames = [m[1] for m in models]
return gr.update(choices=model_filenames, value=model_filenames[0], visible=True)
else:
return gr.update(choices=[], value=None, visible=False)
def update_preprocessor_settings_ui(preprocessor_name):
from ui.layout import MAX_DYNAMIC_CONTROLS
params = PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
slider_updates, dropdown_updates, checkbox_updates = [], [], []
s_idx, d_idx, c_idx = 0, 0, 0
for param in params:
if s_idx + d_idx + c_idx >= MAX_DYNAMIC_CONTROLS: break
name = param["name"]
ptype = param["type"]
config = param["config"]
label = name.replace('_', ' ').title()
if ptype == "INT" or ptype == "FLOAT":
if s_idx < MAX_DYNAMIC_CONTROLS:
slider_updates.append(gr.update(
label=label,
minimum=config.get('min', 0),
maximum=config.get('max', 255),
step=config.get('step', 0.1 if ptype == "FLOAT" else 1),
value=config.get('default', 0),
visible=True
))
s_idx += 1
elif isinstance(ptype, list):
if d_idx < MAX_DYNAMIC_CONTROLS:
dropdown_updates.append(gr.update(
label=label,
choices=ptype,
value=config.get('default', ptype[0] if ptype else None),
visible=True
))
d_idx += 1
elif ptype == "BOOLEAN":
if c_idx < MAX_DYNAMIC_CONTROLS:
checkbox_updates.append(gr.update(
label=label,
value=config.get('default', False),
visible=True
))
c_idx += 1
for _ in range(s_idx, MAX_DYNAMIC_CONTROLS): slider_updates.append(gr.update(visible=False))
for _ in range(d_idx, MAX_DYNAMIC_CONTROLS): dropdown_updates.append(gr.update(visible=False))
for _ in range(c_idx, MAX_DYNAMIC_CONTROLS): checkbox_updates.append(gr.update(visible=False))
return slider_updates + dropdown_updates + checkbox_updates
def update_run_button_for_cpu(preprocessor_name):
if preprocessor_name in CPU_ONLY_PREPROCESSORS:
return gr.update(value="Run Preprocessor CPU Only", variant="primary"), gr.update(visible=False)
else:
return gr.update(value="Run Preprocessor", variant="primary"), gr.update(visible=True)
ui_components["preprocessor_cn"].change(
fn=update_preprocessor_models_dropdown,
inputs=[ui_components["preprocessor_cn"]],
outputs=[ui_components["preprocessor_model_cn"]]
).then(
fn=update_preprocessor_settings_ui,
inputs=[ui_components["preprocessor_cn"]],
outputs=ui_components["cn_sliders"] + ui_components["cn_dropdowns"] + ui_components["cn_checkboxes"]
).then(
fn=update_run_button_for_cpu,
inputs=[ui_components["preprocessor_cn"]],
outputs=[ui_components["run_cn"], ui_components["zero_gpu_cn"]]
)
all_dynamic_inputs = (
ui_components["cn_sliders"] +
ui_components["cn_dropdowns"] +
ui_components["cn_checkboxes"]
)
ui_components["run_cn"].click(
fn=run_cn_preprocessor_entry,
inputs=[
ui_components["cn_input_type"],
ui_components["cn_image_input"],
ui_components["cn_video_input"],
ui_components["preprocessor_cn"],
ui_components["preprocessor_model_cn"],
ui_components["zero_gpu_cn"],
] + all_dynamic_inputs,
outputs=[ui_components["output_gallery_cn"]]
)
def create_lora_event_handlers(prefix):
lora_rows = ui_components[f'lora_rows_{prefix}']
lora_ids = ui_components[f'lora_ids_{prefix}']
lora_scales = ui_components[f'lora_scales_{prefix}']
lora_uploads = ui_components[f'lora_uploads_{prefix}']
count_state = ui_components[f'lora_count_state_{prefix}']
add_button = ui_components[f'add_lora_button_{prefix}']
del_button = ui_components[f'delete_lora_button_{prefix}']
def add_lora_row(c):
updates = {}
if c < MAX_LORAS:
c += 1
updates[lora_rows[c - 1]] = gr.update(visible=True)
updates[count_state] = c
updates[add_button] = gr.update(visible=c < MAX_LORAS)
updates[del_button] = gr.update(visible=c > 1)
return updates
def del_lora_row(c):
updates = {}
if c > 1:
updates[lora_rows[c - 1]] = gr.update(visible=False)
updates[lora_ids[c - 1]] = ""
updates[lora_scales[c - 1]] = 0.0
updates[lora_uploads[c - 1]] = None
c -= 1
updates[count_state] = c
updates[add_button] = gr.update(visible=True)
updates[del_button] = gr.update(visible=c > 1)
return updates
add_outputs = [count_state, add_button, del_button] + lora_rows
del_outputs = [count_state, add_button, del_button] + lora_rows + lora_ids + lora_scales + lora_uploads
add_button.click(add_lora_row, [count_state], add_outputs, show_progress=False)
del_button.click(del_lora_row, [count_state], del_outputs, show_progress=False)
def create_controlnet_event_handlers(prefix):
cn_rows = ui_components[f'controlnet_rows_{prefix}']
cn_types = ui_components[f'controlnet_types_{prefix}']
cn_series = ui_components[f'controlnet_series_{prefix}']
cn_filepaths = ui_components[f'controlnet_filepaths_{prefix}']
cn_images = ui_components[f'controlnet_images_{prefix}']
cn_strengths = ui_components[f'controlnet_strengths_{prefix}']
count_state = ui_components[f'controlnet_count_state_{prefix}']
add_button = ui_components[f'add_controlnet_button_{prefix}']
del_button = ui_components[f'delete_controlnet_button_{prefix}']
accordion = ui_components[f'controlnet_accordion_{prefix}']
def add_cn_row(c):
c += 1
updates = {
count_state: c,
cn_rows[c-1]: gr.update(visible=True),
add_button: gr.update(visible=c < MAX_CONTROLNETS),
del_button: gr.update(visible=True)
}
return updates
def del_cn_row(c):
c -= 1
updates = {
count_state: c,
cn_rows[c]: gr.update(visible=False),
cn_images[c]: None,
cn_strengths[c]: 1.0,
add_button: gr.update(visible=True),
del_button: gr.update(visible=c > 0)
}
return updates
add_outputs = [count_state, add_button, del_button] + cn_rows
del_outputs = [count_state, add_button, del_button] + cn_rows + cn_images + cn_strengths
add_button.click(fn=add_cn_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
del_button.click(fn=del_cn_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
def on_cn_type_change(selected_type):
cn_config = load_controlnet_config()
series_choices = []
if selected_type:
series_choices = sorted(list(set(
model.get("Series", "Default") for model in cn_config
if selected_type in model.get("Type", [])
)))
default_series = series_choices[0] if series_choices else None
filepath = "None"
if default_series:
for model in cn_config:
if model.get("Series") == default_series and selected_type in model.get("Type", []):
filepath = model.get("Filepath")
break
return gr.update(choices=series_choices, value=default_series), filepath
def on_cn_series_change(selected_series, selected_type):
cn_config = load_controlnet_config()
filepath = "None"
if selected_series and selected_type:
for model in cn_config:
if model.get("Series") == selected_series and selected_type in model.get("Type", []):
filepath = model.get("Filepath")
break
return filepath
for i in range(MAX_CONTROLNETS):
cn_types[i].change(
fn=on_cn_type_change,
inputs=[cn_types[i]],
outputs=[cn_series[i], cn_filepaths[i]],
show_progress=False
)
cn_series[i].change(
fn=on_cn_series_change,
inputs=[cn_series[i], cn_types[i]],
outputs=[cn_filepaths[i]],
show_progress=False
)
def on_accordion_expand(*images):
return [gr.update() for _ in images]
accordion.expand(
fn=on_accordion_expand,
inputs=cn_images,
outputs=cn_images,
show_progress=False
)
def create_diffsynth_controlnet_event_handlers(prefix):
cn_rows = ui_components[f'diffsynth_controlnet_rows_{prefix}']
cn_types = ui_components[f'diffsynth_controlnet_types_{prefix}']
cn_series = ui_components[f'diffsynth_controlnet_series_{prefix}']
cn_filepaths = ui_components[f'diffsynth_controlnet_filepaths_{prefix}']
cn_images = ui_components[f'diffsynth_controlnet_images_{prefix}']
cn_strengths = ui_components[f'diffsynth_controlnet_strengths_{prefix}']
count_state = ui_components[f'diffsynth_controlnet_count_state_{prefix}']
add_button = ui_components[f'add_diffsynth_controlnet_button_{prefix}']
del_button = ui_components[f'delete_diffsynth_controlnet_button_{prefix}']
accordion = ui_components[f'diffsynth_controlnet_accordion_{prefix}']
def add_cn_row(c):
c += 1
updates = {
count_state: c,
cn_rows[c-1]: gr.update(visible=True),
add_button: gr.update(visible=c < MAX_CONTROLNETS),
del_button: gr.update(visible=True)
}
return updates
def del_cn_row(c):
c -= 1
updates = {
count_state: c,
cn_rows[c]: gr.update(visible=False),
cn_images[c]: None,
cn_strengths[c]: 1.0,
add_button: gr.update(visible=True),
del_button: gr.update(visible=c > 0)
}
return updates
add_outputs = [count_state, add_button, del_button] + cn_rows
del_outputs = [count_state, add_button, del_button] + cn_rows + cn_images + cn_strengths
add_button.click(fn=add_cn_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
del_button.click(fn=del_cn_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
def on_cn_type_change(selected_type):
cn_config = load_diffsynth_controlnet_config()
series_choices = []
if selected_type:
series_choices = sorted(list(set(
model.get("Series", "Default") for model in cn_config
if selected_type in model.get("Type", [])
)))
default_series = series_choices[0] if series_choices else None
filepath = "None"
if default_series:
for model in cn_config:
if model.get("Series") == default_series and selected_type in model.get("Type", []):
filepath = model.get("Filepath")
break
return gr.update(choices=series_choices, value=default_series), filepath
def on_cn_series_change(selected_series, selected_type):
cn_config = load_diffsynth_controlnet_config()
filepath = "None"
if selected_series and selected_type:
for model in cn_config:
if model.get("Series") == selected_series and selected_type in model.get("Type", []):
filepath = model.get("Filepath")
break
return filepath
for i in range(MAX_CONTROLNETS):
cn_types[i].change(
fn=on_cn_type_change,
inputs=[cn_types[i]],
outputs=[cn_series[i], cn_filepaths[i]],
show_progress=False
)
cn_series[i].change(
fn=on_cn_series_change,
inputs=[cn_series[i], cn_types[i]],
outputs=[cn_filepaths[i]],
show_progress=False
)
def on_accordion_expand(*images):
return [gr.update() for _ in images]
accordion.expand(
fn=on_accordion_expand,
inputs=cn_images,
outputs=cn_images,
show_progress=False
)
def create_embedding_event_handlers(prefix):
rows = ui_components[f'embedding_rows_{prefix}']
ids = ui_components[f'embeddings_ids_{prefix}']
files = ui_components[f'embeddings_files_{prefix}']
count_state = ui_components[f'embedding_count_state_{prefix}']
add_button = ui_components[f'add_embedding_button_{prefix}']
del_button = ui_components[f'delete_embedding_button_{prefix}']
def add_row(c):
c += 1
return {
count_state: c,
rows[c - 1]: gr.update(visible=True),
add_button: gr.update(visible=c < MAX_EMBEDDINGS),
del_button: gr.update(visible=True)
}
def del_row(c):
c -= 1
return {
count_state: c,
rows[c]: gr.update(visible=False),
ids[c]: "",
files[c]: None,
add_button: gr.update(visible=True),
del_button: gr.update(visible=c > 0)
}
add_outputs = [count_state, add_button, del_button] + rows
del_outputs = [count_state, add_button, del_button] + rows + ids + files
add_button.click(fn=add_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
del_button.click(fn=del_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
def create_conditioning_event_handlers(prefix):
rows = ui_components[f'conditioning_rows_{prefix}']
prompts = ui_components[f'conditioning_prompts_{prefix}']
count_state = ui_components[f'conditioning_count_state_{prefix}']
add_button = ui_components[f'add_conditioning_button_{prefix}']
del_button = ui_components[f'delete_conditioning_button_{prefix}']
def add_row(c):
c += 1
return {
count_state: c,
rows[c - 1]: gr.update(visible=True),
add_button: gr.update(visible=c < MAX_CONDITIONINGS),
del_button: gr.update(visible=True),
}
def del_row(c):
c -= 1
return {
count_state: c,
rows[c]: gr.update(visible=False),
prompts[c]: "",
add_button: gr.update(visible=True),
del_button: gr.update(visible=c > 0),
}
add_outputs = [count_state, add_button, del_button] + rows
del_outputs = [count_state, add_button, del_button] + rows + prompts
add_button.click(fn=add_row, inputs=[count_state], outputs=add_outputs, show_progress=False)
del_button.click(fn=del_row, inputs=[count_state], outputs=del_outputs, show_progress=False)
def on_vae_upload(file_obj):
if not file_obj:
return gr.update(), gr.update(), None
hashed_filename = save_uploaded_file_with_hash(file_obj, VAE_DIR)
return hashed_filename, "File", file_obj
def on_lora_upload(file_obj):
if not file_obj:
return gr.update(), gr.update()
hashed_filename = save_uploaded_file_with_hash(file_obj, LORA_DIR)
return hashed_filename, "File"
def on_embedding_upload(file_obj):
if not file_obj:
return gr.update(), gr.update(), None
hashed_filename = save_uploaded_file_with_hash(file_obj, EMBEDDING_DIR)
return hashed_filename, "File", file_obj
def create_run_event(prefix: str, task_type: str):
run_inputs_map = {
'model_display_name': ui_components[f'base_model_{prefix}'],
'positive_prompt': ui_components[f'prompt_{prefix}'],
'negative_prompt': ui_components[f'neg_prompt_{prefix}'],
'seed': ui_components[f'seed_{prefix}'],
'batch_size': ui_components[f'batch_size_{prefix}'],
'guidance_scale': ui_components[f'cfg_{prefix}'],
'num_inference_steps': ui_components[f'steps_{prefix}'],
'sampler': ui_components[f'sampler_{prefix}'],
'scheduler': ui_components[f'scheduler_{prefix}'],
'zero_gpu_duration': ui_components[f'zero_gpu_{prefix}'],
'civitai_api_key': ui_components.get(f'civitai_api_key_{prefix}'),
'clip_skip': ui_components[f'clip_skip_{prefix}'],
'task_type': gr.State(task_type)
}
if task_type not in ['img2img', 'inpaint']:
run_inputs_map.update({'width': ui_components[f'width_{prefix}'], 'height': ui_components[f'height_{prefix}']})
task_specific_map = {
'img2img': {'img2img_image': f'input_image_{prefix}', 'img2img_denoise': f'denoise_{prefix}'},
'inpaint': {'inpaint_image_dict': f'input_image_dict_{prefix}'},
'outpaint': {'outpaint_image': f'input_image_{prefix}', 'outpaint_left': f'outpaint_left_{prefix}', 'outpaint_top': f'outpaint_top_{prefix}', 'outpaint_right': f'outpaint_right_{prefix}', 'outpaint_bottom': f'outpaint_bottom_{prefix}'},
'hires_fix': {'hires_image': f'input_image_{prefix}', 'hires_upscaler': f'hires_upscaler_{prefix}', 'hires_scale_by': f'hires_scale_by_{prefix}', 'hires_denoise': f'denoise_{prefix}'}
}
if task_type in task_specific_map:
for key, comp_name in task_specific_map[task_type].items():
run_inputs_map[key] = ui_components[comp_name]
lora_data_components = ui_components.get(f'all_lora_components_flat_{prefix}', [])
controlnet_data_components = ui_components.get(f'all_controlnet_components_flat_{prefix}', [])
diffsynth_controlnet_data_components = ui_components.get(f'all_diffsynth_controlnet_components_flat_{prefix}', [])
embedding_data_components = ui_components.get(f'all_embedding_components_flat_{prefix}', [])
conditioning_data_components = ui_components.get(f'all_conditioning_components_flat_{prefix}', [])
run_inputs_map['vae_source'] = ui_components.get(f'vae_source_{prefix}')
run_inputs_map['vae_id'] = ui_components.get(f'vae_id_{prefix}')
run_inputs_map['vae_file'] = ui_components.get(f'vae_file_{prefix}')
input_keys = list(run_inputs_map.keys())
input_list_flat = [v for v in run_inputs_map.values() if v is not None]
input_list_flat += lora_data_components + controlnet_data_components + diffsynth_controlnet_data_components + embedding_data_components + conditioning_data_components
def create_ui_inputs_dict(*args):
valid_keys = [k for k in input_keys if run_inputs_map[k] is not None]
ui_dict = dict(zip(valid_keys, args[:len(valid_keys)]))
arg_idx = len(valid_keys)
ui_dict['lora_data'] = list(args[arg_idx : arg_idx + len(lora_data_components)])
arg_idx += len(lora_data_components)
ui_dict['controlnet_data'] = list(args[arg_idx : arg_idx + len(controlnet_data_components)])
arg_idx += len(controlnet_data_components)
ui_dict['diffsynth_controlnet_data'] = list(args[arg_idx : arg_idx + len(diffsynth_controlnet_data_components)])
arg_idx += len(diffsynth_controlnet_data_components)
ui_dict['embedding_data'] = list(args[arg_idx : arg_idx + len(embedding_data_components)])
arg_idx += len(embedding_data_components)
ui_dict['conditioning_data'] = list(args[arg_idx : arg_idx + len(conditioning_data_components)])
return ui_dict
ui_components[f'run_{prefix}'].click(
fn=lambda *args, progress=gr.Progress(track_tqdm=True): generate_image_wrapper(create_ui_inputs_dict(*args), progress),
inputs=input_list_flat,
outputs=[ui_components[f'result_{prefix}']]
)
for prefix, task_type in [
("txt2img", "txt2img"), ("img2img", "img2img"), ("inpaint", "inpaint"),
("outpaint", "outpaint"), ("hires_fix", "hires_fix"),
]:
if f'add_lora_button_{prefix}' in ui_components:
create_lora_event_handlers(prefix)
lora_uploads = ui_components[f'lora_uploads_{prefix}']
lora_ids = ui_components[f'lora_ids_{prefix}']
lora_sources = ui_components[f'lora_sources_{prefix}']
for i in range(MAX_LORAS):
lora_uploads[i].upload(
fn=on_lora_upload,
inputs=[lora_uploads[i]],
outputs=[lora_ids[i], lora_sources[i]],
show_progress=False
)
if f'add_controlnet_button_{prefix}' in ui_components: create_controlnet_event_handlers(prefix)
if f'add_diffsynth_controlnet_button_{prefix}' in ui_components: create_diffsynth_controlnet_event_handlers(prefix)
if f'add_embedding_button_{prefix}' in ui_components:
create_embedding_event_handlers(prefix)
if f'embeddings_uploads_{prefix}' in ui_components:
emb_uploads = ui_components[f'embeddings_uploads_{prefix}']
emb_ids = ui_components[f'embeddings_ids_{prefix}']
emb_sources = ui_components[f'embeddings_sources_{prefix}']
emb_files = ui_components[f'embeddings_files_{prefix}']
for i in range(MAX_EMBEDDINGS):
emb_uploads[i].upload(
fn=on_embedding_upload,
inputs=[emb_uploads[i]],
outputs=[emb_ids[i], emb_sources[i], emb_files[i]],
show_progress=False
)
if f'add_conditioning_button_{prefix}' in ui_components: create_conditioning_event_handlers(prefix)
if f'vae_source_{prefix}' in ui_components:
upload_button = ui_components.get(f'vae_upload_button_{prefix}')
if upload_button:
upload_button.upload(
fn=on_vae_upload,
inputs=[upload_button],
outputs=[
ui_components[f'vae_id_{prefix}'],
ui_components[f'vae_source_{prefix}'],
ui_components[f'vae_file_{prefix}']
]
)
create_run_event(prefix, task_type)
def on_aspect_ratio_change(ratio_key, model_display_name):
model_type = MODEL_TYPE_MAP.get(model_display_name, 'sdxl').lower()
res_map = RESOLUTION_MAP.get(model_type, RESOLUTION_MAP.get("sdxl", {}))
w, h = res_map.get(ratio_key, (1024, 1024))
return w, h
for prefix in ["txt2img", "img2img", "inpaint", "outpaint", "hires_fix"]:
if f'aspect_ratio_{prefix}' in ui_components:
aspect_ratio_dropdown = ui_components[f'aspect_ratio_{prefix}']
width_component = ui_components[f'width_{prefix}']
height_component = ui_components[f'height_{prefix}']
model_dropdown = ui_components[f'base_model_{prefix}']
aspect_ratio_dropdown.change(fn=on_aspect_ratio_change, inputs=[aspect_ratio_dropdown, model_dropdown], outputs=[width_component, height_component], show_progress=False)
if 'view_mode_inpaint' in ui_components:
def toggle_inpaint_fullscreen_view(view_mode):
is_fullscreen = (view_mode == "Fullscreen View")
other_elements_visible = not is_fullscreen
editor_height = 800 if is_fullscreen else 272
return {
ui_components['model_and_run_row_inpaint']: gr.update(visible=other_elements_visible),
ui_components['prompts_column_inpaint']: gr.update(visible=other_elements_visible),
ui_components['params_and_gallery_row_inpaint']: gr.update(visible=other_elements_visible),
ui_components['accordion_wrapper_inpaint']: gr.update(visible=other_elements_visible),
ui_components['input_image_dict_inpaint']: gr.update(height=editor_height),
}
output_components = [
ui_components['model_and_run_row_inpaint'], ui_components['prompts_column_inpaint'],
ui_components['params_and_gallery_row_inpaint'], ui_components['accordion_wrapper_inpaint'],
ui_components['input_image_dict_inpaint']
]
ui_components['view_mode_inpaint'].change(fn=toggle_inpaint_fullscreen_view, inputs=[ui_components['view_mode_inpaint']], outputs=output_components, show_progress=False)
def initialize_all_cn_dropdowns():
# Standard ControlNet
cn_config = load_controlnet_config()
cn_updates = {}
if cn_config:
all_types = sorted(list(set(t for model in cn_config for t in model.get("Type", []))))
default_type = all_types[0] if all_types else None
series_choices = []
if default_type:
series_choices = sorted(list(set(model.get("Series", "Default") for model in cn_config if default_type in model.get("Type", []))))
default_series = series_choices[0] if series_choices else None
filepath = "None"
if default_series and default_type:
for model in cn_config:
if model.get("Series") == default_series and default_type in model.get("Type", []):
filepath = model.get("Filepath")
break
for prefix in ["txt2img", "img2img", "inpaint", "outpaint", "hires_fix"]:
if f'controlnet_types_{prefix}' in ui_components:
for type_dd in ui_components[f'controlnet_types_{prefix}']: cn_updates[type_dd] = gr.update(choices=all_types, value=default_type)
for series_dd in ui_components[f'controlnet_series_{prefix}']: cn_updates[series_dd] = gr.update(choices=series_choices, value=default_series)
for filepath_state in ui_components[f'controlnet_filepaths_{prefix}']: cn_updates[filepath_state] = filepath
# DiffSynth ControlNet
diffsynth_cn_config = load_diffsynth_controlnet_config()
diffsynth_updates = {}
if diffsynth_cn_config:
all_types = sorted(list(set(t for model in diffsynth_cn_config for t in model.get("Type", []))))
default_type = all_types[0] if all_types else None
series_choices = []
if default_type:
series_choices = sorted(list(set(model.get("Series", "Default") for model in diffsynth_cn_config if default_type in model.get("Type", []))))
default_series = series_choices[0] if series_choices else None
filepath = "None"
if default_series and default_type:
for model in diffsynth_cn_config:
if model.get("Series") == default_series and default_type in model.get("Type", []):
filepath = model.get("Filepath")
break
for prefix in ["txt2img", "img2img", "inpaint", "outpaint", "hires_fix"]:
if f'diffsynth_controlnet_types_{prefix}' in ui_components:
for type_dd in ui_components[f'diffsynth_controlnet_types_{prefix}']: diffsynth_updates[type_dd] = gr.update(choices=all_types, value=default_type)
for series_dd in ui_components[f'diffsynth_controlnet_series_{prefix}']: diffsynth_updates[series_dd] = gr.update(choices=series_choices, value=default_series)
for filepath_state in ui_components[f'diffsynth_controlnet_filepaths_{prefix}']: diffsynth_updates[filepath_state] = filepath
return {**cn_updates, **diffsynth_updates}
def run_on_load():
all_updates = initialize_all_cn_dropdowns()
default_preprocessor = "Canny Edge"
model_update = update_preprocessor_models_dropdown(default_preprocessor)
all_updates[ui_components["preprocessor_model_cn"]] = model_update
settings_outputs = update_preprocessor_settings_ui(default_preprocessor)
dynamic_outputs = ui_components["cn_sliders"] + ui_components["cn_dropdowns"] + ui_components["cn_checkboxes"]
for i, comp in enumerate(dynamic_outputs):
all_updates[comp] = settings_outputs[i]
run_button_update, zero_gpu_update = update_run_button_for_cpu(default_preprocessor)
all_updates[ui_components["run_cn"]] = run_button_update
all_updates[ui_components["zero_gpu_cn"]] = zero_gpu_update
return all_updates
all_load_outputs = []
for prefix in ["txt2img", "img2img", "inpaint", "outpaint", "hires_fix"]:
if f'controlnet_types_{prefix}' in ui_components:
all_load_outputs.extend(ui_components[f'controlnet_types_{prefix}'])
all_load_outputs.extend(ui_components[f'controlnet_series_{prefix}'])
all_load_outputs.extend(ui_components[f'controlnet_filepaths_{prefix}'])
if f'diffsynth_controlnet_types_{prefix}' in ui_components:
all_load_outputs.extend(ui_components[f'diffsynth_controlnet_types_{prefix}'])
all_load_outputs.extend(ui_components[f'diffsynth_controlnet_series_{prefix}'])
all_load_outputs.extend(ui_components[f'diffsynth_controlnet_filepaths_{prefix}'])
all_load_outputs.extend([
ui_components["preprocessor_model_cn"],
*ui_components["cn_sliders"],
*ui_components["cn_dropdowns"],
*ui_components["cn_checkboxes"],
ui_components["run_cn"],
ui_components["zero_gpu_cn"]
])
if all_load_outputs:
demo.load(
fn=run_on_load,
outputs=all_load_outputs
)