dreibh's picture
Minor clean-ups.
ae6d4dd verified
raw
history blame
28.6 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ==========================================================================
# ____ __ _ _____ ____ ____
# | _ \ ___ ___ _ __ / _| __ _| | _____ | ____/ ___/ ___|
# | | | |/ _ \/ _ \ '_ \| |_ / _` | |/ / _ \ | _|| | | | _
# | |_| | __/ __/ |_) | _| (_| | < __/ | |__| |__| |_| |
# |____/ \___|\___| .__/|_| \__,_|_|\_\___| |_____\____\____|
# |_|
#
# --- Deepfake ECG Generator ---
# https://github.com/vlbthambawita/deepfake-ecg
# ==========================================================================
#
# DeepfakeECG GUI Application
# Copyright (C) 2023-2025 by Vajira Thambawita
# Copyright (C) 2025 by Thomas Dreibholz
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Contact:
# * Vajira Thambawita <[email protected]>
# * Thomas Dreibholz <[email protected]>
import datetime
import deepfakeecg
import ecg_plot
import getopt
import gradio
import io
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker
import neurokit2
import numpy
import pathlib
import random
import sys
import tempfile
import threading
import torch
import typing
import version
import PIL
import PIL.Image
from typing import Any, Final
# ###### Print log message ##################################################
def log(logstring : str) -> None:
    """Print *logstring* to stdout, prefixed with the current local time.

    The whole line is coloured blue via ANSI escape sequences.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    print(''.join(('\x1b[34m', timestamp, ': ', logstring, '\x1b[0m')))
# ###### DeepFakeECG Plus Session (session with web browser) ################
class Session:
    """State of one browser session of the DeepFakeECG Plus GUI.

    Holds the ECGs generated last (Results, with their ECG type in Type), the
    index of the ECG currently selected in the gallery (Selected), and a
    private temporary directory for the files offered for download.
    """

    # ###### Constructor #####################################################
    def __init__(self) -> None:
        self.Lock = threading.Lock()     # NOTE(review): not used in the code visible here
        self.Counter : int = 0           # NOTE(review): not used in the code visible here
        self.Selected : int = 0
        self.Results : list[Any] = [ ]
        self.Type = None
        # The per-session directory is created inside the application-wide
        # temporary directory (module-global TempDirectory), which is the
        # path Gradio is allowed to serve downloads from.
        self.TempDirectory = tempfile.TemporaryDirectory(dir = TempDirectory.name)
        log(f'Prepared temporary directory {self.TempDirectory.name}')

    # ###### Destructor ######################################################
    def __del__(self) -> None:
        # __init__ may have failed before the directory was created (e.g. when
        # the global TempDirectory has not been set up). Then there is nothing
        # to clean up, and accessing the attribute would raise AttributeError.
        tempDirectory = getattr(self, 'TempDirectory', None)
        if tempDirectory is not None:
            log(f'Cleaning up temporary directory {tempDirectory.name}')
            tempDirectory.cleanup()
# Application-wide temporary directory. It is only declared here; the object
# is assigned in the "__main__" section before the GUI is launched, and
# Session.__init__ creates the per-session directories inside it.
TempDirectory : tempfile.TemporaryDirectory[Any]
# Active browser sessions, keyed by Gradio session hash
# (filled by initializeSession(), emptied by cleanUpSession()).
Sessions : dict[str,Session] = { }
# ###### Initialize a new session ###########################################
def initializeSession(request: gradio.Request) -> None:
    """Create fresh state for the browser session issuing *request*.

    Registered to run on page load; any state already stored under the same
    session hash is replaced.
    """
    sessionHash = request.session_hash
    Sessions[sessionHash] = Session()
    log(f'Session "{sessionHash}" initialized')
# ###### Clean up a session #################################################
def cleanUpSession(request: gradio.Request) -> None:
    """Forget the state of a closed or reloaded browser session.

    Dropping the Session object lets its destructor remove the session's
    temporary directory once it is garbage-collected. Unknown session hashes
    are silently ignored.
    """
    sessionHash = request.session_hash
    if Sessions.pop(sessionHash, None) is not None:
        log(f'Session "{sessionHash}" cleaned up')
# ###### Generate ECGs ######################################################
def predict(numberOfECGs: int = 1,
            # ecgLengthInSeconds: int = 10,
            ecgTypeString: str = 'ECG-12',
            generatorModel: str = 'Default',
            request: gradio.Request = None) -> list[tuple[PIL.Image.Image,str]]:
    """Generate new deepfake ECGs for a session and render them as images.

    The generated ECGs and their type are stored in the session (Results,
    Type), so that they can later be analysed and downloaded. The selection
    is reset to the first new ECG, since an index chosen for a previous batch
    may not exist in the new one.

    Parameters:
       numberOfECGs:   Number of ECGs to generate.
       ecgTypeString:  'ECG-12' or 'ECG-8'; any other value falls back to
                       ECG-12 with a warning on stderr.
       generatorModel: Name of the generator model. NOTE: not used yet; the
                       GUI only offers 'Default'.
       request:        Gradio request, identifying the session.

    Returns:
       List of (image, label) tuples for gradio.Gallery.
    """
    # Fixed length for now; the corresponding slider is disabled in the GUI.
    ecgLengthInSeconds : Final[int] = 10
    log(f'Session "{request.session_hash}": Generate ECGs!')

    # ====== Set ECG type ====================================================
    if ecgTypeString == 'ECG-8':
        ecgType = deepfakeecg.DATA_ECG8
    else:
        if ecgTypeString != 'ECG-12':
            sys.stderr.write(f'WARNING: Invalid ecgTypeString {ecgTypeString}, using ECG-12!\n')
        ecgType = deepfakeecg.DATA_ECG12

    # ====== Raise Locator.MAXTICKS, if necessary ============================
    matplotlib.ticker.Locator.MAXTICKS = \
        max(1000, ecgLengthInSeconds * deepfakeecg.ECG_SAMPLING_RATE)

    # ====== Generate the ECGs ===============================================
    session = Sessions[request.session_hash]
    # The slider value may arrive as float; the generator expects a count.
    results = deepfakeecg.generateDeepfakeECGs(int(numberOfECGs),
                                               ecgType = ecgType,
                                               ecgLengthInSeconds = ecgLengthInSeconds,
                                               ecgScaleFactor = deepfakeecg.ECG_DEFAULT_SCALE_FACTOR,
                                               outputFormat = deepfakeecg.OUTPUT_TENSOR,
                                               showProgress = False,
                                               runOnDevice = runOnDevice)
    session.Results  = results
    session.Type     = ecgType
    # A selection made for a previous (possibly larger) batch would otherwise
    # point to a wrong or non-existing ECG when downloading.
    session.Selected = 0

    # ====== Plot settings for the chosen ECG type ===========================
    info : Final[str] = '25 mm/sec, 1 mV/10 mm'
    if ecgType == deepfakeecg.DATA_ECG12:
        title     = 'ECG-12 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'III', 'aVR', 'aVL', 'aVF' ]
        leadOrder = [0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7]
    else:
        title     = 'ECG-8 – ' + info
        leadIndex = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]
        leadOrder = [0, 1, 2, 3, 4, 5, 6, 7]

    # ====== Create a list of image/label tuples for gradio.Gallery ==========
    plotList : list[tuple[PIL.Image.Image,str]] = [ ]
    for ecgNumber, result in enumerate(results, start = 1):
        # 1. Transpose and convert to NumPy
        # 2. Remove the Timestamp column (0)
        # 3. Convert from µV to mV
        data = result.t().detach().cpu().numpy()[1:] / 1000
        ecg_plot.plot(data,
                      title       = title,
                      sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
                      lead_index  = leadIndex,
                      lead_order  = leadOrder,
                      show_grid   = True)

        # ====== Render the current figure as WebP and release it ==========
        imageBuffer = io.BytesIO()
        plt.savefig(imageBuffer, format = 'webp')
        plt.close()
        image : PIL.Image.Image = PIL.Image.open(imageBuffer)
        plotList.append( (image, f'ECG Number {ecgNumber}') )

    return plotList
# ###### Generic download ###################################################
def download(request: gradio.Request,
             outputFormat: int) -> pathlib.Path | None:
    """Export the currently selected ECG of a session to a file for download.

    The file is written into the session's temporary directory as
    "ECG-<number>.csv" or "ECG-<number>.pdf", where <number> is the 1-based
    index of the selected ECG.

    Parameters:
       request:      Gradio request, identifying the session.
       outputFormat: deepfakeecg.OUTPUT_CSV, deepfakeecg.OUTPUT_PDF or
                     deepfakeecg.OUTPUT_PDF_ANALYSIS.

    Returns:
       Path of the written file, or None for an unsupported output format or
       when the session has no ECG at the selected index.
    """
    # ====== Map the output format to a file name suffix =====================
    if outputFormat == deepfakeecg.OUTPUT_CSV:
        suffix = 'csv'
    elif outputFormat in (deepfakeecg.OUTPUT_PDF, deepfakeecg.OUTPUT_PDF_ANALYSIS):
        suffix = 'pdf'
    else:
        return None

    # ====== Get the selected ECG ============================================
    session = Sessions[request.session_hash]
    # Nothing generated yet, or a stale selection: avoid an IndexError.
    if not 0 <= session.Selected < len(session.Results):
        log(f'Session "{request.session_hash}": No valid ECG selected for download')
        return None
    ecgResult = session.Results[session.Selected]
    ecgType   = session.Type
    ecgNumber = session.Selected + 1
    fileName  = pathlib.Path(session.TempDirectory.name) / f'ECG-{ecgNumber}.{suffix}'

    # ====== Write the file ==================================================
    if suffix == 'csv':
        deepfakeecg.dataToCSV(ecgResult, ecgType, fileName)
        log(f'Session "{request.session_hash}": Download CSV file {fileName}')
    else:
        if ecgType == deepfakeecg.DATA_ECG12:
            outputLeads = [ 'I', 'II', 'III', 'aVL', 'aVR', 'aVF', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
        else:
            outputLeads = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
        deepfakeecg.dataToPDF(ecgResult, ecgType, outputLeads, fileName, outputFormat,
                              ecgNumber)
        log(f'Session "{request.session_hash}": Download PDF file {fileName}')
    return fileName
# ###### Download CSV #######################################################
def downloadCSV(request: gradio.Request) -> pathlib.Path | None:
    """Gradio handler: export the session's selected ECG as a CSV file."""
    return download(request = request, outputFormat = deepfakeecg.OUTPUT_CSV)
# ###### Download PDF #######################################################
def downloadPDF(request: gradio.Request) -> pathlib.Path | None:
    """Gradio handler: export the session's selected ECG as a PDF file."""
    return download(request = request, outputFormat = deepfakeecg.OUTPUT_PDF)
# ###### Download PDF with analysis ######################################
def downloadPDFwithAnalysis(request: gradio.Request) -> pathlib.Path | None:
    """Gradio handler: export the session's selected ECG as a PDF file
    in the deepfakeecg.OUTPUT_PDF_ANALYSIS format."""
    return download(request = request, outputFormat = deepfakeecg.OUTPUT_PDF_ANALYSIS)
# ###### Analyze the selected ECG ###########################################
def analyze(event: gradio.SelectData,
            request: gradio.Request) -> matplotlib.figure.Figure:
    """Analyse the ECG selected in the gallery and return the analysis plot.

    The selection is remembered in the session (Selected), so that the
    download buttons export the same ECG. Lead I (row 0 after conversion) is
    processed by neurokit2.ecg_process() and plotted by neurokit2.ecg_plot().

    Parameters:
       event:   Gradio selection event; event.index is the gallery index.
       request: Gradio request, identifying the session.

    Returns:
       The matplotlib figure containing the analysis.
    """
    session = Sessions[request.session_hash]
    session.Selected = event.index
    log(f'Session "{request.session_hash}": Analyze ECG #{session.Selected + 1}!')

    # Transpose to NumPy, drop the timestamp row, convert from µV to mV
    data = session.Results[session.Selected]
    data = data.t().detach().cpu().numpy()[1:] / 1000
    leadI = data[0]
    signals, info = neurokit2.ecg_process(leadI, sampling_rate = deepfakeecg.ECG_SAMPLING_RATE)
    neurokit2.ecg_plot(signals, info)

    # The analysis plot is taken from pyplot's current figure
    figure = plt.gcf()
    # Figure size 508 mm x 122 mm, converted to inches
    figure.set_size_inches(508 / 25.4, 122 / 25.4, forward = True)
    # Release the figure from pyplot's figure manager. Otherwise every
    # analysis leaves an open figure behind (memory growth, matplotlib's
    # "too many open figures" warning). The Figure object itself can still
    # be rendered by gradio.Plot.
    plt.close(figure)
    return figure
# ###### Print usage and exit ###############################################
def usage(exitCode : int = 0) -> typing.NoReturn:
    """Print the command-line synopsis to stdout and terminate the program.

    Parameters:
       exitCode: Exit status passed to sys.exit().

    Raises:
       SystemExit: always; this function never returns.
    """
    sys.stdout.write('Usage: ' + sys.argv[0] + ' [-d|--device cpu|cuda] [-v|--version]\n')
    sys.exit(exitCode)
# ###### Main program #######################################################
# ====== Initialise =========================================================
# Device for running the generator model: CUDA if PyTorch reports it as
# available, otherwise the CPU. Can be overridden by -d/--device (see the
# argument parsing below).
runOnDevice: str = 'cuda' if torch.cuda.is_available() else 'cpu'
# Style sheet for the GUI, passed to gradio.Blocks() below. The class names
# (program-header, program-logo-left, ...) are used by the header HTML.
css = r"""
div {
background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-essen.png");
}
/* ###### General Settings ############################################## */
html, body {
height: 100%;
margin: 0;
padding: 0;
font-family: sans-serif;
font-size: small;
background-color: #E3E3E3; /* Simula background colour: #E3E3E3 */
background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-wiehl.png");
}
/* ###### Header ######################################################## */
div.program-header {
background-image: none;
background-color: #F15D22; /* Simula header colour: #F15D22 */
height: 7.5vh;
display: flex;
justify-content: space-between;
}
div.program-logo-left {
width: 12.5vw;
float: left;
display: flex;
padding: 0% 1%;
align-items: center;
background: white;
}
div.program-logo-right {
width: 12.5vw;
float: right;
display: flex;
padding: 0% 1%;
align-items: center;
background: white;
}
div.program-title {
display: flex;
align-items: center;
padding: 0% 1%;
background-image: none;
background-color: #F15D22; /* Simula header colour: #F15D22 */
font-family: "Open Sans", sans-serif;
font-size: 4vh;
font-weight: bold;
}
img.program-logo-image {
min-height: 4vh;
max-height: 4vh;
margin-left: auto;
margin-right: auto;
}
"""
# ====== Check arguments ====================================================
# Parse the command line: -d/--device selects the PyTorch device,
# -v/--version prints version information and exits. Positional arguments
# are not accepted.
try:
    options, args = getopt.getopt(
        sys.argv[1:],
        'd:v',
        [
            'device=',
            'version'
        ])
    for option, optarg in options:
        if option in ( '-d', '--device' ):
            runOnDevice = optarg
        elif option in ( '-v', '--version' ):
            # torch.version.cuda is None for PyTorch builds without CUDA
            # support; concatenating it directly would raise a TypeError.
            cudaVersion = torch.version.cuda if torch.version.cuda is not None else 'none'
            sys.stdout.write('PyTorch version: ' + torch.__version__ + '\n')
            sys.stdout.write('CUDA version: ' + cudaVersion + '\n')
            sys.stdout.write('CUDA available: ' + ('yes' if torch.cuda.is_available() else 'no') + '\n')
            sys.stdout.write('Device: ' + runOnDevice + '\n')
            # Printing the version information is a successful run.
            sys.exit(0)
        else:
            # Defensive only: getopt() already raises GetoptError for
            # options not listed above.
            sys.stderr.write('ERROR: Invalid option ' + option + '!\n')
            sys.exit(1)
except getopt.GetoptError as error:
    sys.stderr.write('ERROR: ' + str(error) + '\n')
    usage(1)
if len(args) > 0:
    usage(1)
# ====== Create GUI =========================================================
with gradio.Blocks(css = css, theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue),
                   fill_height = True, fill_width = True) as gui:

    # ====== Session handling ================================================
    # Session initialization, to be called when page is loaded. The returned
    # event is kept, so that the initial ECG generation ("Run on startup"
    # below) can be chained to it and therefore runs after the session exists.
    sessionInitialized = gui.load(initializeSession)
    # Session clean-up, to be called when page is closed/refreshed
    gui.unload(cleanUpSession)

    # ====== Header ==========================================================
    with gradio.Row(height = '10vh', min_height = '10vh', max_height = '10vh'):
        big_block = gradio.HTML("""
<div class="program-header">
<div class="program-logo-left">
<img class="program-logo-image" src="" alt="SimulaMet" height="32" />
</div>
<div class="program-title" id="title"><a href="https://ihi-search.eu/">SEARCH</a>&nbsp;DeepFake ECG Generator v""" + version.DEEPFAKEECGGENPLUS_VERSION + """</div>
<div class="program-logo-right">
<img class="program-logo-image" src="" alt="NorNet" height="64" />
</div>
</div>
""")

    # ====== Settings and action buttons =====================================
    # gradio.Markdown('## Settings')
    with gradio.Row(height = '10vh', min_height = '10vh', max_height = '10vh'):
        sliderNumberOfECGs = gradio.Slider(1, 100, label="Number of ECGs", step = 1, value = 4, interactive = True)
        # sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
        dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
        dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
        with gradio.Column():
            buttonGenerate = gradio.Button("Generate ECGs!")
            # buttonAnalyze = gradio.Button("Analyze this ECG!")
            with gradio.Row():
                # Each visible download button has a hidden twin, which
                # receives the generated file (see event handling below).
                buttonCSV = gradio.DownloadButton("Download CSV")
                buttonCSV_hidden = gradio.DownloadButton(visible=False, elem_id="download_csv_hidden")
                buttonPDF = gradio.DownloadButton("Download ECG PDF")
                buttonPDF_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdf_hidden")
                buttonPDFwAnalysis = gradio.DownloadButton("Download ECG+Analysis PDF")
                buttonPDFwAnalysis_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdfwanalysis_hidden")

    # ====== Output ==========================================================
    # gradio.Markdown('## Output')
    with gradio.Row():   # height = '24vh', min_height = '24vh', max_height = '24vh'):
        outputGallery = gradio.Gallery(label = 'Generated ECGs',
                                       columns = 8,
                                       # rows = 1,
                                       height = 'auto',
                                       object_fit = 'contain',
                                       show_label = True,
                                       allow_preview = True,
                                       preview = False
                                       )
    with gradio.Row():   # height = '24vh', min_height = '24vh', max_height = '24vh'):
        analysisOutput = gradio.Plot(label = 'Analysis')

    # ====== Add click event handling for "Generate" button ==================
    buttonGenerate.click(predict,
                         inputs = [ sliderNumberOfECGs,
                                    # sliderLengthInSeconds,
                                    dropdownType,
                                    dropdownGeneratorModel ],
                         outputs = [ outputGallery ]
                         )

    # ====== Add select event handling for the gallery (ECG analysis) ========
    outputGallery.select(analyze,
                         inputs = [ ],
                         outputs = [ analysisOutput ]
                         )

    # ====== Add click event handling for download buttons ===================
    # Using hidden button and JavaScript, to generate download file on-the-fly:
    # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
    buttonCSV.click(fn = downloadCSV,
                    inputs = None,
                    outputs = [ buttonCSV_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_csv_hidden').click()")
    buttonPDF.click(fn = downloadPDF,
                    inputs = None,
                    outputs = [ buttonPDF_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_pdf_hidden').click()")
    buttonPDFwAnalysis.click(fn = downloadPDFwithAnalysis,
                             inputs = None,
                             outputs = [ buttonPDFwAnalysis_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_pdfwanalysis_hidden').click()")

    # ====== Run on startup ==================================================
    # Chained to the session initialization: a separate gui.load(predict)
    # could run concurrently with initializeSession(), so that predict()
    # might not find the session yet (KeyError).
    sessionInitialized.then(predict,
                            inputs = [ sliderNumberOfECGs,
                                       # sliderLengthInSeconds,
                                       dropdownType,
                                       dropdownGeneratorModel ],
                            outputs = [ outputGallery ]
                            )
# ====== Run the GUI ========================================================
if __name__ == "__main__":
    # ------ Prepare temporary directory -------------------------------------
    # Module-level global: Session.__init__ creates the per-session
    # temporary directories inside it.
    TempDirectory = tempfile.TemporaryDirectory(prefix = 'DeepFakeECGPlus-')
    log(f'Prepared temporary directory {TempDirectory.name}')

    try:
        # ------ Run the GUI, with downloads from temporary directory allowed
        gui.launch(allowed_paths = [ TempDirectory.name ])
    finally:
        # ------ Clean up, also when launch() raised -------------------------
        log(f'Cleaning up temporary directory {TempDirectory.name}')
        TempDirectory.cleanup()
    log('Done!')