FrAnKu34t23 commited on
Commit
64262de
·
verified ·
1 Parent(s): 88a7516

Update new_approach/spa_ensemble.py

Browse files
Files changed (1) hide show
  1. new_approach/spa_ensemble.py +10 -4
new_approach/spa_ensemble.py CHANGED
@@ -12,10 +12,11 @@ from sklearn.preprocessing import StandardScaler
12
  import torchvision.transforms as transforms
13
  import open_clip
14
  import joblib
15
- from huggingface_hub import hf_hub_download # <--- REQUIRED for downloading weights
16
 
17
  # --- CONFIGURATION ---
18
  CONFIDENCE_THRESHOLD = 0.99
 
19
  LIST_DIR = Path("list")
20
 
21
  # ==============================================================================
@@ -43,6 +44,8 @@ class FeatureExtractor:
43
  @staticmethod
44
  def extract_texture_features(img):
45
  img_np = np.array(img); gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY); features = {}
 
 
46
  edges = cv2.Canny(gray, 50, 150)
47
  gx, gy = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3), cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
48
  features.update({
@@ -74,6 +77,7 @@ class FeatureExtractor:
74
  @staticmethod
75
  def extract_all_features(img):
76
  img = img.convert('RGB')
 
77
  max_size = 1024
78
  if max(img.size) > max_size:
79
  img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
@@ -170,7 +174,7 @@ class ModelManager:
170
  print(f"Initializing SPA Ensemble on {self.device}...")
171
 
172
  # --- CONFIG: YOUR MODEL REPO ID ---
173
- # ⚠️ REPLACE 'YourUsername' with your actual HuggingFace username!
174
  self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"
175
 
176
  self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
@@ -192,11 +196,13 @@ class ModelManager:
192
  # 3. Download & Load Scaler
193
  print("SPA Ensemble: Downloading Scaler...")
194
  try:
 
195
  scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
196
  self.scaler = joblib.load(scaler_path)
197
- print("✓ Scaler loaded.")
198
  except Exception as e:
199
- print(f"Warning: Could not download scaler: {e}. Using dummy.")
 
200
  self.scaler = StandardScaler()
201
  self.scaler.fit(np.zeros((1, 49)))
202
 
 
12
  import torchvision.transforms as transforms
13
  import open_clip
14
  import joblib
15
+ from huggingface_hub import hf_hub_download
16
 
17
  # --- CONFIGURATION ---
18
  CONFIDENCE_THRESHOLD = 0.99
19
+ # The list directory remains in the root of the Space
20
  LIST_DIR = Path("list")
21
 
22
  # ==============================================================================
 
44
  @staticmethod
45
  def extract_texture_features(img):
46
  img_np = np.array(img); gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY); features = {}
47
+ # Optimization: Canny/Sobel can be slow on huge images.
48
+ # We assume image is resized in extract_all_features
49
  edges = cv2.Canny(gray, 50, 150)
50
  gx, gy = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3), cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
51
  features.update({
 
77
  @staticmethod
78
  def extract_all_features(img):
79
  img = img.convert('RGB')
80
+ # OPTIMIZATION: Resize for Handcrafted Features to speed up Canny/Sobel
81
  max_size = 1024
82
  if max(img.size) > max_size:
83
  img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
 
174
  print(f"Initializing SPA Ensemble on {self.device}...")
175
 
176
  # --- CONFIG: YOUR MODEL REPO ID ---
177
+ # Using the correct repo ID provided
178
  self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"
179
 
180
  self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
 
196
  # 3. Download & Load Scaler
197
  print("SPA Ensemble: Downloading Scaler...")
198
  try:
199
+ # Now fetching scaler.joblib from the Model Repo
200
  scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
201
  self.scaler = joblib.load(scaler_path)
202
+ print("✓ Scaler downloaded and loaded.")
203
  except Exception as e:
204
+ print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
205
+ print("Using dummy scaler (predictions may be inaccurate).")
206
  self.scaler = StandardScaler()
207
  self.scaler.fit(np.zeros((1, 49)))
208