File size: 5,829 Bytes
e06a9ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import os

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM
# Step 1: Load the World Cuisines knowledge base from the Hugging Face Hub.
ds = load_dataset("worldcuisines/food-kb")

# The food entries are read from the 'main' split.
dataset = ds['main']

# Print the split summary as a quick sanity check of its structure.
print(dataset)

# Split whole rows instead of a single extracted column: every example then
# keeps its image together with its label columns (FoodDataset reads both
# 'image1' and 'fine_categories' from each row). Using the dataset's own
# train_test_split also avoids converting the entire dataset into Python
# lists with to_dict() just to pick one column.
split = dataset.train_test_split(test_size=0.2)
train_data = split['train']
test_data = split['test']

# Report the resulting split sizes.
print(f"Training data size: {len(train_data)}")
print(f"Testing data size: {len(test_data)}")
# Define a custom dataset class for the image classification task
class FoodDataset(Dataset):
    """Torch dataset pairing preprocessed food images with their labels.

    Args:
        dataset: Indexable collection of rows. Each row must provide an
            'image1' entry (a file path, an open file object, or an
            already-decoded image) and a 'fine_categories' entry that is
            returned as the label.
        processor: Image processor called as
            ``processor(images=..., return_tensors="pt")``; its result must
            contain 'pixel_values' with a leading batch dimension.
        max_length: Kept for interface compatibility; the image pipeline
            does not use it.
    """

    def __init__(self, dataset, processor, max_length=256):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': tensor, 'labels': label}`` for row *idx*."""
        item = self.dataset[idx]
        source = item['image1']

        is_path = isinstance(source, str) or hasattr(source, '__fspath__')
        if is_path or hasattr(source, 'read'):
            # Paths and file objects are opened here; the context manager
            # closes the file, and convert() returns an independent RGB copy
            # so the image stays usable after the file is closed.
            with Image.open(source) as handle:
                image = handle.convert("RGB")
        else:
            # Anything else is treated as an already-decoded image.
            image = source

        label = item['fine_categories']

        # Image processors produce 'pixel_values'; 'input_ids',
        # 'attention_mask', padding and truncation are tokenizer concepts and
        # do not apply to an images-only call.
        encoding = self.processor(images=image, return_tensors="pt")

        return {
            # Drop only the batch dimension added by the processor.
            'pixel_values': encoding['pixel_values'].squeeze(0),
            'labels': label,
        }
# Step 2: Image classifier -- the image processor and the classification
# model are loaded from the same ViT checkpoint.
VIT_CHECKPOINT = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(VIT_CHECKPOINT)
vit_model = AutoModelForImageClassification.from_pretrained(VIT_CHECKPOINT)

# Step 3: Text generator -- tokenizer and seq2seq model used to produce the
# free-text nutrition breakdown / diet plan.
# NOTE(review): confirm this checkpoint is suited to nutrition text.
TEXT_CHECKPOINT = "describeai/gemini"
tokenizer = AutoTokenizer.from_pretrained(TEXT_CHECKPOINT)
gemini_model = AutoModelForSeq2SeqLM.from_pretrained(TEXT_CHECKPOINT)
# Helper function to get nutritional breakdown and allergen information
def get_nutrition_and_allergens(food_name):
    """Look up a food by name in the 'main' split of the module-level ``ds``.

    The first row whose 'name' contains *food_name* (case-insensitive,
    surrounding whitespace ignored) is used.

    Args:
        food_name: Name, or part of a name, of the food to look up.

    Returns:
        A ``(nutrition_info, allergens, diet_plan)`` tuple of strings. When
        nothing matches, or when the lookup fails, the strings are
        explanatory messages instead; this function does not raise.
    """
    not_found = (
        "Food item not found in the database.",
        "Allergen information not available.",
        "Diet plan not available for this food item.",
    )

    # An empty query is a substring of every name and would silently "match"
    # the first row, so blank and non-string queries count as not found.
    query = food_name.strip().lower() if isinstance(food_name, str) else ""
    if not query:
        return not_found

    try:
        match = None
        for row in ds['main']:
            name = row['name']
            # Skip rows without a usable name instead of aborting the search.
            if isinstance(name, str) and query in name.lower():
                match = row
                break

        if match is None:
            return not_found

        nutrition_info = match.get('nutrition', 'Nutrition information not available')
        allergens = match.get('allergens', 'Allergen information not available')
        diet_plan = f"This item is suitable for a diet including {match.get('suitable_for', 'N/A')}."
    except KeyError as e:
        # Missing split or missing 'name' column: report it to the caller.
        nutrition_info = f"Key error: {e}"
        allergens = "Allergen information not available."
        diet_plan = "Diet plan not available."
    except Exception as e:
        # Best-effort boundary: the caller renders these strings in the UI.
        nutrition_info = f"An error occurred: {str(e)}"
        allergens = "Allergen information not available."
        diet_plan = "Diet plan not available."

    return nutrition_info, allergens, diet_plan
# Main prediction function for the image classification and text generation
def predict(image):
    """Classify a food photo and build a text report about it.

    The image is classified with the ViT model, the predicted label is looked
    up via get_nutrition_and_allergens, and the seq2seq model generates an
    additional free-text description.

    Args:
        image: The image to analyse, or None when nothing was uploaded.

    Returns:
        str: The formatted report, or a message starting with "Error: " when
        no image was given or any step failed. This function does not raise.
    """
    if image is None:
        # Give a clear message instead of an opaque processor error.
        return "Error: No image provided."

    try:
        # Step 1: Classify the food item in the image using the ViT model.
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = vit_model(**inputs)
        predicted_label = outputs.logits.argmax(-1).item()

        # Map the class index to its human-readable label.
        food_item = vit_model.config.id2label[predicted_label]

        # Labels may list several comma-separated synonyms; the database
        # lookup is a substring match, so search with the first name only.
        # Labels without a comma are used unchanged.
        lookup_name = food_item.split(",")[0].strip()

        # Step 2: Fetch nutrition, allergens and diet suitability.
        nutrition_info, allergens, diet_plan = get_nutrition_and_allergens(lookup_name)

        # Step 3: Generate a free-text description with the seq2seq model.
        description_input = f"Nutritional breakdown and diet plan for {food_item}"
        encoded_prompt = tokenizer(description_input, return_tensors="pt", padding=True, truncation=True)
        generated_ids = gemini_model.generate(**encoded_prompt)
        diet_plan_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Combine results into a single output.
        response = (
            f"**Detected Food:** {food_item}\n\n"
            f"**Nutrition Info:** {nutrition_info}\n\n"
            f"**Allergens:** {allergens}\n\n"
            f"**Diet Plan:** {diet_plan}\n\n"
            f"**Detailed Diet Plan and Breakdown:** {diet_plan_text}"
        )
    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing.
        response = f"Error: {str(e)}"
    return response
# Step 4: Gradio Interface
# Candidate example images shown below the upload box; replace with real paths.
EXAMPLE_IMAGE_PATHS = ["path_to_example_image.jpg"]

# Only offer examples that actually exist on disk: a placeholder path that is
# missing can make Gradio fail while it prepares the examples panel.
available_examples = [[path] for path in EXAMPLE_IMAGE_PATHS if os.path.isfile(path)]

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="NutriScan: AI-Powered Food Analyzer",
    description="Upload an image of food, and get a nutritional breakdown, allergen information, and diet plan recommendations.",
    # None (no examples panel) when no example file is available.
    examples=available_examples or None,
)

# Launch the Gradio interface only when run as a script.
if __name__ == "__main__":
    interface.launch()
|