import os
import gradio as gr
import torch
from monai import bundle

# Set the bundle name and download path
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
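# (torch.hub.get_dir() typically resolves to ~/.cache/torch/hub, and bundle.load
# downloads into its 'bundle' subfolder by default, so BUNDLE_PATH should point
# at the downloaded bundle.)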

# Set up some examples from the test set for better user experience
examples = [
    ['examples/spleen_1.nii.gz', 50],
    ['examples/spleen_11.nii.gz', 50],
]

# Load the pretrained model from Hugging Face Hub
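# (With load_ts_module=True, bundle.load is expected to return the network plus
# its metadata and any extra files; only the network itself is kept here.)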
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/spleen_ct_segmentation_v0.1.0',
    load_ts_module=True,
)

# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load transforms and inferer directly from the bundle
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
preproc_transforms = parser.get_parsed_content(
    'preprocessing',
    lazy=True, eval_expr=True, instantiate=True
)
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True, eval_expr=True, instantiate=True
)
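# ('preprocessing' and 'inferer' are component keys defined in the bundle's
# inference.json; get_parsed_content instantiates them into ready-to-use objects.)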

# Define the prediction function
def predict(input_file, z_axis, model=model, device=device):
    # Wrap the uploaded file path in the dictionary format expected by the
    # bundle's dictionary-based preprocessing transforms
    data = {'image': [input_file.name]}
    data = preproc_transforms(data)
    
    model.to(device)
    model.eval()
    with torch.no_grad():
        # Add a batch dimension and run the bundle's inferer on the device
        inputs = data['image'].to(device)[None, ...]
        data['pred'] = inferer(inputs=inputs, network=model)
    
    # Convert back to NumPy; argmax over the channel dim gives the label map
    input_image = data['image'].numpy()
    pred_image = torch.argmax(data['pred'], dim=1).cpu().detach().numpy()

    # Return the requested slice along the last (z) axis, scaling the mask to 0-255 for display
    return input_image[0, :, :, z_axis], pred_image[0, :, :, z_axis] * 255

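# A minimal local sanity check (hypothetical, not part of the app): Gradio hands
# predict a temp-file wrapper exposing a .name attribute, so any object with that
# attribute works, e.g.:
#
#   from types import SimpleNamespace
#   img_slice, seg_slice = predict(SimpleNamespace(name='examples/spleen_1.nii.gz'), 50)
#   print(img_slice.shape, seg_slice.shape)
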
# Set up the demo interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='input file'),
        gr.Slider(0, 200, label='z-axis', value=50)
    ], 
    outputs=['image', 'image'],
    title='Segment the Spleen using MONAI!',
    description="""## 🚀 To run
    Upload an abdominal CT scan, or try one of the examples below!
    
    More details on the model can be found [here!](https://huggingface.co/katielink/spleen_ct_segmentation_v0.1.0)

    ## ⚠️ Disclaimer
    Not to be used for diagnostic purposes.
    """,
    examples=examples,
)

# Launch the demo
iface.launch()