File size: 3,962 Bytes
05acb49
0f32912
f33b2e2
05acb49
 
 
be36ee6
05acb49
 
 
 
 
 
 
be36ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05acb49
 
 
0f32912
 
05acb49
 
 
 
 
 
0f32912
05acb49
 
2c1cd4b
4fdd5ac
cfa8095
05acb49
 
 
 
 
 
 
 
 
 
 
10145f4
9c54c9d
2c1cd4b
4fdd5ac
 
 
38eeae0
 
 
0f32912
df2873d
05acb49
0f32912
 
 
 
4fdd5ac
05acb49
0c1de5b
e4c8bfb
 
 
0c1de5b
2c1cd4b
 
 
 
0f32912
0c1de5b
05acb49
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from transformers import Tool, load_tool
from huggingface_hub import hf_hub_download
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from llama_index import (
    Prompt,
    LLMPredictor,
    ServiceContext, 
    StorageContext, 
    load_index_from_storage, 
    set_global_service_context
)

# First-pass template: show the retrieved example prompts as context and ask
# the LLM to re-word the user's query in the same style.
_TEXT_QA_TEMPLATE_STR = (
    "Examples of text-to-image prompts are below: \n"
    + "---------------------\n"
    + "{context_str}"
    + "\n---------------------\n"
    + "Given the existing examples of text-to-image prompts, "
    + "write a new text-to-image prompt in the style of the examples, "
    + "by re-wording the following prompt to match the style of the above examples: {query_str}\n"
)
text_qa_template = Prompt(_TEXT_QA_TEMPLATE_STR)


# Refine-pass template: given the answer from the first pass plus additional
# retrieved examples, ask the LLM to improve the prompt (or keep it as-is).
refine_template = Prompt(
    "The initial prompt is as follows: {query_str}\n"
    "We have provided an existing text-to-image prompt based on this query: {existing_answer}\n"
    "We have the opportunity to refine the existing prompt "
    "(only if needed) with some more relevant examples of text-to-image prompts below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new examples of text-to-image prompts, refine the existing text-to-image prompt to better "
    # Fixed typo in the instruction text sent to the LLM: "statisfy" -> "satisfy".
    "satisfy the required style. "
    "If the context isn't useful, or the existing prompt is good enough, return the existing prompt."
)


# Tool card shown to the transformers agent when it selects which tool to use.
PROMPT_ASSISTANT_DESCRIPTION = (
    "This is a tool that creates an image according to a prompt, "
    "which is a text description. It takes an input named `prompt` "
    "which contains the image description and outputs an image."
)


class Text2ImagePromptAssistant(Tool):
    """Agent tool that improves a text-to-image prompt before rendering it.

    The incoming prompt is rewritten by an OpenAI LLM using a llama_index
    query engine over a prebuilt index of example prompts (downloaded from
    the `llamaindex/text2image_prompt_assistant` Space), then passed to the
    `huggingface-tools/text-to-image` tool, which produces the image.
    """

    inputs = ['text']
    outputs = ['image']
    description = PROMPT_ASSISTANT_DESCRIPTION

    def __init__(self, *args, openai_api_key='', model_name='text-davinci-003', temperature=0.3, verbose=False, **hub_kwargs):
        """Configure the LLM and record where the index files will be stored.

        Args:
            openai_api_key: OpenAI API key; exported to the environment.
            model_name: one of 'text-davinci-003', 'gpt-3.5-turbo', 'gpt-4'.
            temperature: sampling temperature for the prompt-rewriting LLM.
            verbose: if True, print the rewritten prompt on each call.
            **hub_kwargs: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``model_name`` is not one of the supported models.
        """
        super().__init__()
        # NOTE(review): this overwrites any OPENAI_API_KEY already set in the
        # environment — with the empty-string default it can clobber a valid
        # key. Kept as-is for behavioral compatibility; confirm intent.
        os.environ['OPENAI_API_KEY'] = openai_api_key
        if model_name == 'text-davinci-003':
            # Completion-style model uses the plain OpenAI wrapper.
            llm = OpenAI(model_name=model_name, temperature=temperature)
        elif model_name in ('gpt-3.5-turbo', 'gpt-4'):
            # Chat models require the chat-specific wrapper.
            llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        else:
            raise ValueError(
                f"{model_name} is not supported, please choose one "
                "of 'text-davinci-003', 'gpt-3.5-turbo', or 'gpt-4'."
            )
        # Make this LLM the default for all llama_index operations.
        service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
        set_global_service_context(service_context)

        # Index files are downloaded next to this module.
        self.storage_path = os.path.dirname(__file__)
        self.verbose = verbose
        self.hub_kwargs = hub_kwargs  # kept for compatibility; currently unused

    def setup(self):
        """Download the persisted index, build the query engine, and load the
        downstream text-to-image tool. Runs once, lazily, on first call."""
        # Fetch the three persisted llama_index storage files.
        for filename in (
            "storage/vector_store.json",
            "storage/index_store.json",
            "storage/docstore.json",
        ):
            hf_hub_download(
                repo_id="llamaindex/text2image_prompt_assistant",
                filename=filename,
                repo_type="space",
                local_dir=self.storage_path,
            )

        self.index = load_index_from_storage(
            StorageContext.from_defaults(
                persist_dir=os.path.join(self.storage_path, "storage")
            )
        )
        self.query_engine = self.index.as_query_engine(
            similarity_top_k=5,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
        )

        self.text2image = load_tool('huggingface-tools/text-to-image')
        self.text2image.setup()

        # BUG FIX: the original set only `self.initialized`, but __call__
        # checks `self.is_initialized` (the flag the transformers Tool base
        # class uses), so setup() — including the hub downloads above — was
        # re-run on every single call. Set the checked flag; keep the old
        # attribute too in case external code reads it.
        self.is_initialized = True
        self.initialized = True

    def __call__(self, prompt):
        """Rewrite ``prompt`` with the LLM, then render it to an image."""
        if not self.is_initialized:
            self.setup()

        better_prompt = str(self.query_engine.query(prompt)).strip()

        if self.verbose:
            print('==New prompt generated by LlamaIndex==', flush=True)
            print(better_prompt, '\n', flush=True)

        return self.text2image(better_prompt)