"""Text-to-image prompt assistant.

Rewrites a user's prompt in the style of curated example prompts (retrieved
via a LlamaIndex vector index) before handing it to a text-to-image tool.
"""
import os

from huggingface_hub import hf_hub_download
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from llama_index import (
    LLMPredictor,
    Prompt,
    ServiceContext,
    StorageContext,
    load_index_from_storage,
    set_global_service_context,
)
from transformers import Tool, load_tool
# Template for the initial retrieval pass: the retrieved example prompts are
# injected as {context_str}, and the LLM is asked to re-word the user's
# {query_str} to match their style.
_TEXT_QA_TEMPLATE_STR = (
    "Examples of text-to-image prompts are below: \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given the existing examples of text-to-image prompts, "
    "write a new text-to-image prompt in the style of the examples, "
    "by re-wording the following prompt to match the style of the above examples: {query_str}\n"
)
text_qa_template = Prompt(_TEXT_QA_TEMPLATE_STR)
# Template for the refine pass: given additional retrieved examples
# ({context_msg}), the LLM may improve {existing_answer} or return it
# unchanged. Fixes the "statisfy" -> "satisfy" typo in the instruction text
# sent to the LLM.
refine_template = Prompt(
    "The initial prompt is as follows: {query_str}\n"
    "We have provided an existing text-to-image prompt based on this query: {existing_answer}\n"
    "We have the opportunity to refine the existing prompt "
    "(only if needed) with some more relevant examples of text-to-image prompts below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new examples of text-to-image prompts, refine the existing text-to-image prompt to better "
    "satisfy the required style. "
    "If the context isn't useful, or the existing prompt is good enough, return the existing prompt."
)
# Description advertised to the agent framework so it knows when this tool
# applies and what input it expects.
PROMPT_ASSISTANT_DESCRIPTION = (
    "This is a tool that creates an image according to a prompt, which is a "
    "text description. It takes an input named `prompt` which contains the "
    "image description and outputs an image."
)
class Text2ImagePromptAssistant(Tool):
    """Tool that rewrites a user's prompt via a LlamaIndex query engine and
    then renders the improved prompt with the `huggingface-tools/text-to-image`
    tool.

    Fix: ``setup()`` previously set ``self.initialized`` while ``__call__``
    checked ``self.is_initialized`` (the attribute the ``Tool`` base class
    uses), so the hub downloads and index load re-ran on every call. ``setup()``
    now sets ``self.is_initialized``.
    """

    inputs = ['text']
    outputs = ['image']
    description = PROMPT_ASSISTANT_DESCRIPTION

    def __init__(self, *args, openai_api_key='', model_name='text-davinci-003', temperature=0.3, verbose=False, **hub_kwargs):
        """Configure the LLM used for prompt rewriting.

        Args:
            openai_api_key: OpenAI key; NOTE(review): written unconditionally
                into ``os.environ['OPENAI_API_KEY']``, clobbering any existing
                value — confirm this is intended.
            model_name: one of 'text-davinci-003', 'gpt-3.5-turbo', 'gpt-4'.
            temperature: sampling temperature for the rewriting LLM.
            verbose: if True, print the rewritten prompt in ``__call__``.
            **hub_kwargs: stored for later use by the hub machinery.

        Raises:
            ValueError: if ``model_name`` is not one of the supported models.
        """
        super().__init__()
        os.environ['OPENAI_API_KEY'] = openai_api_key
        if model_name == 'text-davinci-003':
            # Completion-style model goes through the plain OpenAI wrapper.
            llm = OpenAI(model_name=model_name, temperature=temperature)
        elif model_name in ('gpt-3.5-turbo', 'gpt-4'):
            # Chat models need the chat-specific wrapper.
            llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        else:
            raise ValueError(
                f"{model_name} is not supported, please choose one "
                "of 'text-davinci-003', 'gpt-3.5-turbo', or 'gpt-4'."
            )
        # Make the chosen LLM the global default for all llama_index calls.
        service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
        set_global_service_context(service_context)
        self.storage_path = os.path.dirname(__file__)
        self.verbose = verbose
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Download the prebuilt index files, load the index, and initialize
        the downstream text-to-image tool. Runs once (lazily, on first call)."""
        for filename in (
            "storage/vector_store.json",
            "storage/index_store.json",
            "storage/docstore.json",
        ):
            hf_hub_download(
                repo_id="llamaindex/text2image_prompt_assistant",
                filename=filename,
                repo_type="space",
                local_dir=self.storage_path,
            )
        self.index = load_index_from_storage(
            StorageContext.from_defaults(persist_dir=os.path.join(self.storage_path, "storage"))
        )
        self.query_engine = self.index.as_query_engine(
            similarity_top_k=5,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
        )
        self.text2image = load_tool('huggingface-tools/text-to-image')
        self.text2image.setup()
        # Bug fix: use the attribute name __call__ actually checks.
        self.is_initialized = True

    def __call__(self, prompt):
        """Rewrite ``prompt`` in the style of the indexed examples and return
        the image produced by the text-to-image tool."""
        if not self.is_initialized:
            self.setup()
        better_prompt = str(self.query_engine.query(prompt)).strip()
        if self.verbose:
            print('==New prompt generated by LlamaIndex==', flush=True)
            print(better_prompt, '\n', flush=True)
        return self.text2image(better_prompt)