Nutrigenics-chatbot / demo_chat.py
OmkarThawakar
initial commit
ed00004
import json
import numpy as np
from PIL import Image
import torch.nn.functional as F
import torch
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the output ends with one of the stop sequences.

    Args:
        stops: Iterable of token-id tensors; each one is compared against
            the tail of ``input_ids`` (assumed to be 1-D tensors, since
            ``len(stop)`` is used as the tail length). Defaults to no stop
            sequences.
        encounters: Accepted for backward compatibility; currently unused.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # A fresh list per instance avoids sharing one mutable default
        # between every StoppingCriteriaSub that is built without `stops`.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when every row of ``input_ids`` ends with a stop sequence.

        Args:
            input_ids: Token ids generated so far; indexed as
                ``[batch, sequence]``.
            scores: Prediction scores supplied by the generation loop; unused.
            **kwargs: Extra arguments passed by the generation loop; ignored.
        """
        seq_len = input_ids.shape[-1]
        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop would turn the tail slice into the whole
            # sequence, and a stop longer than the sequence cannot be
            # compared element-wise; neither can match, so skip them.
            if stop_len == 0 or stop_len > seq_len:
                continue
            # Compare on the device of the generated ids so that a stop
            # tensor created elsewhere does not raise a device mismatch.
            tail = input_ids[:, -stop_len:]
            if torch.all(tail == stop.to(input_ids.device)).item():
                return True
        return False
class Chat:
    """Retrieve the recipe closest to an image and answer keyword questions.

    An uploaded image is embedded with ``model``, matched against the
    pre-computed ``tar_img_feats`` by dot-product score, and the matching
    row of ``dataframe`` is stored as ``target_recipe``. ``ask`` then
    returns fields of that row as JSON text.
    """

    # Ordered (keywords, recipe field) pairs: the first entry with a keyword
    # contained in the lower-cased message decides which field is returned.
    _TOPIC_FIELDS = (
        (("nutrition", "nutrient"), "recipe_nutrients"),
        (("instruction",), "recipe_instructions"),
        (("ingredient",), "recipe_ingredients"),
        (("tag", "class"), "tags"),
    )

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        """Store the model, preprocessing transform, recipe table and features.

        Args:
            model: Object exposing ``visual_encoder`` and ``vision_proj``.
            transform: Callable turning a PIL image into a tensor.
            dataframe: Recipe table indexed positionally via ``.iloc``.
            tar_img_feats: Target image features scored against the query.
            device: Device the query image is moved to before encoding.
            stopping_criteria: Optional generation stopping criteria; a
                default one is built when omitted.
        """
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []
        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            # NOTE(review): token id 2 is presumably the tokenizer's
            # end-of-sequence id -- confirm against the tokenizer in use.
            stop_words_ids = [torch.tensor([2], device=self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_path):
        """Embed an image and select the best-matching recipe.

        Args:
            image_path: Despite the name, this value is handed to
                ``Image.fromarray``, so it must be an image array rather
                than a filesystem path.
        """
        img = Image.fromarray(image_path).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)
        # Inference only: skip autograd bookkeeping so no graph is kept
        # alive through the stored features.
        with torch.no_grad():
            img_embs = self.model.visual_encoder(img)
            # Project the first token embedding (presumably the CLS token --
            # TODO confirm) and L2-normalise it.
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        """Set ``target_recipe`` to the row selected by the highest score.

        Args:
            img_feats: Query image features.
            tar_img_feats: Target features; scored via
                ``img_feats @ tar_img_feats.t()``.
        """
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        # NOTE(review): the best-scoring position is shifted by one before
        # the positional lookup. If the feature rows align one-to-one with
        # the dataframe rows this is an off-by-one (and overruns on the last
        # row); kept as-is because the intended alignment is not visible here.
        index = np.argsort(score)[::-1][0] + 1
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        """Answer a question about the current recipe by keyword matching.

        Matching is case-insensitive and substring-based. Recognised topics
        are nutrition, instructions, ingredients and tags/class; the
        matching recipe field is returned as indented JSON text.

        Args:
            msg: The user's message.

        Returns:
            JSON text for the matched field, a prompt to upload an image
            when no recipe has been selected yet, or a fixed fallback reply
            when no keyword matches.
        """
        text = msg.lower()
        for keywords, field in self._TOPIC_FIELDS:
            if any(word in text for word in keywords):
                if self.target_recipe is None:
                    # No image has been encoded yet, so there is nothing to
                    # look up; answer instead of failing on None[...].
                    return "Please upload an image first so a recipe can be retrieved."
                return json.dumps(self.target_recipe[field], indent=4)
        return "Conversational capabilities will be included later."