text2image_prompt_assistant / prompt_assistant.py
cheesyFishes's picture
add verbose
2c1cd4b
raw
history blame contribute delete
No virus
3.96 kB
import os
from transformers import Tool, load_tool
from huggingface_hub import hf_hub_download
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from llama_index import (
Prompt,
LLMPredictor,
ServiceContext,
StorageContext,
load_index_from_storage,
set_global_service_context
)
# First-pass synthesis prompt for the LlamaIndex query engine: the retrieved
# example text-to-image prompts are injected as {context_str}, and the user's
# raw prompt ({query_str}) is re-worded to match their style.
text_qa_template = Prompt(
    "Examples of text-to-image prompts are below: \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given the existing examples of text-to-image prompts, "
    "write a new text-to-image prompt in the style of the examples, by re-wording the following prompt to match the style of the above examples: {query_str}\n"
)
# Refine-step prompt for the LlamaIndex query engine: when additional example
# prompts are retrieved ({context_msg}), the previously generated prompt
# ({existing_answer}) may be improved — or returned unchanged if already good.
# Fix: corrected the typo "statisfy" -> "satisfy" in the instruction text
# sent to the LLM.
refine_template = Prompt(
    "The initial prompt is as follows: {query_str}\n"
    "We have provided an existing text-to-image prompt based on this query: {existing_answer}\n"
    "We have the opportunity to refine the existing prompt "
    "(only if needed) with some more relevant examples of text-to-image prompts below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new examples of text-to-image prompts, refine the existing text-to-image prompt to better "
    "satisfy the required style. "
    "If the context isn't useful, or the existing prompt is good enough, return the existing prompt."
)
# Tool card surfaced to the agent framework; describes the `prompt` input and
# the image output so the agent knows when/how to invoke this tool.
PROMPT_ASSISTANT_DESCRIPTION = (
    "This is a tool that creates an image according to a prompt, which is a "
    "text description. It takes an input named `prompt` which contains the "
    "image description and outputs an image."
)
class Text2ImagePromptAssistant(Tool):
    """Agent tool that improves a text-to-image prompt before rendering it.

    On call, the input prompt is re-worded via a LlamaIndex query engine
    (backed by a vector index of example prompts downloaded from the
    ``llamaindex/text2image_prompt_assistant`` Space) and the improved prompt
    is passed to the ``huggingface-tools/text-to-image`` tool.
    """

    inputs = ['text']
    outputs = ['image']
    description = PROMPT_ASSISTANT_DESCRIPTION

    def __init__(self, *args, openai_api_key='', model_name='text-davinci-003', temperature=0.3, verbose=False, **hub_kwargs):
        """Configure the OpenAI-backed LLM and record local settings.

        Args:
            openai_api_key: OpenAI key; exported via the environment so the
                underlying langchain/llama_index clients pick it up.
            model_name: one of 'text-davinci-003' (completion API) or
                'gpt-3.5-turbo'/'gpt-4' (chat API).
            temperature: sampling temperature for the prompt-rewriting LLM.
            verbose: when True, print the rewritten prompt on each call.
            **hub_kwargs: forwarded hub options (stored; used by the Tool
                machinery elsewhere).

        Raises:
            ValueError: if ``model_name`` is not one of the supported models.
        """
        super().__init__()
        os.environ['OPENAI_API_KEY'] = openai_api_key
        # Completion vs. chat models need different langchain wrappers.
        if model_name == 'text-davinci-003':
            llm = OpenAI(model_name=model_name, temperature=temperature)
        elif model_name in ('gpt-3.5-turbo', 'gpt-4'):
            llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        else:
            raise ValueError(
                f"{model_name} is not supported, please choose one "
                "of 'text-davinci-003', 'gpt-3.5-turbo', or 'gpt-4'."
            )
        # Make this LLM the global default so load_index_from_storage in
        # setup() uses it without re-plumbing a service context.
        service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
        set_global_service_context(service_context)
        self.storage_path = os.path.dirname(__file__)
        self.verbose = verbose
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Download the persisted index, build the query engine, and load
        the downstream text-to-image tool. Runs once, lazily, on first call."""
        # Fetch the three persisted index files next to this module.
        for filename in (
            "storage/vector_store.json",
            "storage/index_store.json",
            "storage/docstore.json",
        ):
            hf_hub_download(
                repo_id="llamaindex/text2image_prompt_assistant",
                filename=filename,
                repo_type="space",
                local_dir=self.storage_path,
            )
        self.index = load_index_from_storage(
            StorageContext.from_defaults(persist_dir=os.path.join(self.storage_path, "storage"))
        )
        self.query_engine = self.index.as_query_engine(
            similarity_top_k=5,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
        )
        self.text2image = load_tool('huggingface-tools/text-to-image')
        self.text2image.setup()
        # BUG FIX: __call__ checks `self.is_initialized` (the transformers
        # Tool convention) but the original only set `self.initialized`, so
        # the guard never saw the flag and setup re-ran on every call. Set
        # both names to keep any external readers of `initialized` working.
        self.is_initialized = True
        self.initialized = True

    def __call__(self, prompt):
        """Rewrite ``prompt`` via the query engine and render it to an image."""
        if not self.is_initialized:
            self.setup()
        better_prompt = str(self.query_engine.query(prompt)).strip()
        if self.verbose:
            print('==New prompt generated by LlamaIndex==', flush=True)
            print(better_prompt, '\n', flush=True)
        return self.text2image(better_prompt)