Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
| from image_evaluators import LlamaEvaluator | |
| from prompt_refiners import LlamaPromptRefiner | |
| from weave_prompt import PromptOptimizer | |
| from similarity_metrics import LPIPSImageSimilarityMetric | |
| from image_generators import FalImageGenerator | |
# Pull variables from a local .env file into the process environment at
# import time, i.e. before main() is invoked at the bottom of this file.
load_dotenv()

# Page-level Streamlit settings. This executes at module level, so it runs
# ahead of every Streamlit element that main() renders.
st.set_page_config(
    layout="wide",
    page_icon="🎨",
    page_title="WeavePrompt",
)
def _init_session_state():
    """Create the per-session optimizer and UI flags on the first script run only."""
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            model=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95
        )
    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False
    if 'current_results' not in st.session_state:
        st.session_state.current_results = None


def _render_history(history):
    """Render one section per recorded optimization step.

    Each entry is read through the keys 'image', 'similarity', 'prompt' and
    'analysis'; 'analysis' is iterated with ``.items()``, so it is expected
    to be a mapping.
    """
    if len(history) == 0:
        return
    st.subheader("Optimization History")
    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        col1, col2 = st.columns([2, 3])
        with col1:
            st.image(hist_entry['image'], width='stretch')
        with col2:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])
            st.text("\nAnalysis:")
            for key, value in hist_entry['analysis'].items():
                st.text(f"{key}: {value}")


def main():
    """Render the WeavePrompt Streamlit app.

    Flow: the user uploads a target image, clicks "Start Optimization" to call
    ``optimizer.initialize(target_image)``, then repeatedly clicks "Next Step"
    (``optimizer.step()``) until the optimizer reports completion. Both calls
    return an ``(is_completed, prompt, generated_image)`` triple, which is kept
    in ``st.session_state.current_results`` across Streamlit reruns.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    """)

    _init_session_state()

    # File uploader
    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    # Display target image
    target_image = Image.open(uploaded_file)
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    # Start button
    if not st.session_state.optimization_started:
        if st.button("Start Optimization"):
            # Run initialize() BEFORE marking the session as started: if it
            # raises, the session must not be left "started" with
            # current_results still None, which would make every later rerun
            # fail when the results are unpacked below.
            is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.session_state.optimization_started = True
            # Rerun so the page is redrawn without the Start button.
            st.rerun()

    # Nothing more to show until an optimization run has produced results.
    if not st.session_state.optimization_started or st.session_state.current_results is None:
        return

    optimizer = st.session_state.optimizer
    is_completed, prompt, generated_image = st.session_state.current_results

    # Display optimization progress
    with col2:
        st.subheader("Generated Image")
        st.image(generated_image, width='stretch')

    # Display prompt and controls
    st.text_area("Current Prompt", prompt, height=100)

    # Progress metrics
    iter_col, sim_col, status_col = st.columns(3)
    with iter_col:
        st.metric("Iteration", len(optimizer.history))
    with sim_col:
        if len(optimizer.history) > 0:
            similarity = optimizer.history[-1]['similarity']
            st.metric("Similarity", f"{similarity:.2%}")
    with status_col:
        st.metric("Status", "Completed" if is_completed else "In Progress")

    # Next step button
    if not is_completed:
        if st.button("Next Step"):
            is_completed, prompt, generated_image = optimizer.step()
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.rerun()
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")

    # Reset button
    if st.button("Reset"):
        # NOTE(review): only the UI flags are cleared here; optimizer.history
        # and any other optimizer state are kept. Whether initialize() resets
        # them is defined in weave_prompt.PromptOptimizer -- verify.
        st.session_state.optimization_started = False
        st.session_state.current_results = None
        st.rerun()

    # Display history
    _render_history(optimizer.history)
# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()