import json
import os
import sys
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

from src.data.embs import ImageDataset
from src.data.transforms import transform_test
from src.model.blip_embs import blip_embs


class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation once any of the given stop-token sequences appears at the end of the input."""

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        )
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        )
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config


print("Creating model")
config = get_blip_config("large")

model = blip_embs(
    pretrained=config["pretrained"],
    image_size=config["image_size"],
    vit=config["vit"],
    vit_grad_ckpt=config["vit_grad_ckpt"],
    vit_ckpt_layer=config["vit_ckpt_layer"],
    queue_size=config["queue_size"],
    negative_all_rank=config["negative_all_rank"],
)
model = model.to(device)
model.eval()
print("Model loaded!")
print("=" * 50)

transform = transform_test(384)

print("Loading data")
df = pd.read_json("datasets/sidechef/my_recipes.json")

print("Loading target embeddings")
tar_img_feats = []
for _id in df["id_"].tolist():
    tar_img_feats.append(
        torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0)
    )
tar_img_feats = torch.cat(tar_img_feats, dim=0)


class Chat:
    def __init__(self, model, transform, dataframe, tar_img_feats, device="cuda:0", stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image):
        # Gradio passes the uploaded image as a numpy array.
        img = Image.fromarray(image).convert("RGB")
        img = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            img_embs = self.model.visual_encoder(img)
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        # Similarity between the query image embedding and every target recipe embedding.
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        index = np.argsort(score)[::-1][0]  # index of the best-matching recipe
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        if "nutrition" in msg or "nutrients" in msg:
            return json.dumps(self.target_recipe["recipe_nutrients"], indent=4)
        elif "instruction" in msg:
            return json.dumps(self.target_recipe["recipe_instructions"], indent=4)
        elif "ingredients" in msg:
            return json.dumps(self.target_recipe["recipe_ingredients"], indent=4)
        elif "tag" in msg or "class" in msg:
            return json.dumps(self.target_recipe["tags"], indent=4)
        else:
            return "Conversational capabilities will be included later."


chat = Chat(model, transform, df, tar_img_feats, device=device)
print("Chat initialized!")

custom_css = """
.primary {
    background-color: #4CAF50; /* Green */
}
"""


def respond_to_user(image, message):
    # Encode the uploaded image, retrieve the closest recipe, and answer the query.
    chat = Chat(model, transform, df, tar_img_feats, device=device)
    chat.encode_image(image)
    response = chat.ask(message)
    return response


iface = gr.Interface(
    fn=respond_to_user,
    inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
    outputs=gr.Textbox(label="Nutrition-GPT"),
    title="Nutrition-GPT Demo",
    description="Upload a food image and ask queries!",
    css=".component-12 {background-color: red}",
)

iface.launch()