import json

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the batch ends with any of the given token sequences.

    Args:
        stops: Iterable of 1-D token-id tensors. Generation stops as soon as the
            trailing tokens of ``input_ids`` equal one of them for every row of
            the batch.
        encounters: Kept for interface compatibility; stored but not used by
            the stopping check.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # ``None`` default avoids one mutable list being shared by all instances.
        self.stops = stops if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if every sequence in ``input_ids`` ends with some stop sequence."""
        seq_len = input_ids.shape[-1]
        for stop in self.stops:
            n = len(stop)
            # An empty stop sequence, or one longer than what has been generated
            # so far, cannot match; slicing would produce a mis-shaped tensor.
            if n == 0 or n > seq_len:
                continue
            if torch.all(input_ids[:, -n:] == stop).item():
                return True
        return False


class Chat:
    """Retrieve a recipe for an image and answer keyword questions about it.

    ``encode_image`` embeds an image with ``model``'s visual encoder, scores it
    against ``tar_img_feats`` and stores the best-matching row of ``dataframe``
    in ``self.target_recipe``; ``ask`` returns fields of that row as JSON.
    """

    # Ordered (keywords, dataframe column) pairs; the first pair with a keyword
    # found in the question wins, mirroring the original if/elif priority.
    _KEYWORD_FIELDS = (
        (("nutrition", "nutrients"), "recipe_nutrients"),
        (("instruction",), "recipe_instructions"),
        (("ingredients",), "recipe_ingredients"),
        (("tag", "class"), "tags"),
    )

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None  # set by encode_image
        self.target_recipe = None  # set by get_target
        self.messages = []
        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            # Default stop token id 2 — presumably the tokenizer's EOS id; confirm.
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_path):
        """Embed an image and select the closest-matching recipe.

        Args:
            image_path: Despite its name, this is passed to
                ``PIL.Image.fromarray``, so it must be an array-like image
                (presumably a NumPy array from the UI), not a file path.
        """
        img = Image.fromarray(image_path).convert("RGB")
        img = self.transform(img).unsqueeze(0).to(self.device)
        # Inference only: no autograd graph is needed for retrieval.
        with torch.no_grad():
            img_embs = self.model.visual_encoder(img)
            # Project the first token's embedding and L2-normalize it.
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        """Set ``self.target_recipe`` to the dataframe row of the best-scoring target.

        Scores are ``img_feats @ tar_img_feats.T``; presumably ``img_feats`` is
        (1, D) and ``tar_img_feats`` is (N, D) — confirm against callers.
        """
        score = (img_feats @ tar_img_feats.t()).squeeze(0).detach().cpu().numpy()
        # NOTE(review): the +1 maps a feature index to a dataframe row, which
        # assumes the dataframe has one extra leading row relative to
        # tar_img_feats. If that is not the case this picks the wrong recipe
        # (and raises IndexError when the last target scores highest).
        index = int(np.argmax(score)) + 1
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        """Answer a question about the current recipe by keyword lookup.

        Matching is case-insensitive. Returns the matching recipe field as
        indented JSON, a fallback message if no keyword matches, or a prompt to
        upload an image if no recipe has been selected yet.
        """
        text = msg.lower()
        field = next(
            (col for keywords, col in self._KEYWORD_FIELDS if any(k in text for k in keywords)),
            None,
        )
        if field is None:
            return "Conversational capabilities will be included later."
        if self.target_recipe is None:
            # Previously this raised TypeError ('NoneType' is not subscriptable).
            return "Please upload an image first."
        return json.dumps(self.target_recipe[field], indent=4)