Spaces:
Runtime error
Runtime error
File size: 3,962 Bytes
05acb49 0f32912 f33b2e2 05acb49 be36ee6 05acb49 be36ee6 05acb49 0f32912 05acb49 0f32912 05acb49 2c1cd4b 4fdd5ac cfa8095 05acb49 10145f4 9c54c9d 2c1cd4b 4fdd5ac 38eeae0 0f32912 df2873d 05acb49 0f32912 4fdd5ac 05acb49 0c1de5b e4c8bfb 0c1de5b 2c1cd4b 0f32912 0c1de5b 05acb49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
from transformers import Tool, load_tool
from huggingface_hub import hf_hub_download
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from llama_index import (
Prompt,
LLMPredictor,
ServiceContext,
StorageContext,
load_index_from_storage,
set_global_service_context
)
# Question-answering template handed to the LlamaIndex query engine in
# setup(): the retrieved example prompts fill {context_str} and the user's
# prompt fills {query_str}; the LLM is asked to re-word the user's prompt in
# the style of the examples.
_TEXT_QA_TEMPLATE_STR = (
    "Examples of text-to-image prompts are below: \n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the existing examples of text-to-image prompts, write a new "
    "text-to-image prompt in the style of the examples, by re-wording the "
    "following prompt to match the style of the above examples: {query_str}\n"
)
text_qa_template = Prompt(_TEXT_QA_TEMPLATE_STR)
# Refine template handed to the LlamaIndex query engine in setup(). LlamaIndex
# fills {query_str} (the user's original prompt), {existing_answer} (the prompt
# rewritten so far) and {context_msg} (further retrieved example prompts); the
# LLM either improves the existing prompt or returns it unchanged.
refine_template = Prompt(
    "The initial prompt is as follows: {query_str}\n"
    "We have provided an existing text-to-image prompt based on this query: {existing_answer}\n"
    "We have the opportunity to refine the existing prompt "
    "(only if needed) with some more relevant examples of text-to-image prompts below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new examples of text-to-image prompts, refine the existing text-to-image prompt to better "
    # Fixed typo ("statisfy") in the instruction sent to the LLM.
    "satisfy the required style. "
    "If the context isn't useful, or the existing prompt is good enough, return the existing prompt."
)
# Text used as the tool's `description`. It advertises plain prompt-to-image
# generation; the LlamaIndex prompt-rewriting step is an internal detail.
PROMPT_ASSISTANT_DESCRIPTION = " ".join((
    "This is a tool that creates an image according to a prompt, which is a text description.",
    "It takes an input named `prompt` which contains the image description and outputs an image.",
))
class Text2ImagePromptAssistant(Tool):
    """Tool that restyles a text-to-image prompt with LlamaIndex, then renders it.

    The incoming prompt is sent as a query to a LlamaIndex index of example
    text-to-image prompts. The LLM rewrites it in the style of the retrieved
    examples (via ``text_qa_template`` / ``refine_template``). The rewritten
    prompt is then passed to the ``huggingface-tools/text-to-image`` tool, whose
    output is returned.
    """

    inputs = ['text']
    outputs = ['image']
    description = PROMPT_ASSISTANT_DESCRIPTION

    def __init__(self, *args, openai_api_key='', model_name='text-davinci-003', temperature=0.3, verbose=False, **hub_kwargs):
        """Configure the LLM used to rewrite prompts.

        Args:
            *args: Accepted and ignored (presumably positional arguments
                supplied by the tool loader -- TODO confirm).
            openai_api_key: OpenAI API key. When non-empty it is exported as
                the ``OPENAI_API_KEY`` environment variable. When empty, any
                key already present in the environment is left untouched.
            model_name: ``'text-davinci-003'`` (built with
                ``langchain.llms.OpenAI``) or ``'gpt-3.5-turbo'`` / ``'gpt-4'``
                (built with ``langchain.chat_models.ChatOpenAI``).
            temperature: Sampling temperature passed to the LLM.
            verbose: If true, ``__call__`` prints each rewritten prompt.
            **hub_kwargs: Stored on ``self.hub_kwargs``; not otherwise used by
                this class.

        Raises:
            ValueError: If ``model_name`` is not one of the supported models.

        Note:
            The resulting ``ServiceContext`` is installed with
            ``set_global_service_context``, so the LLM choice applies
            process-wide to LlamaIndex.
        """
        super().__init__()
        # Only export a key that was actually supplied: assigning the default
        # '' unconditionally would wipe out an OPENAI_API_KEY that is already
        # set in the environment.
        if openai_api_key:
            os.environ['OPENAI_API_KEY'] = openai_api_key
        if model_name == 'text-davinci-003':
            llm = OpenAI(model_name=model_name, temperature=temperature)
        elif model_name in ('gpt-3.5-turbo', 'gpt-4'):
            llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        else:
            raise ValueError(
                f"{model_name} is not supported, please choose one "
                "of 'text-davinci-003', 'gpt-3.5-turbo', or 'gpt-4'."
            )
        service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
        set_global_service_context(service_context)
        # setup() downloads the index storage next to this source file.
        self.storage_path = os.path.dirname(__file__)
        self.verbose = verbose
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Download the prompt index, then load the query engine and image tool.

        This method:

        - Fetches the persisted LlamaIndex storage files from the
          ``llamaindex/text2image_prompt_assistant`` Space into
          ``<storage_path>/storage``.
        - Loads the index from that storage.
        - Builds a query engine that retrieves the 5 most similar example
          prompts.
        - Loads and sets up the ``huggingface-tools/text-to-image`` tool.
        """
        storage_files = (
            "storage/vector_store.json",
            "storage/index_store.json",
            "storage/docstore.json",
        )
        for filename in storage_files:
            hf_hub_download(
                repo_id="llamaindex/text2image_prompt_assistant",
                filename=filename,
                repo_type="space",
                local_dir=self.storage_path,
            )
        storage_context = StorageContext.from_defaults(persist_dir=os.path.join(self.storage_path, "storage"))
        self.index = load_index_from_storage(storage_context)
        self.query_engine = self.index.as_query_engine(
            similarity_top_k=5,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
        )
        self.text2image = load_tool('huggingface-tools/text-to-image')
        self.text2image.setup()
        # __call__ checks `is_initialized`. Setting it is what makes this
        # expensive setup run only once; previously only `initialized` was set,
        # so setup() re-ran on every call.
        self.is_initialized = True
        # Legacy attribute name, kept in case external code reads it.
        self.initialized = True

    def __call__(self, prompt):
        """Rewrite ``prompt`` via the query engine and return the text-to-image output.

        Runs :meth:`setup` lazily on first use.
        """
        # getattr keeps this safe even if the base class never defined the flag.
        if not getattr(self, 'is_initialized', False):
            self.setup()
        better_prompt = str(self.query_engine.query(prompt)).strip()
        if self.verbose:
            print('==New prompt generated by LlamaIndex==', flush=True)
            print(better_prompt, '\n', flush=True)
        return self.text2image(better_prompt)
|