Update new_approach/spa_ensemble.py

new_approach/spa_ensemble.py (+20 -10)
```diff
@@ -32,8 +32,11 @@ class FeatureExtractor:
             features.update({f'color_{channel}_mean': float(np.mean(ch)), f'color_{channel}_std': float(np.std(ch)), f'color_{channel}_skew': float(stats.skew(ch)), f'color_{channel}_min': float(np.min(ch)), f'color_{channel}_max': float(np.max(ch))})
         else:
             features.update({f'color_{channel}_mean': 0.0, f'color_{channel}_std': 0.0, f'color_{channel}_skew': 0.0, f'color_{channel}_min': 0.0, f'color_{channel}_max': 0.0})
-        hist, _ = np.histogram(ch, bins=3, range=(0, 256)); hist = hist / (hist.sum() + 1e-8)
-        for j, v in enumerate(hist): features[f'color_{channel}_hist_bin{j}'] = float(v)
+
+        # --- FIX: Removed Histogram extraction (9 features) to match the 40 features expected by your .pth files ---
+        # hist, _ = np.histogram(ch, bins=3, range=(0, 256)); hist = hist / (hist.sum() + 1e-8);
+        # for j, v in enumerate(hist): features[f'color_{channel}_hist_bin{j}'] = float(v)
+
         try:
             hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
             features.update({'color_hue_mean': float(np.mean(hsv[:, :, 0])), 'color_saturation_mean': float(np.mean(hsv[:, :, 1])), 'color_value_mean': float(np.mean(hsv[:, :, 2]))})
```
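This hunk trades nine histogram features for a fixed 40-feature contract, and a mismatch here only surfaces later as a matrix-multiply error deep inside the model. A sanity check right after extraction fails fast instead. This is a minimal sketch, not part of the patch; `extract(img)` is a hypothetical method standing in for whatever the `FeatureExtractor` API actually exposes:

```python
import numpy as np

EXPECTED_NUM_FEATURES = 40  # what the trained .pth checkpoints and the scaler expect

def features_to_vector(extractor, img):
    """Extract handcrafted features and fail fast on a count mismatch."""
    features = extractor.extract(img)  # hypothetical API; adapt to the real method name
    vec = np.array([features[k] for k in sorted(features)], dtype=np.float32)
    if vec.shape[0] != EXPECTED_NUM_FEATURES:
        raise ValueError(
            f"got {vec.shape[0]} handcrafted features, expected {EXPECTED_NUM_FEATURES}; "
            "check for stray histogram bins"
        )
    return vec
```

Sorting the keys also pins a deterministic feature order, which matters because the scaler and the checkpoints were fit against one specific ordering.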
```diff
@@ -136,18 +139,22 @@ class BioCLIP2ZeroShot:
         prototypes = self.text_features_prototypes
         try: logit_scale = self.model.logit_scale.exp()
         except: logit_scale = torch.tensor(100.0).to(self.device)
-        logits = (logit_scale * image_features @ prototypes.T).cpu().numpy().squeeze()
+        # --- FIX: Added .detach() before .numpy() ---
+        logits = (logit_scale * image_features @ prototypes.T).detach().cpu().numpy().squeeze()
         return logits
 
 class EnsembleClassifier(nn.Module):
-    def __init__(self, num_handcrafted_features=49, dinov2_dim=1024, bioclip2_dim=100,
-                 num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=
+    def __init__(self, num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=100,
+                 num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=768):
         super().__init__()
         self.dinov2_proj = nn.Sequential(nn.Linear(dinov2_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
+
+        # --- FIX: Removed 3rd layer to match training checkpoint (Size mismatch error) ---
         self.handcraft_branch = nn.Sequential(
             nn.Linear(num_handcrafted_features, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout_rate),
-            nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate)
-
+            nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate)
+        )
+
         self.bioclip2_branch = nn.Sequential(
             nn.Linear(bioclip2_dim, hidden_dim // 4), nn.BatchNorm1d(hidden_dim // 4), nn.ReLU(), nn.Dropout(dropout_rate * 0.5))
         fusion_input_dim = hidden_dim + hidden_dim // 2 + hidden_dim // 4
```
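The `.detach()` fix addresses a general PyTorch rule rather than anything specific to this model: `.numpy()` refuses to convert a tensor that is still attached to the autograd graph, and since `logit_scale` comes from a trainable parameter, the logits require grad. A standalone sketch of the failure mode and two equivalent fixes (illustrative shapes, not the repo's code):

```python
import torch

logit_scale = torch.nn.Parameter(torch.tensor(4.6052))  # ~log(100); requires grad
image_features = torch.randn(1, 768)
prototypes = torch.randn(100, 768)

logits = logit_scale.exp() * image_features @ prototypes.T
# logits.cpu().numpy()  # RuntimeError: Can't call numpy() on Tensor that requires grad

# Fix 1: detach from the graph before converting (what this patch does)
out = logits.detach().cpu().numpy().squeeze()

# Fix 2: compute under no_grad so nothing in the expression requires grad
with torch.no_grad():
    out = (logit_scale.exp() * image_features @ prototypes.T).cpu().numpy().squeeze()
```

At inference time `torch.no_grad()` is the slightly better habit, since it also skips building the graph, but `.detach()` is the minimal local fix.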
```diff
@@ -204,7 +211,8 @@ class ModelManager:
             print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
             print("Using dummy scaler (predictions may be inaccurate).")
             self.scaler = StandardScaler()
-            self.scaler.fit(np.zeros((1, 49)))
+            # FIX: Fit on 40 zeros instead of 49 to match the feature reduction
+            self.scaler.fit(np.zeros((1, 40)))
 
         # 4. Download & Load Ensemble Models
         self.models = []
```
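Fitting the dummy scaler on a single all-zero row is a compact way to get an identity fallback: the fitted mean is 0, and scikit-learn clamps the zero variance to a scale of 1, so `transform` passes features through unchanged. The width still has to be 40, or `transform` rejects inputs with a shape error. A short check of that behaviour (illustrative only):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
scaler.fit(np.zeros((1, 40)))  # mean_ == 0; zero variance is clamped to scale_ == 1

x = np.random.rand(3, 40)
assert np.allclose(scaler.transform(x), x)  # identity: features pass through unscaled
assert np.allclose(scaler.mean_, 0.0) and np.allclose(scaler.scale_, 1.0)
```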
```diff
@@ -219,9 +227,11 @@ class ModelManager:
             model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
 
             # Load
+            # FIX: Passed num_handcrafted_features=40 and prototype_dim=768 to match weights
             model = EnsembleClassifier(
-                num_handcrafted_features=49, dinov2_dim=1024, bioclip2_dim=self.num_classes,
-                num_classes=self.num_classes, hidden_dim=hidden_dims[i], dropout_rate=dropout_rates[i]
+                num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=self.num_classes,
+                num_classes=self.num_classes, hidden_dim=hidden_dims[i], dropout_rate=dropout_rates[i],
+                prototype_dim=768
             )
             state_dict = torch.load(model_path, map_location=self.device)
             model.load_state_dict(state_dict)
```
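When constructor arguments drift from the training run, `load_state_dict` raises one dense error string listing every size mismatch at once. A small diagnostic helper (hypothetical, not part of this patch) prints checkpoint-vs-model shapes side by side, which is how mismatches like 49-vs-40 input features show up at a glance:

```python
import torch

def report_state_dict_mismatches(model: torch.nn.Module, checkpoint_path: str) -> None:
    """Print shape mismatches between a checkpoint and a freshly built model."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_state = model.state_dict()
    for name, tensor in checkpoint.items():
        if name not in model_state:
            print(f"unexpected key in checkpoint: {name}")
        elif model_state[name].shape != tensor.shape:
            print(f"{name}: checkpoint {tuple(tensor.shape)} vs model {tuple(model_state[name].shape)}")
    for name in model_state.keys() - checkpoint.keys():
        print(f"missing from checkpoint: {name}")
```

Running it before `model.load_state_dict(state_dict)` on each ensemble member turns the wall of error text into a short, readable report.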