Spaces:
Sleeping
Sleeping
# https://platform.openai.com/examples/default-emoji-translation | |
# https://zhuanlan.zhihu.com/p/672725319 | |
import os | |
import numpy as np | |
from openai import OpenAI | |
from langchain_community.vectorstores import FAISS | |
from langchain_openai import OpenAIEmbeddings | |
import json | |
import random | |
def _to_abs_path(fn, dir):
    """Resolve *fn* against directory *dir* unless *fn* is already absolute."""
    return fn if os.path.isabs(fn) else os.path.join(dir, fn)
def _to_cache_path(dir):
    """Return the ``cache`` subdirectory of *dir*, creating it (and parents) if missing."""
    target = os.path.join(dir, 'cache')
    os.makedirs(name=target, exist_ok=True)
    return target
def _read_txt(fn):
    """Return the entire contents of the text file *fn*, decoded as UTF-8."""
    with open(fn, encoding='utf-8') as handle:
        return handle.read()
class GPTHelper:
    """Helper around an OpenAI chat client and a FAISS index of reference texts.

    ``config`` is read both with ``.get(key, default)`` and with attribute
    access (``dat_dir``, ``db_path``, ``yaml_dir``, ``composition_from``,
    ``image_from``), so it must support both -- presumably an
    OmegaConf/EasyDict-style object; confirm with callers.
    """

    # Default chat model. Override via the optional ``openai_model`` config
    # key, or by assigning ``self.model`` on an instance.
    model = "gpt-4o"

    # System prompt for query_keywords: asks for ';'-separated keywords only.
    _KEYWORD_SYSTEM_PROMPT = "你会得到一句话,请根据这句话给出几个创作简笔画的关键词,只回答关键词,不要回复其他内容,关键词用;隔开"

    def __init__(self, config) -> None:
        self.setup_config(config)

    def setup_config(self, config):
        """Create API clients, load/build the FAISS index, and read prompt templates.

        API key and base URL come from ``config`` first. They fall back to the
        ``openai_api_key`` / ``openai_api_base`` environment variables.
        """
        self.config = config
        api_key = config.get('openai_api_key', os.getenv('openai_api_key'))
        api_base = config.get('openai_api_base', os.getenv('openai_api_base'))
        # Optional override; a missing or empty value keeps the class default.
        self.model = config.get('openai_model') or self.model
        self.client = OpenAI(api_key=api_key, base_url=api_base)
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=api_key,
            openai_api_base=api_base,
        )
        self.prepare_faiss()
        # prompts
        self.prompt_composition = _read_txt(_to_abs_path(config.composition_from, config.yaml_dir))
        self.prompt_image = _read_txt(_to_abs_path(config.image_from, config.yaml_dir))

    def prepare_faiss(self):
        """Load ``DB1404.json`` and open (or build and save) the FAISS index over its values."""
        json_fn = os.path.join(self.config.dat_dir, 'DB1404.json')
        with open(json_fn, 'r', encoding='utf-8') as fp:
            self.donbda_dict = json.load(fp)
        # Indexed texts, in the JSON object's value order; query_in_faiss_db
        # reports positions in this list.
        self.donbda_texts = list(self.donbda_dict.values())
        if os.path.exists(self.config.db_path):
            # NOTE(security): load_local unpickles the stored docstore. Only
            # point db_path at an index this application saved itself.
            self.faiss_db = FAISS.load_local(self.config.db_path, self.embeddings, allow_dangerous_deserialization=True)
        else:
            # Must use the same embedder as load_local above. The original
            # passed self.config.embeddings, which setup_config never creates.
            self.faiss_db = FAISS.from_texts(self.donbda_texts, self.embeddings)
            self.faiss_db.save_local(self.config.db_path)

    def log(self, *args):
        """Default diagnostic sink used by query_composition; override to redirect."""
        print(*args)

    def _chat(self, system_content, user_content, max_tokens):
        """Send one system+user exchange to the chat model and return the reply text (may be None)."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content},
            ],
            temperature=0.8,
            max_tokens=max_tokens,
            top_p=1,
        )
        return response.choices[0].message.content

    # ask gpt for keywords in text
    def query_keywords(self, image_topic):
        """Ask the model for sketch keywords about *image_topic*.

        Returns a list of non-empty, whitespace-stripped keywords. An empty or
        missing reply yields ``[]``.
        """
        content = self._chat(self._KEYWORD_SYSTEM_PROMPT, image_topic, max_tokens=64) or ''
        # Chinese replies frequently use the full-width semicolon.
        content = content.replace(';', ';')
        return [kw.strip() for kw in content.split(';') if kw.strip()]

    # query keyword in faiss db
    def query_in_faiss_db(self, keyword):
        """Search the FAISS index for *keyword* and map hits back to ``self.donbda_texts``.

        Returns a list of ``{'idx': int, 'word': str}`` dicts, one per search
        hit whose text occurs in ``self.donbda_texts``. ``idx`` is a position
        in that list; when the text occurs at several positions, one is chosen
        at random.
        """
        words = []
        for doc in self.faiss_db.similarity_search(keyword):
            text = doc.page_content
            positions = [i for i, t in enumerate(self.donbda_texts) if t == text]
            if positions:
                words.append({
                    'idx': positions[random.randint(0, len(positions) - 1)],
                    'word': text,
                })
        return words

    # query composition of the image
    def query_composition(self, image_topic, keywords_to_query, canvas_width, canvas_height, num_words=0, cell_size=180):
        """Ask the model to lay out *keywords_to_query* on a canvas.

        Fills the ``%width%``, ``%height%``, ``%num_cols%``, ``%num_rows%`` and
        ``%usage%`` placeholders of ``self.prompt_composition``. The grid has
        ``ceil(size / cell_size)`` cells per axis. When *num_words* > 0, the
        model is told to use exactly that many keyword occurrences; otherwise
        each keyword is used once. Returns the raw reply text.
        """
        cols = int(np.ceil(canvas_width / cell_size))
        rows = int(np.ceil(canvas_height / cell_size))
        if num_words > 0:
            usage = f'使用这些关键词,每个关键词可使用多次,总共应该出现 {num_words} 个,回答时按照从远到近的顺序。'
        else:
            usage = '仅使用这些关键词,且每个关键词使用一次,回答时按照从远到近的顺序。'
        prompt = (
            self.prompt_composition
            .replace('%width%', f'{canvas_width}')
            .replace('%height%', f'{canvas_height}')
            .replace('%num_cols%', f'{cols}')
            .replace('%num_rows%', f'{rows}')
            .replace('%usage%', usage)
        )
        self.log(prompt)
        return self._chat(
            prompt,
            f'主题是:"{image_topic}",关键词是:"{keywords_to_query}"',
            max_tokens=4096,
        )

    # query prompt of the image
    def query_image_prompt(self, image_topic):
        """Ask the model for an image prompt for *image_topic*, using ``self.prompt_image`` as the system prompt."""
        return self._chat(self.prompt_image, f'主题是:"{image_topic}"', max_tokens=4096)