Spaces:
Runtime error
Runtime error
| # predict.py - FoodVision Command Line Prediction Script | |
| # ============================================================ | |
| # | |
| # USAGE: | |
| # ------ | |
| # python predict.py --image path/to/food.jpg | |
| # python predict.py --image pizza.jpg --top 5 | |
| # python predict.py --folder food_images/ | |
| # python predict.py --image burger.png --json results.json | |
| # | |
| # FEATURES: | |
| # --------- | |
| # - Single image prediction | |
| # - Batch prediction on folder | |
| # - JSON output option | |
| # - Detailed or simple output modes | |
| # | |
| # ============================================================ | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import timm | |
| import argparse | |
| from pathlib import Path | |
| import json | |
| import sys | |
| # ============================================================ | |
| # FOOD CLASSES (101 categories) | |
| # ============================================================ | |
# Class labels in model output-index order: predict() maps the i-th logit to
# FOOD_CLASSES[i], so the ORDER of this list is part of the model contract and
# must never change. The names are alphabetical (they match the Food-101 label
# set); they are grouped by first letter below purely for readability.
FOOD_CLASSES = [
    # a / b
    "apple_pie",
    "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare",
    "beet_salad", "beignets", "bibimbap", "bread_pudding",
    "breakfast_burrito", "bruschetta",
    # c
    "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche",
    "cheese_plate", "cheesecake", "chicken_curry", "chicken_quesadilla",
    "chicken_wings", "chocolate_cake", "chocolate_mousse", "churros",
    "clam_chowder", "club_sandwich", "crab_cakes", "creme_brulee",
    "croque_madame", "cup_cakes",
    # d / e
    "deviled_eggs", "donuts", "dumplings",
    "edamame", "eggs_benedict", "escargots",
    # f
    "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries",
    "french_onion_soup", "french_toast", "fried_calamari", "fried_rice",
    "frozen_yogurt",
    # g / h / i
    "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich",
    "grilled_salmon", "guacamole", "gyoza",
    "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros", "hummus",
    "ice_cream",
    # l / m / n / o
    "lasagna", "lobster_bisque", "lobster_roll_sandwich",
    "macaroni_and_cheese", "macarons", "miso_soup", "mussels",
    "nachos",
    "omelette", "onion_rings", "oysters",
    # p / r
    "pad_thai", "paella", "pancakes", "panna_cotta", "peking_duck", "pho",
    "pizza", "pork_chop", "poutine", "prime_rib", "pulled_pork_sandwich",
    "ramen", "ravioli", "red_velvet_cake", "risotto",
    # s / t / w
    "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits",
    "spaghetti_bolognese", "spaghetti_carbonara", "spring_rolls", "steak",
    "strawberry_shortcake", "sushi",
    "tacos", "takoyaki", "tiramisu", "tuna_tartare",
    "waffles",
]
| # ============================================================ | |
| # MODEL LOADING | |
| # ============================================================ | |
def load_model(model_path, device):
    """
    Loads a trained model from checkpoint.

    The checkpoint may be either a dict holding 'model_state_dict' (plus the
    optional metadata keys 'model_config' and 'best_val_acc' read below) or a
    bare state dict. Missing metadata falls back to defaults.

    Args:
        model_path: Path to .pth checkpoint file
        device: torch device ('cuda' or 'cpu')
    Returns:
        Loaded model in eval mode, with one output per entry of FOOD_CLASSES
    Exits:
        Prints the error to stderr and calls sys.exit(1) if loading fails.
    """
    default_model_id = 'convnextv2_base.fcmae_ft_in22k_in1k_384'
    print(f"π Loading model from: {model_path}")
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)

        # Tolerate a missing/None config and a config without 'model_id'.
        model_config = checkpoint.get('model_config') or {}
        model = timm.create_model(
            model_config.get('model_id', default_model_id),
            pretrained=False,
            # Keep the head size tied to the label list predict() indexes into.
            num_classes=len(FOOD_CLASSES),
        )

        # Accept both wrapped checkpoints and bare state dicts.
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        # 'best_val_acc' may be absent or None; treat both as "unknown" (0).
        accuracy = float(checkpoint.get('best_val_acc') or 0)
    except Exception as e:
        # CLI boundary: report and stop rather than continue without a model.
        print(f"β Error loading model: {e}", file=sys.stderr)
        sys.exit(1)

    print("β Model loaded successfully!")
    print(f" Architecture: {model_config.get('name', 'ConvNeXt V2')}")
    if accuracy > 0:
        print(f" Accuracy: {accuracy:.2f}%")
    return model
| # ============================================================ | |
| # IMAGE PREPROCESSING | |
| # ============================================================ | |
def preprocess_image(image_path, resize_size=256, crop_size=224):
    """
    Loads and preprocesses an image.

    Resizes the shorter side to ``resize_size``, center-crops a
    ``crop_size`` x ``crop_size`` square, converts to a tensor and normalizes
    with the ImageNet mean/std.

    NOTE(review): the default model id in load_model names a 384px variant
    ("..._384"); confirm the resolution the checkpoint was trained at and pass
    matching sizes if it differs from the 256/224 defaults.

    Args:
        image_path: Path to image file
        resize_size: Target length of the shorter side before cropping (default 256)
        crop_size: Side length of the square center crop (default 224)
    Returns:
        Preprocessed image tensor of shape (1, 3, crop_size, crop_size),
        or None if the image could not be loaded (the error is printed)
    """
    transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    try:
        # The context manager closes the file handle promptly; convert()
        # returns an independent in-memory copy, so it stays usable afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # Add the leading batch dimension expected by the model.
        return transform(image).unsqueeze(0)
    except Exception as e:
        # Best-effort per image: report and let the caller skip this file.
        print(f"β Error loading image {image_path}: {e}", file=sys.stderr)
        return None
| # ============================================================ | |
| # PREDICTION FUNCTION | |
| # ============================================================ | |
def predict(model, image_tensor, device, top_k=3, class_names=None):
    """
    Predicts food class for a single image.

    Args:
        model: Trained PyTorch model; called as ``model(image_tensor)`` and
            expected to return logits of shape (batch, num_classes)
        image_tensor: Preprocessed image batch; only the first item's
            predictions are returned
        device: torch device
        top_k: Number of top predictions. Clamped to [1, num_classes], so an
            oversized request (e.g. 200) no longer makes torch.topk raise and
            0/negative values still yield the single best prediction.
        class_names: Optional sequence mapping output index -> label;
            defaults to FOOD_CLASSES.
    Returns:
        List of (class_name, confidence) tuples, confidence in percent,
        highest confidence first
    """
    names = FOOD_CLASSES if class_names is None else class_names
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = F.softmax(outputs, dim=1)
        # Never ask topk for more entries than there are classes.
        k = max(1, min(int(top_k), probabilities.size(1)))
        top_probs, top_indices = torch.topk(probabilities, k)

    # Row 0 = the first (and, from this script, only) image in the batch.
    probs = top_probs[0].cpu().tolist()
    indices = top_indices[0].cpu().tolist()
    return [(names[idx], float(prob) * 100) for prob, idx in zip(probs, indices)]
| # ============================================================ | |
| # OUTPUT FORMATTING | |
| # ============================================================ | |
def print_predictions(image_path, predictions, detailed=True):
    """
    Print predictions for one image to stdout.

    Args:
        image_path: Image path (only used for display)
        predictions: Ranked list of (class_name, confidence_percent) tuples
        detailed: When True, print a framed block with one entry and a
            50-cell confidence bar per rank; when False, print a single
            line containing only the top prediction.
    """
    if not detailed:
        # One-line summary of the best guess only.
        best_food, best_conf = predictions[0]
        pretty = best_food.replace('_', ' ').title()
        print(f"{image_path}: {pretty} ({best_conf:.1f}%)")
        return

    divider = '=' * 60
    print(f"\n{divider}")
    print(f"π· Image: {image_path}")
    print(divider)
    for rank, (food, conf) in enumerate(predictions, 1):
        if rank == 1:
            medal = "π₯"
        elif rank == 2:
            medal = "π₯"
        else:
            medal = "π₯"
        pretty = food.replace('_', ' ').title()
        # A 0-100 percentage maps onto 50 cells (2% per cell).
        filled = int(conf / 2)
        bar = 'β' * filled + 'β' * (50 - filled)
        print(f"\n{medal} Rank {rank}:")
        print(f" Food: {pretty}")
        print(f" Confidence: {conf:.2f}%")
        print(f" {bar}")
    print(f"\n{divider}\n")
def save_json(image_path, predictions, output_file):
    """
    Write a single image's predictions to a JSON file and report the path.

    The document has three keys, in this order: 'image' (path as string),
    'predictions' (rank / food / formatted_name / confidence rounded to
    2 decimals, one per entry) and 'top_prediction' (first entry only).

    Args:
        image_path: Path to image
        predictions: Ranked list of (class_name, confidence_percent) tuples
        output_file: Destination JSON file path (overwritten)
    """
    ranked = []
    for position, (label, score) in enumerate(predictions, 1):
        ranked.append({
            'rank': position,
            'food': label,
            'formatted_name': label.replace('_', ' ').title(),
            'confidence': round(score, 2),
        })

    best_label, best_score = predictions[0]
    document = {
        'image': str(image_path),
        'predictions': ranked,
        'top_prediction': {
            'food': best_label,
            'confidence': round(best_score, 2),
        },
    }

    with open(output_file, 'w') as handle:
        json.dump(document, handle, indent=2)
    print(f"πΎ Saved results to: {output_file}")
| # ============================================================ | |
| # MAIN FUNCTION | |
| # ============================================================ | |
# Lower-case file extensions picked up when scanning --folder.
_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.webp')


def _build_parser():
    """Create the argparse parser for the command-line interface."""
    parser = argparse.ArgumentParser(
        description='FoodVision - AI-powered food classification',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python predict.py --image pizza.jpg
python predict.py --image food.png --top 5
python predict.py --folder food_images/
python predict.py --image burger.jpg --simple
python predict.py --image pasta.jpg --json output.json
"""
    )
    # Input options: exactly one of --image / --folder is required.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--image', type=str, help='Path to single image')
    input_group.add_argument('--folder', type=str, help='Path to folder of images')
    # Model options
    parser.add_argument('--model', type=str, default='model1_best.pth',
                        help='Path to model checkpoint (default: model1_best.pth)')
    # Output options
    parser.add_argument('--top', type=int, default=3,
                        help='Number of top predictions to show (default: 3)')
    parser.add_argument('--simple', action='store_true',
                        help='Simple output format (one line per image)')
    parser.add_argument('--json', type=str,
                        help='Save results to JSON file')
    # Device option
    parser.add_argument('--cpu', action='store_true',
                        help='Force CPU usage (default: auto-detect GPU)')
    return parser


def _select_device(force_cpu):
    """Return the torch device to run on, printing which one was chosen."""
    if force_cpu:
        print("π» Using CPU")
        return torch.device('cpu')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        print(f"β‘ Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("π» Using CPU (no GPU detected)")
    return device


def _collect_images(args):
    """Return the image paths to classify; exit(1) if --folder yields none.

    Extensions are matched case-insensitively (so ``IMG_01.JPG`` is found),
    only regular files are kept, and paths are sorted by name so the
    processing order is deterministic.
    """
    if args.image:
        return [Path(args.image)]

    folder = Path(args.folder)
    image_paths = []
    # A missing or non-directory path is reported like an empty folder.
    if folder.is_dir():
        image_paths = sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
        )
    if not image_paths:
        print(f"β No images found in {folder}", file=sys.stderr)
        sys.exit(1)
    print(f"π Found {len(image_paths)} images in {folder}\n")
    return image_paths


def _save_batch_json(all_results, output_file):
    """Write the multi-image JSON report (top prediction + full ranked list per image)."""
    batch_results = {
        'total_images': len(all_results),
        'results': [
            {
                'image': img_path,
                'top_prediction': {
                    'food': predictions[0][0],
                    'confidence': round(predictions[0][1], 2)
                },
                'all_predictions': [
                    {
                        'rank': i,
                        'food': food,
                        'confidence': round(conf, 2)
                    }
                    for i, (food, conf) in enumerate(predictions, 1)
                ]
            }
            for img_path, predictions in all_results.items()
        ]
    }
    with open(output_file, 'w') as f:
        json.dump(batch_results, f, indent=2)
    print(f"\nπΎ Saved batch results to: {output_file}")


def main():
    """Run the FoodVision CLI: parse arguments, load the model, classify images.

    Exit status is 1 when the model fails to load, a --folder contains no
    supported images, or none of the requested images could be processed;
    argparse exits with status 2 on invalid arguments, including an
    out-of-range --top.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # predict() runs torch.topk over the class scores, so --top must lie in
    # [1, number of classes]: 0 leaves nothing to print, >101 makes topk fail.
    num_classes = len(FOOD_CLASSES)
    if not 1 <= args.top <= num_classes:
        parser.error(f"--top must be between 1 and {num_classes}")

    device = _select_device(args.cpu)

    model = load_model(args.model, device)
    print()  # Blank line

    image_paths = _collect_images(args)

    # Classify each image; unreadable files are reported by preprocess_image
    # (which returns None) and skipped.
    all_results = {}
    for img_path in image_paths:
        img_tensor = preprocess_image(img_path)
        if img_tensor is None:
            continue
        predictions = predict(model, img_tensor, device, args.top)
        all_results[str(img_path)] = predictions
        print_predictions(img_path, predictions, detailed=not args.simple)

    # Fail loudly instead of reporting success (or writing an empty JSON
    # report) when nothing could be classified.
    if not all_results:
        print("No images could be processed.", file=sys.stderr)
        sys.exit(1)

    if args.json:
        if len(all_results) == 1:
            # Single image - save simple format
            img_path, predictions = next(iter(all_results.items()))
            save_json(img_path, predictions, args.json)
        else:
            # Multiple images - save batch format
            _save_batch_json(all_results, args.json)

    # Summary for batch processing
    if len(all_results) > 1:
        print(f"\n{'='*60}")
        print(f"β Successfully processed {len(all_results)} images")
        print(f"{'='*60}")
| # ============================================================ | |
| # RUN SCRIPT | |
| # ============================================================ | |
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()