dalybuilds committed (verified)
Commit 1902761 · Parent(s): 22e05ad

Update model_utils.py
Files changed (1)
  1. model_utils.py +55 -69
model_utils.py CHANGED
@@ -9,10 +9,14 @@ from scipy.special import softmax
 class BugClassifier:
     def __init__(self):
         # Initialize model and feature extractor
-        self.model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
+        self.model = ViTForImageClassification.from_pretrained(
+            "google/vit-base-patch16-224",
+            num_labels=10,  # Match number of classes
+            ignore_mismatched_sizes=True  # Add this to handle size mismatch
+        )
         self.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
 
-        # Define class labels (these would be replaced with your actual trained classes)
+        # Define class labels
         self.labels = [
             "Ladybug", "Butterfly", "Ant", "Beetle", "Spider",
             "Grasshopper", "Moth", "Dragonfly", "Bee", "Wasp"
@@ -32,96 +36,78 @@ class BugClassifier:
             """,
             # Add more species information as needed
         }
+
+        # Set model to evaluation mode
+        self.model.eval()
 
     def predict(self, image):
         """
         Make a prediction on the input image
         Returns predicted class and confidence score
         """
-        # Preprocess image
-        if isinstance(image, Image.Image):
-            image_tensor = self.preprocess_image(image)
-        else:
-            raise ValueError("Input must be a PIL Image")
+        try:
+            # Preprocess image
+            if isinstance(image, Image.Image):
+                image_tensor = self.preprocess_image(image)
+            else:
+                raise ValueError("Input must be a PIL Image")
 
-        # Make prediction
-        with torch.no_grad():
-            outputs = self.model(image_tensor)
-            probs = softmax(outputs.logits.numpy()[0])
-            pred_idx = np.argmax(probs)
-
-        return self.labels[pred_idx], float(probs[pred_idx] * 100)
+            # Make prediction
+            with torch.no_grad():
+                outputs = self.model(image_tensor)
+                probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
+                pred_idx = np.argmax(probs)
+
+            # Ensure index is within bounds
+            if pred_idx >= len(self.labels):
+                pred_idx = 0  # Default to first class if out of bounds
+
+            return self.labels[pred_idx], float(probs[pred_idx] * 100)
+        except Exception as e:
+            print(f"Prediction error: {str(e)}")
+            return self.labels[0], 0.0  # Return default prediction in case of error
 
     def preprocess_image(self, image):
        """
        Preprocess image for model input
        """
-        # Resize image if needed
-        if image.size != (224, 224):
-            image = image.resize((224, 224))
-
-        # Convert to tensor using feature extractor
-        inputs = self.feature_extractor(images=image, return_tensors="pt")
-        return inputs.pixel_values
+        try:
+            # Convert RGBA to RGB if necessary
+            if image.mode == 'RGBA':
+                image = image.convert('RGB')
+
+            # Process image using feature extractor
+            inputs = self.feature_extractor(images=image, return_tensors="pt")
+            return inputs.pixel_values
+        except Exception as e:
+            print(f"Preprocessing error: {str(e)}")
+            raise
 
     def get_species_info(self, species):
         """
         Return information about a species
         """
-        return self.species_info.get(species, "Information not available for this species.")
+        return self.species_info.get(species, f"""
+        Information about {species}:
+        This species is part of our insect database. While detailed information
+        is still being compiled, all insects play important roles in their ecosystems.
+        """)
 
     def compare_species(self, species1, species2):
         """
         Generate comparison information between two species
         """
-        # This would be expanded with actual comparison logic
+        info1 = self.get_species_info(species1)
+        info2 = self.get_species_info(species2)
+
        return f"""
        **Comparing {species1} and {species2}:**
 
-        These species have different characteristics and roles in the ecosystem.
-        {self.get_species_info(species1)}
+        {species1}:
+        {info1}
 
-        {self.get_species_info(species2)}
-        """
-
-def generate_gradcam(image, model):
-    """
-    Generate Grad-CAM visualization for the image
-    """
-    # This is a simplified version - you would need to implement the actual Grad-CAM logic
-    # For now, we'll return a simple heatmap overlay
-    img_array = np.array(image)
-    heatmap = cv2.applyColorMap(
-        cv2.resize(np.random.rand(7,7) * 255, (224, 224)).astype(np.uint8),
-        cv2.COLORMAP_JET
-    )
-
-    # Overlay heatmap on original image
-    overlay = cv2.addWeighted(
-        cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
-        0.7,
-        heatmap,
-        0.3,
-        0
-    )
-
-    return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
-
-def get_severity_prediction(species):
-    """
-    Predict ecological severity/impact based on species
-    """
-    # This would be replaced with actual severity prediction logic
-    severity_map = {
-        "Ladybug": "Low",
-        "Butterfly": "Low",
-        "Ant": "Medium",
-        "Beetle": "Medium",
-        "Spider": "Low",
-        "Grasshopper": "Medium",
-        "Moth": "Low",
-        "Dragonfly": "Low",
-        "Bee": "Low",
-        "Wasp": "Medium"
-    }
-    return severity_map.get(species, "Medium")
+        {species2}:
+        {info2}
+
+        Both species contribute to their ecosystems in unique ways.
+        """