Spaces:
Sleeping
Sleeping
Update new_approach/spa_ensemble.py
Browse files- new_approach/spa_ensemble.py +10 -4
new_approach/spa_ensemble.py
CHANGED
|
@@ -12,10 +12,11 @@ from sklearn.preprocessing import StandardScaler
|
|
| 12 |
import torchvision.transforms as transforms
|
| 13 |
import open_clip
|
| 14 |
import joblib
|
| 15 |
-
from huggingface_hub import hf_hub_download
|
| 16 |
|
| 17 |
# --- CONFIGURATION ---
|
| 18 |
CONFIDENCE_THRESHOLD = 0.99
|
|
|
|
| 19 |
LIST_DIR = Path("list")
|
| 20 |
|
| 21 |
# ==============================================================================
|
|
@@ -43,6 +44,8 @@ class FeatureExtractor:
|
|
| 43 |
@staticmethod
|
| 44 |
def extract_texture_features(img):
|
| 45 |
img_np = np.array(img); gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY); features = {}
|
|
|
|
|
|
|
| 46 |
edges = cv2.Canny(gray, 50, 150)
|
| 47 |
gx, gy = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3), cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 48 |
features.update({
|
|
@@ -74,6 +77,7 @@ class FeatureExtractor:
|
|
| 74 |
@staticmethod
|
| 75 |
def extract_all_features(img):
|
| 76 |
img = img.convert('RGB')
|
|
|
|
| 77 |
max_size = 1024
|
| 78 |
if max(img.size) > max_size:
|
| 79 |
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
|
|
@@ -170,7 +174,7 @@ class ModelManager:
|
|
| 170 |
print(f"Initializing SPA Ensemble on {self.device}...")
|
| 171 |
|
| 172 |
# --- CONFIG: YOUR MODEL REPO ID ---
|
| 173 |
-
#
|
| 174 |
self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"
|
| 175 |
|
| 176 |
self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
|
|
@@ -192,11 +196,13 @@ class ModelManager:
|
|
| 192 |
# 3. Download & Load Scaler
|
| 193 |
print("SPA Ensemble: Downloading Scaler...")
|
| 194 |
try:
|
|
|
|
| 195 |
scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
|
| 196 |
self.scaler = joblib.load(scaler_path)
|
| 197 |
-
print("✓ Scaler loaded.")
|
| 198 |
except Exception as e:
|
| 199 |
-
print(f"Warning: Could not download scaler: {e}.
|
|
|
|
| 200 |
self.scaler = StandardScaler()
|
| 201 |
self.scaler.fit(np.zeros((1, 49)))
|
| 202 |
|
|
|
|
| 12 |
import torchvision.transforms as transforms
|
| 13 |
import open_clip
|
| 14 |
import joblib
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
|
| 17 |
# --- CONFIGURATION ---
|
| 18 |
CONFIDENCE_THRESHOLD = 0.99
|
| 19 |
+
# The list directory remains in the root of the Space
|
| 20 |
LIST_DIR = Path("list")
|
| 21 |
|
| 22 |
# ==============================================================================
|
|
|
|
| 44 |
@staticmethod
|
| 45 |
def extract_texture_features(img):
|
| 46 |
img_np = np.array(img); gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY); features = {}
|
| 47 |
+
# Optimization: Canny/Sobel can be slow on huge images.
|
| 48 |
+
# We assume image is resized in extract_all_features
|
| 49 |
edges = cv2.Canny(gray, 50, 150)
|
| 50 |
gx, gy = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3), cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
| 51 |
features.update({
|
|
|
|
| 77 |
@staticmethod
|
| 78 |
def extract_all_features(img):
|
| 79 |
img = img.convert('RGB')
|
| 80 |
+
# OPTIMIZATION: Resize for Handcrafted Features to speed up Canny/Sobel
|
| 81 |
max_size = 1024
|
| 82 |
if max(img.size) > max_size:
|
| 83 |
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
|
|
|
|
| 174 |
print(f"Initializing SPA Ensemble on {self.device}...")
|
| 175 |
|
| 176 |
# --- CONFIG: YOUR MODEL REPO ID ---
|
| 177 |
+
# Using the correct repo ID provided
|
| 178 |
self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"
|
| 179 |
|
| 180 |
self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
|
|
|
|
| 196 |
# 3. Download & Load Scaler
|
| 197 |
print("SPA Ensemble: Downloading Scaler...")
|
| 198 |
try:
|
| 199 |
+
# Now fetching scaler.joblib from the Model Repo
|
| 200 |
scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
|
| 201 |
self.scaler = joblib.load(scaler_path)
|
| 202 |
+
print("✓ Scaler downloaded and loaded.")
|
| 203 |
except Exception as e:
|
| 204 |
+
print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
|
| 205 |
+
print("Using dummy scaler (predictions may be inaccurate).")
|
| 206 |
self.scaler = StandardScaler()
|
| 207 |
self.scaler.fit(np.zeros((1, 49)))
|
| 208 |
|