doozer21 committed on
Commit 2ce327e · verified · 1 Parent(s): cbdf1e6

Update app.py

Files changed (1)
  1. app.py +111 -192
app.py CHANGED
@@ -4,13 +4,10 @@
  # IMPROVEMENTS:
  # -------------
  # ✅ Mobile-friendly single-column layout
- # ✅ Fixed mobile upload issues with session state
- # ✅ Persistent predictions across reruns
- # ✅ Simplified, responsive CSS
- # ✅ Better error handling
- # ✅ Loads model from Hugging Face Hub OR local file
- # ✅ Optimized for slow connections
- # ✅ Touch-friendly interface
  #
  # ============================================================

@@ -21,7 +18,6 @@ from torchvision import transforms
  from PIL import Image
  import timm
  from pathlib import Path
- import hashlib

  # ============================================================
  # PAGE CONFIGURATION
@@ -34,37 +30,23 @@ st.set_page_config(
  initial_sidebar_state="collapsed"
  )

- # ============================================================
- # SESSION STATE INITIALIZATION
- # ============================================================
-
- if 'predictions' not in st.session_state:
- st.session_state.predictions = None
- if 'processed_image' not in st.session_state:
- st.session_state.processed_image = None
- if 'last_image_hash' not in st.session_state:
- st.session_state.last_image_hash = None
-
  # ============================================================
  # MINIMAL CSS (Mobile-First)
  # ============================================================

  st.markdown("""
  <style>
- /* Remove extra padding on mobile */
  .block-container {
  padding-top: 2rem;
  padding-bottom: 2rem;
  }

- /* Cleaner header */
  h1 {
  text-align: center;
  color: #FF6B6B;
  margin-bottom: 0.5rem;
  }

- /* Result cards */
  .prediction-card {
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
  padding: 1.5rem;
@@ -85,14 +67,12 @@ st.markdown("""
  opacity: 0.9;
  }

- /* Confidence bars */
  .conf-bar {
  background: #f0f0f0;
  border-radius: 8px;
  height: 36px;
  margin: 0.5rem 0;
  overflow: hidden;
- position: relative;
  }

  .conf-fill {
@@ -105,16 +85,6 @@ st.markdown("""
  font-weight: 600;
  font-size: 0.95rem;
  }
-
- /* Make file uploader more visible */
- .stFileUploader {
- margin-bottom: 1rem;
- }
-
- /* Make camera input more visible */
- .stCameraInput {
- margin-top: 1rem;
- }
  </style>
  """, unsafe_allow_html=True)

@@ -145,52 +115,35 @@ FOOD_CLASSES = [
  "sushi", "tacos", "takoyaki", "tiramisu", "tuna_tartare", "waffles"
  ]

- # ============================================================
- # HELPER FUNCTIONS
- # ============================================================
-
- def get_image_hash(image_bytes):
- """Create a hash of image bytes to detect if it's a new image."""
- return hashlib.md5(image_bytes).hexdigest()
-
  # ============================================================
  # MODEL LOADING
  # ============================================================

  @st.cache_resource
  def load_model():
- """
- Loads model from local file or Hugging Face Hub.
- Cached for performance across sessions.
- """
  try:
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- # Try loading from local file first (for HF Spaces)
  local_path = Path("model1_best.pth")

  if local_path.exists():
- checkpoint = torch.load(local_path, map_location=device)
  else:
- # Fallback: try to download from HF Hub
  try:
  from huggingface_hub import hf_hub_download
  model_path = hf_hub_download(
  repo_id="doozer21/FoodVision",
  filename="model1_best.pth"
  )
- checkpoint = torch.load(model_path, map_location=device)
- except Exception as e:
- st.error("❌ Could not load model from local file or Hugging Face Hub")
- st.info("Make sure model1_best.pth is in your Space's repository")
  return None, None, None

- # Get config
  model_config = checkpoint.get('model_config', {
  'model_id': 'convnextv2_base.fcmae_ft_in22k_in1k_384'
  })

- # Create and load model
  model = timm.create_model(
  model_config['model_id'],
  pretrained=False,
@@ -202,7 +155,6 @@ def load_model():
  model.eval()

  accuracy = checkpoint.get('best_val_acc', 0)
-
  return model, device, accuracy

  except Exception as e:
@@ -251,6 +203,50 @@ def predict(model, image_tensor, device, top_k=3):

  return results

  # ============================================================
  # MAIN APP
  # ============================================================
@@ -260,14 +256,14 @@ def main():
  st.title("🍕 FoodVision AI")
  st.markdown("**Identify 101 food dishes instantly**")

- # Load model with status
- with st.spinner("🔄 Loading AI model..."):
- model, device, accuracy = load_model()

  if model is None:
  st.stop()

- # Show model info in expander (cleaner for mobile)
  with st.expander("ℹ️ Model Info"):
  st.write(f"**Architecture:** ConvNeXt V2 Base")
  st.write(f"**Accuracy:** {accuracy:.2f}%")
@@ -276,149 +272,72 @@ def main():

  st.markdown("---")

- # Single-column layout (mobile-friendly)
- st.subheader("📸 Upload or Take a Photo")
-
- # File uploader
- uploaded_file = st.file_uploader(
- "Choose a food image",
- type=['jpg', 'jpeg', 'png', 'webp'],
- key="file_uploader"
- )
-
- # Camera input (below uploader)
- st.markdown("**Or use your camera:**")
- camera_photo = st.camera_input(
- "Take a picture",
- key="camera_input"
- )
-
- # Determine which image to use
- image_source = None
- source_name = ""
- image_bytes = None

- if camera_photo is not None:
- image_source = camera_photo
- source_name = "camera"
- image_bytes = camera_photo.getvalue()
- elif uploaded_file is not None:
- image_source = uploaded_file
- source_name = "upload"
- image_bytes = uploaded_file.getvalue()

- # Process image if we have one
- if image_source is not None and image_bytes is not None:
- try:
- # Check if this is a new image
- current_hash = get_image_hash(image_bytes)
-
- # Only process if it's a new image
- if current_hash != st.session_state.last_image_hash:
- # Load image
- image = Image.open(image_source)
-
- # Store image in session state
- st.session_state.processed_image = image
- st.session_state.last_image_hash = current_hash

- # Show loading indicator
- with st.spinner("🧠 Analyzing your food..."):
- # Preprocess and predict
  img_tensor = preprocess_image(image)
  predictions = predict(model, img_tensor, device, top_k=3)
-
- # Store predictions in session state
- st.session_state.predictions = predictions
-
- # Display results (from session state)
- if st.session_state.processed_image is not None:
- # Show image preview
- st.image(
- st.session_state.processed_image,
- caption=f"Image from {source_name}",
- use_column_width=True
- )
-
- if st.session_state.predictions is not None:
- st.markdown("---")

- # Display top prediction prominently
- top_food, top_conf = st.session_state.predictions[0]

- st.markdown(f"""
- <div class="prediction-card">
- <h2>🏆 {top_food}</h2>
- <h3>{top_conf:.1f}% Confidence</h3>
- </div>
- """, unsafe_allow_html=True)
-
- # Show all top-3 predictions
- st.markdown("### 📊 Top 3 Predictions")

- for i, (food, conf) in enumerate(st.session_state.predictions, 1):
- emoji = "🥇" if i == 1 else "🥈" if i == 2 else "🥉"
-
- st.markdown(f"**{emoji} {food}**")
- st.markdown(f"""
- <div class="conf-bar">
- <div class="conf-fill" style="width: {conf}%">
- {conf:.1f}%
- </div>
- </div>
- """, unsafe_allow_html=True)

- # Feedback based on confidence
- st.markdown("---")
- if top_conf > 90:
- st.success("🎉 **Very confident!** The model is very sure about this prediction.")
- elif top_conf > 70:
- st.success("👍 **Good confidence!** This looks like a solid prediction.")
- elif top_conf > 50:
- st.warning("🤔 **Moderate confidence.** The food might be ambiguous or partially visible.")
- else:
- st.warning("😕 **Low confidence.** Try a clearer photo with better lighting.")

- # Add a clear button to reset
- if st.button("🔄 Analyze Another Image", use_container_width=True):
- st.session_state.predictions = None
- st.session_state.processed_image = None
- st.session_state.last_image_hash = None
- st.rerun()
-
- except Exception as e:
- st.error(f"❌ Error: {str(e)}")
- st.info("Try a different image or check if the file is corrupted")
-
- # Reset state on error
- st.session_state.predictions = None
- st.session_state.processed_image = None
- st.session_state.last_image_hash = None

- else:
- # Instructions (only show if no predictions)
- if st.session_state.predictions is None:
- st.info("👆 Upload a food image or take a photo to get started!")
-
- with st.expander("💡 Tips for Best Results"):
- st.markdown("""
- - Use clear, well-lit photos
- - Make sure food is the main subject
- - Avoid heavily filtered images
- - Try different angles if confidence is low
- - Works best with common dishes
- """)
-
- with st.expander("🍽️ What can it recognize?"):
- st.markdown("""
- The model can identify **101 popular dishes** including:
- - 🍕 Pizza, Pasta, Burgers
- - 🍣 Sushi, Ramen, Pad Thai
- - 🥗 Salads, Sandwiches
- - 🍰 Desserts (cakes, ice cream, etc.)
- - 🍳 Breakfast foods
- - And many more!
- """)

  # Footer
  st.markdown("---")
 
  # IMPROVEMENTS:
  # -------------
  # ✅ Mobile-friendly single-column layout
+ # ✅ SIMPLIFIED: No session state complexity
+ # ✅ Direct processing on every upload
+ # ✅ Works reliably on mobile
+ # ✅ No unnecessary buttons
  #
  # ============================================================
 
  from PIL import Image
  import timm
  from pathlib import Path

  # ============================================================
  # PAGE CONFIGURATION

  initial_sidebar_state="collapsed"
  )

  # ============================================================
  # MINIMAL CSS (Mobile-First)
  # ============================================================

  st.markdown("""
  <style>
  .block-container {
  padding-top: 2rem;
  padding-bottom: 2rem;
  }

  h1 {
  text-align: center;
  color: #FF6B6B;
  margin-bottom: 0.5rem;
  }

  .prediction-card {
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
  padding: 1.5rem;

  opacity: 0.9;
  }

  .conf-bar {
  background: #f0f0f0;
  border-radius: 8px;
  height: 36px;
  margin: 0.5rem 0;
  overflow: hidden;
  }

  .conf-fill {

  font-weight: 600;
  font-size: 0.95rem;
  }
  </style>
  """, unsafe_allow_html=True)
 
  "sushi", "tacos", "takoyaki", "tiramisu", "tuna_tartare", "waffles"
  ]

  # ============================================================
  # MODEL LOADING
  # ============================================================

  @st.cache_resource
  def load_model():
+ """Loads model from local file or Hugging Face Hub."""
  try:
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  local_path = Path("model1_best.pth")

  if local_path.exists():
+ checkpoint = torch.load(local_path, map_location=device, weights_only=False)
  else:
  try:
  from huggingface_hub import hf_hub_download
  model_path = hf_hub_download(
  repo_id="doozer21/FoodVision",
  filename="model1_best.pth"
  )
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
+ except Exception:
  return None, None, None

  model_config = checkpoint.get('model_config', {
  'model_id': 'convnextv2_base.fcmae_ft_in22k_in1k_384'
  })

  model = timm.create_model(
  model_config['model_id'],
  pretrained=False,

  model.eval()

  accuracy = checkpoint.get('best_val_acc', 0)
  return model, device, accuracy

  except Exception as e:
 
  return results

+ # ============================================================
+ # DISPLAY RESULTS
+ # ============================================================
+
+ def display_results(predictions):
+ """Display prediction results."""
+ st.markdown("---")
+
+ # Top prediction
+ top_food, top_conf = predictions[0]
+
+ st.markdown(f"""
+ <div class="prediction-card">
+ <h2>🏆 {top_food}</h2>
+ <h3>{top_conf:.1f}% Confidence</h3>
+ </div>
+ """, unsafe_allow_html=True)
+
+ # Top 3 predictions
+ st.markdown("### 📊 Top 3 Predictions")
+
+ for i, (food, conf) in enumerate(predictions, 1):
+ emoji = "🥇" if i == 1 else "🥈" if i == 2 else "🥉"
+
+ st.markdown(f"**{emoji} {food}**")
+ st.markdown(f"""
+ <div class="conf-bar">
+ <div class="conf-fill" style="width: {conf}%">
+ {conf:.1f}%
+ </div>
+ </div>
+ """, unsafe_allow_html=True)
+
+ # Feedback
+ st.markdown("---")
+ if top_conf > 90:
+ st.success("🎉 **Very confident!** The model is very sure.")
+ elif top_conf > 70:
+ st.success("👍 **Good confidence!** Solid prediction.")
+ elif top_conf > 50:
+ st.warning("🤔 **Moderate confidence.** Food might be ambiguous.")
+ else:
+ st.warning("😕 **Low confidence.** Try a clearer photo.")
+
  # ============================================================
  # MAIN APP
  # ============================================================
 
  st.title("🍕 FoodVision AI")
  st.markdown("**Identify 101 food dishes instantly**")

+ # Load model
+ model, device, accuracy = load_model()

  if model is None:
+ st.error("❌ Could not load model. Check if model1_best.pth exists.")
  st.stop()

+ # Model info
  with st.expander("ℹ️ Model Info"):
  st.write(f"**Architecture:** ConvNeXt V2 Base")
  st.write(f"**Accuracy:** {accuracy:.2f}%")

  st.markdown("---")

+ # Input section
+ st.subheader("📸 Choose Your Input Method")

+ # Tab-based approach (better for mobile)
+ tab1, tab2 = st.tabs(["📁 Upload Image", "📷 Take Photo"])

+ with tab1:
+ uploaded_file = st.file_uploader(
+ "Select a food image",
+ type=['jpg', 'jpeg', 'png', 'webp'],
+ label_visibility="collapsed"
+ )
+
+ if uploaded_file is not None:
+ try:
+ image = Image.open(uploaded_file)
+ st.image(image, caption="Uploaded Image", use_column_width=True)

+ with st.spinner("🧠 Analyzing..."):
  img_tensor = preprocess_image(image)
  predictions = predict(model, img_tensor, device, top_k=3)

+ display_results(predictions)

+ except Exception as e:
+ st.error(f"❌ Error: {str(e)}")
+
+ with tab2:
+ camera_photo = st.camera_input("Take a picture", label_visibility="collapsed")
+
+ if camera_photo is not None:
+ try:
+ image = Image.open(camera_photo)
+ st.image(image, caption="Camera Photo", use_column_width=True)

+ with st.spinner("🧠 Analyzing..."):
+ img_tensor = preprocess_image(image)
+ predictions = predict(model, img_tensor, device, top_k=3)

+ display_results(predictions)

+ except Exception as e:
+ st.error(f"❌ Error: {str(e)}")

+ # Instructions (show at bottom when no image)
+ if uploaded_file is None and camera_photo is None:
+ st.info("👆 Choose a tab above to get started!")
+
+ with st.expander("💡 Tips for Best Results"):
+ st.markdown("""
+ - Use clear, well-lit photos
+ - Make sure food is the main subject
+ - Avoid heavily filtered images
+ - Try different angles if confidence is low
+ """)
+
+ with st.expander("🍽️ What can it recognize?"):
+ st.markdown("""
+ **101 popular dishes** including:
+ - 🍕 Pizza, Pasta, Burgers
+ - 🍣 Sushi, Ramen, Pad Thai
+ - 🥗 Salads, Sandwiches
+ - 🍰 Desserts (cakes, ice cream)
+ - 🍳 Breakfast foods
+ - And many more!
+ """)

  # Footer
  st.markdown("---")
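Both new torch.load calls pass weights_only=False. A short sketch of the motivation, assuming a recent PyTorch (2.6 or later, where the default flipped to weights_only=True) and the checkpoint keys this diff already references; it only runs where model1_best.pth is actually present:

import torch

# With weights_only=True (the PyTorch 2.6+ default), torch.load refuses
# checkpoints that pickle arbitrary Python objects, such as a config dict
# stored next to the state_dict. weights_only=False restores full unpickling,
# which is why the app sets it explicitly; only use it on trusted files.
checkpoint = torch.load("model1_best.pth", map_location="cpu", weights_only=False)
print(checkpoint.get("model_config"), checkpoint.get("best_val_acc"))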