import os

from transformers import Tool, load_tool
from huggingface_hub import hf_hub_download
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from llama_index import (
    Prompt,
    LLMPredictor,
    ServiceContext,
    StorageContext,
    load_index_from_storage,
    set_global_service_context,
)

# First-pass template: rewrite the user's query in the style of the retrieved
# example prompts ({context_str} is filled by the query engine).
text_qa_template = Prompt(
    "Examples of text-to-image prompts are below: \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given the existing examples of text-to-image prompts, "
    "write a new text-to-image prompt in the style of the examples, by re-wording the following prompt to match the style of the above examples: {query_str}\n"
)

# Refine-pass template: optionally improve the draft prompt with further
# retrieved examples, keeping the existing answer when the context adds nothing.
refine_template = Prompt(
    "The initial prompt is as follows: {query_str}\n"
    "We have provided an existing text-to-image prompt based on this query: {existing_answer}\n"
    "We have the opportunity to refine the existing prompt "
    "(only if needed) with some more relevant examples of text-to-image prompts below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new examples of text-to-image prompts, refine the existing text-to-image prompt to better "
    # FIX: "statisfy" -> "satisfy" (typo in the instruction sent to the LLM).
    "satisfy the required style. "
    "If the context isn't useful, or the existing prompt is good enough, return the existing prompt."
)

PROMPT_ASSISTANT_DESCRIPTION = (
    "This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which "
    "contains the image description and outputs an image."
)


class Text2ImagePromptAssistant(Tool):
    """Tool that improves a text-to-image prompt before rendering it.

    The incoming text is rewritten by a LlamaIndex query engine (backed by a
    vector store of example prompts downloaded from the
    ``llamaindex/text2image_prompt_assistant`` Space) and the rewritten prompt
    is then passed to the ``huggingface-tools/text-to-image`` tool, which
    produces the image.
    """

    inputs = ['text']
    outputs = ['image']
    description = PROMPT_ASSISTANT_DESCRIPTION

    def __init__(self, *args, openai_api_key='', model_name='text-davinci-003',
                 temperature=0.3, verbose=False, **hub_kwargs):
        """Configure the LLM used for prompt rewriting.

        Args:
            openai_api_key: OpenAI key; exported to ``OPENAI_API_KEY`` for the
                langchain clients.
            model_name: one of ``'text-davinci-003'``, ``'gpt-3.5-turbo'`` or
                ``'gpt-4'``.
            temperature: sampling temperature for the rewriting LLM.
            verbose: if True, print the rewritten prompt on each call.
            **hub_kwargs: extra Hub options.
                NOTE(review): stored but never used in this file — confirm
                whether they should be forwarded to ``load_tool``/
                ``hf_hub_download``.

        Raises:
            ValueError: if ``model_name`` is not one of the supported models.
        """
        super().__init__()
        os.environ['OPENAI_API_KEY'] = openai_api_key
        # text-davinci-003 is a completion model (OpenAI client); the chat
        # models need the ChatOpenAI client.
        if model_name == 'text-davinci-003':
            llm = OpenAI(model_name=model_name, temperature=temperature)
        elif model_name in ('gpt-3.5-turbo', 'gpt-4'):
            llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        else:
            raise ValueError(
                f"{model_name} is not supported, please choose one "
                "of 'text-davinci-003', 'gpt-3.5-turbo', or 'gpt-4'."
            )
        service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
        set_global_service_context(service_context)
        self.storage_path = os.path.dirname(__file__)
        self.verbose = verbose
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Download the persisted index and build the query engine and the
        downstream text-to-image tool. Invoked lazily on first ``__call__``."""
        # These three files together make up a persisted llama_index store.
        for filename in ("vector_store.json", "index_store.json", "docstore.json"):
            hf_hub_download(
                repo_id="llamaindex/text2image_prompt_assistant",
                filename=f"storage/{filename}",
                repo_type="space",
                local_dir=self.storage_path,
            )
        self.index = load_index_from_storage(
            StorageContext.from_defaults(persist_dir=os.path.join(self.storage_path, "storage"))
        )
        self.query_engine = self.index.as_query_engine(
            similarity_top_k=5,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
        )
        self.text2image = load_tool('huggingface-tools/text-to-image')
        self.text2image.setup()
        # BUG FIX: the original set `self.initialized`, but `__call__` (and the
        # transformers Tool base class) check `self.is_initialized` — so setup,
        # including all three downloads, re-ran on every single call.
        self.is_initialized = True

    def __call__(self, prompt):
        """Rewrite `prompt` in the example style and return the generated image."""
        if not self.is_initialized:
            self.setup()
        better_prompt = str(self.query_engine.query(prompt)).strip()
        if self.verbose:
            print('==New prompt generated by LlamaIndex==', flush=True)
            print(better_prompt, '\n', flush=True)
        return self.text2image(better_prompt)