Spaces:
Runtime error
Runtime error
import json | |
def data2reference(top_k_items, output_n=3):
    """Render retrieved items as a ``#Reference:`` block of one-line JSON objects.

    Items are de-duplicated by their ``"keyword"`` (first occurrence wins) and
    at most ``output_n`` distinct items are emitted. Each entry is a JSON object
    with keys ``name_in_life``, ``name_in_cultivation`` and
    ``description_in_cultivation``; entries are separated by a blank line.

    Args:
        top_k_items: iterable of dicts carrying the keys ``"keyword"``,
            ``"name_in_cultivation"`` and ``"description_in_cultivation"``
            (presumably ordered best match first -- confirm with the search code).
        output_n: maximum number of distinct items to include; ``0`` or a
            negative value yields only the header.

    Returns:
        The reference block, with surrounding whitespace stripped.

    Raises:
        KeyError: if an inspected item lacks one of the required keys.
    """
    seen_keywords = set()
    entries = []
    for item in top_k_items:
        # Check the limit before adding so that output_n <= 0 emits nothing.
        if len(entries) >= output_n:
            break
        item_in_life = item["keyword"]
        if item_in_life in seen_keywords:
            continue
        output_data = {
            "name_in_life": item_in_life,
            "name_in_cultivation": item["name_in_cultivation"],
            "description_in_cultivation": item["description_in_cultivation"],
        }
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text as-is instead of \uXXXX escapes.
        entries.append(json.dumps(output_data, ensure_ascii=False))
        seen_keywords.add(item_in_life)
    return ("#Reference:\n" + "\n\n".join(entries)).strip()
def data2prompt(query_item, top_k_items, reference_n=3):
    """Build the LLM prompt that rewrites an everyday item as a cultivation-world item.

    The prompt is the concatenation of: the reference block produced by
    ``data2reference``, the task instruction, the ``# Input:`` section and a
    step-by-step list of JSON fields the model is asked to output.

    Args:
        query_item: mapping describing the item to rewrite. If it has a
            ``"keyword"`` key, that keyword (plus ``"description"`` when
            present) is rendered as ``input_name`` / ``description_in_life``
            lines; otherwise the whole mapping is dumped as JSON.
        top_k_items: retrieved reference items, forwarded to ``data2reference``.
        reference_n: maximum number of distinct reference items to include.

    Returns:
        The complete prompt string.
    """
    reference_prompt = data2reference(top_k_items, reference_n)
    task_prompt1 = "\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n"
    input_prompt = "# Input:\n"
    if "keyword" in query_item:
        input_prompt += f"input_name:{query_item['keyword']}\n"
        if "description" in query_item:
            input_prompt += f"description_in_life:{query_item['description']}\n"
    else:
        # No structured keyword: hand the whole item to the model as JSON.
        input_prompt += json.dumps(query_item, ensure_ascii=False) + "\n"
    # Field names referenced in the text below must match the fields the model
    # is asked to emit (e.g. final_enhanced_description, name_in_cultivation_1).
    cot_prompt = """Let's think it step by step,以json形式输出逐个字段。包含以下字段
- name_in_life: 进一步明确要生成描述的物品名称
- name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
- description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
- echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
- critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
- echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
- analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
- echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
- candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
- analysis_candidates: 分析各个candidates有什么优点
- echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
- final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
- echo_5: "我将根据final_enhanced_description,分析是否建议将物品名称替换为新的名词"
- name_fit_analysis: 分析name_in_cultivation_1是否还匹配final_enhanced_description的描述,是否需要给input_name起一个更响亮的名字
- new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
"""
    return reference_prompt + task_prompt1 + input_prompt + cot_prompt
try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    # Support both the package layout (src.ZhipuClient) and a flat layout where
    # the module is importable directly. Only import failures trigger the
    # fallback; errors raised while executing src.ZhipuClient still surface.
    from ZhipuClient import ZhipuClient
# Shared LLM client; generate_cultivation_with_rag creates it on first use.
zhipu_client = None
import json | |
def markdown_to_json(markdown_str):
    """Parse JSON text that may be wrapped in a Markdown code fence.

    Accepts plain JSON or a fenced block such as ``"```json\\n{...}\\n```"``.
    Leading/trailing whitespace is ignored, the ``json`` language tag is
    matched case-insensitively, text after the closing fence is dropped, and a
    missing closing fence (truncated output) is tolerated.

    Args:
        markdown_str: the raw text, e.g. an LLM reply.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        json.JSONDecodeError: if the unwrapped text is not valid JSON.
    """
    text = markdown_str.strip()
    if text.startswith("```"):
        text = text[3:]
        # Optional language tag immediately after the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
        # Cut at the last fence so trailing commentary after the block is ignored;
        # if there is no closing fence, keep everything.
        closing = text.rfind("```")
        if closing != -1:
            text = text[:closing]
        text = text.strip()
    return json.loads(text)
import re | |
def forced_extract(input_str, keywords):
    """Best-effort regex extraction of string fields from text that is not valid JSON.

    For each key, the first ``"key": "value"`` pair in ``input_str`` is located
    and its value decoded with JSON string rules (so ``\\"`` becomes ``"``).
    Escaped quotes and raw line breaks inside the value are handled. If the
    value contains escapes that are not valid JSON, the raw matched text is
    kept. Keys that are not found map to ``""``.

    Args:
        input_str: text that looks like JSON but may be malformed.
        keywords: keys to look for.

    Returns:
        dict mapping every key in ``keywords`` to its extracted string (or ``""``).
    """
    result = {key: "" for key in keywords}
    for key in keywords:
        # Value = run of non-quote/non-backslash chars or backslash escapes, so an
        # escaped quote does not end the match. Raw string avoids invalid "\s" escapes.
        pattern = r'"' + re.escape(key) + r'"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if not match:
            continue
        raw_value = match.group(1)
        try:
            # strict=False tolerates literal control characters such as raw newlines.
            result[key] = json.loads(f'"{raw_value}"', strict=False)
        except json.JSONDecodeError:
            result[key] = raw_value
    return result
def generate_cultivation_with_rag(query_item, search_result):
    """Rewrite ``query_item`` as a cultivation-world item using the Zhipu LLM and retrieved references.

    Args:
        query_item: item to rewrite, forwarded to ``data2prompt``.
        search_result: retrieved reference items, forwarded to ``data2prompt``.

    Returns:
        dict of fields parsed from the model reply. If the reply is not a JSON
        object, only the fields in ``keyword_list`` are recovered by regex.
        ``"new_name"`` and ``"final_enhanced_description"`` are always present.
        When missing or empty, ``"new_name"`` falls back to
        ``"name_in_cultivation_1"``. ``"final_enhanced_description"`` falls back
        to ``"description_in_cultivation_1"``, or to ``"new_name"`` if that key
        is absent.
    """
    global zhipu_client
    if zhipu_client is None:
        # Created on first use and cached in the module-level global for later calls.
        zhipu_client = ZhipuClient()
    prompt = data2prompt(query_item, search_result)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError: the reply is not valid JSON.
        json_response = None
    if not isinstance(json_response, dict):
        # Malformed reply, or valid JSON that is not an object: salvage needed fields by regex.
        keyword_list = [
            "name_in_life",
            "name_in_cultivation_1",
            "description_in_cultivation_1",
            "final_enhanced_description",
            "new_name",
        ]
        json_response = forced_extract(response, keyword_list)

    # Fall back to the first-draft fields when the final ones are missing/empty/None.
    if not json_response.get("new_name"):
        json_response["new_name"] = json_response.get("name_in_cultivation_1", "")
    if not json_response.get("final_enhanced_description"):
        json_response["final_enhanced_description"] = json_response.get(
            "description_in_cultivation_1", json_response["new_name"]
        )
    return json_response
if __name__ == '__main__':
    # Manual end-to-end smoke test. It captions a sample image, searches the
    # database with it, picks the major object, and rewrites that object as a
    # cultivation-world item.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database
    db = Database()
    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner
    import os
    # NOTE(review): hard-coded local proxy, presumably needed by the captioner's
    # outbound requests on the author's machine -- confirm / make configurable.
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
    captioner = Captioner()
    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)
    search_result = db.search_with_image_name(test_image)
    # Unique translated words, in first-seen order.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))
    # Same package/flat import fallback as the other project modules.
    try:
        from src.get_major_object import get_major_object, verify_keyword_in_base
    except ImportError:
        from get_major_object import get_major_object, verify_keyword_in_base
    json_response = get_major_object(caption_response, keywords)
    print(json_response)
    print()
    in_base_data, alt_data = verify_keyword_in_base(json_response, db)
    if alt_data is not None:
        result = generate_cultivation_with_rag(alt_data, search_result)
        print(result)