File size: 5,829 Bytes
e06a9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM

# Step 1: Load the World Cuisines dataset
ds = load_dataset("worldcuisines/food-kb")

# Access the 'main' dataset
dataset = ds['main']

# Check the structure of the dataset
print(dataset)

# Converting dataset to a list of dictionaries for easier manipulation
data_list = dataset.to_dict()['image1']  # Accessing the first image column (you can access others like image2, etc.)

# Now split the dataset into train and test
train_data, test_data = train_test_split(data_list, test_size=0.2)

# Check the shapes of train_data and test_data
print(f"Training data size: {len(train_data)}")
print(f"Testing data size: {len(test_data)}")

# Define a custom dataset class for the image classification task
class FoodDataset(Dataset):
    def __init__(self, dataset, processor, max_length=256):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # For simplicity, let's use image1 for training and test
        image = Image.open(item['image1'])  # Assuming 'image1' has the food images
        label = item['fine_categories']  # You can modify this based on the label

        # Process the image
        encoding = self.processor(images=image, return_tensors="pt", padding=True, truncation=True)

        # Return the input and target labels
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': label  # Assuming that 'fine_categories' is used as labels
        }

# Step 2: Load the ViT model for image classification
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Step 3: Load the text generation model (Gemini) for nutrition breakdown and diet plan
tokenizer = AutoTokenizer.from_pretrained("describeai/gemini")
gemini_model = AutoModelForSeq2SeqLM.from_pretrained("describeai/gemini")

# Helper function to get nutritional breakdown and allergen information
def get_nutrition_and_allergens(food_name):
    # Look for the food item in the dataset
    result = None
    try:
        dataset = ds['main']  # Access the correct dataset split
        for item in dataset:
            if food_name.lower() in item['name'].lower():
                result = item
                break

        if result:
            nutrition_info = result.get('nutrition', 'Nutrition information not available')
            allergens = result.get('allergens', 'Allergen information not available')
            diet_plan = f"This item is suitable for a diet including {result.get('suitable_for', 'N/A')}."
        else:
            nutrition_info = "Food item not found in the database."
            allergens = "Allergen information not available."
            diet_plan = "Diet plan not available for this food item."

    except KeyError as e:
        nutrition_info = f"Key error: {e}"
        allergens = "Allergen information not available."
        diet_plan = "Diet plan not available."

    except Exception as e:
        nutrition_info = f"An error occurred: {str(e)}"
        allergens = "Allergen information not available."
        diet_plan = "Diet plan not available."

    return nutrition_info, allergens, diet_plan

# Main prediction function for the image classification and text generation
def predict(image):
    try:
        # Step 1: Classify the food item in the image using ViT model
        inputs = processor(images=image, return_tensors="pt")
        outputs = vit_model(**inputs)

        # Get the predicted label (food item)
        predicted_label = outputs.logits.argmax(-1).item()

        # Get the food name from the class labels (assuming the model has the food labels)
        class_labels = vit_model.config.id2label  # Get the class label mapping
        food_item = class_labels[predicted_label]

        # Step 2: Generate nutritional breakdown, allergens, and diet plan
        nutrition_info, allergens, diet_plan = get_nutrition_and_allergens(food_item)

        # Step 3: Generate a detailed description using the Gemini model
        description_input = f"Nutritional breakdown and diet plan for {food_item}"
        diet_plan_text = tokenizer(description_input, return_tensors="pt", padding=True, truncation=True)
        diet_plan_output = gemini_model.generate(**diet_plan_text)
        diet_plan_text = tokenizer.decode(diet_plan_output[0], skip_special_tokens=True)

        # Combine results into a single output
        response = f"**Detected Food:** {food_item}\n\n"
        response += f"**Nutrition Info:** {nutrition_info}\n\n"
        response += f"**Allergens:** {allergens}\n\n"
        response += f"**Diet Plan:** {diet_plan}\n\n"
        response += f"**Detailed Diet Plan and Breakdown:** {diet_plan_text}"

    except Exception as e:
        response = f"Error: {str(e)}"

    return response

# Step 4: Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="NutriScan: AI-Powered Food Analyzer",
    description="Upload an image of food, and get a nutritional breakdown, allergen information, and diet plan recommendations.",
    examples=[["path_to_example_image.jpg"]]  # replace with paths to example images if needed
)

# Launch the Gradio interface
if __name__ == "__main__":
    interface.launch()