# FoodVision / predict.py
# doozer21's picture
# Initial FoodVision deployment with Git LFS for model
# 74037f6
# raw
# history blame
# 13.1 kB
# predict.py - FoodVision Command Line Prediction Script
# ============================================================
#
# USAGE:
# ------
# python predict.py --image path/to/food.jpg
# python predict.py --image pizza.jpg --top 5
# python predict.py --folder food_images/
# python predict.py --image burger.png --json results.json
#
# FEATURES:
# ---------
# ✅ Single image prediction
# ✅ Batch prediction on folder
# ✅ JSON output option
# ✅ Detailed or simple output modes
#
# ============================================================
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import argparse
from pathlib import Path
import json
import sys
# ============================================================
# FOOD CLASSES (101 categories)
# ============================================================
# Label names, kept in alphabetical order. predict() maps a model output index
# to a name by position in this list, so entries must not be reordered.
FOOD_CLASSES = [
    "apple_pie",
    "baby_back_ribs",
    "baklava",
    "beef_carpaccio",
    "beef_tartare",
    "beet_salad",
    "beignets",
    "bibimbap",
    "bread_pudding",
    "breakfast_burrito",
    "bruschetta",
    "caesar_salad",
    "cannoli",
    "caprese_salad",
    "carrot_cake",
    "ceviche",
    "cheese_plate",
    "cheesecake",
    "chicken_curry",
    "chicken_quesadilla",
    "chicken_wings",
    "chocolate_cake",
    "chocolate_mousse",
    "churros",
    "clam_chowder",
    "club_sandwich",
    "crab_cakes",
    "creme_brulee",
    "croque_madame",
    "cup_cakes",
    "deviled_eggs",
    "donuts",
    "dumplings",
    "edamame",
    "eggs_benedict",
    "escargots",
    "falafel",
    "filet_mignon",
    "fish_and_chips",
    "foie_gras",
    "french_fries",
    "french_onion_soup",
    "french_toast",
    "fried_calamari",
    "fried_rice",
    "frozen_yogurt",
    "garlic_bread",
    "gnocchi",
    "greek_salad",
    "grilled_cheese_sandwich",
    "grilled_salmon",
    "guacamole",
    "gyoza",
    "hamburger",
    "hot_and_sour_soup",
    "hot_dog",
    "huevos_rancheros",
    "hummus",
    "ice_cream",
    "lasagna",
    "lobster_bisque",
    "lobster_roll_sandwich",
    "macaroni_and_cheese",
    "macarons",
    "miso_soup",
    "mussels",
    "nachos",
    "omelette",
    "onion_rings",
    "oysters",
    "pad_thai",
    "paella",
    "pancakes",
    "panna_cotta",
    "peking_duck",
    "pho",
    "pizza",
    "pork_chop",
    "poutine",
    "prime_rib",
    "pulled_pork_sandwich",
    "ramen",
    "ravioli",
    "red_velvet_cake",
    "risotto",
    "samosa",
    "sashimi",
    "scallops",
    "seaweed_salad",
    "shrimp_and_grits",
    "spaghetti_bolognese",
    "spaghetti_carbonara",
    "spring_rolls",
    "steak",
    "strawberry_shortcake",
    "sushi",
    "tacos",
    "takoyaki",
    "tiramisu",
    "tuna_tartare",
    "waffles",
]
# ============================================================
# MODEL LOADING
# ============================================================
def load_model(model_path, device):
    """
    Loads a trained model from a checkpoint and prepares it for inference.

    The checkpoint is read as a dict with these keys:
        'model_state_dict': weights loaded into the timm model (required)
        'model_config':     optional dict; 'model_id' picks the timm
                            architecture, 'name' is only used for display
        'best_val_acc':     optional number, printed when greater than 0

    Args:
        model_path: Path to .pth checkpoint file
        device: torch device ('cuda' or 'cpu')

    Returns:
        Loaded model in eval mode, moved to `device`

    On any failure the error is printed and the process exits with status 1.
    """
    default_model_id = 'convnextv2_base.fcmae_ft_in22k_in1k_384'
    print(f"πŸ“‚ Loading model from: {model_path}")
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)

        # A checkpoint may carry no config, or a config without 'model_id';
        # both fall back to the default architecture instead of raising KeyError.
        model_config = checkpoint.get('model_config') or {}
        model_id = model_config.get('model_id', default_model_id)

        # The classifier head must have one output per label in FOOD_CLASSES.
        model = timm.create_model(
            model_id,
            pretrained=False,
            num_classes=len(FOOD_CLASSES)
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        # Treat a missing or None accuracy as "unknown" and skip printing it.
        accuracy = checkpoint.get('best_val_acc') or 0
        print("βœ… Model loaded successfully!")
        print(f" Architecture: {model_config.get('name', 'ConvNeXt V2')}")
        if accuracy > 0:
            print(f" Accuracy: {accuracy:.2f}%")
        return model
    except Exception as e:
        # Top-level CLI boundary: report the problem and stop.
        print(f"❌ Error loading model: {e}")
        sys.exit(1)
# ============================================================
# IMAGE PREPROCESSING
# ============================================================
def preprocess_image(image_path):
    """
    Loads an image from disk and converts it to a model-ready tensor.

    The image is converted to RGB, resized to 256 on its shorter side,
    center-cropped to 224x224, turned into a tensor and normalized with the
    mean/std values below. A leading batch dimension is added.

    Args:
        image_path: Path to image file

    Returns:
        Preprocessed image tensor with a batch dimension of 1, or None if the
        image could not be loaded (the error is printed).
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    try:
        # Image.open keeps the file handle open until the image is closed.
        # convert() returns a new in-memory image, so the source can be closed
        # right away; otherwise batch runs leak one handle per image.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        # Add the batch dimension expected by the model.
        return transform(image).unsqueeze(0)
    except Exception as e:
        # Best effort: report the bad file and let the caller skip it.
        print(f"❌ Error loading image {image_path}: {e}")
        return None
# ============================================================
# PREDICTION FUNCTION
# ============================================================
def predict(model, image_tensor, device, top_k=3, class_names=None):
    """
    Predicts the most likely classes for a single preprocessed image.

    Args:
        model: Trained PyTorch model returning one row of logits per image
        image_tensor: Preprocessed image with a batch dimension of 1
        device: torch device
        top_k: Number of top predictions; clamped to the range
            [1, number of model outputs] so out-of-range values neither
            raise inside torch.topk nor produce an empty result
        class_names: Optional sequence of label names indexed by class id;
            defaults to FOOD_CLASSES

    Returns:
        List of (class_name, confidence) tuples, best first, where
        confidence is a percentage between 0 and 100
    """
    names = FOOD_CLASSES if class_names is None else class_names
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probabilities = F.softmax(logits, dim=1)
        # torch.topk raises if k exceeds the number of classes.
        k = max(1, min(int(top_k), probabilities.shape[1]))
        top_probs, top_indices = torch.topk(probabilities, k, dim=1)

    # tolist() yields plain Python floats/ints for the single image in the batch.
    return [
        (names[index], prob * 100)
        for prob, index in zip(top_probs[0].tolist(), top_indices[0].tolist())
    ]
# ============================================================
# OUTPUT FORMATTING
# ============================================================
def print_predictions(image_path, predictions, detailed=True):
    """
    Writes predictions for one image to stdout.

    In simple mode only the first prediction is printed on a single line.
    In detailed mode every prediction gets a rank marker, a display name,
    the confidence and a bar scaled to 50 cells; ranks after the second all
    share the third-place marker.

    Args:
        image_path: Path to image
        predictions: List of (class_name, confidence) tuples
        detailed: Whether to show detailed output
    """
    if not detailed:
        best_label, best_conf = predictions[0]
        pretty = best_label.replace('_', ' ').title()
        print(f"{image_path}: {pretty} ({best_conf:.1f}%)")
        return

    rule = '=' * 60
    medals = {1: "πŸ₯‡", 2: "πŸ₯ˆ"}

    print(f"\n{rule}")
    print(f"πŸ“· Image: {image_path}")
    print(rule)
    for rank, (label, confidence) in enumerate(predictions, start=1):
        medal = medals.get(rank, "πŸ₯‰")
        pretty = label.replace('_', ' ').title()
        # Two percentage points per cell, so 100% fills all 50 cells.
        filled = int(confidence / 2)
        bar = 'β–ˆ' * filled + 'β–‘' * (50 - filled)
        print(f"\n{medal} Rank {rank}:")
        print(f" Food: {pretty}")
        print(f" Confidence: {confidence:.2f}%")
        print(f" {bar}")
    print(f"\n{rule}\n")
def save_json(image_path, predictions, output_file):
    """
    Writes the predictions for a single image to a JSON file.

    The document holds the image path, every prediction with its 1-based
    rank, raw label, display name and confidence rounded to two decimals,
    plus a 'top_prediction' entry taken from the first prediction.

    Args:
        image_path: Path to image
        predictions: List of (class_name, confidence) tuples
        output_file: Output JSON file path
    """
    ranked = []
    for rank, (label, confidence) in enumerate(predictions, start=1):
        ranked.append({
            'rank': rank,
            'food': label,
            'formatted_name': label.replace('_', ' ').title(),
            'confidence': round(confidence, 2)
        })

    best_label, best_confidence = predictions[0]
    payload = {
        'image': str(image_path),
        'predictions': ranked,
        'top_prediction': {
            'food': best_label,
            'confidence': round(best_confidence, 2)
        }
    }

    with open(output_file, 'w') as handle:
        json.dump(payload, handle, indent=2)
    print(f"πŸ’Ύ Saved results to: {output_file}")
# ============================================================
# MAIN FUNCTION
# ============================================================
def _collect_image_paths(folder):
    """Returns the .jpg/.jpeg/.png/.webp files directly inside `folder`, sorted by path."""
    supported = {'.jpg', '.jpeg', '.png', '.webp'}
    if not folder.is_dir():
        return []
    # Compare suffixes case-insensitively so files such as IMG_001.JPG are
    # found, and sort so the processing order is the same on every run.
    return sorted(
        entry for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in supported
    )


def _batch_payload(all_results):
    """Builds the JSON document written when more than one image was processed."""
    return {
        'total_images': len(all_results),
        'results': [
            {
                'image': image,
                'top_prediction': {
                    'food': predictions[0][0],
                    'confidence': round(predictions[0][1], 2)
                },
                'all_predictions': [
                    {
                        'rank': rank,
                        'food': food,
                        'confidence': round(conf, 2)
                    }
                    for rank, (food, conf) in enumerate(predictions, 1)
                ]
            }
            for image, predictions in all_results.items()
        ]
    }


def main():
    """
    Command line entry point.

    Parses the arguments, picks the device, loads the model, runs prediction
    on one image or on every supported image in a folder, prints the results
    and optionally saves them as JSON.

    Exit status is 1 when the folder contains no images or when no image
    could be processed; argument errors exit through argparse.
    """
    parser = argparse.ArgumentParser(
        description='FoodVision - AI-powered food classification',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python predict.py --image pizza.jpg
python predict.py --image food.png --top 5
python predict.py --folder food_images/
python predict.py --image burger.jpg --simple
python predict.py --image pasta.jpg --json output.json
"""
    )
    # Input options
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--image', type=str, help='Path to single image')
    input_group.add_argument('--folder', type=str, help='Path to folder of images')
    # Model options
    parser.add_argument('--model', type=str, default='model1_best.pth',
                        help='Path to model checkpoint (default: model1_best.pth)')
    # Output options
    parser.add_argument('--top', type=int, default=3,
                        help='Number of top predictions to show (default: 3)')
    parser.add_argument('--simple', action='store_true',
                        help='Simple output format (one line per image)')
    parser.add_argument('--json', type=str,
                        help='Save results to JSON file')
    # Device option
    parser.add_argument('--cpu', action='store_true',
                        help='Force CPU usage (default: auto-detect GPU)')
    args = parser.parse_args()

    # A non-positive --top would give empty predictions and break the output
    # code, which always reads the first prediction.
    if args.top < 1:
        parser.error('--top must be at least 1')
    # Asking for more predictions than there are classes makes topk raise.
    top_k = min(args.top, len(FOOD_CLASSES))

    # Setup device
    if args.cpu:
        device = torch.device('cpu')
        print("πŸ’» Using CPU")
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if device.type == 'cuda':
            print(f"⚑ Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("πŸ’» Using CPU (no GPU detected)")

    # Load model
    model = load_model(args.model, device)
    print()  # Blank line

    # Get list of images to process
    if args.image:
        image_paths = [Path(args.image)]
    else:
        folder = Path(args.folder)
        image_paths = _collect_image_paths(folder)
        if not image_paths:
            print(f"❌ No images found in {folder}")
            sys.exit(1)
        print(f"πŸ“ Found {len(image_paths)} images in {folder}\n")

    # Process each image; unreadable images are reported by preprocess_image
    # and skipped.
    all_results = {}
    for img_path in image_paths:
        img_tensor = preprocess_image(img_path)
        if img_tensor is None:
            continue
        predictions = predict(model, img_tensor, device, top_k)
        all_results[str(img_path)] = predictions
        print_predictions(img_path, predictions, detailed=not args.simple)

    if not all_results:
        # Every image failed to load (each error was already printed), so
        # signal failure instead of exiting 0 with an empty report.
        sys.exit(1)

    # Save to JSON if requested
    if args.json:
        if len(all_results) == 1:
            # Single image - save simple format
            img_path, predictions = next(iter(all_results.items()))
            save_json(img_path, predictions, args.json)
        else:
            # Multiple images - save batch format
            with open(args.json, 'w') as f:
                json.dump(_batch_payload(all_results), f, indent=2)
            print(f"\nπŸ’Ύ Saved batch results to: {args.json}")

    # Summary for batch processing
    if len(all_results) > 1:
        print(f"\n{'='*60}")
        print(f"βœ… Successfully processed {len(all_results)} images")
        print(f"{'='*60}")
# ============================================================
# RUN SCRIPT
# ============================================================
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()