Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import openai | |
import time | |
from sentence_transformers import SentenceTransformer | |
from langchain.prompts import PromptTemplate | |
from collections import Counter | |
def process(caption, category, asr, ocr): | |
preference = "兴趣标签" | |
if len(category) == 0: | |
category = "空" | |
if len(asr) == 0: | |
asr = "空" | |
example = "例如,给定一个视频,它的\"标题\"为\"长安系最便宜的轿车,4W起很多人都看不上它,但我知道车只是代步工具,又需要什么面子呢" \ | |
"!\",\"类别\"为\"汽车\",\"ocr\"为\"长安系最便宜的一款轿车\",\"asr\"为\"我不否认现在的国产和合资还有一定的差距," \ | |
"但确实是他们让我们5万开了MP V8万开上了轿车,10万开张了ICV15万开张了大七座。\",\"{}\"生成机器人推断出合理的\"{}\"为\"" \ | |
"长安轿车报价、最便宜的长安轿车、新款长安轿车\"。".format(preference, preference) | |
prompt = PromptTemplate( | |
input_variables=["preference", "caption", "ocr", "asr", "category", "example"], | |
template="你是一个视频的\"{preference}\"生成机器人,根据输入的视频标题、类别、ocr、asr推理出合理的\"{preference}\",以多个多" | |
"于两字的标签形式进行表达,以顿号隔开。{example}那么,给定一个新的视频,它的\"标题\"为\"{caption}\",\"类别\"为" | |
"\"{category}\",\"ocr\"为\"{ocr}\",\"asr\"为\"{asr}\",请推断出该视频的\"{preference}\":" | |
) | |
text = prompt.format(preference=preference, caption=caption, category=category, ocr=ocr, asr=asr, example=example) | |
try: | |
completion = openai.ChatCompletion.create( | |
model="gpt-3.5-turbo", | |
messages=[{"role": "user", "content": text}], | |
temperature=1.5, | |
n=5 | |
) | |
res = [] | |
for j in range(5): | |
ans = completion.choices[j].message["content"].strip() | |
ans = ans.replace("\n", "") | |
ans = ans.replace("。", "") | |
ans = ans.replace(",", "、") | |
res += ans.split('、') | |
tag_count = Counter(res) | |
tag_count = sorted(tag_count.items(), key=lambda x: x[1], reverse=True)[:10] | |
tags_embed = np.load('./tag_data/tags_embed.npy') | |
tags_dis = np.load('./tag_data/tags_dis.npy') | |
candidate_tags = [_[0] for _ in tag_count] | |
encoder = SentenceTransformer("hfl/chinese-roberta-wwm-ext-large", device='cuda') | |
candidate_tags_embed = encoder.encode(candidate_tags) | |
candidate_tags_dis = [np.sqrt(np.dot(_, _.T)) for _ in candidate_tags_embed] | |
scores = np.dot(candidate_tags_embed, tags_embed.T) | |
f = open('./tag_data/tags.txt', 'r') | |
all_tags = [] | |
for line in f.readlines(): | |
all_tags.append(line.strip()) | |
f.close() | |
final_ans = [] | |
for i in range(scores.shape[0]): | |
for j in range(scores.shape[1]): | |
score = scores[i][j] / (candidate_tags_dis[i] * tags_dis[j]) | |
if score > 0.9: | |
final_ans.append(all_tags[j]) | |
final_ans = Counter(final_ans) | |
final_ans = sorted(final_ans.items(), key=lambda x: x[1], reverse=True)[:5] | |
final_ans = [_[0] for _ in final_ans] | |
return "、".join(final_ans) | |
except: | |
return 'api error' | |
def connection(api_key): | |
openai.api_key = api_key | |
time.sleep(5) | |
with gr.Blocks() as demo: | |
gr.Markdown("<h3><center>TagGPT</center></h3>") | |
gr.Markdown( | |
""" | |
This is a demo to the work [TagGPT: Large Language Models are Zero-shot Multimodal Taggers](https://github.com/TencentARC/TagGPT).<br> | |
This space connects TagGPT to provide tagging service based on the tag set (from the Kuaishou data).<br> | |
""" | |
) | |
gr.Markdown( | |
""" | |
Step 1: input openai api key (sk-...) and click "connect" button. | |
""" | |
) | |
with gr.Column(variant="panel"): | |
with gr.Row(variant="compact"): | |
text_api = gr.Textbox( | |
label='OpenAI API key', | |
placeholder="Paste your OpenAI API key here️", | |
type="password" | |
).style( | |
container=False, | |
) | |
btn_api = gr.Button("Connect").style(full_width=False) | |
btn_api.click(connection, [text_api]) | |
gr.Markdown( | |
""" | |
Step 2: fall in the four items (i.e., caption, category, ASR, and OCR) in the appropriate input fields. | |
Click any item of the "Examples" to quickly see the tagging results. | |
""" | |
) | |
text_caption = gr.Textbox( | |
label='标题(Caption)', | |
placeholder="Indispensable" | |
) | |
text_category = gr.Textbox( | |
label='类别(Category)', | |
placeholder="Dispensable" | |
) | |
text_asr = gr.Textbox( | |
label='ASR', | |
placeholder="Dispensable" | |
) | |
text_ocr = gr.Textbox( | |
label='OCR', | |
placeholder="Dispensable" | |
) | |
text_output = gr.Textbox(value='', label='Output') | |
btn = gr.Button(value='Submit') | |
btn.click(process, inputs=[text_caption, text_category, text_asr, text_ocr], outputs=[text_output]) | |
examples = [ | |
[ | |
'正确解决iCloud储存空间将满的实用技巧', | |
'高新数码', | |
'你的iCloud一直提示,iCloud储存空间将满,而且还会把你的iPhone iPad Mac等相关连设备都弹一遍,怎么办,烦死了。其实只要用iPho' | |
'ne打开设置,点击上方的头像栏,然后点击iCloud,选择管理账户储存空间,进入后,把这里不需要云备份按自己的需要进行删除,然后再回' | |
'到上一级页面,把这里的iCloud云备份关闭就可以了。', | |
'iCloud储存空间将满实用技巧' | |
], | |
[ | |
'30平米迷你小公寓如何设计?创意设计多出一间房!', | |
'房产家居', | |
'这是一个30平米的迷你公寓,入户没有鞋柜,没有餐厅,没有书桌位置,咱们可以这样设计,首先吧卫生间墙内推30公分,嵌入鞋柜,' | |
'厨房旁边装三联动推拉门,进出舒适,隔绝油烟,把原有沙发该到卧室位置,可以让出一个餐厅位置,沙发前移做一组书柜,内藏隐形' | |
'壁床,打开是卧室,收起便是小客厅,对面墙壁,再做一组书柜和书桌,各种需求都满足。', | |
'30平迷你公寓怎么设计' | |
], | |
[ | |
'好好钓鱼!千万别算账', | |
'生活', | |
'一只鱼竿四千多,用它来钓多少鱼才能回本,钓鱼是娱乐而算账却很扎心,这是一根国产标配波爬竿,我的深海标配,主要针对海绵大' | |
'型鱼类,比如金枪鱼、gt、鬼头刀等,经济价值最高的算是金枪鱼了,一不小心钓上一条鱼竿就回来了,不小心又钓上一条轮子就回来了' | |
'当你发现你的渔具全是渔获换来的,吹起牛来是不是更加有底气。一支鱼竿,一副轮子,又来一套,哈哈。你钓了几套渔具呢?', | |
'钓鱼很开心,算账很扎心' | |
], | |
[ | |
'很多朋友再问,为什么蒸出来的花卷葱没变色,今天告诉你', | |
'美食', | |
'为什么你自己在家做的画卷,蒸出来之后,它的葱总是黄色的呢?而早餐店里面的画卷蒸出来过后它的葱都是绿油油的,那么今天呢,' | |
'把方法教给你,那就是我们做花卷的时候呢,再从里面加入少量的泡打粉,拌均匀,这样做出来的画卷呢,蒸熟过后,它的葱也是绿' | |
'油油的,大家可以看一下。', | |
'为什么你蒸的葱油花卷总会变黄' | |
], | |
[ | |
'豪华七座SUV 比亚迪唐DM-i冠军版带你体验出行新方式', | |
'汽车', | |
'你听说了吗,2023年冠军版本的唐DM-i已经上市了,上市当天,汉唐总销量已经超过了8000多台,大家觉得这个成绩怎么样,那么,今天' | |
'艾琳就带大家全民啊了解一下,2023年冠军版本的唐DM-i,新款的唐DM-i和老款之间的差别并不大,但是我们的新款新增了冰川蓝的颜色' | |
'还有啊,你看这个轮毂是不是和我们老款252公里的轮毂一样呢,另外新款多了铝合金底盘,FSD可变主力悬挂,快充也升级到了40千瓦,还' | |
'有专门的快充孔,20分钟就可以让他满血复活,2023款唐DM-i一共分为三个配置,尊贵尊荣尊享,不过都是7座,那配置上的差异,艾琳' | |
'就把它放在图上了,方便大家对比。想入手的朋友记住了,武汉比亚迪找艾琳,艾琳带你早提车。', | |
'豪华标杆SUV,23款唐DM-1冠军版' | |
] | |
] | |
gr.Examples( | |
examples, | |
[text_caption, text_category, text_asr, text_ocr], | |
text_output, | |
process | |
) | |
if __name__ == "__main__": | |
demo.launch() | |