import os
import requests
import hashlib
import re
from typing import Sequence, Mapping, Any, Union, Set
from pathlib import Path
import shutil

import gradio as gr
from huggingface_hub import hf_hub_download, constants as hf_constants
import torch
import numpy as np
from PIL import Image, ImageChops

from core.settings import *

DISK_LIMIT_GB = 120
MODELS_ROOT_DIR = "ComfyUI/models"

# Lazily-built caches; populated on first use by the build_* functions below.
PREPROCESSOR_MODEL_MAP = None
PREPROCESSOR_PARAMETER_MAP = None

def save_uploaded_file_with_hash(file_obj: gr.File, target_dir: str) -> str:
    """Copy an uploaded file into target_dir under its SHA-256 hash, deduplicating re-uploads."""
    if not file_obj:
        return ""
    temp_path = file_obj.name
    sha256 = hashlib.sha256()
    with open(temp_path, 'rb') as f:
        for block in iter(lambda: f.read(65536), b''):
            sha256.update(block)
    file_hash = sha256.hexdigest()
    _, extension = os.path.splitext(temp_path)
    hashed_filename = f"{file_hash}{extension.lower()}"
    dest_path = os.path.join(target_dir, hashed_filename)
    os.makedirs(target_dir, exist_ok=True)
    if not os.path.exists(dest_path):
        shutil.copy(temp_path, dest_path)
        print(f"✅ Saved uploaded file as: {dest_path}")
    else:
        print(f"ℹ️ File already exists (deduplicated): {dest_path}")
    return hashed_filename

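# Illustrative usage (a sketch, not wired up anywhere in this module; the
# target directory "ComfyUI/input" is an assumption for the example):
def _example_handle_upload(file_obj: gr.File) -> str:
    # Two uploads with identical bytes map to the same hashed filename,
    # so duplicate content is stored only once.
    return save_uploaded_file_with_hash(file_obj, "ComfyUI/input")
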
def bytes_to_gb(byte_size: int) -> float:
    if not byte_size:
        return 0.0
    return round(byte_size / (1024 ** 3), 2)

def get_directory_size(path: str) -> int:
    """Return the total size in bytes of regular files under path, skipping symlinks."""
    total_size = 0
    if not os.path.exists(path):
        return 0
    try:
        for dirpath, _, filenames in os.walk(path):
            for f in filenames:
                fp = os.path.join(dirpath, f)
                if os.path.isfile(fp) and not os.path.islink(fp):
                    total_size += os.path.getsize(fp)
    except OSError as e:
        print(f"Warning: Could not access {path} to calculate size: {e}")
    return total_size

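# Illustrative usage (not called anywhere in this module): report how much of
# the models directory is in use against the disk budget.
def _example_report_usage() -> None:
    used_gb = bytes_to_gb(get_directory_size(MODELS_ROOT_DIR))
    print(f"Models directory: {used_gb} GB of {DISK_LIMIT_GB} GB budget")
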
def enforce_disk_limit():
    """Delete the oldest files in the Hugging Face cache until usage drops below DISK_LIMIT_GB."""
    disk_limit_bytes = DISK_LIMIT_GB * (1024 ** 3)
    cache_dir = hf_constants.HF_HUB_CACHE
    if not os.path.exists(cache_dir):
        return
    print(f"--- [Storage Manager] Checking disk usage in '{cache_dir}' (Limit: {DISK_LIMIT_GB} GB) ---")
    try:
        all_files = []
        current_size_bytes = 0
        for dirpath, _, filenames in os.walk(cache_dir):
            for f in filenames:
                # Skip in-flight downloads and lock files.
                if f.endswith(".incomplete") or f.endswith(".lock"):
                    continue
                file_path = os.path.join(dirpath, f)
                if os.path.isfile(file_path) and not os.path.islink(file_path):
                    try:
                        file_size = os.path.getsize(file_path)
                        # ctime is used as a proxy for age (on Linux it is the
                        # inode change time, not strictly the creation time).
                        creation_time = os.path.getctime(file_path)
                        all_files.append((creation_time, file_path, file_size))
                        current_size_bytes += file_size
                    except OSError:
                        continue
        print(f"--- [Storage Manager] Current usage: {bytes_to_gb(current_size_bytes)} GB ---")
        if current_size_bytes > disk_limit_bytes:
            print("--- [Storage Manager] Usage exceeds limit. Starting cleanup... ---")
            all_files.sort(key=lambda x: x[0])  # oldest first
            while current_size_bytes > disk_limit_bytes and all_files:
                _, oldest_file_path, oldest_file_size = all_files.pop(0)
                try:
                    os.remove(oldest_file_path)
                    current_size_bytes -= oldest_file_size
                    print(f"--- [Storage Manager] Deleted oldest file: {os.path.basename(oldest_file_path)} ({bytes_to_gb(oldest_file_size)} GB freed) ---")
                except OSError as e:
                    print(f"--- [Storage Manager] Error deleting file {oldest_file_path}: {e} ---")
            print(f"--- [Storage Manager] Cleanup finished. New usage: {bytes_to_gb(current_size_bytes)} GB ---")
        else:
            print("--- [Storage Manager] Disk usage is within the limit. No action needed. ---")
    except Exception as e:
        print(f"--- [Storage Manager] An unexpected error occurred: {e} ---")

def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Index into a node output, falling back to its 'result' field."""
    try:
        return obj[index]
    except (KeyError, IndexError):
        try:
            return obj["result"][index]
        except (KeyError, IndexError):
            return None

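# Illustrative behaviour (shapes mirror ComfyUI node outputs):
#
#   get_value_at_index(("a", "b"), 1)               -> "b"
#   get_value_at_index({"result": ("a", "b")}, 0)   -> "a"
#   get_value_at_index({}, 0)                       -> None
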
def sanitize_prompt(prompt: str) -> str:
    """Strip non-printable characters from a prompt, keeping newlines and tabs."""
    if not isinstance(prompt, str):
        return ""
    return "".join(char for char in prompt if char.isprintable() or char in ('\n', '\t'))

def sanitize_id(input_id: str) -> str:
    """Reduce an ID to its digits only."""
    if not isinstance(input_id, str):
        return ""
    return re.sub(r'[^0-9]', '', input_id)

def sanitize_url(url: str) -> str:
    if not isinstance(url, str):
        raise ValueError("URL must be a string.")
    url = url.strip()
    if not re.match(r'^https?://[^\s/$.?#].[^\s]*$', url):
        raise ValueError("Invalid URL format or scheme. Only HTTP and HTTPS are allowed.")
    return url

def sanitize_filename(filename: str) -> str:
    """Remove path-traversal sequences and restrict to word characters, dots, and dashes."""
    if not isinstance(filename, str):
        return ""
    sanitized = filename.replace('..', '')
    sanitized = re.sub(r'[^\w\.\-]', '_', sanitized)
    return sanitized.lstrip('/\\')

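# Illustrative behaviour of the sanitizers (made-up inputs):
#
#   sanitize_id("models/123?modelVersionId=456")  -> "123456"
#   sanitize_filename("../etc/passwd")            -> "_etc_passwd"
#   sanitize_url(" https://example.com/file ")    -> "https://example.com/file"
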
def get_civitai_file_info(version_id: str) -> dict | None:
    """Query the Civitai API for a model version and pick the primary downloadable file."""
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
        # Prefer the main model weights; fall back to the first listed file.
        for file_data in data.get('files', []):
            if file_data.get('type') == 'Model' and file_data['name'].endswith(('.safetensors', '.pt', '.bin')):
                return file_data
        if data.get('files'):
            return data['files'][0]
    except Exception:
        pass
    return None

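# Fields this helper relies on, as a trimmed illustrative response (the real
# API returns many more fields):
#
#   {
#     "files": [
#       {"name": "model.safetensors", "type": "Model",
#        "downloadUrl": "https://civitai.com/api/download/models/..."}
#     ]
#   }
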
def download_file(url: str, save_path: str, api_key: str = None, progress=None, desc: str = "") -> str:
    """Stream a file to disk, reporting progress and cleaning up on failure."""
    enforce_disk_limit()
    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"
    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    try:
        if progress:
            progress(0, desc=desc)
        response = requests.get(url, stream=True, headers=headers, timeout=15)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(save_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                if progress and total_size > 0:
                    downloaded += len(chunk)
                    progress(downloaded / total_size, desc=desc)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        # Remove any partially written file so a retry starts clean.
        if os.path.exists(save_path):
            os.remove(save_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"

def get_lora_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    try:
        if source == "Civitai":
            version_id = sanitize_id(id_or_url)
            if not version_id:
                return None, "Invalid Civitai ID provided. Must be numeric."
            filename = sanitize_filename(f"civitai_{version_id}.safetensors")
            local_path = os.path.join(LORA_DIR, filename)
            file_info = get_civitai_file_info(version_id)
            api_key_to_use = civitai_key
            source_name = f"Civitai ID {version_id}"
        else:
            return None, "Invalid source."
    except ValueError as e:
        return None, f"Input validation failed: {e}"
    if os.path.exists(local_path):
        return local_path, "File already exists."
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)

def get_embedding_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    try:
        file_ext = ".safetensors"
        if source == "Civitai":
            version_id = sanitize_id(id_or_url)
            if not version_id:
                return None, "Invalid Civitai ID. Must be numeric."
            file_info = get_civitai_file_info(version_id)
            if file_info and file_info['name'].lower().endswith(('.pt', '.bin')):
                file_ext = os.path.splitext(file_info['name'])[1]
            filename = sanitize_filename(f"civitai_{version_id}{file_ext}")
            local_path = os.path.join(EMBEDDING_DIR, filename)
            api_key_to_use = civitai_key
            source_name = f"Embedding Civitai ID {version_id}"
        else:
            return None, "Invalid source."
    except ValueError as e:
        return None, f"Input validation failed: {e}"
    if os.path.exists(local_path):
        return local_path, "File already exists."
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)

def get_vae_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    try:
        file_ext = ".safetensors"
        if source == "Civitai":
            version_id = sanitize_id(id_or_url)
            if not version_id:
                return None, "Invalid Civitai ID. Must be numeric."
            file_info = get_civitai_file_info(version_id)
            if file_info and file_info['name'].lower().endswith(('.pt', '.bin')):
                file_ext = os.path.splitext(file_info['name'])[1]
            filename = sanitize_filename(f"civitai_{version_id}{file_ext}")
            local_path = os.path.join(VAE_DIR, filename)
            api_key_to_use = civitai_key
            source_name = f"VAE Civitai ID {version_id}"
        else:
            return None, "Invalid source."
    except ValueError as e:
        return None, f"Input validation failed: {e}"
    if os.path.exists(local_path):
        return local_path, "File already exists."
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)

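# The three get_*_path helpers above differ only in the target directory, the
# extension handling, and the label. A possible consolidation (a sketch under
# those assumptions, not a drop-in replacement):
def _example_get_asset_path(kind_dir: str, label: str, id_or_url: str,
                            civitai_key: str, progress) -> tuple[str | None, str]:
    version_id = sanitize_id(id_or_url or "")
    if not version_id:
        return None, "Invalid Civitai ID. Must be numeric."
    file_info = get_civitai_file_info(version_id)
    # Default to .safetensors, but respect .pt/.bin names reported by the API.
    file_ext = ".safetensors"
    if file_info and file_info['name'].lower().endswith(('.pt', '.bin')):
        file_ext = os.path.splitext(file_info['name'])[1]
    local_path = os.path.join(kind_dir, sanitize_filename(f"civitai_{version_id}{file_ext}"))
    if os.path.exists(local_path):
        return local_path, "File already exists."
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {label} {version_id}."
    status = download_file(file_info['downloadUrl'], local_path, civitai_key,
                           progress=progress, desc=f"Downloading {label} {version_id}")
    return (local_path, status) if "Successfully" in status else (None, status)
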
def _ensure_model_downloaded(filename: str, progress=gr.Progress()):
    download_info = ALL_FILE_DOWNLOAD_MAP.get(filename)
    if not download_info:
        raise gr.Error(f"Model component '{filename}' not found in file_list.yaml. Cannot download.")
    category_to_dir_map = {
        "diffusion_models": DIFFUSION_MODELS_DIR,
        "text_encoders": TEXT_ENCODERS_DIR,
        "vae": VAE_DIR,
        "checkpoints": CHECKPOINT_DIR,
        "loras": LORA_DIR,
        "controlnet": CONTROLNET_DIR,
        "model_patches": MODEL_PATCHES_DIR,
        "clip_vision": os.path.join(os.path.dirname(LORA_DIR), "clip_vision"),
    }
    category = download_info.get('category')
    dest_dir = category_to_dir_map.get(category)
    if not dest_dir:
        raise ValueError(f"Unknown model category '{category}' for file '{filename}'.")
    dest_path = os.path.join(dest_dir, filename)
    if os.path.lexists(dest_path):
        if not os.path.exists(dest_path):
            # lexists but not exists: a symlink whose target is gone.
            print(f"⚠️ Removing broken symlink: {dest_path}")
            os.remove(dest_path)
        else:
            return filename
    source = download_info.get("source")
    try:
        progress(0, desc=f"Downloading: {filename}")
        if source == "hf":
            repo_id = download_info.get("repo_id")
            hf_filename = download_info.get("repository_file_path", filename)
            if not repo_id:
                raise ValueError(f"repo_id is missing for HF model '{filename}'")
            cached_path = hf_hub_download(repo_id=repo_id, filename=hf_filename)
            os.makedirs(dest_dir, exist_ok=True)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked '{cached_path}' to '{dest_path}'")
        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError(f"model_version_id is missing for Civitai model '{filename}'")
            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")
            status = download_file(
                file_info['downloadUrl'], dest_path, progress=progress, desc=f"Downloading: {filename}"
            )
            # download_file() reports errors as "Download failed for ...";
            # match case-insensitively so failures are not silently ignored.
            if "failed" in status.lower():
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for '{filename}'")
        progress(1.0, desc=f"Downloaded: {filename}")
    except Exception as e:
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download and link '{filename}': {e}")
    return filename

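# The entries consumed above come from ALL_FILE_DOWNLOAD_MAP (loaded from
# file_list.yaml via core.settings). An illustrative entry showing the fields
# this function reads; the exact schema lives in that file:
#
#   model.safetensors:
#     category: checkpoints
#     source: hf
#     repo_id: some-user/some-repo
#     repository_file_path: model.safetensors
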
def ensure_controlnet_model_downloaded(filename: str, progress):
    if not filename or filename == "None":
        return
    _ensure_model_downloaded(filename, progress)

def build_preprocessor_model_map():
    """Scan the comfyui_controlnet_aux node wrappers and map each preprocessor
    display name to the (repo_id, filename) pairs it may download."""
    global PREPROCESSOR_MODEL_MAP
    if PREPROCESSOR_MODEL_MAP is not None:
        return PREPROCESSOR_MODEL_MAP
    print("--- Building ControlNet Preprocessor model map ---")
    # Models that cannot be discovered from from_pretrained() calls.
    manual_map = {
        "dwpose": [
            ("yzd-v/DWPose", "yolox_l.onnx"),
            ("yzd-v/DWPose", "dw-ll_ucoco_384.onnx"),
            ("hr16/UnJIT-DWPose", "dw-ll_ucoco.onnx"),
            ("hr16/DWPose-TorchScript-BatchSize5", "dw-ll_ucoco_384_bs5.torchscript.pt"),
            ("hr16/DWPose-TorchScript-BatchSize5", "rtmpose-m_ap10k_256_bs5.torchscript.pt"),
            ("hr16/yolo-nas-fp16", "yolo_nas_l_fp16.onnx"),
            ("hr16/yolo-nas-fp16", "yolo_nas_m_fp16.onnx"),
            ("hr16/yolo-nas-fp16", "yolo_nas_s_fp16.onnx"),
        ],
        "densepose": [
            ("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r50_fpn_dl.torchscript"),
            ("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r101_fpn_dl.torchscript"),
        ],
    }
    temp_map = {}
    wrappers_dir = Path("./custom_nodes/comfyui_controlnet_aux/node_wrappers/")
    if not wrappers_dir.exists():
        print("⚠️ ControlNet AUX wrappers directory not found. Cannot build model map.")
        PREPROCESSOR_MODEL_MAP = {}
        return PREPROCESSOR_MODEL_MAP
    for wrapper_file in wrappers_dir.glob("*.py"):
        if wrapper_file.name == "__init__.py":
            continue
        with open(wrapper_file, 'r', encoding='utf-8') as f:
            content = f.read()
        # Scrape display names from each wrapper's NODE_DISPLAY_NAME_MAPPINGS.
        display_name_matches = re.findall(
            r'NODE_DISPLAY_NAME_MAPPINGS\s*=\s*{(?:.|\n)*?["\'](.*?)["\']\s*:\s*["\'](.*?)["\']',
            content,
        )
        for _, display_name in display_name_matches:
            if display_name not in temp_map:
                temp_map[display_name] = []
            manual_key = wrapper_file.stem
            if manual_key in manual_map:
                temp_map[display_name].extend(manual_map[manual_key])
            matches = re.findall(r"from_pretrained\s*\(\s*(?:filename=)?\s*f?[\"']([^\"']+)[\"']", content)
            for model_filename in matches:
                # Most annotator weights live in lllyasviel/Annotators; a few
                # families come from dedicated repos.
                repo_id = "lllyasviel/Annotators"
                if "depth_anything" in model_filename and "v2" in model_filename:
                    repo_id = "LiheYoung/Depth-Anything-V2"
                elif "depth_anything" in model_filename:
                    repo_id = "LiheYoung/Depth-Anything"
                elif "diffusion_edge" in model_filename:
                    repo_id = "hr16/Diffusion-Edge"
                temp_map[display_name].append((repo_id, model_filename))
    final_map = {name: sorted(set(models)) for name, models in temp_map.items() if models}
    PREPROCESSOR_MODEL_MAP = final_map
    print("✅ ControlNet Preprocessor model map built.")
    return PREPROCESSOR_MODEL_MAP

def build_preprocessor_parameter_map():
    """Introspect INPUT_TYPES() of every comfyui_controlnet_aux node and record
    its tweakable parameters, keyed by display name."""
    global PREPROCESSOR_PARAMETER_MAP
    if PREPROCESSOR_PARAMETER_MAP is not None:
        return
    print("--- Building ControlNet Preprocessor parameter map ---")
    param_map = {}
    from nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
    for class_name, node_class in NODE_CLASS_MAPPINGS.items():
        if not hasattr(node_class, "INPUT_TYPES"):
            continue
        # Only inspect nodes that come from the controlnet_aux wrappers.
        if hasattr(node_class, '__module__') and 'comfyui_controlnet_aux.node_wrappers' not in node_class.__module__:
            continue
        display_name = NODE_DISPLAY_NAME_MAPPINGS.get(class_name)
        if not display_name:
            continue
        try:
            input_types = node_class.INPUT_TYPES()
            all_inputs = {**input_types.get('required', {}), **input_types.get('optional', {})}
            params = []
            for name, details in all_inputs.items():
                # Skip inputs that are wired by the pipeline, not the user.
                if name in ('image', 'resolution', 'pose_kps'):
                    continue
                if not isinstance(details, (list, tuple)) or not details:
                    continue
                param_type = details[0]
                param_config = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
                params.append({"name": name, "type": param_type, "config": param_config})
            if params:
                param_map[display_name] = params
        except Exception as e:
            print(f"⚠️ Could not parse parameters for {display_name}: {e}")
    PREPROCESSOR_PARAMETER_MAP = param_map
    print("✅ ControlNet Preprocessor parameter map built.")

def print_welcome_message():
    author_name = "RioShiina"
    project_url = "https://huggingface.co/RioShiina"
    border = "=" * 72
    message = (
        f"\n{border}\n\n"
        f" Thank you for using this project!\n\n"
        f" **Author:** {author_name}\n"
        f" **Find more from the author:** {project_url}\n\n"
        f" This project is open-source under the GNU General Public License v3.0 (GPL-3.0).\n"
        f" As it's built upon GPL-3.0 components (like ComfyUI), any modifications you\n"
        f" distribute must also be open-sourced under the same license.\n\n"
        f" Your respect for the principles of free software is greatly appreciated!\n\n"
        f"{border}\n"
    )
    print(message)