Raycosine committed on
Commit
8e118e5
·
1 Parent(s): 380b053

first commit

Browse files
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import torch
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ from transformers import AutoTokenizer, AutoModel, MarianMTModel, MarianTokenizer
6
+
7
# === Model loading ===
print("Loading models...")

# Embedding tokenizer and encoder come from the same "BAAI/bge-large-en"
# checkpoint; both are module-level globals read by encode().
embed_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en")
embed_model = AutoModel.from_pretrained("BAAI/bge-large-en")

# MarianMT tokenizer and model from the "Helsinki-NLP/opus-mt-zh-en"
# checkpoint; both are read by translate_to_english().
trans_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
trans_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")

print("Models loaded.")
15
+
16
# === Load the dictionary files ===
# Both files are parsed once at import time; the resulting objects are the
# module-level lookups searched by find_similar_hanzi().
with open("k_definition_cleaned.json", encoding="utf-8") as modern_file:
    modern_dict = json.load(modern_file)

with open("oc_definition_cleaned.json", encoding="utf-8") as ancient_file:
    ancient_dict = json.load(ancient_file)
22
+
23
# === Encoding function ===
def encode(texts):
    """Embed a batch of texts as L2-normalised vectors.

    The batch is tokenised with padding and truncation, passed through
    ``embed_model`` with gradient tracking disabled, and the hidden state at
    sequence position 0 of each text is used as that text's embedding.

    Args:
        texts: Batch of strings to hand to ``embed_tokenizer``.

    Returns:
        A 2-D tensor with one row per input text, each row scaled to unit
        L2 norm.
    """
    batch = embed_tokenizer(
        texts, padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        hidden_states = embed_model(**batch).last_hidden_state
    # Keep only the first token's hidden state for every sequence.
    first_token = hidden_states[:, 0]
    return torch.nn.functional.normalize(first_token, p=2, dim=1)
31
+
32
# === Translation function ===
def translate_to_english(text):
    """Return *text* in English, translating only when it is not pure ASCII.

    Pure-ASCII input (including the empty string) is assumed to be English
    already and is returned unchanged without touching the translation
    model. Anything else is run through ``trans_model`` and the first
    generated sequence is decoded.

    Args:
        text: The phrase to translate.

    Returns:
        The original string for ASCII input, otherwise the decoded
        translation with special tokens removed.
    """
    if text.isascii():
        return text
    # truncation=True mirrors encode(): over-long input is cut to the
    # tokenizer's maximum length instead of being passed to the model as is.
    inputs = trans_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    translated = trans_model.generate(**inputs)
    return trans_tokenizer.decode(translated[0], skip_special_tokens=True)
39
+
40
# === Core matching function ===
# Definition embeddings per dictionary, keyed by id(dictionary). The keys are
# only meaningful while the dictionary object stays alive and unmodified,
# which holds for the module-level dictionaries searched below.
_DEFINITION_EMBEDDINGS = {}


def _embed_definitions(dictionary):
    """Return cached ``(hanzi, defs, def_vecs)`` triples for *dictionary*."""
    cache_key = id(dictionary)
    entries = _DEFINITION_EMBEDDINGS.get(cache_key)
    if entries is None:
        # Each entry is encoded as its own batch, so the vectors match what a
        # per-query encode(defs) call would produce. Entries without any
        # definitions are skipped because there is nothing to score.
        entries = [
            (hanzi, defs, encode(defs))
            for hanzi, defs in dictionary.items()
            if defs
        ]
        _DEFINITION_EMBEDDINGS[cache_key] = entries
    return entries


def find_similar_hanzi(idea_text, top_k=10):
    """Rank dictionary entries by similarity of their definitions to an idea.

    The idea is translated to English if needed, embedded, and compared by
    cosine similarity against every definition of every entry. Each entry is
    represented by its best-scoring definition. Definition embeddings are
    computed on the first call and reused afterwards, so only the query is
    embedded on later calls.

    Args:
        idea_text: The query phrase.
        top_k: Maximum number of results returned per dictionary.

    Returns:
        A ``(modern, ancient)`` pair. Each element is a list of
        ``(hanzi, best_definition, score)`` tuples sorted by descending
        score, drawn from ``modern_dict`` and ``ancient_dict`` respectively.
    """
    # NOTE(review): assumes every dictionary value is a list of definition
    # strings -- confirm against the JSON files.
    idea_en = translate_to_english(idea_text)
    # encode() already returns one row per input, i.e. a single-row matrix.
    idea_vec = encode([idea_en])

    def search(dictionary):
        results = []
        for hanzi, defs, def_vecs in _embed_definitions(dictionary):
            scores = cosine_similarity(def_vecs, idea_vec).flatten()
            best = int(scores.argmax())
            results.append((hanzi, defs[best], float(scores[best])))
        results.sort(key=lambda item: item[2], reverse=True)
        return results[:top_k]

    modern = search(modern_dict)
    ancient = search(ancient_dict)

    return modern, ancient
58
+
59
# === Formatting helper for table display ===
def gradio_interface(query):
    """Run a top-50 search and shape the results as table rows.

    Args:
        query: The phrase passed to ``find_similar_hanzi``.

    Returns:
        A dict with keys ``"modern_results"`` and ``"ancient_results"``;
        each value is a list of ``[hanzi, definition, score]`` rows with the
        score rounded to four decimal places.
    """

    def as_rows(matches):
        return [
            [hanzi, definition, round(score, 4)]
            for hanzi, definition, score in matches
        ]

    modern_hits, ancient_hits = find_similar_hanzi(query, top_k=50)
    return {
        "modern_results": as_rows(modern_hits),
        "ancient_results": as_rows(ancient_hits),
    }
66
+
67
# === Gradio page setup ===
with gr.Blocks() as demo:
    gr.Markdown("# Hanzi Imagery Search")

    # Query box and search button.
    with gr.Row():
        inp = gr.Textbox(
            label="输入意象短语(中/英文)",
            placeholder="如:warrior, warmth, 月亮等",
        )
        btn = gr.Button("搜索")

    # Two read-only result tables plus the raw JSON payload.
    modern_output = gr.Dataframe(
        headers=["汉字", "释义", "相似度"],
        label="现代释义匹配",
        interactive=False,
    )
    ancient_output = gr.Dataframe(
        headers=["汉字", "释义", "相似度"],
        label="古代释义匹配",
        interactive=False,
    )
    json_output = gr.JSON(label="JSON 返回结构")

    def full_response(query):
        """Return the modern rows, the ancient rows, and the full result dict."""
        payload = gradio_interface(query)
        return payload["modern_results"], payload["ancient_results"], payload

    btn.click(
        fn=full_response,
        inputs=[inp],
        outputs=[modern_output, ancient_output, json_output],
    )

demo.launch(share=True)
k_definition_cleaned.json ADDED
The diff for this file is too large to render. See raw diff
 
oc_definition_cleaned.json ADDED
The diff for this file is too large to render. See raw diff