Spaces:
Runtime error
Runtime error
Anurag Bhardwaj
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,7 +13,6 @@ import torch
|
|
| 13 |
from PIL import Image
|
| 14 |
import gradio as gr
|
| 15 |
|
| 16 |
-
|
| 17 |
from diffusers import (
|
| 18 |
DiffusionPipeline,
|
| 19 |
AutoencoderTiny,
|
|
@@ -26,13 +25,21 @@ from huggingface_hub import (
|
|
| 26 |
hf_hub_download,
|
| 27 |
HfFileSystem,
|
| 28 |
ModelCard,
|
| 29 |
-
snapshot_download
|
|
|
|
|
|
|
| 30 |
|
| 31 |
from diffusers.utils import load_image
|
| 32 |
|
| 33 |
import spaces
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def calculate_shift(
|
| 38 |
image_seq_len,
|
|
@@ -2089,24 +2096,25 @@ loras = [
|
|
| 2089 |
]
|
| 2090 |
|
| 2091 |
#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
|
| 2092 |
-
|
| 2093 |
dtype = torch.bfloat16
|
| 2094 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 2095 |
base_model = "black-forest-labs/FLUX.1-dev"
|
| 2096 |
|
| 2097 |
-
#TAEF1 is very tiny autoencoder
|
| 2098 |
-
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
| 2099 |
-
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
| 2100 |
-
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
| 2101 |
-
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
| 2102 |
-
|
| 2103 |
-
|
| 2104 |
-
|
| 2105 |
-
|
| 2106 |
-
|
| 2107 |
-
|
| 2108 |
-
|
| 2109 |
-
|
|
|
|
|
|
|
| 2110 |
|
| 2111 |
MAX_SEED = 2**32-1
|
| 2112 |
|
|
@@ -2210,7 +2218,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
| 2210 |
pipe.unload_lora_weights()
|
| 2211 |
pipe_i2i.unload_lora_weights()
|
| 2212 |
|
| 2213 |
-
#LoRA weights flow
|
| 2214 |
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
|
| 2215 |
pipe_to_use = pipe_i2i if image_input is not None else pipe
|
| 2216 |
weight_name = selected_lora.get("weights", None)
|
|
@@ -2235,7 +2243,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
| 2235 |
final_image = None
|
| 2236 |
step_counter = 0
|
| 2237 |
for image in image_generator:
|
| 2238 |
-
step_counter+=1
|
| 2239 |
final_image = image
|
| 2240 |
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
|
| 2241 |
yield image, seed, gr.update(value=progress_bar, visible=True)
|
|
@@ -2243,41 +2251,37 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
| 2243 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
| 2244 |
|
| 2245 |
def get_huggingface_safetensors(link):
|
| 2246 |
-
|
| 2247 |
-
|
| 2248 |
-
|
| 2249 |
-
|
| 2250 |
-
|
| 2251 |
|
| 2252 |
-
|
| 2253 |
-
|
| 2254 |
-
|
| 2255 |
|
| 2256 |
-
|
| 2257 |
-
|
| 2258 |
-
|
| 2259 |
-
|
| 2260 |
-
|
| 2261 |
-
|
| 2262 |
-
|
| 2263 |
-
|
| 2264 |
-
|
| 2265 |
-
|
| 2266 |
-
|
| 2267 |
-
|
| 2268 |
-
|
| 2269 |
-
|
| 2270 |
-
|
| 2271 |
-
|
| 2272 |
-
|
| 2273 |
-
print(e)
|
| 2274 |
-
gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
| 2275 |
-
raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
| 2276 |
-
return split_link[1], link, safetensors_name, trigger_word, image_url
|
| 2277 |
|
| 2278 |
def check_custom_model(link):
|
| 2279 |
-
if
|
| 2280 |
-
if
|
| 2281 |
link_split = link.split("huggingface.co/")
|
| 2282 |
return get_huggingface_safetensors(link_split[1])
|
| 2283 |
else:
|
|
@@ -2285,7 +2289,7 @@ def check_custom_model(link):
|
|
| 2285 |
|
| 2286 |
def add_custom_lora(custom_lora):
|
| 2287 |
global loras
|
| 2288 |
-
if
|
| 2289 |
try:
|
| 2290 |
title, repo, path, trigger_word, image = check_custom_model(custom_lora)
|
| 2291 |
print(f"Loaded custom LoRA: {repo}")
|
|
@@ -2302,7 +2306,7 @@ def add_custom_lora(custom_lora):
|
|
| 2302 |
</div>
|
| 2303 |
'''
|
| 2304 |
existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
|
| 2305 |
-
if
|
| 2306 |
new_item = {
|
| 2307 |
"image": image,
|
| 2308 |
"title": title,
|
|
@@ -2316,8 +2320,8 @@ def add_custom_lora(custom_lora):
|
|
| 2316 |
|
| 2317 |
return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
|
| 2318 |
except Exception as e:
|
| 2319 |
-
gr.Warning(
|
| 2320 |
-
return gr.update(visible=True, value=
|
| 2321 |
else:
|
| 2322 |
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
|
| 2323 |
|
|
@@ -2371,7 +2375,7 @@ with gr.Blocks(theme="prithivMLmods/Minecraft-Theme", css=css, delete_cache=(60,
|
|
| 2371 |
custom_lora_info = gr.HTML(visible=False)
|
| 2372 |
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
|
| 2373 |
with gr.Column():
|
| 2374 |
-
progress_bar = gr.Markdown(elem_id="progress",visible=False)
|
| 2375 |
result = gr.Image(label="Generated Image")
|
| 2376 |
|
| 2377 |
with gr.Row():
|
|
|
|
| 13 |
from PIL import Image
|
| 14 |
import gradio as gr
|
| 15 |
|
|
|
|
| 16 |
from diffusers import (
|
| 17 |
DiffusionPipeline,
|
| 18 |
AutoencoderTiny,
|
|
|
|
| 25 |
hf_hub_download,
|
| 26 |
HfFileSystem,
|
| 27 |
ModelCard,
|
| 28 |
+
snapshot_download,
|
| 29 |
+
login # imported for one-time authentication
|
| 30 |
+
)
|
| 31 |
|
| 32 |
from diffusers.utils import load_image
|
| 33 |
|
| 34 |
import spaces
|
| 35 |
|
| 36 |
+
# -------------------------------
|
| 37 |
+
# Authenticate with Hugging Face once
|
| 38 |
+
# -------------------------------
|
| 39 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 40 |
+
if HF_TOKEN:
|
| 41 |
+
login(HF_TOKEN)
|
| 42 |
+
print("Authenticated with Hugging Face.")
|
| 43 |
|
| 44 |
def calculate_shift(
|
| 45 |
image_seq_len,
|
|
|
|
| 2096 |
]
|
| 2097 |
|
| 2098 |
#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
|
|
|
|
| 2099 |
dtype = torch.bfloat16
|
| 2100 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 2101 |
base_model = "black-forest-labs/FLUX.1-dev"
|
| 2102 |
|
| 2103 |
+
# TAEF1 is a very tiny autoencoder using the same "latent API" as FLUX.1's VAE.
|
| 2104 |
+
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, use_auth_token=HF_TOKEN).to(device)
|
| 2105 |
+
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, use_auth_token=HF_TOKEN).to(device)
|
| 2106 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, use_auth_token=HF_TOKEN).to(device)
|
| 2107 |
+
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
| 2108 |
+
base_model,
|
| 2109 |
+
vae=good_vae,
|
| 2110 |
+
transformer=pipe.transformer,
|
| 2111 |
+
text_encoder=pipe.text_encoder,
|
| 2112 |
+
tokenizer=pipe.tokenizer,
|
| 2113 |
+
text_encoder_2=pipe.text_encoder_2,
|
| 2114 |
+
tokenizer_2=pipe.tokenizer_2,
|
| 2115 |
+
torch_dtype=dtype,
|
| 2116 |
+
use_auth_token=HF_TOKEN
|
| 2117 |
+
)
|
| 2118 |
|
| 2119 |
MAX_SEED = 2**32-1
|
| 2120 |
|
|
|
|
| 2218 |
pipe.unload_lora_weights()
|
| 2219 |
pipe_i2i.unload_lora_weights()
|
| 2220 |
|
| 2221 |
+
# LoRA weights flow
|
| 2222 |
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
|
| 2223 |
pipe_to_use = pipe_i2i if image_input is not None else pipe
|
| 2224 |
weight_name = selected_lora.get("weights", None)
|
|
|
|
| 2243 |
final_image = None
|
| 2244 |
step_counter = 0
|
| 2245 |
for image in image_generator:
|
| 2246 |
+
step_counter += 1
|
| 2247 |
final_image = image
|
| 2248 |
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
|
| 2249 |
yield image, seed, gr.update(value=progress_bar, visible=True)
|
|
|
|
| 2251 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
| 2252 |
|
| 2253 |
def get_huggingface_safetensors(link):
|
| 2254 |
+
split_link = link.split("/")
|
| 2255 |
+
if len(split_link) == 2:
|
| 2256 |
+
model_card = ModelCard.load(link)
|
| 2257 |
+
base_model = model_card.data.get("base_model")
|
| 2258 |
+
print(base_model)
|
| 2259 |
|
| 2260 |
+
# Allows Both
|
| 2261 |
+
if (base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell"):
|
| 2262 |
+
raise Exception("Flux LoRA Not Found!")
|
| 2263 |
|
| 2264 |
+
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
| 2265 |
+
trigger_word = model_card.data.get("instance_prompt", "")
|
| 2266 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
| 2267 |
+
fs = HfFileSystem()
|
| 2268 |
+
try:
|
| 2269 |
+
list_of_files = fs.ls(link, detail=False)
|
| 2270 |
+
for file in list_of_files:
|
| 2271 |
+
if file.endswith(".safetensors"):
|
| 2272 |
+
safetensors_name = file.split("/")[-1]
|
| 2273 |
+
if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
|
| 2274 |
+
image_elements = file.split("/")
|
| 2275 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
|
| 2276 |
+
except Exception as e:
|
| 2277 |
+
print(e)
|
| 2278 |
+
gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
| 2279 |
+
raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
| 2280 |
+
return split_link[1], link, safetensors_name, trigger_word, image_url
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2281 |
|
| 2282 |
def check_custom_model(link):
|
| 2283 |
+
if link.startswith("https://"):
|
| 2284 |
+
if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
|
| 2285 |
link_split = link.split("huggingface.co/")
|
| 2286 |
return get_huggingface_safetensors(link_split[1])
|
| 2287 |
else:
|
|
|
|
| 2289 |
|
| 2290 |
def add_custom_lora(custom_lora):
|
| 2291 |
global loras
|
| 2292 |
+
if custom_lora:
|
| 2293 |
try:
|
| 2294 |
title, repo, path, trigger_word, image = check_custom_model(custom_lora)
|
| 2295 |
print(f"Loaded custom LoRA: {repo}")
|
|
|
|
| 2306 |
</div>
|
| 2307 |
'''
|
| 2308 |
existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
|
| 2309 |
+
if existing_item_index is None:
|
| 2310 |
new_item = {
|
| 2311 |
"image": image,
|
| 2312 |
"title": title,
|
|
|
|
| 2320 |
|
| 2321 |
return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
|
| 2322 |
except Exception as e:
|
| 2323 |
+
gr.Warning("Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
|
| 2324 |
+
return gr.update(visible=True, value="Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
|
| 2325 |
else:
|
| 2326 |
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
|
| 2327 |
|
|
|
|
| 2375 |
custom_lora_info = gr.HTML(visible=False)
|
| 2376 |
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
|
| 2377 |
with gr.Column():
|
| 2378 |
+
progress_bar = gr.Markdown(elem_id="progress", visible=False)
|
| 2379 |
result = gr.Image(label="Generated Image")
|
| 2380 |
|
| 2381 |
with gr.Row():
|