FrAnKu34t23 committed
Commit f8ed33a · verified · 1 Parent(s): 64262de

Update new_approach/spa_ensemble.py

Files changed (1):
  1. new_approach/spa_ensemble.py  +20 -10
new_approach/spa_ensemble.py CHANGED
@@ -32,8 +32,11 @@ class FeatureExtractor:
                 features.update({f'color_{channel}_mean': float(np.mean(ch)), f'color_{channel}_std': float(np.std(ch)), f'color_{channel}_skew': float(stats.skew(ch)), f'color_{channel}_min': float(np.min(ch)), f'color_{channel}_max': float(np.max(ch))})
             else:
                 features.update({f'color_{channel}_mean': 0.0, f'color_{channel}_std': 0.0, f'color_{channel}_skew': 0.0, f'color_{channel}_min': 0.0, f'color_{channel}_max': 0.0})
-            hist, _ = np.histogram(ch, bins=3, range=(0, 256)); hist = hist / (hist.sum() + 1e-8);
-            for j, v in enumerate(hist): features[f'color_{channel}_hist_bin{j}'] = float(v)
+
+            # --- FIX: Removed histogram extraction (9 features) to match the 40 features expected by the .pth files ---
+            # hist, _ = np.histogram(ch, bins=3, range=(0, 256)); hist = hist / (hist.sum() + 1e-8);
+            # for j, v in enumerate(hist): features[f'color_{channel}_hist_bin{j}'] = float(v)
+
         try:
             hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
             features.update({'color_hue_mean': float(np.mean(hsv[:, :, 0])), 'color_saturation_mean': float(np.mean(hsv[:, :, 1])), 'color_value_mean': float(np.mean(hsv[:, :, 2]))})
@@ -136,18 +139,22 @@ class BioCLIP2ZeroShot:
         prototypes = self.text_features_prototypes
         try: logit_scale = self.model.logit_scale.exp()
         except: logit_scale = torch.tensor(100.0).to(self.device)
-        logits = (logit_scale * image_features @ prototypes.T).cpu().numpy().squeeze()
+        # --- FIX: Added .detach() before .numpy() ---
+        logits = (logit_scale * image_features @ prototypes.T).detach().cpu().numpy().squeeze()
         return logits
 
 class EnsembleClassifier(nn.Module):
-    def __init__(self, num_handcrafted_features=49, dinov2_dim=1024, bioclip2_dim=100,
-                 num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=512):
+    def __init__(self, num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=100,
+                 num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=768):
         super().__init__()
         self.dinov2_proj = nn.Sequential(nn.Linear(dinov2_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
+
+        # --- FIX: Removed 3rd layer to match training checkpoint (size mismatch error) ---
         self.handcraft_branch = nn.Sequential(
             nn.Linear(num_handcrafted_features, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout_rate),
-            nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate),
-            nn.Linear(hidden_dim // 2, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate))
+            nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate)
+        )
+
         self.bioclip2_branch = nn.Sequential(
             nn.Linear(bioclip2_dim, hidden_dim // 4), nn.BatchNorm1d(hidden_dim // 4), nn.ReLU(), nn.Dropout(dropout_rate * 0.5))
         fusion_input_dim = hidden_dim + hidden_dim // 2 + hidden_dim // 4
@@ -204,7 +211,8 @@ class ModelManager:
             print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
             print("Using dummy scaler (predictions may be inaccurate).")
             self.scaler = StandardScaler()
-            self.scaler.fit(np.zeros((1, 49)))
+            # FIX: Fit on 40 zeros instead of 49 to match the feature reduction
+            self.scaler.fit(np.zeros((1, 40)))
 
         # 4. Download & Load Ensemble Models
         self.models = []
@@ -219,9 +227,11 @@
             model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
 
             # Load
+            # FIX: Passed num_handcrafted_features=40 and prototype_dim=768 to match weights
             model = EnsembleClassifier(
-                num_handcrafted_features=49, dinov2_dim=1024, bioclip2_dim=self.num_classes,
-                num_classes=self.num_classes, hidden_dim=hidden_dims[i], dropout_rate=dropout_rates[i]
+                num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=self.num_classes,
+                num_classes=self.num_classes, hidden_dim=hidden_dims[i], dropout_rate=dropout_rates[i],
+                prototype_dim=768
             )
             state_dict = torch.load(model_path, map_location=self.device)
             model.load_state_dict(state_dict)
 
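Why the first hunk fixes the feature count: each of the three color channels contributed three normalized histogram bins, so dropping the block removes 3 bins x 3 channels = 9 features and shrinks the handcrafted vector from 49 to the 40 the checkpoints expect. A minimal standalone sketch of the removed computation, using a synthetic channel rather than the extractor's real data:

import numpy as np

# One channel's share of the removed features: 3 normalized histogram bins.
ch = np.random.randint(0, 256, size=10_000)
hist, _ = np.histogram(ch, bins=3, range=(0, 256))
hist = hist / (hist.sum() + 1e-8)   # bin frequencies, summing to ~1

print(hist.shape[0] * 3)            # 3 bins x 3 channels = 9 features removed
print(49 - hist.shape[0] * 3)       # 40 handcrafted features remain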
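The .detach() fix in the second hunk follows a standard PyTorch rule: .numpy() raises a RuntimeError on any tensor still attached to the autograd graph, and the logits tensor is attached whenever gradients are enabled during encoding. A small reproduction with synthetic stand-ins for image_features and prototypes:

import torch

image_features = torch.randn(1, 768, requires_grad=True)   # stand-in for the encoder output
prototypes = torch.randn(100, 768)                          # stand-in for the text prototypes
logits = 100.0 * image_features @ prototypes.T

# logits.cpu().numpy() would raise:
#   RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
arr = logits.detach().cpu().numpy().squeeze()               # detach from the graph first
print(arr.shape)                                            # (100,)

Wrapping the whole encode step in torch.no_grad() would work equally well and also skips building the graph; .detach() is simply the more local change.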
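The dummy-scaler change leans on a scikit-learn detail worth spelling out: fitting StandardScaler on a single all-zero row yields mean_ of zeros and scale_ of ones (zero variance is internally replaced by 1), so transform is the identity. What actually matters in the fix is the column count, because transform rejects input whose width differs from the fitted 40. A short sketch:

import numpy as np
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
scaler.fit(np.zeros((1, 40)))               # one all-zero sample, 40 columns

x = np.random.rand(2, 40)
assert np.allclose(scaler.transform(x), x)  # identity: mean_ == 0, scale_ == 1

# A 49-wide vector now fails fast instead of silently misaligning features:
# scaler.transform(np.zeros((1, 49)))  ->  ValueError (expects 40 features)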
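The EnsembleClassifier changes (40 input features, the trimmed handcraft branch, prototype_dim=768) all trace back to the same failure mode: load_state_dict runs with strict=True by default and raises a RuntimeError enumerating every key and shape that disagrees with the checkpoint. Note the fusion width is unaffected: with hidden_dim=512 it stays 512 + 256 + 128 = 896, since the removed third layer mapped hidden_dim // 2 to hidden_dim // 2. A hedged sketch for diagnosing such mismatches; the checkpoint filename is hypothetical and the import assumes this module is on the path:

import torch

from new_approach.spa_ensemble import EnsembleClassifier   # assumes package-style import

model = EnsembleClassifier(num_handcrafted_features=40, prototype_dim=768)
state_dict = torch.load("ensemble_0.pth", map_location="cpu")  # hypothetical checkpoint name

# Compare shapes key by key before attempting a strict load.
own = model.state_dict()
for k, v in state_dict.items():
    if k in own and own[k].shape != v.shape:
        print(f"{k}: checkpoint {tuple(v.shape)} vs model {tuple(own[k].shape)}")

model.load_state_dict(state_dict)   # strict=True: raises on any remaining mismatch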