Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| from diffusers import FluxControlPipeline, FluxTransformer2DModel | |
# ---------------------------------------------------------------------------
# Model loading: runs once at import time and places the pipeline on the GPU.
# ---------------------------------------------------------------------------

# Edit-tuned Flux transformer that replaces the stock FLUX.1-dev transformer.
path = "sayakpaul/FLUX.1-dev-edit-v0"
edit_transformer = FluxTransformer2DModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
)

# Build the control pipeline from the FLUX.1-dev base, swapping in the edit
# transformer, then move every component to CUDA.
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=edit_transformer,
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to("cuda")

# Hyper-SD LoRA (an 8-step variant, per its filename), blended in at weight
# 0.125. generate() runs 8 inference steps to match.
_hyper_sd_lora_file = hf_hub_download(
    "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"
)
pipeline.load_lora_weights(_hyper_sd_lora_file, adapter_name="hyper-sd")
pipeline.set_adapters(["hyper-sd"], adapter_weights=[0.125])
#####################################
# The function for our Gradio app   #
#####################################
@spaces.GPU
def generate(prompt, input_image):
    """Edit ``input_image`` according to ``prompt`` with the Flux Control pipeline.

    The module-level ``pipeline`` is moved to CUDA when the model is loaded.
    ``@spaces.GPU`` asks a ZeroGPU Space to attach a GPU for the duration of
    this call.

    Args:
        prompt: Text instruction describing the desired edit.
        input_image: PIL image to edit (the Gradio input uses ``type="pil"``).
            It is passed as the control image, and its size sets the output
            size.

    Returns:
        The first image in the pipeline output's ``images``.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Gradio passes None when the image box is empty; fail with a readable
    # message instead of an AttributeError on ``.height``.
    if input_image is None:
        raise gr.Error("Please upload an image to edit.")

    output_image = pipeline(
        control_image=input_image,
        prompt=prompt,
        guidance_scale=30.0,
        num_inference_steps=8,  # matches the 8-step Hyper-SD LoRA loaded above
        max_sequence_length=512,
        height=input_image.height,
        width=input_image.width,
        # Fixed seed (this also reseeds torch's global RNG) so the same
        # prompt/image pair yields a reproducible edit.
        generator=torch.manual_seed(0),
    ).images[0]
    return output_image
def launch_app():
    """Build the Gradio Blocks UI for the Flux Control editing demo.

    Despite its name, this only constructs the interface. The caller must
    call ``.launch()`` on the returned object.

    Returns:
        gr.Blocks: The assembled, not-yet-launched demo.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Flux Control Editing 🖌️

            This demo uses the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
            pipeline with an edit transformer from [Sayak Paul](https://huggingface.co/sayakpaul).

            **Acknowledgements**:

            - [Sayak Paul](https://huggingface.co/sayakpaul) for open-sourcing FLUX.1-dev-edit-v0
            - [black-forest-labs](https://huggingface.co/black-forest-labs) for FLUX.1-dev
            """
        )

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. 'Edit a certain thing in the image'",
            )
            input_image = gr.Image(
                label="Image",
                type="pil",  # generate() reads .height/.width from a PIL image
            )

        generate_button = gr.Button("Generate")
        output_image = gr.Image(label="Edited Image")

        # Run the edit when the button is clicked.
        generate_button.click(
            fn=generate,
            inputs=[prompt, input_image],
            outputs=[output_image],
        )

        # No fn/outputs are given, so clicking an example only fills the
        # inputs. NOTE(review): assumes mushroom.jpg ships next to this
        # script; confirm.
        gr.Examples(
            examples=[
                ["Turn the color of the mushroom to gray", "mushroom.jpg"],
                ["Make the mushroom polka-dotted", "mushroom.jpg"],
            ],
            inputs=[prompt, input_image],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Blocks UI, then start serving it.
    demo = launch_app()
    demo.launch()