import glob
import os
import tempfile
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import open_clip
import pandas as pd
import torch
from gradio import processing_utils, utils
from PIL import Image

from download_example_images import read_actor_files, save_images_to_folder
DEFAULT_INITIAL_NAME = "John Doe"

PROMPTS = [
    '{0}',
    'an image of {0}',
    'a photo of {0}',
    '{0} on a photo',
    'a photo of a person named {0}',
    'a person named {0}',
    'a man named {0}',
    'a woman named {0}',
    'the name of the person is {0}',
    'a photo of a person with the name {0}',
    '{0} at a gala',
    'a photo of the celebrity {0}',
    'actor {0}',
    'actress {0}',
    'a colored photo of {0}',
    'a black and white photo of {0}',
    'a cool photo of {0}',
    'a cropped photo of {0}',
    'a cropped image of {0}',
    '{0} in a suit',
    '{0} in a dress'
]
OPEN_CLIP_LAION400M_MODEL_NAMES = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']
OPEN_CLIP_LAION2B_MODEL_NAMES = [('ViT-B-32', 'laion2b_s34b_b79k'), ('ViT-L-14', 'laion2b_s32b_b82k')]
OPEN_AI_MODELS = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']

NUM_TOTAL_NAMES = 1_000
SEED = 42
MIN_NUM_CORRECT_PROMPT_PREDS = 1
EXAMPLE_IMAGE_DIR = './example_images/'
IMG_BATCHSIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EXAMPLE_IMAGE_URLS = read_actor_files(EXAMPLE_IMAGE_DIR)
save_images_to_folder(os.path.join(EXAMPLE_IMAGE_DIR, 'images'), EXAMPLE_IMAGE_URLS)
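
# Instantiate every supported CLIP model once at startup. Each entry keeps the model (on the CPU until it is used),
# its image preprocessing, its tokenizer, and the precomputed text embeddings of the prompts for the candidate names.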
MODELS = {}
for model_name in OPEN_CLIP_LAION400M_MODEL_NAMES:
    dataset = 'LAION400M'
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name,
        pretrained=f'{dataset.lower()}_e32'
    )
    model = model.eval()
    MODELS[f'OpenClip {model_name} trained on {dataset}'] = {
        'model_instance': model,
        'preprocessing': preprocess,
        'model_name': model_name,
        'tokenizer': open_clip.get_tokenizer(model_name),
        'prompt_text_embeddings': torch.load(f'./prompt_text_embeddings/{model_name}_{dataset.lower()}_prompt_text_embeddings.pt')
    }

for model_name, dataset_name in OPEN_CLIP_LAION2B_MODEL_NAMES:
    dataset = 'LAION2B'
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name,
        pretrained=dataset_name
    )
    model = model.eval()
    MODELS[f'OpenClip {model_name} trained on {dataset}'] = {
        'model_instance': model,
        'preprocessing': preprocess,
        'model_name': model_name,
        'tokenizer': open_clip.get_tokenizer(model_name),
        'prompt_text_embeddings': torch.load(f'./prompt_text_embeddings/{model_name}_{dataset.lower()}_prompt_text_embeddings.pt')
    }

for model_name in OPEN_AI_MODELS:
    dataset = 'OpenAI'
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name,
        pretrained=dataset.lower()
    )
    model = model.eval()
    MODELS[f'OpenClip {model_name} trained by {dataset}'] = {
        'model_instance': model,
        'preprocessing': preprocess,
        'model_name': model_name,
        'tokenizer': open_clip.get_tokenizer(model_name),
        'prompt_text_embeddings': torch.load(f'./prompt_text_embeddings/{model_name}_{dataset.lower()}_prompt_text_embeddings.pt')
    }
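
# full_names.csv holds the candidate names (first name, last name, sex) the model can choose from;
# laion_membership_occurence_count.csv holds how often known names occur in LAION-400M.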
FULL_NAMES_DF = pd.read_csv('full_names.csv', index_col=0)
LAION_MEMBERSHIP_OCCURENCE = pd.read_csv('laion_membership_occurence_count.csv', index_col=0)

EXAMPLE_ACTORS_BY_MODEL = {
    ("ViT-B-32", "laion400m"): ["T._J._Thyne"],
    ("ViT-B-16", "laion400m"): ["Barbara_Schöneberger", "Carolin_Kebekus"],
    ("ViT-L-14", "laion400m"): ["Max_Giermann", "Nicole_De_Boer"]
}
EXAMPLES = []
for (model_name, dataset_name), person_names in EXAMPLE_ACTORS_BY_MODEL.items():
    for name in person_names:
        image_folder = os.path.join("./example_images/images/", name)
        for dd_model_name in MODELS.keys():
            if not (model_name.lower() in dd_model_name.lower() and dataset_name.lower() in dd_model_name.lower()):
                continue
            EXAMPLES.append([
                dd_model_name,
                name.replace("_", " "),
                [[x.format(name.replace("_", " ")) for x in PROMPTS]],
                [os.path.join(image_folder, x) for x in os.listdir(image_folder)]
            ])
LICENSE_DETAILS = """
See [README.md](https://huggingface.co/spaces/AIML-TUDA/does-clip-know-my-face/blob/main/README.md) for more information about the licenses of the example images.
"""
CORRECT_RESULT_INTERPRETATION = """<br>
<h2>{0} is in the Training Data!</h2>
The name of {0} has been <b>correctly predicted for {1} out of {2} prompts.</b> This means that <b>{0} was in
the training data and was used to train the model.</b>
Keep in mind that the probability of correctly predicting the name for {3} by chance {4} times, with {5} possible names for the model to
choose from, is only (<sup>1</sup> ⁄ <sub>{5}</sub>)<sup>{6}</sup> = {7}%.
"""

INDECISIVE_RESULT_INTERPRETATION = """<br>
<h2>{0} might be in the Training Data!</h2>
For none of the {1} prompts was the majority vote for the name of {0} correct. However, while the majority votes are not
correct, the name of {0} was correctly predicted {2} times for {3}. This is an indication that the model has seen {0}
during training. A different selection of images might yield a clearer result. Keep in mind that the probability
that the name is correctly predicted by chance {2} times for {3} is
(<sup>1</sup> ⁄ <sub>{4}</sub>)<sup>{2}</sup> = {5}%.
"""

INCORRECT_RESULT_INTERPRETATION = """<br>
<h2>{0} is most likely not in the Training Data!</h2>
The name of {0} has not been correctly predicted for any of the {1} prompts. This is an indication that {0} was
most likely not used to train the model.
"""

OCCURENCE_INFORMATION = """<br><br>
According to our analysis, {0} appeared {1} times among the 400 million image-text pairs of the LAION-400M training dataset.
"""
CSS = """
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
#file_upload {
    max-height: 250px;
    overflow-y: auto !important;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4 {
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
"""
# monkey patch the preprocess function of the Files component since otherwise it is not possible to access the
# original file name of the uploaded files
def preprocess(
    self, x: List[Dict[str, Any]] | None
) -> bytes | tempfile._TemporaryFileWrapper | List[
    bytes | tempfile._TemporaryFileWrapper
] | None:
    """
    Parameters:
        x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
    Returns:
        File objects in requested format
    """
    if x is None:
        return None

    def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper:
        file_name, orig_name, data, is_file = (
            f["name"] if "name" in f.keys() else f["orig_name"],
            f["orig_name"] if "orig_name" in f.keys() else f["name"],
            f["data"],
            f.get("is_file", False),
        )
        if self.type == "file":
            if is_file:
                temp_file_path = self.make_temp_copy_if_needed(file_name)
                file = tempfile.NamedTemporaryFile(delete=False)
                file.name = temp_file_path
                file.orig_name = os.path.basename(orig_name.replace(self.hash_file(file_name), ""))  # type: ignore
            else:
                file = processing_utils.decode_base64_to_file(
                    data, file_path=file_name
                )
                file.orig_name = file_name  # type: ignore
            self.temp_files.add(str(utils.abspath(file.name)))
            return file
        elif (
            self.type == "binary" or self.type == "bytes"
        ):  # "bytes" is included for backwards compatibility
            if is_file:
                with open(file_name, "rb") as file_data:
                    return file_data.read()
            return processing_utils.decode_base64_to_binary(data)[0]
        else:
            raise ValueError(
                "Unknown type: "
                + str(self.type)
                + ". Please choose from: 'file', 'bytes'."
            )

    if self.file_count == "single":
        if isinstance(x, list):
            return process_single_file(x[0])
        else:
            return process_single_file(x)
    else:
        if isinstance(x, list):
            return [process_single_file(f) for f in x]
        else:
            return process_single_file(x)


gr.Files.preprocess = preprocess
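
# Embedding helpers: the selected model is moved to DEVICE only for the forward pass and back to the CPU afterwards,
# so at most one model occupies the GPU at a time.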
def calculate_text_embeddings(model_name, prompts):
    tokenizer = MODELS[model_name]['tokenizer']
    context_vecs = tokenizer(prompts)

    model_instance = MODELS[model_name]['model_instance']
    model_instance = model_instance.to(DEVICE)
    context_vecs = context_vecs.to(DEVICE)
    text_features = model_instance.encode_text(context_vecs, normalize=True).cpu()
    model_instance = model_instance.cpu()
    context_vecs = context_vecs.cpu()

    return text_features

def calculate_image_embeddings(model_name, images):
    preprocessing = MODELS[model_name]['preprocessing']
    model_instance = MODELS[model_name]['model_instance']

    # load the given images
    user_imgs = []
    for tmp_file_img in images:
        img = Image.open(tmp_file_img.name)
        # preprocess the images
        user_imgs.append(preprocessing(img))

    # calculate the image embeddings
    image_embeddings = []
    model_instance = model_instance.to(DEVICE)
    for batch_idx in range(0, len(user_imgs), IMG_BATCHSIZE):
        imgs = user_imgs[batch_idx:batch_idx + IMG_BATCHSIZE]
        imgs = torch.stack(imgs)
        imgs = imgs.to(DEVICE)
        emb = model_instance.encode_image(imgs, normalize=True).cpu()
        image_embeddings.append(emb)
        imgs = imgs.cpu()
    model_instance = model_instance.cpu()

    return torch.cat(image_embeddings)

def get_possible_names(true_name):
    possible_names = FULL_NAMES_DF
    possible_names['full_names'] = FULL_NAMES_DF['first_name'].astype(str) + ' ' + FULL_NAMES_DF['last_name'].astype(str)
    possible_names = possible_names[possible_names['full_names'] != true_name]
    # sample the same amount of male and female names
    sampled_names = possible_names.groupby('sex').sample(int(NUM_TOTAL_NAMES / 2), random_state=SEED)
    # shuffle the rows randomly
    sampled_names = sampled_names.sample(frac=1)
    # get only the full names since we don't need first and last name and gender anymore
    possible_full_names = sampled_names['full_names']

    return possible_full_names

def round_to_first_digit(value: Decimal):
    """Formats the value and keeps only the leading zeros and the first significant digit, e.g. 0.0001234 -> '0.0001'."""
    tmp = np.format_float_positional(value)
    prob_str = []
    for c in str(tmp):
        if c in ("0", "."):
            prob_str.append(c)
        else:
            prob_str.append(c)
            break
    return "".join(prob_str)

def get_majority_predictions(predictions: pd.Series, values_only=False, counts_only=False, value=None):
    """Takes a series of predictions and returns the unique values and the number of prediction occurrences
    in descending order."""
    values, counts = np.unique(predictions, return_counts=True)
    descending_counts_indices = counts.argsort()[::-1]
    values, counts = values[descending_counts_indices], counts[descending_counts_indices]
    idx_most_often_pred_names = np.argwhere(counts == counts.max()).flatten()

    if values_only:
        return values[idx_most_often_pred_names]
    elif counts_only:
        return counts[idx_most_often_pred_names]
    elif value is not None:
        if value not in values:
            return [0]
        # return how often the value appears in the predictions
        return counts[np.where(values == value)[0]]
    else:
        return values[idx_most_often_pred_names], counts[idx_most_often_pred_names]

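
# Main prediction routine: embed the uploaded images and the prompts filled with the entered name, compare them
# against the precomputed text embeddings of the candidate names, take the most similar name per image and prompt,
# and aggregate the per-image predictions into a majority vote for every prompt.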
def on_submit_btn_click(model_name, true_name, prompts, images):
    # assert that the name is in the prompts
    if not prompts.iloc[0].str.contains(true_name).sum() == len(prompts.T):
        return None, None, """<br>
        <div class="error-message" style="background-color: #fce4e4; border: 1px solid #fcc2c3; padding: 20px 30px; border-radius: var(--radius-lg);">
            <span class="error-text" style="color: #cc0033; font-weight: bold;">
                The given name does not match the name in the prompts. Sometimes the UI responds slowly.
                Please retype the name and check that it is fully inserted into the prompts.
            </span>
        </div>
        """

    if images is None or len(images) < 1:
        return None, None, f"""<br>
        <div class="error-message" style="background-color: #fce4e4; border: 1px solid #fcc2c3; padding: 20px 30px; border-radius: var(--radius-lg);">
            <span class="error-text" style="color: #cc0033; font-weight: bold;">
                No images were given. Images are needed to determine whether {true_name} was in the dataset. Please upload at least one image of {true_name}.
            </span>
        </div>
        """
    # calculate the image embeddings
    img_embeddings = calculate_image_embeddings(model_name, images)
    # calculate the text embeddings of the populated prompts
    user_text_emb = calculate_text_embeddings(model_name, prompts.values[0].tolist())

    # get the indices of the possible names
    possible_names = get_possible_names(true_name)
    # get the text embeddings of the possible names
    prompt_text_embeddings = MODELS[model_name]['prompt_text_embeddings']
    text_embeddings_used_for_prediction = prompt_text_embeddings.index_select(
        1, torch.tensor(possible_names.index.values)
    )

    # add the true name and the text embeddings to the possible names
    names_used_for_prediction = pd.concat([possible_names, pd.Series(true_name)], ignore_index=True)
    text_embeddings_used_for_prediction = torch.cat(
        [text_embeddings_used_for_prediction, user_text_emb.unsqueeze(1)], dim=1
    )

    # calculate the similarity of the images and the given texts
    with torch.no_grad():
        logits_per_image = MODELS[model_name][
            'model_instance'
        ].logit_scale.exp().cpu() * img_embeddings @ text_embeddings_used_for_prediction.swapaxes(-1, -2)
        preds = logits_per_image.argmax(-1)

    # get the predicted names for each prompt
    predicted_names = []
    for pred in preds:
        predicted_names.append(names_used_for_prediction.iloc[pred])
    predicted_names = np.array(predicted_names)

    # convert the predictions into a dataframe
    name_predictions = pd.DataFrame(predicted_names).T.reset_index().rename(
        columns={i: f'Prompt {i + 1}' for i in range(len(predicted_names))}
    ).rename(columns={'index': 'Image'})
    # add the image names
    name_predictions['Image'] = [x.orig_name for x in images]

    # get the majority votes
    majority_preds = name_predictions[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].apply(
        lambda x: get_majority_predictions(x, values_only=True)
    )
    # get how often the majority name was predicted
    majority_preds_counts = name_predictions[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].apply(
        lambda x: get_majority_predictions(x, counts_only=True)
    ).apply(lambda x: x[0])
    # get how often the correct name was predicted - even if no majority
    true_name_preds_counts = name_predictions[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].apply(
        lambda x: get_majority_predictions(x, value=true_name)
    ).apply(lambda x: x[0])

    # convert the majority preds to a series of lists if it is a dataframe
    majority_preds = majority_preds.T.squeeze().apply(lambda x: [x]) if len(majority_preds) == 1 else majority_preds

    # create the results dataframe for display
    result = pd.concat(
        [name_predictions,
         pd.concat([pd.Series({'Image': 'Correct Name Predictions'}), true_name_preds_counts]).to_frame().T],
        ignore_index=True
    )
    result = pd.concat(
        [result, pd.concat([pd.Series({'Image': 'Majority Vote'}), majority_preds]).to_frame().T],
        ignore_index=True
    )
    result = pd.concat(
        [result, pd.concat([pd.Series({'Image': 'Majority Vote Counts'}), majority_preds_counts]).to_frame().T],
        ignore_index=True
    )
    result = result.set_index('Image')

    # check whether there is only one majority vote. If not, display Not Applicable
    result.loc['Majority Vote'] = result.loc['Majority Vote'].apply(lambda x: x[0] if len(x) == 1 else "N/A")
    # check whether the majority prediction is the correct name
    result.loc['Correct Majority Prediction'] = result.apply(lambda x: x['Majority Vote'] == true_name, axis=0)
    result = result[[f'Prompt {i + 1}' for i in range(len(PROMPTS))]].sort_values(
        ['Correct Name Predictions', 'Majority Vote Counts', "Correct Majority Prediction"], axis=1, ascending=False
    )

    predictions = result.loc[[x.orig_name for x in images]]
    prediction_results = result.loc[['Correct Name Predictions', 'Majority Vote', 'Correct Majority Prediction']]

    # if there are correct predictions
    num_correct_maj_preds = prediction_results.loc['Correct Majority Prediction'].sum()
    num_correct_name_preds = result.loc['Correct Name Predictions'].max()
    if num_correct_maj_preds > 0:
        interpretation = CORRECT_RESULT_INTERPRETATION.format(
            true_name,
            num_correct_maj_preds,
            len(PROMPTS),
            prediction_results.columns[0],
            prediction_results.iloc[0, 0],
            len(possible_names),
            predictions.iloc[:, 0].value_counts()[true_name],
            round_to_first_digit(
                (
                    (Decimal(1) / Decimal(len(possible_names))) ** predictions.iloc[:, 0].value_counts()[true_name]
                ) * Decimal(100)
            )
        )
    elif num_correct_name_preds > 0:
        interpretation = INDECISIVE_RESULT_INTERPRETATION.format(
            true_name,
            len(PROMPTS),
            num_correct_name_preds,
            prediction_results.columns[result.loc['Correct Name Predictions'].to_numpy().argmax()],
            len(possible_names),
            round_to_first_digit(
                (
                    (Decimal(1) / Decimal(len(possible_names))) ** Decimal(num_correct_name_preds)
                ) * Decimal(100)
            )
        )
    else:
        interpretation = INCORRECT_RESULT_INTERPRETATION.format(
            true_name,
            len(PROMPTS)
        )

    if 'laion400m' in model_name.lower() and true_name.lower() in LAION_MEMBERSHIP_OCCURENCE['name'].str.lower().values:
        row = LAION_MEMBERSHIP_OCCURENCE[LAION_MEMBERSHIP_OCCURENCE['name'].str.lower() == true_name.lower()]
        interpretation = interpretation + OCCURENCE_INFORMATION.format(true_name, row['count'].values[0])

    return predictions.reset_index(), prediction_results.reset_index(names=[""]), interpretation

def populate_prompts(name):
    return [[x.format(name) for x in PROMPTS]]


def load_uploaded_imgs(images):
    if images is None:
        return None

    imgs = []
    for file_wrapper in images:
        img = Image.open(file_wrapper.name)
        imgs.append((img, file_wrapper.orig_name))
    return imgs

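
# Gradio UI: inputs (model dropdown, name textbox, auto-populated prompts, image upload with gallery preview) on the
# left, outputs (prediction table, result summary and interpretation text) on the right.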
block = gr.Blocks(css=CSS)

with block as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 750px; margin: 0 auto;">
            <div>
                <img
                    class="logo"
                    src="https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1666181274838-62fa1d95e8c9c532aa75331c.png"
                    alt="AIML Logo"
                    style="margin: auto; max-width: 7rem;"
                >
                <h1 style="font-weight: 900; font-size: 3rem;">
                    Does CLIP Know My Face?
                </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 94%">
                Want to know whether you were used to train a CLIP model? Below you can choose a model, enter your name and upload some pictures.
                If the model correctly predicts your name for multiple images, it is very likely that you were part of the training data.
                Pick one of the examples below and try it out!<br><br>
                Details and further analysis can be found in the paper
                <a href="https://arxiv.org/abs/2209.07341" style="text-decoration: underline;" target="_blank">
                    Does CLIP Know My Face?
                </a>. Our code can be found on
                <a href="https://github.com/D0miH/does-clip-know-my-face" style="text-decoration: underline;" target="_blank">
                    GitHub
                </a>.
                <br><br>
                <b>How does it work?</b> We give CLIP your images and let it choose from 1,000 possible names.
                Since CLIP predicts the name that best matches a given image, we can probe whether the model has seen your images
                during training. The more images you upload, the more confident you can be in the result!
                <br><br>
                <b>Disclaimer:</b> In order to process the images, they are cached on the server. The images are only used to predict whether the person was in the training data.
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Box():
            gr.Markdown("## Inputs")
            with gr.Column():
                model_dd = gr.Dropdown(label="CLIP Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
                true_name = gr.Textbox(label='Name of Person (make sure it matches the prompts):', lines=1, value=DEFAULT_INITIAL_NAME)
                prompts = gr.Dataframe(
                    value=[[x.format(DEFAULT_INITIAL_NAME) for x in PROMPTS]],
                    label='Prompts Used (hold shift to scroll sideways):',
                    interactive=False
                )
                true_name.change(fn=populate_prompts, inputs=[true_name], outputs=prompts, show_progress=True,
                                 status_tracker=None)

                uploaded_imgs = gr.Files(label='Upload Images:', file_types=['image'], elem_id='file_upload').style()
                image_gallery = gr.Gallery(label='Images Used:', show_label=True, elem_id="image_gallery").style(grid=[5])
                uploaded_imgs.change(load_uploaded_imgs, inputs=uploaded_imgs, outputs=image_gallery)

                submit_btn = gr.Button(value='Submit')

        with gr.Box():
            gr.Markdown("## Outputs")
            prediction_df = gr.Dataframe(label="Prediction Output (hold shift to scroll sideways):", interactive=False)
            result_df = gr.DataFrame(label="Result (hold shift to scroll sideways):", interactive=False)
            interpretation = gr.HTML()

    submit_btn.click(on_submit_btn_click, inputs=[model_dd, true_name, prompts, uploaded_imgs],
                     outputs=[prediction_df, result_df, interpretation])
    gr.Examples(
        examples=EXAMPLES,
        inputs=[model_dd, true_name, prompts, uploaded_imgs],
        outputs=[prediction_df, result_df, interpretation],
        fn=on_submit_btn_click,
        cache_examples=True
    )

    gr.Markdown(LICENSE_DETAILS)

    gr.HTML(
        """
        <div class="footer">
            <p>Gradio Demo by AIML@TU Darmstadt</p>
        </div>
        <div class="acknowledgments">
            <p>Created by <a href="https://www.ml.informatik.tu-darmstadt.de/people/dhintersdorf/">Dominik Hintersdorf</a> at <a href="https://www.aiml.informatik.tu-darmstadt.de">AIML Lab</a>.</p>
        </div>
        """
    )
if __name__ == "__main__":
    demo.launch()