File size: 4,072 Bytes
c426a27
 
 
 
 
 
 
5c74464
c426a27
5c74464
c426a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eabdb1c
c426a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eabdb1c
 
 
 
 
 
c426a27
 
 
 
eabdb1c
c426a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from langchain.llms.openai import OpenAI
import torch
from PIL import Image, ImageDraw, ImageOps
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
import pdb

class TextRefiner:
    """Rewrite an image caption with an LLM according to user-selected controls.

    Short controls (length, sentiment, language) become comma-joined phrases;
    boolean long controls (imagination) append a full instruction sentence.
    Optionally a second LLM call produces a Wikipedia-style summary of the
    caption's main object.
    """

    def __init__(self, device, api_key=""):
        """Create the langchain ``OpenAI`` client and the prompt templates.

        Args:
            device: Only printed in the start-up message; not otherwise used.
            api_key: Forwarded as ``openai_api_key`` to langchain's ``OpenAI``.
        """
        print(f"Initializing TextRefiner to {device}")
        self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
        # Not read anywhere in this class; kept because callers may access it.
        self.prompt_tag = {
            "imagination": {"True": "could",
                            "False": "could not"}
        }
        # Control name -> phrase template; the placeholder has the control's name.
        self.short_prompts = {
            "length": "around {length} words",
            "sentiment": "of {sentiment} sentiment",
            "language": "in {language}",
        }

        # Control name -> instruction added only when the control is truthy.
        self.long_prompts = {
            "imagination": "The new sentence could extend the original description by using your imagination to create additional details, or think about what might have happened before or after the scene in the image, but should not conflict with the original sentence",
        }

        self.wiki_prompts = "I want you to act as a Wikipedia page. I will give you a sentence and you will parse the single main object in the sentence and provide a summary of that object in the format of a Wikipedia page. Your summary should be informative and factual, covering the most important aspects of the object. Start your summary with an introductory paragraph that gives an overview of the object. The overall length of the response should be around 100 words. You should not describe the parsing process and only provide the final summary. The sentence is \"{query}\"."

        self.control_prompts = "As a text reviser, you will convert an image description into a new sentence or long paragraph. The new text is {prompts}. {long_prompts} The sentence is \"{query}\" (give me the revised sentence only)"

    def parse(self, response):
        """Return the LLM caption response with surrounding whitespace removed."""
        return response.strip()

    def parse2(self, response):
        """Return the LLM wiki response with surrounding whitespace removed."""
        return self.parse(response)

    def prepare_input(self, query, short_prompts, long_prompts):
        """Fill ``control_prompts`` with the control phrases and the caption.

        Args:
            query: The original caption.
            short_prompts: Already-formatted short phrases, comma-joined.
            long_prompts: Instruction sentences. Each is terminated with a
                period so it does not run into the following
                "The sentence is ..." part of the template.

        Returns:
            The complete prompt string (also printed for debugging).
        """
        long_text = " ".join(
            p if p.endswith(".") else p + "." for p in long_prompts
        )
        prompt = self.control_prompts.format(
            prompts=", ".join(short_prompts),
            long_prompts=long_text,
            query=query,
        )
        print('prompt: ', prompt)
        return prompt

    def inference(self, query: str, controls: dict, context: "list | None" = None, enable_wiki=False):
        """Refine a caption with the LLM and optionally fetch a wiki summary.

        Args:
            query: The caption of the region of interest, generated by the captioner.
            controls: Control signals, e.g. ``{"length": 5, "sentiment": "positive"}``.
                Keys in ``short_prompts`` are formatted with their value. Any
                other key whose value is ``True``/``"True"``/``"true"`` adds the
                matching ``long_prompts`` entry; an unknown key with such a
                value raises ``KeyError``.
            context: Currently unused; accepted for interface compatibility.
            enable_wiki: When true, also request a Wikipedia-style summary.

        Returns:
            dict with keys ``raw_caption`` (the input query), ``caption`` (the
            refined text) and ``wiki`` (the summary, or ``""``).
        """
        short_phrases = []
        long_phrases = []
        for control, value in controls.items():
            if control in self.short_prompts:
                short_phrases.append(self.short_prompts[control].format(**{control: value}))
            elif value in (True, "True", "true"):
                long_phrases.append(self.long_prompts[control])
        prompt = self.prepare_input(query, short_phrases, long_phrases)
        caption = self.parse(self.llm(prompt))

        wiki = ""
        if enable_wiki:
            wiki = self.parse2(self.llm(self.wiki_prompts.format(query=query)))
        out = {
            'raw_caption': query,
            'caption': caption,
            'wiki': wiki,
        }
        print(out)
        return out
    
if __name__ == "__main__":
    # Manual smoke test: refine one sample caption with a fixed set of controls.
    refiner = TextRefiner(device='cpu')
    demo_controls = dict(
        length="30",
        sentiment="negative",
        imagination="False",  # "True" adds the imagination instruction to the prompt
        language="English",
    )
    # Another sample caption worth trying: 'a dog is sitting on a brown bench'
    refiner.inference(query='a cat is sleeping', controls=demo_controls)