Spaces:
Sleeping
Sleeping
import json | |
import numpy as np | |
from PIL import Image | |
import torch.nn.functional as F | |
import torch | |
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the generated ids end with one of the given stop sequences.

    Args:
        stops: Stop sequences, each a 1-D tensor of token ids. ``None`` (the
            default) means no stop sequences.
        encounters: Accepted for backward compatibility and stored, but not
            used by the check.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # A None default avoids sharing one mutable list across instances.
        self.stops = [] if stops is None else stops
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when every sequence in the batch ends with some stop sequence.

        ``**kwargs`` mirrors the base ``StoppingCriteria.__call__`` signature.
        """
        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop, or one longer than what has been generated so far,
            # cannot match. Slicing would yield a mismatched window and the
            # element-wise comparison would fail to broadcast.
            if stop_len == 0 or input_ids.shape[-1] < stop_len:
                continue
            # Compare on the same device as the generated ids.
            stop = stop.to(input_ids.device)
            if torch.all(input_ids[:, -stop_len:] == stop).item():
                return True
        return False
class Chat:
    """Match an uploaded image to a recipe row and answer keyword questions about it.

    ``encode_image`` embeds the image with ``model.visual_encoder`` and
    ``model.vision_proj``, and scores it against ``tar_img_feats`` by dot
    product. It stores the best-matching row of ``dataframe`` as
    ``target_recipe``. ``ask`` then returns fields of that row as indented JSON.
    """

    # Keyword routes for ``ask``: checked in order, first match wins.
    _QUESTION_ROUTES = (
        (("nutrition", "nutrients"), "recipe_nutrients"),
        (("instruction",), "recipe_instructions"),
        (("ingredients",), "recipe_ingredients"),
        (("tag", "class"), "tags"),
    )

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        """Store the model, preprocessing callable and lookup data.

        Args:
            model: Object exposing ``visual_encoder`` and ``vision_proj``
                (used by ``encode_image``).
            transform: Callable applied to a PIL RGB image that returns a
                tensor (shape presumably (C, H, W); TODO confirm).
            dataframe: Recipe table accessed positionally via ``.iloc``.
                Rows are read for ``recipe_nutrients``,
                ``recipe_instructions``, ``recipe_ingredients`` and ``tags``.
            tar_img_feats: Candidate image features, one row per candidate,
                scored against the query by dot product.
            device: Device the preprocessed image is moved to before encoding.
            stopping_criteria: Stored as ``self.stopping_criteria``. Defaults
                to a criterion that stops on token id 2.
        """
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []
        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            # Token id 2 is hard-coded; presumably the tokenizer's EOS id (TODO confirm).
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_path):
        """Embed an image, cache its features and select the best-matching recipe.

        Args:
            image_path: Despite the name, an array-like image. It is passed
                to ``PIL.Image.fromarray``, not opened from disk.
        """
        img = Image.fromarray(image_path).convert("RGB")
        img = self.transform(img).unsqueeze(0).to(self.device)
        # Inference only: skip building an autograd graph for the embedding.
        with torch.no_grad():
            img_embs = self.model.visual_encoder(img)
            # Uses the first token's embedding (presumably a CLS token; TODO confirm).
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        """Set ``target_recipe`` to the dataframe row whose candidate features score highest.

        Args:
            img_feats: Query features (presumably shape (1, D); TODO confirm).
            tar_img_feats: Candidate features of shape (N, D).
        """
        score = (img_feats @ tar_img_feats.t()).squeeze(0).detach().cpu().numpy()
        # argmax finds the best score in O(N) without fully sorting.
        # NOTE(review): the +1 offset (kept from the original) maps feature row i
        # to dataframe row i + 1. If the dataframe has exactly N rows, a best match
        # on the last feature row raises IndexError. Confirm the intended alignment.
        index = int(np.argmax(score)) + 1
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        """Answer a question about the matched recipe by case-insensitive keyword lookup.

        Returns one of:
        - the requested recipe field serialized as indented JSON;
        - a prompt to upload an image, if a known question arrives before any
          recipe has been matched;
        - a placeholder reply, for unrecognised questions.
        """
        text = (msg or "").lower()
        field = next(
            (column for keywords, column in self._QUESTION_ROUTES
             if any(keyword in text for keyword in keywords)),
            None,
        )
        if field is None:
            return "Conversational capabilities will be included later."
        if self.target_recipe is None:
            # encode_image has not run yet, so there is no recipe to describe.
            return "Please upload an image first so a recipe can be matched."
        return json.dumps(self.target_recipe[field], indent=4)