from typing import Optional

from langchain.llms.openai import OpenAI
import torch
from PIL import Image, ImageDraw, ImageOps
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
import pdb


class TextRefiner:
    """Rewrite an image caption with an OpenAI LLM according to user controls.

    The refiner builds a text prompt from a caption plus a set of controls
    (length, sentiment, language, imagination), sends it to the LLM, and
    optionally asks the LLM for a Wikipedia-style summary of the caption's
    main object.
    """

    def __init__(self, device, api_key=""):
        """Create the LLM client and the prompt templates.

        Args:
            device: Only echoed in the start-up message; it is not otherwise
                used by this class.
            api_key: Forwarded to ``OpenAI`` as ``openai_api_key``.
        """
        print(f"Initializing TextRefiner to {device}")
        self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
        # Not read by any method of this class; kept because it is a public
        # attribute that external code may rely on.
        self.prompt_tag = {
            "imagination": {"True": "could", "False": "could not"}
        }
        # Controls whose value is substituted into a short phrase.
        self.short_prompts = {
            "length": "around {length} words",
            "sentiment": "of {sentiment} sentiment",
            "language": "in {language}",
        }
        # Controls that act as on/off switches for a full instruction sentence.
        self.long_prompts = {
            "imagination": "The new sentence could extend the original description by using your imagination to create additional details, or think about what might have happened before or after the scene in the image, but should not conflict with the original sentence",
        }
        self.wiki_prompts = "I want you to act as a Wikipedia page. I will give you a sentence and you will parse the single main object in the sentence and provide a summary of that object in the format of a Wikipedia page. Your summary should be informative and factual, covering the most important aspects of the object. Start your summary with an introductory paragraph that gives an overview of the object. The overall length of the response should be around 100 words. You should not describe the parsing process and only provide the final summary. The sentence is \"{query}\"."
        self.control_prompts = "As a text reviser, you will convert an image description into a new sentence or long paragraph. The new text is {prompts}. {long_prompts} The sentence is \"{query}\" (give me the revised sentence only)"

    def parse(self, response: str) -> str:
        """Return the refined-caption response with surrounding whitespace removed."""
        return response.strip()

    def parse2(self, response: str) -> str:
        """Return the wiki-summary response with surrounding whitespace removed."""
        return response.strip()

    def prepare_input(self, query, short_prompts, long_prompts) -> str:
        """Fill the control prompt template and return the final prompt text.

        Args:
            query: Caption to be revised.
            short_prompts: Iterable of short constraint phrases; joined
                with ``', '``.
            long_prompts: Iterable of full instruction sentences; joined
                with ``'. '``.

        Returns:
            The formatted prompt. It is also printed for debugging.
        """
        prompt_text = self.control_prompts.format(
            prompts=', '.join(short_prompts),
            long_prompts='. '.join(long_prompts),
            query=query,
        )
        print('prompt: ', prompt_text)
        return prompt_text

    def inference(self, query: str, controls: dict, context: Optional[list] = None, enable_wiki=False) -> dict:
        """Refine a caption according to ``controls``.

        Args:
            query: The caption of the region of interest, generated by the
                captioner.
            controls: A dict of control signals, e.g.
                ``{"length": 5, "sentiment": "positive"}``. Keys found in
                ``self.short_prompts`` are formatted with their value; keys
                found in ``self.long_prompts`` are enabled when the value is
                ``True``, ``"True"`` or ``"true"``. Any other key is ignored.
            context: Currently unused. Defaults to ``None`` rather than a
                shared mutable list.
            enable_wiki: When true, additionally query the LLM for a
                Wikipedia-style summary of the caption's main object.

        Returns:
            A dict with keys ``'raw_caption'`` (the input ``query``),
            ``'caption'`` (the refined text) and ``'wiki'`` (the summary, or
            ``""`` when ``enable_wiki`` is false).
        """
        short_parts = []
        long_parts = []
        for control, value in controls.items():
            if control in self.short_prompts:
                short_parts.append(self.short_prompts[control].format(**{control: value}))
            elif control in self.long_prompts and value in (True, "True", "true"):
                # The membership guard keeps an unrecognised control from
                # raising KeyError; such controls are simply skipped.
                long_parts.append(self.long_prompts[control])

        prompt_text = self.prepare_input(query, short_parts, long_parts)
        response = self.parse(self.llm(prompt_text))

        response_wiki = ""
        if enable_wiki:
            prompt_wiki = self.wiki_prompts.format(query=query)
            response_wiki = self.parse2(self.llm(prompt_wiki))

        out = {
            'raw_caption': query,
            'caption': response,
            'wiki': response_wiki
        }
        print(out)
        return out


if __name__ == "__main__":
    # Manual smoke test: issues a real LLM request through the OpenAI client.
    model = TextRefiner(device='cpu')
    controls = {
        "length": "30",
        "sentiment": "negative",
        # "imagination": "True",
        "imagination": "False",
        "language": "English",
    }
    # model.inference(query='a dog is sitting on a brown bench', controls=controls)
    model.inference(query='a cat is sleeping', controls=controls)