Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import io | |
| import gradio as gr | |
| from transformers import AutoModel | |
| import ecg_plot | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torch | |
# Hugging Face Hub repository that hosts the ECG generator.
# NOTE(review): trust_remote_code=True executes the repository's own Python
# code on load; consider pinning `revision=` to a reviewed commit hash.
MODEL_ID = "deepsynthbody/deepfake_ecg"
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
def _expand_to_12_leads(eight_lead):
    """Derive the four dependent limb leads and return all 12 leads in display order.

    ``eight_lead`` is indexed by lead along dim 0: row 0 is used as lead I,
    row 1 as lead II, and rows 2 onward are appended after the limb leads
    unchanged (presumably V1-V6 — TODO confirm against the model card).

    Returns:
        A tensor whose dim 0 is ordered I, II, III, aVR, aVL, aVF, rows 2..
    """
    lead_i, lead_ii = eight_lead[0], eight_lead[1]
    # Einthoven / Goldberger relations: III, aVR, aVL and aVF are linear
    # combinations of leads I and II, so the model only needs to emit I and II.
    lead_iii = lead_ii - lead_i
    lead_avr = (lead_i + lead_ii) * (-0.5)
    lead_avl = lead_i - lead_ii * 0.5
    lead_avf = lead_ii - lead_i * 0.5
    derived = torch.stack((lead_iii, lead_avr, lead_avl, lead_avf), dim=0)
    return torch.cat((eight_lead[:2], derived, eight_lead[2:]), dim=0)


def predict():
    """Generate one synthetic ECG and render it as a 12-lead plot image.

    Returns:
        A ``PIL.Image.Image`` decoded from the PNG that ``ecg_plot.plot``
        draws onto the current matplotlib figure.
    """
    # Inference only: no_grad keeps torch from building an autograd graph.
    with torch.no_grad():
        # NOTE(review): model(1) presumably requests a single generated sample;
        # [0] takes it and .t() transposes it so leads run along dim 0 —
        # confirm against the model's remote code.
        # The /1000 was originally commented "to micro volts", but dividing by
        # 1000 looks more like a uV -> mV conversion — TODO confirm units.
        prediction = model(1)[0].t() / 1000
    leads = _expand_to_12_leads(prediction)

    buf = io.BytesIO()
    try:
        # Hand matplotlib a plain CPU array rather than a torch tensor.
        ecg_plot.plot(
            leads.detach().cpu().numpy(), sample_rate=500, title='ECG 12'
        )
        plt.savefig(buf, format="png")
    finally:
        # Close the figure drawn above so repeated requests in a long-running
        # server do not accumulate open figures in pyplot's global registry.
        plt.close()
    buf.seek(0)
    return Image.open(buf)
# Gradio UI with no input components: running it calls predict() and shows
# the returned image. Launched at import time, as in the original script.
demo = gr.Interface(
    fn=predict,
    inputs=None,
    outputs="image",
    title="Generating Fake ECGs",
)
demo.launch()