import torch
from transformers import LlavaForConditionalGeneration, LlavaConfig
from PIL import Image
from random import randint


class VQApair(LlavaForConditionalGeneration):
    config_class = LlavaConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        # The processor is supplied at construction time under the "proc" keyword.
        self.processor = kwargs.pop("proc")

    def genChoice(self, question, base_prompt, img_obj):
        # Ask the model for the correct answer to its own question.
        base_prompt += "{}<|end|>\n<|user|> Suggest 1 correct answer<|end|><|assistant|> ".format(question)
        inputs = self.processor(base_prompt, img_obj, return_tensors='pt').to(0)
        # Token 32007 is <|end|> and 32001 is <|assistant|> in the Phi-3 tokenizer;
        # decoding from the last <|assistant|> keeps only the newly generated text.
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500)
        index = torch.where(output[0] == 32001)[0][-1].item()
        answer = self.processor.decode(output[0][index:], skip_special_tokens=True)

        # Ask for three distractors, conditioned on the correct answer.
        base_prompt += "{}<|end|>\n<|user|> Suggest 3 incorrect answers<|end|><|assistant|> ".format(answer)
        inputs = self.processor(base_prompt, img_obj, return_tensors='pt').to(0)
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500)
        index = torch.where(output[0] == 32001)[0][-1].item()
        choices = self.processor.decode(output[0][index:], skip_special_tokens=True)

        # Split the numbered list, stripping the "1) " / "1. " prefixes and empty lines.
        a = choices.split("\n")
        a = [x[3:].strip() for x in a]
        a = [x for x in a if x]
        # randint is inclusive on both ends, so the correct answer may also be
        # inserted at the end of the list.
        correct_answer = randint(0, len(a))
        a.insert(correct_answer, answer)
        a = ["{}) {}".format(i + 1, a[i]) for i in range(len(a))]
        ans = "Correct Answer: {}".format(a[correct_answer])
        return {"Choices": a, "Answers": ans}

    def generateQn(self, img_path, n):
        prompt = '''
<|user|>\n\nDescribe this image in a passage<|end|><|assistant|> '''
        artifacts = []
        img_obj = Image.open(img_path)
        inputs = self.processor(prompt, img_obj, return_tensors='pt').to(0)

        # Generate a description of the image.
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500)
        index = torch.where(output[0] == 32001)[0][-1].item()
        desc = self.processor.decode(output[0][index:], skip_special_tokens=True)

        # Extend the prompt so the model generates a question about the image.
        prompt += "{}<|end|>\n<|user|> {}<|end|><|assistant|> ".format(desc, "Generate a simple question")
        inputs = self.processor(prompt, img_obj, return_tensors='pt').to(0)

        # Generate n questions with diverse beam search; num_return_sequences
        # must not exceed num_beams, hence max(3, n).
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500,
                               do_sample=False, num_beams=max(3, n),
                               num_beam_groups=max(3, n), diversity_penalty=10.0,
                               num_return_sequences=n)

        for out in output:
            entry = {}
            index = torch.where(out == 32001)[0][-1].item()
            text = self.processor.decode(out[index:], skip_special_tokens=True)
            entry.update({"desc": desc})
            entry.update({"question": text})
            # Build the multiple-choice options and answer key for this question.
            entry.update(self.genChoice(text, prompt, img_obj))
            artifacts.append(entry)
        return artifacts
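

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not taken from the code above: the
# checkpoint name "xtuner/llava-phi-3-mini-hf", the image path "sample.jpg",
# and the availability of a CUDA device at index 0. The class expects a LLaVA
# checkpoint whose tokenizer uses the Phi-3 chat tokens (<|user|>, <|end|>,
# <|assistant|>); extra kwargs such as proc= are forwarded by from_pretrained
# to the custom __init__ above.
if __name__ == "__main__":
    from transformers import AutoProcessor

    model_id = "xtuner/llava-phi-3-mini-hf"   # assumed checkpoint name
    processor = AutoProcessor.from_pretrained(model_id)

    # __init__ pops the processor from kwargs via the "proc" key.
    model = VQApair.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        proc=processor,
    ).to(0)

    # Generate 3 diverse questions, each with choices and an answer key.
    qa_items = model.generateQn("sample.jpg", 3)
    for item in qa_items:
        print(item["question"])
        print("\n".join(item["Choices"]))
        print(item["Answers"])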