File size: 19,762 Bytes
f36e497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f5732f
f36e497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f5732f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f36e497
 
 
 
 
9f5732f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f36e497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
import os
import requests
import hashlib
import re
from typing import Sequence, Mapping, Any, Union, Set
from pathlib import Path
import shutil

import gradio as gr
from huggingface_hub import hf_hub_download, constants as hf_constants
import torch
import numpy as np
from PIL import Image, ImageChops


from core.settings import *

# Size cap, in GiB (sizes are converted with 1024**3), that enforce_disk_limit()
# applies to the Hugging Face hub cache; it runs at the start of download_file().
DISK_LIMIT_GB = 120
# NOTE(review): not referenced in this part of the file; presumably the ComfyUI
# models root used elsewhere — confirm before relying on it.
MODELS_ROOT_DIR = "ComfyUI/models"

# Lazily-built caches. None means "not built yet"; they are filled on first call
# by build_preprocessor_model_map() / build_preprocessor_parameter_map().
PREPROCESSOR_MODEL_MAP = None
PREPROCESSOR_PARAMETER_MAP = None


def save_uploaded_file_with_hash(file_obj: "gr.File | str", target_dir: str) -> str:
    """Copy an uploaded file into *target_dir* under a content-addressed name.

    The stored name is the SHA-256 hex digest of the file's bytes followed by
    the original extension (lower-cased), so identical uploads are kept once.

    Args:
        file_obj: The upload to store: either an object exposing the temporary
            file path as ``.name`` or a plain path string / ``os.PathLike``.
            Falsy values are ignored.
        target_dir: Directory to store the file in; created if missing.

    Returns:
        The hashed filename (not the full path), or ``""`` if *file_obj* is falsy.
    """
    if not file_obj:
        return ""

    # Accept both tempfile-style wrappers (``.name``) and plain paths.
    if isinstance(file_obj, (str, os.PathLike)):
        temp_path = os.fspath(file_obj)
    else:
        temp_path = file_obj.name

    sha256 = hashlib.sha256()
    with open(temp_path, 'rb') as f:
        for block in iter(lambda: f.read(65536), b''):
            sha256.update(block)

    _, extension = os.path.splitext(temp_path)
    hashed_filename = f"{sha256.hexdigest()}{extension.lower()}"
    dest_path = os.path.join(target_dir, hashed_filename)

    os.makedirs(target_dir, exist_ok=True)
    if os.path.exists(dest_path):
        print(f"ℹ️ File already exists (deduplicated): {dest_path}")
        return hashed_filename

    # Copy to a unique temp file in the same directory, then rename atomically.
    # A crash mid-copy therefore never leaves a truncated file under the hashed
    # name, which the existence check above would otherwise trust forever.
    import tempfile
    fd, partial_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
    os.close(fd)
    try:
        shutil.copy(temp_path, partial_path)
        os.replace(partial_path, dest_path)
    except BaseException:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    print(f"✅ Saved uploaded file as: {dest_path}")
    return hashed_filename


def bytes_to_gb(byte_size: int) -> float:
    """Convert a byte count to GiB (units of 1024**3 bytes), rounded to 2 places.

    Both ``None`` and ``0`` map to ``0.0``.
    """
    if byte_size in (None, 0):
        return 0.0
    gib = byte_size / 1024 ** 3
    return round(gib, 2)

def get_directory_size(path: str) -> int:
    """Return the total size in bytes of the regular files under *path*.

    The tree is walked recursively; symbolic links are neither followed nor
    counted. A missing *path* yields 0. Entries that cannot be accessed (e.g. a
    file deleted between listing and stat) are reported and skipped rather than
    aborting the whole scan.
    """
    if not os.path.exists(path):
        return 0

    def _warn(err: OSError) -> None:
        print(f"Warning: Could not access {path} to calculate size: {err}")

    total_size = 0
    # onerror surfaces unreadable directories instead of silently skipping them.
    for dirpath, _, filenames in os.walk(path, onerror=_warn):
        for name in filenames:
            fp = os.path.join(dirpath, name)
            if os.path.islink(fp) or not os.path.isfile(fp):
                continue
            try:
                total_size += os.path.getsize(fp)
            except OSError as e:
                _warn(e)
    return total_size

def enforce_disk_limit():
    """Trim the Hugging Face hub cache so it stays under ``DISK_LIMIT_GB`` GiB.

    Walks ``hf_constants.HF_HUB_CACHE`` and totals the sizes of regular files,
    skipping symlinks and ``.incomplete`` / ``.lock`` files. If the total
    exceeds the limit, files are deleted in ascending ``ctime`` order until it
    fits. Best-effort: unexpected errors are printed, never raised.
    """
    disk_limit_bytes = DISK_LIMIT_GB * (1024 ** 3)
    cache_dir = hf_constants.HF_HUB_CACHE

    if not os.path.exists(cache_dir):
        return

    print(f"--- [Storage Manager] Checking disk usage in '{cache_dir}' (Limit: {DISK_LIMIT_GB} GB) ---")

    try:
        all_files = []
        current_size_bytes = 0
        for dirpath, _, filenames in os.walk(cache_dir):
            for f in filenames:
                if f.endswith((".incomplete", ".lock")):
                    continue
                file_path = os.path.join(dirpath, f)
                if os.path.isfile(file_path) and not os.path.islink(file_path):
                    try:
                        file_size = os.path.getsize(file_path)
                        creation_time = os.path.getctime(file_path)
                    except OSError:
                        continue
                    all_files.append((creation_time, file_path, file_size))
                    current_size_bytes += file_size

        print(f"--- [Storage Manager] Current usage: {bytes_to_gb(current_size_bytes)} GB ---")

        if current_size_bytes <= disk_limit_bytes:
            print("--- [Storage Manager] Disk usage is within the limit. No action needed. ---")
            return

        print("--- [Storage Manager] Usage exceeds limit. Starting cleanup... ---")
        # NOTE: getctime() is creation time only on Windows; on POSIX it is the
        # last metadata-change time, so this order only approximates file age.
        all_files.sort(key=lambda entry: entry[0])

        # One pass over the sorted list (the old pop(0) loop was O(n^2)).
        for _, file_path, file_size in all_files:
            if current_size_bytes <= disk_limit_bytes:
                break
            try:
                os.remove(file_path)
                current_size_bytes -= file_size
                print(f"--- [Storage Manager] Deleted oldest file: {os.path.basename(file_path)} ({bytes_to_gb(file_size)} GB freed) ---")
            except OSError as e:
                print(f"--- [Storage Manager] Error deleting file {file_path}: {e} ---")

        print(f"--- [Storage Manager] Cleanup finished. New usage: {bytes_to_gb(current_size_bytes)} GB ---")

    except Exception as e:
        print(f"--- [Storage Manager] An unexpected error occurred: {e} ---")


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback supports mappings that keep their values under a ``"result"``
    key. Returns ``None`` when neither lookup succeeds, including when *obj* is
    a sequence (which cannot be indexed by ``"result"``) or when
    ``obj["result"]`` is not subscriptable.
    """
    try:
        return obj[index]
    except (KeyError, IndexError):
        pass
    try:
        return obj["result"][index]
    except (KeyError, IndexError, TypeError):
        # TypeError: e.g. ``[1, 2]["result"]`` after an out-of-range index;
        # previously this escaped instead of yielding None.
        return None

def sanitize_prompt(prompt: str) -> str:
    """Drop non-printable characters from *prompt*, keeping newlines and tabs.

    Anything that is not a ``str`` yields an empty string.
    """
    if isinstance(prompt, str):
        kept = [ch for ch in prompt if ch in ('\n', '\t') or ch.isprintable()]
        return "".join(kept)
    return ""

def sanitize_id(input_id: str) -> str:
    """Keep only the ASCII digits ``0``-``9`` of *input_id*.

    Non-string input yields ``""``. All digits are concatenated, so a string
    containing several separate numbers collapses into a single digit run.
    """
    if not isinstance(input_id, str):
        return ""
    return "".join(ch for ch in input_id if "0" <= ch <= "9")

def sanitize_url(url: str) -> str:
    """Validate an HTTP(S) URL and return it with surrounding whitespace removed.

    Raises:
        ValueError: if *url* is not a string, or is not an ``http://`` /
            ``https://`` URL free of embedded whitespace.
    """
    if isinstance(url, str):
        candidate = url.strip()
        if re.match(r'^https?://[^\s/$.?#].[^\s]*$', candidate) is None:
            raise ValueError("Invalid URL format or scheme. Only HTTP and HTTPS are allowed.")
        return candidate
    raise ValueError("URL must be a string.")

def sanitize_filename(filename: str) -> str:
    """Make *filename* safe to use as a single path component.

    First deletes non-overlapping ``..`` pairs (left to right), then replaces
    every character that is not a word character, ``.`` or ``-`` with ``_``.
    Non-string input yields ``""``.
    """
    if not isinstance(filename, str):
        return ""
    without_dot_pairs = filename.replace('..', '')
    cleaned = re.sub(r'[^\w\.\-]', '_', without_dot_pairs)
    # Path separators were already turned into "_" above; this strip is a
    # purely defensive guard against a leading separator.
    return cleaned.lstrip('/\\')


def get_civitai_file_info(version_id: str) -> dict | None:
    """Fetch metadata for the main downloadable file of a Civitai model version.

    Queries ``https://civitai.com/api/v1/model-versions/{version_id}``. Returns
    the first ``files`` entry whose ``type`` is ``"Model"`` and whose ``name``
    ends in ``.safetensors``, ``.pt`` or ``.bin``; otherwise the first listed
    file. Returns ``None`` if there are no files, or if the request or response
    parsing fails (best-effort: the failure is printed, not raised).
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        files = response.json().get('files') or []

        for file_data in files:
            # .get(): an entry without "name" must not abort the whole lookup.
            name = file_data.get('name') or ''
            if file_data.get('type') == 'Model' and name.endswith(('.safetensors', '.pt', '.bin')):
                return file_data

        return files[0] if files else None
    except Exception as e:
        print(f"⚠️ Could not fetch Civitai file info for version {version_id}: {e}")
        return None


def download_file(url: str, save_path: str, api_key: str | None = None, progress=None, desc: str = "") -> str:
    """Stream *url* to *save_path* with optional progress reporting; never raises.

    Runs enforce_disk_limit() first and does nothing if *save_path* already
    exists. When *api_key* is non-blank it is sent as
    ``Authorization: Bearer <api_key>``. The body is written to
    ``<save_path>.part`` and renamed into place only once fully received, so an
    interrupted download never leaves a truncated file at *save_path*. Later
    calls would otherwise treat such a truncated file as complete.

    Args:
        url: Source URL.
        save_path: Final destination path; its parent directory is created.
        api_key: Optional bearer token.
        progress: Optional callable ``progress(fraction, desc=...)``.
        desc: Label passed to *progress*.

    Returns:
        A status string starting with ``"File already exists"``,
        ``"Successfully downloaded"`` or ``"Download failed"``.
    """
    enforce_disk_limit()

    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"

    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    partial_path = f"{save_path}.part"
    try:
        if progress:
            progress(0, desc=desc)

        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # `with` closes the streamed response, releasing the connection even
        # when the transfer fails midway.
        with requests.get(url, stream=True, headers=headers, timeout=15) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if not chunk:  # skip keep-alive chunks
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if progress and total_size > 0:
                        progress(downloaded / total_size, desc=desc)

        os.replace(partial_path, save_path)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"


def get_lora_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Resolve a Civitai LoRA to a local file, downloading it if necessary.

    Only ``source == "Civitai"`` is supported. The digits of *id_or_url* (via
    sanitize_id, which concatenates every digit) form the model-version ID. The
    file is stored as ``LORA_DIR/civitai_<id>.safetensors``. An existing local
    copy is returned without contacting the API.

    Returns:
        ``(local_path, status)`` on success, otherwise ``(None, error_message)``.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."

    if source != "Civitai":
        return None, "Invalid source."

    try:
        version_id = sanitize_id(id_or_url)
        if not version_id:
            return None, "Invalid Civitai ID provided. Must be numeric."
        filename = sanitize_filename(f"civitai_{version_id}.safetensors")
    except ValueError as e:
        return None, f"Input validation failed: {e}"

    local_path = os.path.join(LORA_DIR, filename)
    source_name = f"Civitai ID {version_id}"

    # Check the cache before calling the API so known LoRAs resolve offline.
    if os.path.exists(local_path):
        return local_path, "File already exists."

    file_info = get_civitai_file_info(version_id)
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."

    status = download_file(file_info['downloadUrl'], local_path, civitai_key, progress=progress, desc=f"Downloading {source_name}")

    # "File already exists" means another caller finished the download first.
    if status.startswith(("Successfully", "File already exists")):
        return local_path, status
    return None, status

def get_embedding_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Resolve a Civitai textual-inversion embedding to a local file.

    Only ``source == "Civitai"`` is supported. The digits of *id_or_url* (via
    sanitize_id, which concatenates every digit) form the model-version ID. The
    file is stored in ``EMBEDDING_DIR`` as ``civitai_<id><ext>``, where
    ``<ext>`` is ``.pt`` / ``.bin`` when the Civitai file name says so and
    ``.safetensors`` otherwise. An existing local copy under any of those
    extensions is returned without contacting the API.

    Returns:
        ``(local_path, status)`` on success, otherwise ``(None, error_message)``.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."

    if source != "Civitai":
        return None, "Invalid source."

    try:
        version_id = sanitize_id(id_or_url)
        if not version_id:
            return None, "Invalid Civitai ID. Must be numeric."
    except ValueError as e:
        return None, f"Input validation failed: {e}"

    source_name = f"Embedding Civitai ID {version_id}"

    # Reuse a previous download before spending a network round-trip.
    for ext in (".safetensors", ".pt", ".bin"):
        cached_path = os.path.join(EMBEDDING_DIR, sanitize_filename(f"civitai_{version_id}{ext}"))
        if os.path.exists(cached_path):
            return cached_path, "File already exists."

    file_info = get_civitai_file_info(version_id)
    file_ext = ".safetensors"
    # .get(): a file entry without "name" must not raise here.
    remote_name = (file_info or {}).get('name') or ''
    if remote_name.lower().endswith(('.pt', '.bin')):
        file_ext = os.path.splitext(remote_name)[1]

    local_path = os.path.join(EMBEDDING_DIR, sanitize_filename(f"civitai_{version_id}{file_ext}"))
    if os.path.exists(local_path):  # e.g. an upper-case extension such as ".PT"
        return local_path, "File already exists."

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."

    status = download_file(file_info['downloadUrl'], local_path, civitai_key, progress=progress, desc=f"Downloading {source_name}")

    if status.startswith(("Successfully", "File already exists")):
        return local_path, status
    return None, status

def get_vae_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Resolve a Civitai VAE to a local file, downloading it if necessary.

    Only ``source == "Civitai"`` is supported. The digits of *id_or_url* (via
    sanitize_id, which concatenates every digit) form the model-version ID. The
    file is stored in ``VAE_DIR`` as ``civitai_<id><ext>``, where ``<ext>`` is
    ``.pt`` / ``.bin`` when the Civitai file name says so and ``.safetensors``
    otherwise. An existing local copy under any of those extensions is returned
    without contacting the API.

    Returns:
        ``(local_path, status)`` on success, otherwise ``(None, error_message)``.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."

    if source != "Civitai":
        return None, "Invalid source."

    try:
        version_id = sanitize_id(id_or_url)
        if not version_id:
            return None, "Invalid Civitai ID. Must be numeric."
    except ValueError as e:
        return None, f"Input validation failed: {e}"

    source_name = f"VAE Civitai ID {version_id}"

    # Reuse a previous download before spending a network round-trip.
    for ext in (".safetensors", ".pt", ".bin"):
        cached_path = os.path.join(VAE_DIR, sanitize_filename(f"civitai_{version_id}{ext}"))
        if os.path.exists(cached_path):
            return cached_path, "File already exists."

    file_info = get_civitai_file_info(version_id)
    file_ext = ".safetensors"
    # .get(): a file entry without "name" must not raise here.
    remote_name = (file_info or {}).get('name') or ''
    if remote_name.lower().endswith(('.pt', '.bin')):
        file_ext = os.path.splitext(remote_name)[1]

    local_path = os.path.join(VAE_DIR, sanitize_filename(f"civitai_{version_id}{file_ext}"))
    if os.path.exists(local_path):  # e.g. an upper-case extension such as ".PT"
        return local_path, "File already exists."

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."

    status = download_file(file_info['downloadUrl'], local_path, civitai_key, progress=progress, desc=f"Downloading {source_name}")

    if status.startswith(("Successfully", "File already exists")):
        return local_path, status
    return None, status


def _ensure_model_downloaded(filename: str, progress=gr.Progress()):
    """Make sure the model component *filename* is present in its model directory.

    Looks *filename* up in ``ALL_FILE_DOWNLOAD_MAP`` (defined outside this part
    of the file; the error text suggests it mirrors file_list.yaml) and maps its
    ``category`` to a destination directory. If the file is missing, it is
    fetched as follows:

    * ``source == "hf"``: hf_hub_download() into the Hugging Face cache, then
      symlinked into the destination directory.
    * ``source == "civitai"``: downloaded directly via download_file().

    A dangling symlink at the destination is removed and the file fetched again.

    Returns:
        *filename* unchanged.

    Raises:
        gr.Error: *filename* is unknown, or downloading/linking failed.
        ValueError: the entry's category has no known directory.
    """
    download_info = ALL_FILE_DOWNLOAD_MAP.get(filename)
    if not download_info:
        raise gr.Error(f"Model component '{filename}' not found in file_list.yaml. Cannot download.")

    category_to_dir_map = {
        "diffusion_models": DIFFUSION_MODELS_DIR,
        "text_encoders": TEXT_ENCODERS_DIR,
        "vae": VAE_DIR,
        "checkpoints": CHECKPOINT_DIR,
        "loras": LORA_DIR,
        "controlnet": CONTROLNET_DIR,
        "model_patches": MODEL_PATCHES_DIR,
        # clip_vision lives next to the loras directory (sibling under one root).
        "clip_vision": os.path.join(os.path.dirname(LORA_DIR), "clip_vision")
    }

    category = download_info.get('category')
    dest_dir = category_to_dir_map.get(category)
    if not dest_dir:
        raise ValueError(f"Unknown model category '{category}' for file '{filename}'.")

    dest_path = os.path.join(dest_dir, filename)

    if os.path.lexists(dest_path):
        if not os.path.exists(dest_path):
            # lexists() but not exists(): a symlink whose target is gone (e.g. the
            # cached file it pointed to was deleted), so drop it and fetch again.
            print(f"⚠️ Found and removed broken symlink: {dest_path}")
            os.remove(dest_path)
        else:
            return filename

    source = download_info.get("source")
    try:
        progress(0, desc=f"Downloading: {filename}")
        os.makedirs(dest_dir, exist_ok=True)

        if source == "hf":
            repo_id = download_info.get("repo_id")
            hf_filename = download_info.get("repository_file_path", filename)
            if not repo_id:
                raise ValueError(f"repo_id is missing for HF model '{filename}'")

            cached_path = hf_hub_download(repo_id=repo_id, filename=hf_filename)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked '{cached_path}' to '{dest_path}'")

        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError(f"model_version_id is missing for Civitai model '{filename}'")

            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")

            status = download_file(
                file_info['downloadUrl'], dest_path, progress=progress, desc=f"Downloading: {filename}"
            )
            # download_file() never raises; it reports through its status prefix.
            # Its failure text is "Download failed ...", so testing for "Failed"
            # (capital F) let failed downloads pass as successes.
            if not status.startswith(("Successfully", "File already exists")):
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for '{filename}'")

        progress(1.0, desc=f"Downloaded: {filename}")

    except Exception as e:
        # Remove any half-created file or link so the next call retries cleanly.
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download and link '{filename}': {e}") from e

    return filename

def ensure_controlnet_model_downloaded(filename: str, progress):
    """Download the selected ControlNet model *filename*, if there is one.

    An empty value or the literal placeholder ``"None"`` means "no model" and
    is ignored. Always returns ``None``.
    """
    if filename and filename != "None":
        _ensure_model_downloaded(filename, progress)


def build_preprocessor_model_map():
    """Build (once) a map from preprocessor display name to required model files.

    Scans ``./custom_nodes/comfyui_controlnet_aux/node_wrappers/*.py``. For
    each wrapper file, display names are pulled from its
    ``NODE_DISPLAY_NAME_MAPPINGS`` literal by regex. Model files come from its
    ``from_pretrained(...)`` string arguments (Hugging Face repo guessed from
    the filename), plus any hard-coded entries for that wrapper in
    ``manual_map``. The result maps each display name to a sorted,
    de-duplicated list of ``(repo_id, filename)`` tuples. Display names with
    no models are omitted.

    The map is cached in ``PREPROCESSOR_MODEL_MAP``; later calls return it
    directly. If the wrappers directory is missing, an empty map is cached.
    """
    global PREPROCESSOR_MODEL_MAP
    if PREPROCESSOR_MODEL_MAP is not None:
        return PREPROCESSOR_MODEL_MAP
    print("--- Building ControlNet Preprocessor model map ---")
    # Models that the from_pretrained() scan cannot discover, keyed by wrapper file stem.
    manual_map = {
        "dwpose": [("yzd-v/DWPose", "yolox_l.onnx"), ("yzd-v/DWPose", "dw-ll_ucoco_384.onnx"), ("hr16/UnJIT-DWPose", "dw-ll_ucoco.onnx"), ("hr16/DWPose-TorchScript-BatchSize5", "dw-ll_ucoco_384_bs5.torchscript.pt"), ("hr16/DWPose-TorchScript-BatchSize5", "rtmpose-m_ap10k_256_bs5.torchscript.pt"), ("hr16/yolo-nas-fp16", "yolo_nas_l_fp16.onnx"), ("hr16/yolo-nas-fp16", "yolo_nas_m_fp16.onnx"), ("hr16/yolo-nas-fp16", "yolo_nas_s_fp16.onnx")],
        "densepose": [("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r50_fpn_dl.torchscript"), ("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r101_fpn_dl.torchscript")]
    }
    # NOTE(review): the imported name is unused below; the import is kept in case
    # loading `nodes` matters for its side effects — confirm before removing.
    from nodes import NODE_DISPLAY_NAME_MAPPINGS  # noqa: F401
    wrappers_dir = Path("./custom_nodes/comfyui_controlnet_aux/node_wrappers/")
    if not wrappers_dir.exists():
        print("⚠️ ControlNet AUX wrappers directory not found. Cannot build model map.")
        PREPROCESSOR_MODEL_MAP = {}
        return PREPROCESSOR_MODEL_MAP

    display_name_re = re.compile(r'NODE_DISPLAY_NAME_MAPPINGS\s*=\s*{(?:.|\n)*?["\'](.*?)["\']\s*:\s*["\'](.*?)["\']')
    from_pretrained_re = re.compile(r"from_pretrained\s*\(\s*(?:filename=)?\s*f?[\"']([^\"']+)[\"']")

    temp_map = {}
    for wrapper_file in wrappers_dir.glob("*.py"):
        if wrapper_file.name == "__init__.py":
            continue
        with open(wrapper_file, 'r', encoding='utf-8') as f:
            content = f.read()

        display_names = [display_name for _, display_name in display_name_re.findall(content)]
        if not display_names:
            continue

        # The model list depends only on the file, so compute it once per file
        # rather than once per display name.
        file_models = list(manual_map.get(wrapper_file.stem, []))
        for model_filename in from_pretrained_re.findall(content):
            if "depth_anything" in model_filename and "v2" in model_filename:
                repo_id = "LiheYoung/Depth-Anything-V2"
            elif "depth_anything" in model_filename:
                repo_id = "LiheYoung/Depth-Anything"
            elif "diffusion_edge" in model_filename:
                repo_id = "hr16/Diffusion-Edge"
            else:
                repo_id = "lllyasviel/Annotators"
            file_models.append((repo_id, model_filename))

        for display_name in display_names:
            temp_map.setdefault(display_name, []).extend(file_models)

    PREPROCESSOR_MODEL_MAP = {name: sorted(set(models)) for name, models in temp_map.items() if models}
    print("✅ ControlNet Preprocessor model map built.")
    return PREPROCESSOR_MODEL_MAP

def build_preprocessor_parameter_map():
    """Build (once) a map from preprocessor display name to its tunable inputs.

    Considers each node class in ``nodes.NODE_CLASS_MAPPINGS`` that has
    ``INPUT_TYPES`` and a display name, and whose ``__module__`` contains
    ``comfyui_controlnet_aux.node_wrappers``. Its required and optional
    inputs are collected, except ``image``, ``resolution`` and ``pose_kps``.
    Each input becomes ``{"name", "type", "config"}``, where ``type`` is the
    first element of the input spec and ``config`` its dict second element
    (or ``{}``). Nodes left with no inputs are omitted. Parse failures are
    printed and skipped.

    Returns:
        The map, which is also cached in ``PREPROCESSOR_PARAMETER_MAP``.
        Later calls return the cached map.
    """
    global PREPROCESSOR_PARAMETER_MAP
    if PREPROCESSOR_PARAMETER_MAP is not None:
        return PREPROCESSOR_PARAMETER_MAP
    print("--- Building ControlNet Preprocessor parameter map ---")
    param_map = {}
    from nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
    for class_name, node_class in NODE_CLASS_MAPPINGS.items():
        if not hasattr(node_class, "INPUT_TYPES"):
            continue
        if hasattr(node_class, '__module__') and 'comfyui_controlnet_aux.node_wrappers' not in node_class.__module__:
            continue
        display_name = NODE_DISPLAY_NAME_MAPPINGS.get(class_name)
        if not display_name:
            continue
        try:
            input_types = node_class.INPUT_TYPES()
            all_inputs = {**input_types.get('required', {}), **input_types.get('optional', {})}
            params = []
            for name, details in all_inputs.items():
                # These inputs are not exposed as tunable parameters.
                if name in ['image', 'resolution', 'pose_kps']:
                    continue
                if not isinstance(details, (list, tuple)) or not details:
                    continue
                param_type = details[0]
                param_config = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
                params.append({"name": name, "type": param_type, "config": param_config})
            if params:
                param_map[display_name] = params
        except Exception as e:
            print(f"⚠️ Could not parse parameters for {display_name}: {e}")
    PREPROCESSOR_PARAMETER_MAP = param_map
    print("✅ ControlNet Preprocessor parameter map built.")
    return PREPROCESSOR_PARAMETER_MAP

def print_welcome_message():
    """Print the author credit and GPL-3.0 licensing notice to stdout."""
    author = "RioShiina"
    url = "https://huggingface.co/RioShiina"
    rule = "=" * 72

    pieces = (
        f"\n{rule}\n\n",
        f"  Thank you for using this project!\n\n",
        f"  **Author:** {author}\n",
        f"  **Find more from the author:** {url}\n\n",
        f"  This project is open-source under the GNU General Public License v3.0 (GPL-3.0).\n",
        f"  As it's built upon GPL-3.0 components (like ComfyUI), any modifications you\n",
        f"  distribute must also be open-sourced under the same license.\n\n",
        f"  Your respect for the principles of free software is greatly appreciated!\n\n",
        f"{rule}\n",
    )
    print("".join(pieces))