|
import os

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
|
|
|
|
|
# Load the "worldcuisines/food-kb" dataset; `ds` is also read by the lookup
# helper further down in this file.
ds = load_dataset("worldcuisines/food-kb")

# Work with the "main" entry of the loaded dataset and show its summary.
dataset = ds["main"]
print(dataset)

# Take the "image1" column and hold out 20% of it for testing.
# NOTE(review): no random_state is passed, so the partition is not
# reproducible between runs.
data_list = dataset.to_dict()["image1"]
train_data, test_data = train_test_split(data_list, test_size=0.2)

print(f"Training data size: {len(train_data)}")
print(f"Testing data size: {len(test_data)}")
|
|
|
|
|
class FoodDataset(Dataset):
    """Torch dataset that turns food records into processor outputs plus a label.

    Each record is expected to expose an ``"image1"`` entry (the image) and a
    ``"fine_categories"`` entry (returned unchanged as the label).

    Args:
        dataset: Indexable collection of records supporting ``len()``.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            that returns a mapping of tensors with a leading batch dimension.
        max_length: Stored on the instance for backward compatibility; it is
            not used when processing images.
    """

    def __init__(self, dataset, processor, max_length=256):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed tensors and label for record ``idx``.

        Returns:
            A dict holding every tensor the processor produced (batch
            dimension removed) plus ``"labels"``.
        """
        item = self.dataset[idx]
        source = item["image1"]

        # Only open the image when the record holds a path or file object;
        # an already-decoded image is handed to the processor as-is.
        if isinstance(source, (str, bytes, os.PathLike)) or hasattr(source, "read"):
            # The context manager closes the file handle; convert() loads the
            # pixel data first and normalises to three channels.
            with Image.open(source) as opened:
                image = opened.convert("RGB")
        else:
            image = source

        # Image processors take no padding/truncation arguments and do not
        # emit input_ids/attention_mask, so forward whatever they do produce
        # (e.g. pixel_values) instead of indexing tokenizer-only keys.
        encoding = self.processor(images=image, return_tensors="pt")
        sample = {key: value.squeeze(0) for key, value in encoding.items()}
        sample["labels"] = item["fine_categories"]
        return sample
|
|
|
|
|
# Checkpoint identifiers, named once so each processor/model pair is loaded
# from the same source.
_VIT_CHECKPOINT = "google/vit-base-patch16-224"
_TEXT_CHECKPOINT = "describeai/gemini"

# Image preprocessing and classification model used by predict().
processor = AutoImageProcessor.from_pretrained(_VIT_CHECKPOINT)
vit_model = AutoModelForImageClassification.from_pretrained(_VIT_CHECKPOINT)

# Tokenizer and seq2seq model used by predict() to generate the free-text section.
tokenizer = AutoTokenizer.from_pretrained(_TEXT_CHECKPOINT)
gemini_model = AutoModelForSeq2SeqLM.from_pretrained(_TEXT_CHECKPOINT)
|
|
|
|
|
def get_nutrition_and_allergens(food_name):
    """Look up a food in ``ds['main']`` and describe its nutrition data.

    The first record whose ``"name"`` contains one of the search terms
    (case-insensitive substring match) is used. ``food_name`` may hold several
    comma-separated alternatives, e.g. ``"pizza, pizza pie"``; each one is
    tried as a separate term.

    Args:
        food_name: Name, or comma-separated names, of the food to search for.

    Returns:
        A ``(nutrition_info, allergens, diet_plan)`` tuple of strings. Fallback
        messages are returned when nothing matches, when a record lacks the
        requested fields, or when the lookup raises.
    """
    not_found = (
        "Food item not found in the database.",
        "Allergen information not available.",
        "Diet plan not available for this food item.",
    )

    try:
        # An empty term would be a substring of every name and so match the
        # first record; blank terms are therefore dropped.
        terms = [part.strip().lower() for part in (food_name or "").split(",")]
        terms = [term for term in terms if term]
        if not terms:
            return not_found

        result = None
        # NOTE(review): this scans every record on each call; consider a
        # prebuilt name index if lookups become slow.
        for item in ds["main"]:
            name = item["name"]
            if not isinstance(name, str):
                # Skip records without a usable name instead of aborting.
                continue
            lowered = name.lower()
            if any(term in lowered for term in terms):
                result = item
                break

        if result is None:
            return not_found

        nutrition_info = result.get("nutrition", "Nutrition information not available")
        allergens = result.get("allergens", "Allergen information not available")
        diet_plan = (
            "This item is suitable for a diet including "
            f"{result.get('suitable_for', 'N/A')}."
        )

    except KeyError as e:
        nutrition_info = f"Key error: {e}"
        allergens = "Allergen information not available."
        diet_plan = "Diet plan not available."

    except Exception as e:
        # Best-effort boundary: the caller shows these strings to the user.
        nutrition_info = f"An error occurred: {str(e)}"
        allergens = "Allergen information not available."
        diet_plan = "Diet plan not available."

    return nutrition_info, allergens, diet_plan
|
|
|
|
|
def predict(image):
    """Classify a food image and build a text report about it.

    The image is classified with ``vit_model``; the predicted label is looked
    up with ``get_nutrition_and_allergens`` and also used as a prompt for
    ``gemini_model`` to generate an additional free-text section.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A Markdown-formatted report string, or a string starting with
        ``"Error:"`` when no image was given or any step failed.
    """
    if image is None:
        # Submitting without an upload would otherwise surface an opaque
        # processor error.
        return "Error: no image provided. Please upload a food image."

    try:
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = vit_model(**inputs)

        predicted_label = outputs.logits.argmax(-1).item()
        food_item = vit_model.config.id2label[predicted_label]

        nutrition_info, allergens, diet_plan = get_nutrition_and_allergens(food_item)

        description_input = f"Nutritional breakdown and diet plan for {food_item}"
        prompt_tokens = tokenizer(
            description_input, return_tensors="pt", padding=True, truncation=True
        )
        generated = gemini_model.generate(**prompt_tokens)
        diet_plan_text = tokenizer.decode(generated[0], skip_special_tokens=True)

        response = "\n\n".join(
            [
                f"**Detected Food:** {food_item}",
                f"**Nutrition Info:** {nutrition_info}",
                f"**Allergens:** {allergens}",
                f"**Diet Plan:** {diet_plan}",
                f"**Detailed Diet Plan and Breakdown:** {diet_plan_text}",
            ]
        )

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        response = f"Error: {str(e)}"

    return response
|
|
|
|
|
# Candidate example images for the UI. Only files that actually exist are
# passed on: the placeholder path below is not a real file, and a missing
# example can make Gradio fail while building the examples gallery.
_EXAMPLE_IMAGE_PATHS = ["path_to_example_image.jpg"]
_available_examples = [[path] for path in _EXAMPLE_IMAGE_PATHS if os.path.isfile(path)]

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    # predict() returns Markdown (**bold** headings), so render it as Markdown
    # rather than showing the markers literally in a plain textbox.
    outputs="markdown",
    title="NutriScan: AI-Powered Food Analyzer",
    description="Upload an image of food, and get a nutritional breakdown, allergen information, and diet plan recommendations.",
    examples=_available_examples or None,
)

if __name__ == "__main__":
    interface.launch()
|
|