# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradio demo application for the UNO-FLUX image generation pipeline."""

import dataclasses
import json
from pathlib import Path

import gradio as gr
import torch
import spaces

from uno.flux.pipeline import UNOPipeline

# Optional config.json keys naming the reference images of an example, in the
# order of the four reference-image inputs of the UI.
_REFERENCE_IMAGE_KEYS = ("image_ref1", "image_ref2", "image_ref3", "image_ref4")

# Custom CSS styles for the demo page.
_DEMO_CSS = """
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
    background: linear-gradient(to right, #4776E6, #8E54E9);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 700;
    padding: 1rem 0;
}
.container {
    border-radius: 12px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
    padding: 20px;
    background: white;
    margin-bottom: 1.5rem;
}
.input-container {
    background: rgba(245, 247, 250, 0.7);
    border-radius: 10px;
    padding: 1rem;
    margin-bottom: 1rem;
}
.image-grid {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
    gap: 10px;
}
.generate-btn {
    background: linear-gradient(90deg, #4776E6, #8E54E9);
    border: none;
    color: white;
    padding: 10px 20px;
    border-radius: 50px;
    font-weight: 600;
    box-shadow: 0 4px 10px rgba(0,0,0,0.1);
    transition: all 0.3s ease;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 15px rgba(0,0,0,0.15);
}
.badge-container {
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 8px;
    flex-wrap: wrap;
    margin-bottom: 1rem;
}
.badge {
    display: inline-block;
    padding: 0.25rem 0.75rem;
    font-size: 0.875rem;
    font-weight: 500;
    line-height: 1.5;
    text-align: center;
    white-space: nowrap;
    vertical-align: middle;
    border-radius: 30px;
    color: white;
    background: #6c5ce7;
    text-decoration: none;
}
.output-container {
    background: rgba(243, 244, 246, 0.7);
    border-radius: 10px;
    padding: 1.5rem;
}
.slider-container label {
    font-weight: 600;
    margin-bottom: 0.5rem;
    color: #4a5568;
}
"""


def get_examples(examples_dir: str = "assets/examples") -> list:
    """Collect the example rows shown in the demo's examples table.

    Every sub-directory of ``examples_dir`` must contain a ``config.json``
    with the keys ``useage``, ``prompt`` and ``seed``; the keys
    ``image_ref1`` .. ``image_ref4`` are optional and name image files
    relative to that sub-directory.

    Args:
        examples_dir: Directory whose sub-directories each hold one example.

    Returns:
        One list per example, laid out as
        ``[useage, prompt, ref1, ref2, ref3, ref4, seed]``. A missing
        reference image is ``None``; a present one is a path string.

    Raises:
        FileNotFoundError: If ``examples_dir`` or a ``config.json`` is missing.
        KeyError: If a required key is absent from a ``config.json``.
    """
    rows = []
    # iterdir() yields entries in arbitrary order; sort so the examples table
    # is laid out the same way on every run and platform.
    for example_dir in sorted(Path(examples_dir).iterdir()):
        if not example_dir.is_dir():
            continue

        # Explicit encoding: the locale default would garble non-ASCII prompts.
        with open(example_dir / "config.json", encoding="utf-8") as f:
            config = json.load(f)

        ref_images = [
            str(example_dir / config[key]) if key in config else None
            for key in _REFERENCE_IMAGE_KEYS
        ]
        rows.append(
            [config["useage"], config["prompt"], *ref_images, config["seed"]]
        )
    return rows


def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
):
    """Build the Gradio Blocks UI around a ``UNOPipeline``.

    Args:
        model_type: Model name forwarded to ``UNOPipeline``.
        device: Device string forwarded to ``UNOPipeline``.
        offload: Offload flag forwarded to ``UNOPipeline``.

    Returns:
        The assembled ``gr.Blocks`` demo (not yet launched).
    """
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
    # The keyword was misspelled ("duratioin"), so the requested 120 s GPU
    # allocation was never applied to the generation call.
    pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)

    # gr.Box no longer exists in recent Gradio releases, while launch() below
    # passes ssr_mode, which only recent releases accept. Fall back to gr.Group
    # so the layout builds on either.
    container = getattr(gr, "Box", gr.Group)

    badges_text = r"""
GitHub Stars Project Page arXiv
""".strip()

    with gr.Blocks(css=_DEMO_CSS) as demo:
        gr.Markdown("# UNO-FLUX Image Generator")
        gr.Markdown(badges_text)

        with gr.Row():
            # Left column: prompt, reference images and generation settings.
            with gr.Column(scale=3):
                with container(elem_classes="container"):
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        value="handsome woman in the city",
                        elem_classes="input-container",
                    )

                    gr.Markdown("### Reference Images")
                    with gr.Row(elem_classes="image-grid"):
                        image_prompt1 = gr.Image(label="Ref Img 1", visible=True, interactive=True, type="pil")
                        image_prompt2 = gr.Image(label="Ref Img 2", visible=True, interactive=True, type="pil")
                        image_prompt3 = gr.Image(label="Ref Img 3", visible=True, interactive=True, type="pil")
                        image_prompt4 = gr.Image(label="Ref Img 4", visible=True, interactive=True, type="pil")

                    with gr.Row():
                        with gr.Column(scale=2):
                            with container(elem_classes="slider-container"):
                                width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
                                height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
                        with gr.Column(scale=1):
                            gr.Markdown(
                                "📌 The model was trained on 512x512 resolution.\n"
                                "Sizes closer to 512 are more stable, higher sizes "
                                "give better visual effects but are less stable."
                            )

                    with gr.Accordion("Advanced Options", open=False):
                        with gr.Row():
                            with gr.Column():
                                num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                            with gr.Column():
                                guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                            with gr.Column():
                                seed = gr.Number(-1, label="Seed (-1 for random)")

                    generate_btn = gr.Button("Generate", elem_classes="generate-btn")

            # Right column: generated image and downloadable file.
            with gr.Column(scale=2):
                with container(elem_classes="output-container"):
                    gr.Markdown("### Generated Result")
                    output_image = gr.Image(label="Generated Image")
                    download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)

        inputs = [
            prompt, width, height, guidance, num_steps,
            seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
        ]
        generate_btn.click(
            fn=pipeline.gradio_generate,
            inputs=inputs,
            outputs=[output_image, download_btn],
        )

        # Hidden field that receives the "useage" column of an example row.
        example_text = gr.Text("", visible=False, label="Case For:")
        examples = get_examples("./assets/examples")

        with container(elem_classes="container"):
            gr.Markdown("### Examples")
            # NOTE(review): rows from get_examples() carry seven values, so the
            # trailing output_image entry has no matching column - confirm
            # whether example outputs were meant to be listed as well.
            gr.Examples(
                examples=examples,
                inputs=[
                    example_text, prompt,
                    image_prompt1, image_prompt2, image_prompt3, image_prompt4,
                    seed, output_image
                ],
            )

    return demo


if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class AppArgs:
        # Command-line options of the demo, parsed by HfArgumentParser.
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."}
        )
        port: int = 7860

    parser = HfArgumentParser([AppArgs])
    args_tuple = parser.parse_args_into_dataclasses()  # type: tuple[AppArgs]
    args = args_tuple[0]

    demo = create_demo(args.name, args.device, args.offload)
    demo.launch(server_port=args.port, ssr_mode=False)