import pandas as pd
import json
from PIL import Image
import numpy as np
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs
from src.data.transforms import transform_test

from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

import gradio as gr
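
# Custom stopping criterion: generation halts once the tail of the output
# matches any of the provided stop-token sequences.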
class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False
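
# Run on GPU when available; all model weights and tensors below are moved to this device.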
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        )
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        )
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config
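
# Load the BLIP embedding model with the "large" retrieval configuration and
# switch it to inference mode.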
print("Creating model") | |
config = get_blip_config("large") | |
model = blip_embs( | |
pretrained=config["pretrained"], | |
image_size=config["image_size"], | |
vit=config["vit"], | |
vit_grad_ckpt=config["vit_grad_ckpt"], | |
vit_ckpt_layer=config["vit_ckpt_layer"], | |
queue_size=config["queue_size"], | |
negative_all_rank=config["negative_all_rank"], | |
) | |
model = model.to(device) | |
model.eval() | |
print("Model Loaded !") | |
print("="*50) | |
transform = transform_test(384)

print("Loading data")
df = pd.read_json("datasets/sidechef/my_recipes.json")

print("Loading target embeddings")
# One precomputed BLIP image embedding per recipe, loaded in the same row order as df.
tar_img_feats = []
for _id in df["id_"].tolist():
    tar_img_feats.append(
        torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0)
    )
tar_img_feats = torch.cat(tar_img_feats, dim=0)
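
# Chat pairs an uploaded food image with its closest recipe (by BLIP image-embedding
# similarity) and answers keyword-style questions about that recipe.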
class Chat:

    def __init__(self, model, transform, dataframe, tar_img_feats, device=device, stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image):
        # Gradio passes the uploaded image as a NumPy array.
        img = Image.fromarray(image).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)
        img_embs = self.model.visual_encoder(img)
        img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)
    def get_target(self, img_feats, tar_img_feats):
        # Cosine similarity between the query image and every stored recipe embedding;
        # the best-scoring row is taken as the matched recipe.
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        index = int(np.argmax(score))
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        # Simple keyword routing over the matched recipe's fields.
        if "nutrition" in msg or "nutrients" in msg:
            return json.dumps(self.target_recipe["recipe_nutrients"], indent=4)
        elif "instruction" in msg:
            return json.dumps(self.target_recipe["recipe_instructions"], indent=4)
        elif "ingredients" in msg:
            return json.dumps(self.target_recipe["recipe_ingredients"], indent=4)
        elif "tag" in msg or "class" in msg:
            return json.dumps(self.target_recipe["tags"], indent=4)
        else:
            return "Conversational capabilities will be included later."
chat = Chat(model, transform, df, tar_img_feats, device=device)
print("Chat initialized!")
custom_css = """ | |
.primary{ | |
background-color: #4CAF50; /* Green */ | |
} | |
""" | |
def respond_to_user(image, message):
    # Build a fresh Chat per request, embed the uploaded image, and answer the query.
    chat = Chat(model, transform, df, tar_img_feats, device=device)
    chat.encode_image(image)
    response = chat.ask(message)
    return response


iface = gr.Interface(
    fn=respond_to_user,
    inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
    outputs=gr.Textbox(label="Nutrition-GPT"),
    title="Nutrition-GPT Demo",
    description="Upload a food image and ask queries!",
    css=".component-12 {background-color: red}",
)

iface.launch()