DSatishchandra commited on
Commit
d96a71d
·
verified ·
1 Parent(s): b27384b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -1,37 +1,40 @@
1
  import torch
2
- from torchvision import models, transforms
3
  from torch import nn
 
4
  from PIL import Image
5
  import gradio as gr
6
  import numpy as np
7
 
8
- # Load the pre-trained model (ResNet50)
9
  model = models.resnet50(pretrained=True)
10
- model.fc = nn.Linear(model.fc.in_features, 2) # For binary classification (Thyroid: Positive/Negative)
11
 
12
- # Define image transformation for preprocessing
13
  transform = transforms.Compose([
14
- transforms.Resize((224, 224)),
15
  transforms.ToTensor(),
16
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
17
  ])
18
 
19
- # Function to classify the thyroid condition based on image input
20
- def classify_thyroid_image(image):
21
- # Convert the numpy array (from Gradio) to a Pillow image
22
- image = Image.fromarray(image.astype('uint8'), 'RGB')
23
-
24
- # Apply transformation (resize, normalize, etc.)
25
  image = transform(image).unsqueeze(0) # Add batch dimension
26
-
27
- model.eval() # Set the model to evaluation mode
28
 
 
29
  with torch.no_grad():
30
  output = model(image)
31
  _, predicted = torch.max(output, 1)
32
 
33
- diagnosis = "Thyroid Disease Detected" if predicted.item() == 1 else "No Thyroid Disease"
 
 
 
 
 
 
 
34
  return diagnosis
35
 
36
  # Create Gradio interface for image input
37
- gr.Interface(fn=classify_thyroid_image, inputs="image", outputs="text").launch()
 
1
  import torch
 
2
  from torch import nn
3
+ from torchvision import models, transforms
4
  from PIL import Image
5
  import gradio as gr
6
  import numpy as np
7
 
8
+ # Define model (pretrained ResNet50)
9
  model = models.resnet50(pretrained=True)
10
+ model.fc = nn.Linear(model.fc.in_features, 3) # 3 output classes: Normal, Hypothyroidism, Hyperthyroidism
11
 
12
+ # Define image transformations (resizing and normalization)
13
  transform = transforms.Compose([
14
+ transforms.Resize((224, 224)), # Resize image for input
15
  transforms.ToTensor(),
16
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
17
  ])
18
 
19
+ # Example classification function
20
+ def classify_thyroid_condition(image):
21
+ image = Image.fromarray(image.astype('uint8'), 'RGB') # Convert numpy array to Pillow Image
 
 
 
22
  image = transform(image).unsqueeze(0) # Add batch dimension
 
 
23
 
24
+ model.eval() # Set the model to evaluation mode
25
  with torch.no_grad():
26
  output = model(image)
27
  _, predicted = torch.max(output, 1)
28
 
29
+ # Map prediction to class labels
30
+ if predicted.item() == 0:
31
+ diagnosis = "Normal"
32
+ elif predicted.item() == 1:
33
+ diagnosis = "Hypothyroidism"
34
+ else:
35
+ diagnosis = "Hyperthyroidism"
36
+
37
  return diagnosis
38
 
39
  # Create Gradio interface for image input
40
+ gr.Interface(fn=classify_thyroid_condition, inputs="image", outputs="text").launch()