silk-road commited on
Commit
0117cec
·
verified ·
1 Parent(s): 4591767

Upload 23 files

Browse files
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Idiot Cultivation System
3
- emoji: 👀
4
- colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.41.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
+ title: Test Idiot-Cultivation-System
3
+ emoji:
4
+ colorFrom: purple
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.40.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
models/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ cache_file for CLIP models and BGE-small-zh-v1.5
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ openai
2
+ torch
3
+ transformers
4
+ scikit-learn
5
+ numpy
6
+ pandas
7
+ zhipuai
8
+ opencv-python
9
+ pillow
src/CLIPExtractor.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ import cv2
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+
9
+
10
+ class CLIPExtractor:
11
+ def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
12
+
13
+ # 设置代理环境变量
14
+ # os.environ['HTTP_PROXY'] = 'http://localhost:8234'
15
+ # os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
16
+
17
+ # # 设置环境变量
18
+ # os.environ["HF_ENDPOINT"] = "https://hf-api.gitee.com"
19
+ # os.environ["HF_HOME"] = os.path.expanduser("models/")
20
+
21
+ if not cache_dir:
22
+ # 指定缓存目录
23
+ cache_dir = "models"
24
+ if not os.path.exists(cache_dir) and os.path.exists("../models"):
25
+ cache_dir = "../models"
26
+
27
+ # Initialize the model and processor with specified values
28
+ self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
29
+ self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
30
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ self.model.to(self.device)
32
+
33
+ def extract_image(self, frame):
34
+ # Convert frame (from OpenCV) to PIL Image
35
+ image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
36
+ images = [image]
37
+
38
+ # Process the image and extract features
39
+ inputs = self.processor(images=images, return_tensors="pt").to(self.device)
40
+ with torch.no_grad():
41
+ outputs = self.model.get_image_features(**inputs)
42
+
43
+ ans = outputs.cpu().numpy()
44
+ return ans[0]
45
+
46
+ def extract_image_from_file(self, file_name):
47
+ if not os.path.exists(file_name):
48
+ raise FileNotFoundError(f"File {file_name} not found.")
49
+
50
+ images = [Image.open(file_name).convert("RGB")]
51
+
52
+ # Process the image and extract features
53
+ inputs = self.processor(images=images, return_tensors="pt").to(self.device)
54
+ with torch.no_grad():
55
+ outputs = self.model.get_image_features(**inputs)
56
+
57
+ ans = outputs.cpu().numpy()
58
+ return ans[0]
59
+
60
+ def extract_text(self, text):
61
+ if not isinstance(text, str) or not text:
62
+ raise ValueError("Input text should be a non-empty string.")
63
+
64
+ # Tokenize the text
65
+ inputs = self.processor.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=77).to(self.device)
66
+
67
+
68
+ # Process the text and extract features
69
+ # inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
70
+
71
+ with torch.no_grad():
72
+ outputs = self.model.get_text_features(**inputs)
73
+
74
+ ans = outputs.cpu().numpy()
75
+ return ans[0]
76
+
77
+
78
+ if __name__ == "__main__":
79
+
80
+ clip_extractor = CLIPExtractor()
81
+
82
+ sample_image = "images/狐狸.jpg"
83
+ # 提取图像特征
84
+ image_feature = clip_extractor.extract_image_from_file(sample_image)
85
+
86
+
87
+ # 提取文本特征
88
+ sample_text = "A photo of fox"
89
+ text_feature = clip_extractor.extract_text(sample_text)
90
+
91
+ # consine similarity
92
+ cosine_similarity = np.dot(image_feature, text_feature) / (np.linalg.norm(image_feature) * np.linalg.norm(text_feature))
93
+ print(cosine_similarity)
src/Captioner.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import base64
3
+ from io import BytesIO
4
+ import os
5
+ from openai import OpenAI
6
+ import json
7
+
8
+ class Captioner:
9
+ def __init__(self, api_key_path = None, proxy=None, api_base="https://api.lingyiwanwu.com/v1"):
10
+
11
+ # if api_key_path is None:
12
+ # # try find datas/01_key.txt and ../datas/01_key.txt
13
+ # cand_paths = ['datas/01_key.txt', '../datas/01_key.txt']
14
+ # flag = False
15
+ # for path in cand_paths:
16
+ # if os.path.exists(path):
17
+ # api_key_path = path
18
+ # flag = True
19
+ # break
20
+
21
+ # if not flag:
22
+ # raise ValueError("Please provide the path to the API key file.")
23
+
24
+
25
+ self.api_key = os.getenv('YI_VL_KEY')
26
+ self.api_base = api_base
27
+ # if proxy:
28
+ # os.environ['HTTP_PROXY'] = proxy
29
+ # os.environ['HTTPS_PROXY'] = proxy
30
+ self.client = OpenAI(
31
+ api_key=self.api_key,
32
+ base_url=self.api_base
33
+ )
34
+
35
+ self.history = {}
36
+ self.history_file = None
37
+
38
+ self.load_history()
39
+
40
+ def load_access_token(self, file_path):
41
+ with open(file_path, 'r') as file:
42
+ return file.read().strip()
43
+
44
+ def image2base64(self, image_path):
45
+ # 打开图像
46
+ with Image.open(image_path) as img:
47
+ # 检查图像高度是否超过480
48
+ if img.height > 480:
49
+ # 计算调整后的宽度,以保持宽高比不变
50
+ aspect_ratio = img.width / img.height
51
+ new_height = 480
52
+ new_width = int(new_height * aspect_ratio)
53
+ img = img.resize((new_width, new_height), Image.ANTIALIAS)
54
+
55
+ # 使用BytesIO在内存中保存调整大小后的图像
56
+ buffered = BytesIO()
57
+ img.save(buffered, format="JPEG")
58
+ buffered.seek(0)
59
+
60
+ # 将图像转换为Base64编码字符串
61
+ img_base64 = "data:image/jpeg;base64," + base64.b64encode(buffered.read()).decode('utf-8')
62
+
63
+ return img_base64
64
+
65
+ def load_history(self, jsonl_file_name=None):
66
+ if jsonl_file_name is None:
67
+ jsonl_file_name = "datas/caption_history.jsonl"
68
+
69
+ self.history_file = jsonl_file_name
70
+
71
+ if os.path.exists(jsonl_file_name):
72
+ with open(jsonl_file_name, 'r', encoding='utf-8') as f:
73
+ for line in f:
74
+ data = json.loads(line)
75
+ self.history[data['file_name']] = data['response']
76
+
77
+ def search_from_history(self, file_name):
78
+ return self.history.get(file_name, None)
79
+
80
+ def save_history(self, jsonl_file_name=None):
81
+ if jsonl_file_name is None:
82
+ jsonl_file_name = self.history_file
83
+
84
+ if jsonl_file_name:
85
+ with open(jsonl_file_name, 'w', encoding='utf-8') as f:
86
+ for file_name, response in self.history.items():
87
+ json.dump({'file_name': file_name, 'response': response}, f, ensure_ascii=False)
88
+ f.write('\n')
89
+
90
+ # print(f"History saved to {jsonl_file_name}")
91
+
92
+ def add_to_history(self, file_name, response):
93
+ self.history[file_name] = response
94
+
95
+ def caption(self, image_name):
96
+
97
+ # Check if the caption is already in the history
98
+ cached_response = self.search_from_history(image_name)
99
+ if cached_response:
100
+ # print("return the cache")
101
+ return cached_response
102
+
103
+ prompt = """Analyze the image and output in JSON format, including the following fields:
104
+ - "detailed_description": A detailed description of the image content.
105
+ - "major_object": Determine the main object/scene in the image based on the description, output with a simple word
106
+ - "Chinese_name": 判断图片中主要物体的中文名
107
+ - "real_or_composite": Determine whether this image was taken with a camera or created/modifed by a computer, output with real or composite."""
108
+
109
+ img_base64 = self.image2base64(image_name)
110
+
111
+ completion = self.client.chat.completions.create(
112
+ model="yi-vision",
113
+ messages=[
114
+ {
115
+ "role": "user",
116
+ "content": [
117
+ {
118
+ "type": "text",
119
+ "text": prompt
120
+ },
121
+ {
122
+ "type": "image_url",
123
+ "image_url": {
124
+ "url": img_base64
125
+ }
126
+ }
127
+ ]
128
+ }
129
+ ],
130
+ stream=False
131
+ )
132
+
133
+ response = completion.choices[0].message.content
134
+
135
+ # Add the new response to history
136
+ self.add_to_history(image_name, response)
137
+ # Save history after adding the new entry
138
+ self.save_history()
139
+
140
+ return response
141
+
142
+ if __name__ == "__main__":
143
+ import os
144
+ os.environ['HTTP_PROXY'] = 'http://localhost:8234'
145
+ os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
146
+ captioner = Captioner()
147
+ test_image = "temp_images/3zjz9b3l.jpg"
148
+ print(captioner.caption(test_image))
src/Database.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os
3
+ from tqdm import tqdm
4
+
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+
8
+
9
+ class Database:
10
+ def __init__(self, parquet_path=None, customized_parquets = None):
11
+ self.default_parquet_path = 'datas/database_4000.parquet'
12
+ self.parquet_path = parquet_path or self.default_parquet_path
13
+
14
+ self.default_customized_parquets = ["datas/customized_database_0.parquet"]
15
+ self.customized_parquets = customized_parquets or self.default_customized_parquets
16
+
17
+ self.datas = None
18
+ self.last_save_table = None
19
+
20
+ if os.path.exists(self.parquet_path):
21
+ # self.load_from_parquet(self.parquet_path)
22
+ pass
23
+
24
+ # self.load_from_customized(self.customized_parquets)
25
+
26
+ self.clip_extractor = None
27
+ self.bge_extractor = None
28
+
29
+ self.en_keyword2data = {}
30
+
31
+ def build_en_keyword2index(self):
32
+ # build in lower case
33
+ self.en_keyword2data = {row['translated_word'].lower(): row for i, row in self.datas.iterrows()}
34
+
35
+ def search_by_en_keyword(self, keyword):
36
+ if len(self.en_keyword2data) == 0:
37
+ self.build_en_keyword2index()
38
+
39
+ keyword = keyword.lower()
40
+ if keyword in self.en_keyword2data:
41
+ ans = self.en_keyword2data[keyword].to_dict()
42
+ del ans["clip_feature"]
43
+ del ans["bge_feature"]
44
+ return ans
45
+ else:
46
+ return None
47
+
48
+ def load_from_parquet(self, parquet_path):
49
+ self.datas = pd.read_parquet(parquet_path)
50
+
51
+ def load_from_customized(self, customized_parquets=None):
52
+ customized_parquets = customized_parquets or self.customized_parquets
53
+
54
+ # Load each parquet file and concatenate them into the self.datas DataFrame
55
+ for index, parquet_file in enumerate(customized_parquets):
56
+ if os.path.exists(parquet_file):
57
+ temp_df = pd.read_parquet(parquet_file)
58
+ if self.datas is None:
59
+ self.datas = temp_df
60
+ else:
61
+ self.datas = pd.concat([self.datas, temp_df], ignore_index=True)
62
+
63
+ # if last parquet file
64
+ if index == len(customized_parquets) - 1:
65
+ self.last_save_table = temp_df
66
+
67
+ # if customized_parquets:
68
+ # Record the last parquet file's contents as self.last_save_table
69
+
70
+
71
+ def add_data(self, data, if_save=True):
72
+ required_columns = ['keyword', 'name_in_cultivation', 'description_in_cultivation', 'translated_word', 'description']
73
+ for column in required_columns:
74
+ if column not in data:
75
+ raise ValueError(f"Missing required field: {column}")
76
+
77
+ # Optional field
78
+ if 'founder' not in data:
79
+ data['founder'] = ""
80
+
81
+ # Extract features
82
+ if self.clip_extractor is None:
83
+ self.init_clip_extractor()
84
+ if self.bge_extractor is None:
85
+ self.init_bge_extractor()
86
+
87
+ data['clip_feature'] = self.clip_extractor.extract_text(data['translated_word'] + '.' + data['description'])
88
+ data['bge_feature'] = self.bge_extractor.extract([data['keyword']])[0].tolist()
89
+
90
+ # Convert to DataFrame and add to self.datas
91
+ data_df = pd.DataFrame([data])
92
+ if self.datas is None:
93
+ self.datas = data_df
94
+ else:
95
+ self.datas = pd.concat([self.datas, data_df], ignore_index=True)
96
+
97
+ # set self.en_keyword2data to last row of self.datas
98
+ self.en_keyword2data[data['translated_word'].lower()] = self.datas.iloc[-1]
99
+
100
+ # Add to last_save_table
101
+ if self.last_save_table is None:
102
+ # self.last_save_table = data_df
103
+ # create a new DataFrame with the same columns as self.datas
104
+ self.last_save_table = pd.DataFrame(columns=self.datas.columns)
105
+
106
+ self.last_save_table = pd.concat([self.last_save_table, data_df], ignore_index=True)
107
+
108
+ if if_save:
109
+ self.save_to_parquet(self.customized_parquets[-1], self.last_save_table )
110
+
111
+ def add_datas(self, datas, if_save=True):
112
+ for data in datas:
113
+ self.add_data(data, if_save=False)
114
+ if if_save:
115
+ self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)
116
+
117
+ def init_from_excel(self, excel_path):
118
+ df = pd.read_excel(excel_path)
119
+
120
+ # Drop rows with any empty cell in the required columns
121
+ df.dropna(subset=['keyword', 'name_in_cultivation', 'description_in_cultivation', 'translated_word', 'description'], inplace=True)
122
+
123
+ # Add the new columns
124
+ df['clip_feature'] = None
125
+ df['bge_feature'] = None
126
+
127
+ self.datas = df
128
+
129
+ self.extract_clip()
130
+ self.extract_bge()
131
+
132
+ def save_to_parquet(self, parquet_path=None, df = None):
133
+
134
+ parquet_path = parquet_path or self.default_parquet_path
135
+ if df is None:
136
+ if self.datas is not None:
137
+ self.datas.to_parquet(parquet_path)
138
+ else:
139
+ df.to_parquet(parquet_path)
140
+
141
+ def init_clip_extractor(self):
142
+ if self.clip_extractor is None:
143
+ try:
144
+ from CLIPExtractor import CLIPExtractor
145
+ except:
146
+ from src.CLIPExtractor import CLIPExtractor
147
+
148
+ cache_dir = "models"
149
+
150
+ self.clip_extractor = CLIPExtractor(model_name = "openai/clip-vit-large-patch14",cache_dir = cache_dir)
151
+
152
+
153
+ def extract_clip(self):
154
+ if self.clip_extractor is None:
155
+ self.init_clip_extractor()
156
+
157
+ clip_features = []
158
+ # for text in tqdm(self.datas['keyword'], desc='Extracting CLIP features'):
159
+ for index, row in tqdm(self.datas.iterrows(), desc='Extracting CLIP features', total=len(self.datas)):
160
+ text = row['translated_word'] + '.' + row['description']
161
+ if text:
162
+ feature = self.clip_extractor.extract_text(text)
163
+ else:
164
+ feature = None
165
+ clip_features.append(feature)
166
+
167
+ self.datas['clip_feature'] = clip_features
168
+
169
+ def init_bge_extractor(self):
170
+ if self.bge_extractor is None:
171
+ try:
172
+ from text_embedding import TextExtractor
173
+ except:
174
+ from src.text_embedding import TextExtractor
175
+
176
+ self.bge_extractor = TextExtractor('BAAI/bge-small-zh-v1.5')
177
+
178
+ def top_k_search(self, query_feature, attribute, top_k=15):
179
+ return self.remoter.top_k_search(query_feature, attribute, top_k)
180
+ '''
181
+ # Ensure the attribute exists in the dataframe
182
+ if attribute not in self.datas.columns:
183
+ raise ValueError(f"Attribute {attribute} not found in the data.")
184
+
185
+ # Convert query feature and attribute features to numpy arrays
186
+ query_feature = np.array(query_feature).reshape(1, -1)
187
+ attribute_features = np.stack(self.datas[attribute].dropna().values)
188
+
189
+ # Compute cosine similarity between query and all attributes
190
+ similarities = cosine_similarity(query_feature, attribute_features)[0]
191
+
192
+ # Get the top_k indices based on similarity
193
+ top_k_indices = np.argsort(similarities)[-top_k:][::-1]
194
+
195
+ # Retrieve the top_k most similar items
196
+ top_k_results = self.datas.iloc[top_k_indices].copy()
197
+
198
+ top_k_results = top_k_results.drop(columns=['clip_feature', 'bge_feature'])
199
+
200
+ top_k_results['similarity'] = similarities[top_k_indices]
201
+
202
+ return top_k_results.to_dict(orient='records')
203
+ '''
204
+ def search_with_image_name(self, image_name):
205
+ self.init_clip_extractor()
206
+
207
+ img_feature = self.clip_extractor.extract_image_from_file(image_name)
208
+
209
+ return self.top_k_search(img_feature, 'clip_feature')
210
+
211
+ def search_with_image(self, image, if_opencv = False ):
212
+ if self.clip_extractor is None:
213
+ self.init_clip_extractor()
214
+
215
+ img_feature = self.clip_extractor.extract_image(image, if_opencv = if_opencv)
216
+
217
+ return self.top_k_search(img_feature, 'clip_feature')
218
+
219
+ def search_with_chinese(self, text):
220
+ if self.bge_extractor is None:
221
+ self.init_bge_extractor()
222
+
223
+ text_feature = self.bge_extractor.extract([text])[0].tolist()
224
+
225
+ return self.top_k_search(text_feature, 'bge_feature')
226
+
227
+
228
+
229
+ def extract_bge(self):
230
+ if self.bge_extractor is None:
231
+ self.init_bge_extractor()
232
+
233
+ # Extract features for each row and store them in the bge_feature column
234
+ bge_features = []
235
+ for text in tqdm(self.datas['keyword'], desc='Extracting BGE features'):
236
+ if text:
237
+ feature = self.bge_extractor.extract([text])[0].tolist()
238
+ else:
239
+ feature = None
240
+ bge_features.append(feature)
241
+
242
+ self.datas['bge_feature'] = bge_features
243
+
244
+ if __name__ == '__main__':
245
+ # Usage example
246
+ db = Database()
247
+ re_extract = False
248
+ if db.datas is None or re_extract:
249
+ print("Rebuilding database from excel file")
250
+ db.init_from_excel('datas/database_4000.xlsx')
251
+ db.save_to_parquet()
252
+
253
+ # print(db.datas[0].keys())
254
+
255
+ query_text = "钢琴"
256
+
257
+ results = db.search_with_chinese(query_text)
258
+
259
+ print(results[0].keys())
260
+
261
+ for result in results[:3]:
262
+ print(result)
263
+
264
+ image_path = "datas/老虎.jpg"
265
+
266
+ results = db.search_with_image_name(image_path)
267
+
268
+ for result in results[:3]:
269
+ print(result)
270
+ # 'keyword': '老虎狗', 'name_in_cultivation': '灵虎犬神', 'description_in_cultivation': '在九天灵脉汇聚的仙山之巅,灵虎犬神身披星图
271
+ # 斑纹,汲取日月精华,以雷霆之力守护仙脉,其双眼中映照着轮回之道,是修仙者追寻天地真理的指引,也是象征极致灵性的神秘灵兽。', 'translated_word': 'Tiger Dog', 'description': 'A Tiger Dog is a term that might refer to a mythical creature or a breed of dog with a distinctive and unusual appearance, resembling the features of a tiger. It could be characterized by its striking coat with patterns similar to those of a tiger, or by having a demeanor that is fierce and majestic like a tiger. This term is not commonly used in
272
+ # conventional contexts and might be found in stories, folktales, or in the names of unique dog breeds that have been bred to exhibit such features.', 'founder': ''
273
+ # test_new_data = {
274
+ # "keyword": "老虎狗2",
275
+ # "name_in_cultivation": "灵虎犬神",
276
+ # "description_in_cultivation": "在九天灵脉汇聚的仙山之巅,灵虎犬神身披星图斑纹,汲取日月精华,以雷霆之力守护仙脉,其双眼中映照着轮回之道,是修仙者追寻天地真理的指引,也是象征极致灵性的神秘灵兽。",
277
+ # "translated_word": "Tiger Dog",
278
+ # "description":"A Tiger Dog is a term that might refer to a mythical creature or a breed of dog with a distinctive and unusual appearance, resembling the features of a tiger. It could be characterized by its striking coat with patterns similar to those of a tiger, or by having a demeanor that is fierce and majestic like a tiger. This term is not commonly used in conventional contexts and might be found in stories, folktales, or in the names of unique dog breeds that have been bred to exhibit such features."
279
+ # }
280
+
281
+ # db.add_data(test_new_data)
src/Founder.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import defaultdict
3
+
4
+ from src.RemoteDatabase import RemoteDatabase
5
+
6
+ class Founder:
7
+ def __init__(self, filepath='datas/founder.jsonl'):
8
+ self.filepath = filepath
9
+ self.datas = {}
10
+ self.founder2items = defaultdict(list)
11
+ self.remote = RemoteDatabase()
12
+
13
+ try:
14
+ # self.load_founder()
15
+ pass
16
+ except FileNotFoundError:
17
+ self.datas = {}
18
+
19
+ # Initialize the reverse mapping
20
+ for word, founder in self.datas.items():
21
+ self.founder2items[founder].append(word)
22
+
23
+ def load_founder(self):
24
+ """Load founder data from a jsonl file."""
25
+ with open(self.filepath, 'r', encoding='utf-8') as file:
26
+ for line in file:
27
+ data = json.loads(line.strip())
28
+ self.datas.update(data)
29
+
30
+ def save_founder(self):
31
+ """Save founder data to a jsonl file."""
32
+ with open(self.filepath, 'w', encoding='utf-8') as file:
33
+ for word, founder in self.datas.items():
34
+ file.write(json.dumps({word: founder}, ensure_ascii=False) + '\n')
35
+
36
+ def get_founder(self, word):
37
+ return self.remote.get_top_founders(word=word)
38
+
39
+ def set_founder(self, word, founder, enforce=False):
40
+ """Set the founder of a word if it's not already set or if enforce is True."""
41
+ self.remote.set_founders(word=word, founder=founder, enforce=enforce)
42
+ '''
43
+ if word in self.datas and not enforce:
44
+ print(f"Warning: {word} already has a founder: {self.datas[word]}. Use enforce=True to override.")
45
+ else:
46
+ self.datas[word] = founder
47
+ self.founder2items[founder].append(word)
48
+ self.save_founder()
49
+ '''
50
+
51
+ def get_all_items_from_founder(self, founder):
52
+ """Get all words discovered by a specific founder."""
53
+ return self.founder2items.get(founder, [])
54
+
55
+ def get_top_rank(self, top_k=20):
56
+ return self.remote.get_top_founders(top_k=top_k)
57
+ '''
58
+ """Get the top_k founders with the most discovered words."""
59
+ sorted_founders = sorted(self.founder2items.items(), key=lambda x: len(x[1]), reverse=True)
60
+ return sorted_founders[:top_k]
61
+ '''
62
+
63
+ # Example usage:
64
+ # founder = Founder()
65
+ # founder.set_founder('apple', 'Alice')
66
+ # founder.set_founder('banana', 'Bob')
67
+ # print(founder.get_founder('apple'))
68
+ # print(founder.get_all_items_from_founder('Alice'))
69
+ # print(founder.get_top_rank())
70
+
71
+ if __name__ == '__main__':
72
+ founder = Founder()
73
+ founder.set_founder('test_apple', '鲁鲁道祖')
74
+ founder.set_founder('test_banana', '鲁鲁道祖')
75
+ founder.set_founder('test_orange', "文钊道祖")
76
+ print(founder.get_founder('test_apple'))
77
+ print(founder.get_all_items_from_founder('Alice'))
78
+ print(founder.get_top_rank())
src/GameMaster.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+
4
+ try:
5
+ from src.Database import Database
6
+ from src.Captioner import Captioner
7
+ from src.ImageBase import Imagebase
8
+ from src.RemoteDatabase import RemoteDatabase
9
+ from src.get_major_object import get_major_object, verify_keyword_in_base
10
+ from src.generate_cultivation import generate_cultivation_with_rag
11
+ except:
12
+ from Database import Database
13
+ from Captioner import Captioner
14
+ from ImageBase import Imagebase
15
+ from RemoteDatabase import RemoteDatabase
16
+ from get_major_object import get_major_object, verify_keyword_in_base
17
+ from generate_cultivation import generate_cultivation_with_rag
18
+
19
+
20
+ class GameMaster:
21
+ def __init__( self ):
22
+ self.textdb = self.init_textdb()
23
+
24
+ self.clip_extractor = self.textdb.clip_extractor
25
+
26
+ self.imgdb = self.init_imgdb()
27
+
28
+ self.captioner = Captioner()
29
+
30
+ self.minimal_image_threshold = 0.9
31
+
32
+ self.remote = RemoteDatabase()
33
+
34
+ def init_textdb( self ):
35
+ text_db = Database()
36
+ text_db.init_bge_extractor()
37
+ text_db.init_clip_extractor()
38
+ return text_db
39
+
40
+ def init_imgdb( self ):
41
+ img_db = Imagebase()
42
+ return img_db
43
+
44
+ def random_image_text_data( self, n = 12 ):
45
+ random_img_datas = self.remote.random_sample(n)
46
+ # keep image_name and keywords only
47
+ image_names = [img_data['image_name'] for img_data in random_img_datas]
48
+ blank_image_path = "datas/blank_item.jpg"
49
+ for i in range(len(image_names)):
50
+ if not os.path.exists(image_names[i]):
51
+ image_names[i] = blank_image_path
52
+
53
+ keywords_zh = [img_data['keyword'] for img_data in random_img_datas]
54
+ keywords = [img_data['translated_word'] for img_data in random_img_datas]
55
+ descriptions = []
56
+
57
+ for keyword, keyword_zh in zip(keywords, keywords_zh):
58
+ result = self.remote.search_by_en_keyword(keyword)
59
+ if result and "description_in_cultivation" in result:
60
+ description = result['description_in_cultivation']
61
+ if "name_in_cultivation" in result:
62
+ description = result['name_in_cultivation'] + "--" + description
63
+ descriptions.append(description)
64
+ else:
65
+ descriptions.append("")
66
+
67
+ #return tuple of imapge path and description
68
+ return zip(image_names, descriptions)
69
+
70
+
71
+ def search_with_path( self, image_path , threshold = None ):
72
+ # this is a relatively light weight search
73
+ image_feature = self.clip_extractor.extract_image_from_file(image_path)
74
+
75
+ # image_search_result = img_db.search_with_image_name(image_path)
76
+ # image_search_result = self.imgdb.top_k_search(image_feature, top_k=1)
77
+ image_search_result = self.remote.top_k_search(image_feature, 'clip_feature', top_k=1)
78
+
79
+ search_result = None
80
+
81
+ if threshold is None:
82
+ threshold = self.minimal_image_threshold
83
+
84
+ if image_search_result and len(image_search_result)>0 and image_search_result[0]['similarity'] > threshold:
85
+
86
+ # try find data with translated_word
87
+ result = self.remote.search_by_en_keyword(image_search_result[0]['translated_word'])
88
+ if result and "name_in_cultivation" in result:
89
+ search_result = result
90
+ search_result['similarity'] = image_search_result[0]['similarity']
91
+ else:
92
+ print("Warning! Unfound keyword: ", image_search_result[0]['translated_word'])
93
+
94
+ # backup_results = None
95
+ # if search_result is None:
96
+ # try search with textdb
97
+ backup_results = self.remote.top_k_search(image_feature, 'text_feature', top_k = 5)
98
+
99
+ return search_result, backup_results, image_feature
100
+
101
+ def generate_cultivation_data( self, image_path , image_feature, text_search_result ):
102
+ # this is very expensive
103
+
104
+ cultivation_data = None
105
+
106
+ try:
107
+ caption_response = self.captioner.caption(image_path)
108
+ except:
109
+ print("Error occurred while captioning the image ", image_path)
110
+ return cultivation_data
111
+
112
+ if text_search_result is None:
113
+ # complete text search
114
+ text_search_result = self.remote.top_k_search(image_feature, 'text_feature', top_k = 5)
115
+
116
+ seen = set()
117
+ keywords = [res['translated_word'] for res in text_search_result if not (res['translated_word'] in seen or seen.add(res['translated_word']))]
118
+
119
+ try:
120
+ json_response = get_major_object(caption_response , keywords)
121
+ except:
122
+ print("Error occurred while getting major object from caption ", caption_response)
123
+ return cultivation_data
124
+
125
+ in_base_data , alt_data = verify_keyword_in_base(json_response , self.remote )
126
+
127
+ if in_base_data is not None:
128
+ cultivation_data = in_base_data
129
+
130
+ # 这意味着找到了一张新的图片,不需要生成额外的词条
131
+ # required_fields = ['image_name', 'keyword', 'translated_word']
132
+ image_data = {
133
+ 'image_name': image_path,
134
+ 'keyword': in_base_data['keyword'],
135
+ 'translated_word': in_base_data['translated_word']
136
+ }
137
+ #self.imgdb.add_image( image_data, True, image_feature )
138
+ self.remote.add_data(image_data, None, image_feature, None)
139
+ elif alt_data is not None:
140
+ try:
141
+ cultivation_data = generate_cultivation_with_rag(alt_data, text_search_result)
142
+ except:
143
+ print("Error occurred while generating cultivation data")
144
+ return cultivation_data
145
+
146
+ new_data = {
147
+ "keyword": alt_data['keyword'],
148
+ "name_in_cultivation": cultivation_data['new_name'],
149
+ "description_in_cultivation": cultivation_data['final_enhanced_description'],
150
+ "translated_word": alt_data['translated_word'],
151
+ "description": alt_data['description']
152
+ }
153
+ #self.textdb.add_data(new_data)
154
+ text_feature = self.textdb.clip_extractor.extract_text(new_data['translated_word'] + '.' + new_data['description'])
155
+ print("Added new data to textdb: ", new_data["name_in_cultivation"])
156
+
157
+ image_data = {
158
+ 'image_name': image_path,
159
+ 'keyword': new_data['keyword'],
160
+ 'translated_word': new_data['translated_word']
161
+ }
162
+ #self.imgdb.add_image( image_data, True, image_feature )
163
+ self.remote.add_data(image_data, new_data, image_feature, text_feature)
164
+ print("Added new image to imgdb: ", image_data["keyword"])
165
+
166
+ cultivation_data = new_data
167
+
168
+ self.remote.add_file(image_path)
169
+ return cultivation_data
170
+
171
+
172
+
173
+ if __name__ == "__main__":
174
+ os.environ['HTTP_PROXY'] = 'http://localhost:8234'
175
+ os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
176
+
177
+ game_master = GameMaster()
178
+
179
+ target_folder="temp_images"
180
+
181
+ image_files = glob(os.path.join(target_folder, "*.jpg"))
182
+
183
+ for index, image_path in enumerate(image_files):
184
+ print("index:" , index )
185
+
186
+ search_result, backup_results, image_feature = game_master.search_with_path(image_path)
187
+
188
+ if search_result:
189
+ print(search_result)
190
+
191
+ break
192
+
193
+ test_image_path = "temp_images/向日葵.jpg"
194
+
195
+ search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)
196
+ cultivation_data = game_master.generate_cultivation_data( \
197
+ test_image_path, image_feature, backup_results )
198
+ print(cultivation_data)
src/ImageBase.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+
7
+ class Imagebase:
8
+ def __init__(self, parquet_path=None):
9
+ self.default_parquet_path = 'datas/imagebase.parquet'
10
+ self.parquet_path = parquet_path or self.default_parquet_path
11
+ self.datas = None
12
+
13
+ if os.path.exists(self.parquet_path):
14
+ # self.load_from_parquet(self.parquet_path)
15
+ pass
16
+ self.clip_extractor = None
17
+
18
+ def random_sample(self, num_samples=12):
19
+ if self.datas is not None:
20
+ return self.datas.sample(num_samples).to_dict(orient='records')
21
+ else:
22
+ return []
23
+
24
+ def load_from_parquet(self, parquet_path):
25
+ self.datas = pd.read_parquet(parquet_path)
26
+
27
+ def save_to_parquet(self, parquet_path=None):
28
+ parquet_path = parquet_path or self.default_parquet_path
29
+ if self.datas is not None:
30
+ self.datas.to_parquet(parquet_path)
31
+
32
+ def init_clip_extractor(self):
33
+ if self.clip_extractor is None:
34
+ try:
35
+ from CLIPExtractor import CLIPExtractor
36
+ except:
37
+ from src.CLIPExtractor import CLIPExtractor
38
+
39
+ cache_dir = "models"
40
+ self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)
41
+
42
+ def top_k_search(self, query_feature, top_k=15):
43
+ if self.datas is None:
44
+ return []
45
+ if 'clip_feature' not in self.datas.columns:
46
+ raise ValueError("clip_feature column not found in the data.")
47
+
48
+ query_feature = np.array(query_feature).reshape(1, -1)
49
+ attribute_features = np.stack(self.datas['clip_feature'].dropna().values)
50
+
51
+ similarities = cosine_similarity(query_feature, attribute_features)[0]
52
+
53
+ top_k_indices = np.argsort(similarities)[-top_k:][::-1]
54
+
55
+ top_k_results = self.datas.iloc[top_k_indices].copy()
56
+
57
+ top_k_results['similarity'] = similarities[top_k_indices]
58
+
59
+ # Drop the 'clip_feature' column
60
+ top_k_results = top_k_results.drop(columns=['clip_feature'])
61
+
62
+ return top_k_results.to_dict(orient='records')
63
+
64
+
65
+ def search_with_image_name(self, image_name):
66
+ self.init_clip_extractor()
67
+
68
+ img_feature = self.clip_extractor.extract_image_from_file(image_name)
69
+
70
+ return self.top_k_search(img_feature)
71
+
72
+ def search_with_image(self, image, if_opencv=False):
73
+ self.init_clip_extractor()
74
+
75
+ img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
76
+
77
+ return self.top_k_search(img_feature)
78
+
79
+ def add_image(self, data, if_save = True, image_feature = None):
80
+ required_fields = ['image_name', 'keyword', 'translated_word']
81
+ if not all(field in data for field in required_fields):
82
+ raise ValueError(f"Data must contain the following fields: {required_fields}")
83
+
84
+
85
+
86
+ image_name = data['image_name']
87
+ if image_feature is None:
88
+ self.init_clip_extractor()
89
+ data['clip_feature'] = self.clip_extractor.extract_image_from_file(image_name)
90
+ else:
91
+ data['clip_feature'] = image_feature
92
+
93
+ if self.datas is None:
94
+ self.datas = pd.DataFrame([data])
95
+ else:
96
+ self.datas = pd.concat([self.datas, pd.DataFrame([data])], ignore_index=True)
97
+ if if_save:
98
+ self.save_to_parquet()
99
+
100
+ def add_images(self, datas):
101
+ for data in datas:
102
+ self.add_image(data, if_save=False)
103
+ self.save_to_parquet()
104
+
105
+ import os
106
+ from glob import glob
107
+
108
+ def scan_and_update_imagebase(db, target_folder="temp_images"):
109
+ # 获取target_folder目录下所有.jpg文件
110
+ image_files = glob(os.path.join(target_folder, "*.jpg"))
111
+
112
+ duplicate_count = 0
113
+ added_count = 0
114
+
115
+ for image_path in image_files:
116
+ # 使用文件名作为keyword
117
+ keyword = os.path.basename(image_path).rsplit('.', 1)[0]
118
+ translated_word = keyword # 可以根据需要调整translated_word
119
+
120
+ # 搜索数据库中是否有相似的图片
121
+ results = db.search_with_image_name(image_path)
122
+
123
+ if results and results[0]['similarity'] > 0.9:
124
+ print(f"Image '{image_path}' is considered a duplicate.")
125
+ duplicate_count += 1
126
+ else:
127
+ new_image_data = {
128
+ 'image_name': image_path,
129
+ 'keyword': keyword,
130
+ 'translated_word': translated_word
131
+ }
132
+ db.add_image(new_image_data)
133
+ print(f"Image '{image_path}' added to the database.")
134
+ added_count += 1
135
+
136
+ print(f"Total duplicate images found: {duplicate_count}")
137
+ print(f"Total new images added to the database: {added_count}")
138
+
139
+ if __name__ == '__main__':
140
+ img_db = Imagebase()
141
+
142
+ # 目标目录
143
+ target_folder = "temp_images"
144
+
145
+ # 扫描并更新数据库
146
+ scan_and_update_imagebase(img_db, target_folder)
147
+
148
+ # Usage example
149
+ # img_db = Imagebase()
150
+
151
+ # new_image_data = {
152
+ # 'image_name': "datas/老虎.jpg",
153
+ # 'keyword': 'tiger',
154
+ # 'translated_word': '老虎'
155
+ # }
156
+
157
+ # img_db.add_image(new_image_data)
158
+
159
+ # image_path = "datas/老虎.jpg"
160
+ # results = img_db.search_with_image_name(image_path)
161
+ # for result in results[:3]:
162
+ # print(result)
src/RemoteDatabase.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import requests
4
+ import pickle
5
+ import base64
6
+ import json, os
7
+
8
+ def serialize_feature(feature):
9
+ if feature is None:
10
+ return None
11
+ return base64.encodebytes(pickle.dumps(feature)).decode('ascii')
12
+
13
+ def deserialize_feature(feature):
14
+ if feature is None:
15
+ return None
16
+ return pickle.loads(base64.decodebytes(feature.encode('ascii')))
17
+
18
+ class RemoteDatabase:
19
+ def __init__(self):
20
+ self.url = 'http://110.40.175.218:6007/'
21
+ #self.url = 'http://127.0.0.1:6007/'
22
+ pass
23
+
24
+ def top_k_search(self, query_feature, attribute='clip_feature', top_k=15):
25
+ url = self.url + 'top_k_search'
26
+ query_feature = serialize_feature(query_feature)
27
+ response = requests.post(url,
28
+ data=json.dumps({"feature":query_feature, "attribute":attribute, "top_k":top_k}),
29
+ headers={'Content-Type': 'application/json'})
30
+ response.encoding = 'utf-8'
31
+ return json.loads(response.text)
32
+
33
+ def search_by_en_keyword(self, keyword):
34
+ url = self.url + 'search_by_en_keyword'
35
+ response = requests.post(url,
36
+ data=json.dumps({"keyword":keyword}),
37
+ headers={'Content-Type': 'application/json'})
38
+ response.encoding = 'utf-8'
39
+ return json.loads(response.text)
40
+
41
+ def random_sample(self, n):
42
+ url = self.url + 'random_sample'
43
+ response = requests.post(url,
44
+ data=json.dumps({"number":n}),
45
+ headers={'Content-Type': 'application/json'})
46
+ response.encoding = 'utf-8'
47
+ result = json.loads(response.text)
48
+ for r in result:
49
+ image_name = r['image_name']
50
+ image_name = os.path.basename(image_name)
51
+ self.download_file(image_name, os.path.join('temp_images', image_name))
52
+ return result
53
+
54
+ def get_top_founders(self, word="", top_k=20):
55
+ url = self.url + 'get_top_founders'
56
+ response = requests.post(url,
57
+ data=json.dumps({'top_k':top_k, 'word':word}),
58
+ headers={'Content-Type': 'application/json'})
59
+ response.encoding = 'utf-8'
60
+ return json.loads(response.text)
61
+
62
+ def set_founders(self, word, founder, enforce=False):
63
+ url = self.url + 'set_founder'
64
+ response = requests.post(url,
65
+ data=json.dumps({'founder':founder, 'word':word}),
66
+ headers={'Content-Type': 'application/json'})
67
+ response.encoding = 'utf-8'
68
+ return json.loads(response.text)
69
+
70
+ def add_data(self, img_data, text_data, img_feature, text_feature):
71
+ url = self.url + 'add_data'
72
+ img_feature = serialize_feature(img_feature)
73
+ text_feature = serialize_feature(text_feature)
74
+ response = requests.post(url,
75
+ data=json.dumps({"img_data":img_data, "text_data":text_data, "img_feature":img_feature, "text_feature":text_feature}),
76
+ headers={'Content-Type': 'application/json'})
77
+ response.encoding = 'utf-8'
78
+ return json.loads(response.text)
79
+
80
+ def add_file(self, file_path):
81
+ url = self.url + 'add_file'
82
+ response = requests.post(url,
83
+ data = {},
84
+ files = {'file': open(file_path, 'rb')},
85
+ stream=True)
86
+ response.encoding = 'utf-8'
87
+ return json.loads(response.text)
88
+
89
+ def download_file(self, file_name, save_path):
90
+ url = self.url + 'download_file'
91
+ params = {'file_name': file_name}
92
+
93
+ response = requests.post(url, data=params, stream=True)
94
+
95
+ if response.status_code == 200:
96
+ with open(save_path, 'wb') as file:
97
+ for chunk in response.iter_content(1024):
98
+ if chunk:
99
+ file.write(chunk)
100
+
101
+ if __name__ == '__main__':
102
+ db = RemoteDatabase()
103
+ db.download_file('xyxmi1lr55pm.jpg', './test.jpg')
src/ZhipuClient.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from zhipuai import ZhipuAI
2
+ import os
3
+
4
+ class ZhipuClient:
5
+ def __init__(self, api_key_file_path = None):
6
+ # if api_key_file_path is None:
7
+ # cands = ['./datas/zhipu_key.txt', '../datas/zhipu_key.txt']
8
+ # flag = False
9
+ # for cand in cands:
10
+ # if os.path.exists(cand):
11
+ # api_key_file_path = cand
12
+ # flag = True
13
+ # break
14
+ # if not flag:
15
+ # raise ValueError("No valid api key file found.")
16
+
17
+ self.api_key = os.getenv("ZHIPU_4_API")
18
+ self.client = ZhipuAI(api_key=self.api_key)
19
+
20
+ def _load_access_token(self, file_path):
21
+ with open(file_path, 'r') as file:
22
+ return file.read().strip()
23
+
24
+ def prompt2response(self, prompt):
25
+ response = self.client.chat.completions.create(
26
+ model="glm-4", # 填写需要调用的模型名称
27
+ messages=[
28
+ {"role": "user", "content": prompt}
29
+ ],
30
+ )
31
+ return response.choices[0].message.content
32
+
33
+ # Usage:
34
+ # zhipu_client = ZhipuClient('../datas/zhipu_key.txt')
35
+ # response = zhipu_client.prompt2response('Your prompt here')
src/__pycache__/Captioner.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
src/__pycache__/Database.cpython-310.pyc ADDED
Binary file (7 kB). View file
 
src/__pycache__/GameMaster.cpython-310.pyc ADDED
Binary file (4.83 kB). View file
 
src/__pycache__/ImageBase.cpython-310.pyc ADDED
Binary file (4.53 kB). View file
 
src/__pycache__/ZhipuClient.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
src/__pycache__/generate_cultivation.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
src/__pycache__/get_major_object.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
src/__pycache__/text_embedding.cpython-310.pyc ADDED
Binary file (7.49 kB). View file
 
src/generate_cultivation.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ def data2reference( top_k_items, output_n = 3 ):
4
+ outputted_items = set()
5
+
6
+ output_str = "#Reference:\n"
7
+
8
+ for item in top_k_items:
9
+ item_in_life = item["keyword"]
10
+ if item_in_life in outputted_items:
11
+ continue
12
+ name_in_cultivation = item["name_in_cultivation"]
13
+ description_in_cultivation = item["description_in_cultivation"]
14
+ # output_str += f"name_in_life: {item_in_life}\n"
15
+ # output_str += f"name_in_cultivation: {name_in_cultivation}\n"
16
+ # output_str += f"description_in_cultivation: {description_in_cultivation}\n\n"
17
+ # output with into json format
18
+ output_data = {
19
+ "name_in_life": item_in_life,
20
+ "name_in_cultivation": name_in_cultivation,
21
+ "description_in_cultivation": description_in_cultivation
22
+ }
23
+ output_str += json.dumps(output_data, ensure_ascii=False) + "\n\n"
24
+
25
+ outputted_items.add(item_in_life)
26
+ if len(outputted_items) >= output_n:
27
+ break
28
+ return output_str.strip()
29
+
30
+
31
+
32
+ def data2prompt(query_item , top_k_items):
33
+
34
+ reference_prompt = data2reference(top_k_items, 3)
35
+
36
+ task_prompt1 = "\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n"
37
+
38
+ input_prompt = "# Input:\n"
39
+ if "keyword" in query_item:
40
+ input_prompt += f"input_name:{query_item['keyword']}\n"
41
+ if "description" in query_item:
42
+ input_prompt += f"description_in_life:{query_item['description']}\n"
43
+ else:
44
+ # directly dump query_item
45
+ input_prompt += json.dumps(query_item, ensure_ascii=False) + "\n"
46
+
47
+ CoT_prompt = \
48
+ """Let's think it step by step,以json形式输出逐个字段。包含以下字段
49
+ - name_in_life: 进一步明确要生成描述的物品名称
50
+ - name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
51
+ - description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
52
+ - echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
53
+ - critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
54
+ - echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
55
+ - analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
56
+ - echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
57
+ - candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
58
+ - analysis_candidates: 分析各个candidates有什么优点
59
+ - echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
60
+ - final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
61
+ - echo_5: "我将分析根据final_description,是否简易将物品名称替换为新的名词"
62
+ - name_fit_analysis: 分析item_name是否还匹配final_description的描述,是否需要给input_name起一个更响亮的名字
63
+ - new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
64
+ """
65
+
66
+ return reference_prompt + task_prompt1 + input_prompt + CoT_prompt
67
+
68
+ try:
69
+ from src.ZhipuClient import ZhipuClient
70
+ except:
71
+ from ZhipuClient import ZhipuClient
72
+
73
+ zhipu_client = None
74
+
75
+
76
+ import json
77
+
78
+ def markdown_to_json(markdown_str):
79
+ # 移除Markdown语法中可能存在的标记,如代码块标记等
80
+ if markdown_str.startswith("```json"):
81
+ markdown_str = markdown_str[7:-3].strip()
82
+ elif markdown_str.startswith("```"):
83
+ markdown_str = markdown_str[3:-3].strip()
84
+
85
+ # 将字符串转换为JSON字典
86
+ json_dict = json.loads(markdown_str)
87
+
88
+ return json_dict
89
+
90
+ import re
91
+
92
+ def forced_extract(input_str, keywords):
93
+ result = {key: "" for key in keywords}
94
+
95
+ for key in keywords:
96
+ # 使用正则表达式来查找关键词-值对
97
+ pattern = f'"{key}":\s*"(.*?)"'
98
+ match = re.search(pattern, input_str)
99
+ if match:
100
+ result[key] = match.group(1)
101
+
102
+ return result
103
+
104
+ def generate_cultivation_with_rag( query_item, search_result ):
105
+ global zhipu_client
106
+ if zhipu_client is None:
107
+ zhipu_client = ZhipuClient()
108
+ prompt = data2prompt(query_item, search_result)
109
+ response = zhipu_client.prompt2response(prompt)
110
+
111
+ try:
112
+ json_response = markdown_to_json(response)
113
+ except:
114
+ keyword_list = ["name_in_life", "name_in_cultivation_1","description_in_cultivation_1", "final_enhanced_description", "new_name"]
115
+ json_response = forced_extract(response, keyword_list)
116
+
117
+ if "new_name" not in json_response or json_response["new_name"] == "":
118
+ if "name_in_cultivation_1" in json_response:
119
+ json_response["new_name"] = json_response["name_in_cultivation_1"]
120
+ else:
121
+ json_response["new_name"] = ""
122
+
123
+ if "final_enhanced_description" not in json_response or json_response["final_enhanced_description"] == "":
124
+ if "description_in_cultivation_1" in json_response:
125
+ json_response["final_enhanced_description"] = json_response["description_in_cultivation_1"]
126
+ else:
127
+ json_response["final_enhanced_description"] = json_response["new_name"]
128
+
129
+
130
+ return json_response
131
+
132
+ if __name__ == '__main__':
133
+ try:
134
+ from src.Database import Database
135
+ except:
136
+ from Database import Database
137
+
138
+ db = Database()
139
+
140
+ try:
141
+ from src.Captioner import Captioner
142
+ except:
143
+ from Captioner import Captioner
144
+
145
+ import os
146
+ os.environ['HTTP_PROXY'] = 'http://localhost:8234'
147
+ os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
148
+
149
+
150
+ captioner = Captioner()
151
+
152
+ test_image = "temp_images/3or47vg0.jpg"
153
+ caption_response = captioner.caption(test_image)
154
+
155
+ # print(caption_response)
156
+
157
+ search_result = db.search_with_image_name( test_image )
158
+
159
+ # print(search_result[0].keys())
160
+ # reference_str = data2reference(search_result, output_n = 3)
161
+ # print(reference_str)
162
+
163
+ seen = set()
164
+ keywords = [res['translated_word'] for res in search_result if not (res['translated_word'] in seen or seen.add(res['translated_word']))]
165
+ # print(keywords)
166
+
167
+ # prompt = data2prompt(caption_response , keywords)
168
+ # print(prompt)
169
+ from get_major_object import get_major_object, verify_keyword_in_base
170
+
171
+ json_response = get_major_object(caption_response , keywords)
172
+
173
+ print(json_response)
174
+
175
+ print()
176
+
177
+ in_base_data , alt_data = verify_keyword_in_base(json_response , db)
178
+
179
+ if alt_data is not None:
180
+ result = generate_cultivation_with_rag(alt_data , search_result)
181
+ print(result)
182
+
183
+
src/get_comments_from_level.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_comments_from_level(inbase_similarity_level, inlibrary_similarity_level):
2
+ name = "{name}" # Placeholder for the item name
3
+
4
+ # Define comments based on the similarity levels
5
+ comments_mapping = {
6
+ (0, 0): f"道友,在下才疏学浅不太认识这个东西,我感觉这个东西有点像古籍上所说的{name},还要找经验更丰富的长老来确定下",
7
+ (0, 1): f"这个东西似乎有些眼熟,但天机阁尚未收录,古籍上的描述也有些模糊,可能需要长老们进一步鉴定,我猜是{name}",
8
+ (0, 2): f"这东西颇为罕见,天机阁未曾有过记录,但古籍中的描述与{name}颇为相似,我将呈给长老们以作鉴定",
9
+ (0, 3): f"想必这一定是{name}吧,虽然天机阁还没有收录过这个东西,倒是和修仙古籍上的记载非常相像。我赶紧拿给长老再鉴定下",
10
+ (1, 0): f"天机阁的记录中似乎没有这个东西,但我依稀记得古籍中提到过{name},还需长老进一步确认",
11
+ (1, 1): f"这个物品有些特别,天机阁的记录不多,古籍中的描述也只是一笔带过,可能是{name},还需长老鉴定",
12
+ (1, 2): f"此物颇为罕见,天机阁记录较少,但古籍中的描述与{name}有一定相似之处,长老们或能给出答案",
13
+ (1, 3): f"虽然古籍中对{name}的描述详细,但天机阁中却鲜有记录,或许这是一件稀世之宝",
14
+ (2, 0): f"天机阁中对此物知之甚少,但古籍中曾提到{name},这件物品或许不简单,需长老们鉴定",
15
+ (2, 1): f"天机阁中对此物的记录不多,古籍中对{name}的描述也有限,但似乎是一件非凡之物",
16
+ (2, 2): f"这件物品在古籍中有所记载,天机阁也有少量收录,看来是{name}无疑,但还需长老确认",
17
+ (2, 3): f"虽然在古籍中有记载,天机阁过往有一点点收录,但也算稀世珍宝,{name}确实非凡",
18
+ (3, 0): f"天机阁中没有记录,但古籍中对{name}的描述颇为详细,这件物品可能是个谜",
19
+ (3, 1): f"天机阁中记录较少,但古籍中对{name}的描述详尽,这件物品或许有着不同寻常的来历",
20
+ (3, 2): f"古籍中记载{name}颇多,天机阁中也有所收录,看来这东西并不罕见",
21
+ (3, 3): f"{name}这种东西很常见啊,天机阁的库房里面都有不少呢"
22
+ }
23
+
24
+ # Return the appropriate comment based on the similarity levels
25
+ return comments_mapping.get((inbase_similarity_level, inlibrary_similarity_level), "道友,我会给出初步的鉴定") + "。"
26
+
27
+ # Example usage:
28
+ # comments = get_comments_from_level(2, 3)
29
+ # print(comments)
src/get_major_object.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def data2prompt(caption_response, ref_words):
2
+
3
+ ref_word_str = ",".join(ref_words[:5])
4
+
5
+ task_prompt = "Based on the following Caption Response, you will output a description of the Major Object's name."
6
+
7
+ input_str = "# Caption Response:\n" + caption_response + "\n"
8
+
9
+ CoT_prompt = \
10
+ f"""
11
+ Let's think it step by step. Output each field in JSON format. Include the following fields:
12
+ - major_object: From the caption response, identify the major_object. If not present, extract it again from the detailed_description or caption_response.
13
+ - better_major_object: Reread the description in the caption response to see if there's a more suitable word for the major object. If not, still output major_object.
14
+ - echo_1: "I will generate a simple description in about 200 words in English for the better_major_object, introducing what the input object is."
15
+ - description: Generate a WIKI description for the better_major_object (explain what is the better_major_object).
16
+ - major_object_chinese: Translate the better_major_object into Chinese.
17
+ - echo_2: "I will check whether there is synonym of the major_object_chinese in the '{ref_word_str}'."
18
+ - synonym: If present, output the synonym directly; otherwise, output "NOT_INCLUDED."
19
+ - recheck: Based on the content of the Caption Response, determine whether the synonym is accurate. If accurate, output "ACCURATE"; otherwise, output "NOT_ACCURATE."
20
+ """
21
+ return task_prompt + input_str + CoT_prompt
22
+
23
+
24
+ # def data2prompt(caption_response , ref_words ):
25
+
26
+
27
+ # ref_word_str = ",".join(ref_words[:5])
28
+
29
+ # ref_str = "# Reference Word:\n"+ref_word_str+"\n\n"
30
+
31
+ # task_prompt = "你将根据下面的Caption Response,输出Major Object的名称描述"
32
+
33
+ # input_str = "# Caption Response:\n"+caption_response+"\n"
34
+
35
+ # CoT_prompt = \
36
+ # """
37
+ # Let's think it step by step,以json形式输出逐个字段。包含以下字段
38
+ # - major_object: 从caption response中,确认major_object,如果没有,则从detailed_description或者caption_response中重新抽取
39
+ # - better_major_object: 重新阅读caption response中的描述,看看是否有更合适的major object的词语,如果没有则仍然输出major_object
40
+ # - echo_1: "I will generate a simple description in about 200 words in English for the input word, introducing what the input object is"
41
+ # - description: generate the description for the input object
42
+ # - major_object_chinese: 将major_object翻译为中文
43
+ # - echo_2: "我将判断reference word中,是否存在major_object的同义词"
44
+ # - 同义词: 如果存在,则直接输出同义词,否则输出"NOT_INCLUDED"
45
+ # - recheck: 结合Caption Response的内容,判断同义词是否准确,如果准确,则输出"ACCURATE",否则输出"NOT_ACCURATE"
46
+ # """
47
+ # return ref_str+task_prompt+input_str+CoT_prompt
48
+
49
+ try:
50
+ from src.ZhipuClient import ZhipuClient
51
+ except:
52
+ from ZhipuClient import ZhipuClient
53
+
54
+ zhipu_client = None
55
+
56
+ import json
57
+
58
+ def markdown_to_json(markdown_str):
59
+ # 移除Markdown语法中可能存在的标记,如代码块标记等
60
+ if markdown_str.startswith("```json"):
61
+ markdown_str = markdown_str[7:-3].strip()
62
+ elif markdown_str.startswith("```"):
63
+ markdown_str = markdown_str[3:-3].strip()
64
+
65
+ # 将字符串转换为JSON字典
66
+ json_dict = json.loads(markdown_str)
67
+
68
+ return json_dict
69
+
70
+ import re
71
+
72
+ def forced_extract(input_str, keywords):
73
+ result = {key: "" for key in keywords}
74
+
75
+ for key in keywords:
76
+ # 使用正则表达式来查找关键词-值对
77
+ pattern = f'"{key}":\s*"(.*?)"'
78
+ match = re.search(pattern, input_str)
79
+ if match:
80
+ result[key] = match.group(1)
81
+
82
+ return result
83
+
84
+ def get_major_object(caption_response, ref_words):
85
+ global zhipu_client
86
+ if zhipu_client is None:
87
+ zhipu_client = ZhipuClient()
88
+ prompt = data2prompt(caption_response , ref_words)
89
+ response = zhipu_client.prompt2response(prompt)
90
+
91
+ try:
92
+ json_response = markdown_to_json(response)
93
+ except:
94
+ keyword_list = ["major_object", "better_major_object", "description", "major_object_chinese", "synonym", "recheck"]
95
+ json_response = forced_extract(response, keyword_list)
96
+
97
+ return json_response
98
+
99
+ def verify_keyword_in_base( json_response , database ):
100
+
101
+ keyword2verify = []
102
+ if "better_major_object" in json_response:
103
+ keyword2verify.append(json_response["better_major_object"].lower())
104
+
105
+ if "major_object" in json_response:
106
+ keyword2verify.append(json_response["major_object"].lower())
107
+
108
+ if "recheck" in json_response and json_response["recheck"] == "ACCURATE":
109
+ if "synonym" in json_response and json_response["synonym"] != "NOT_INCLUDED":
110
+ keyword2verify.append(json_response["synonym"].lower())
111
+
112
+ ans = None
113
+
114
+ for keyword in keyword2verify:
115
+ res = database.search_by_en_keyword(keyword)
116
+ if res is None:
117
+ continue
118
+ ans = res
119
+ return ans, None
120
+
121
+ if len(keyword2verify) == 0:
122
+ return None, None
123
+
124
+ # 这里我们需要一个新的data, keyword是中文名, translated_word是英文名,description是英文描述
125
+ description = keyword2verify[0]
126
+ if "description" in json_response:
127
+ description = json_response["description"]
128
+
129
+ translated_word = keyword2verify[0]
130
+
131
+ keyword = translated_word
132
+ if "major_object_chinese" in json_response:
133
+ keyword = json_response["major_object_chinese"]
134
+
135
+ data = {
136
+ "keyword": keyword,
137
+ "translated_word": translated_word,
138
+ "description": description
139
+ }
140
+
141
+ return None, data
142
+
143
+
144
+
145
+
146
+ if __name__ == '__main__':
147
+ try:
148
+ from src.Database import Database
149
+ except:
150
+ from Database import Database
151
+
152
+ db = Database()
153
+
154
+ try:
155
+ from src.Captioner import Captioner
156
+ except:
157
+ from Captioner import Captioner
158
+
159
+ import os
160
+ os.environ['HTTP_PROXY'] = 'http://localhost:8234'
161
+ os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
162
+
163
+
164
+ captioner = Captioner()
165
+
166
+ test_image = "temp_images/3or47vg0.jpg"
167
+ caption_response = captioner.caption(test_image)
168
+
169
+ # print(caption_response)
170
+
171
+ search_result = db.search_with_image_name( test_image )
172
+
173
+ seen = set()
174
+ keywords = [res['translated_word'] for res in search_result if not (res['translated_word'] in seen or seen.add(res['translated_word']))]
175
+ # print(keywords)
176
+
177
+ # prompt = data2prompt(caption_response , keywords)
178
+ # print(prompt)
179
+
180
+ json_response = get_major_object(caption_response , keywords)
181
+
182
+ print(json_response)
183
+
184
+ print()
185
+
186
+ in_base_data , alt_data = verify_keyword_in_base(json_response , db)
187
+
188
+ if in_base_data is not None:
189
+ print(in_base_data)
190
+
191
+ if alt_data is not None:
192
+ print(alt_data)
193
+
194
+ # {'keyword': '埃菲尔铁塔', 'translated_word': 'eiffel tower', 'description': "The Eiffel Tower is an iconic symbol of Paris and one of the most recognizable stru
195
+ # ower', 'description': "The Eiffel Tower is an iconic symbol of Paris and one of the most recognizable structures in the world. Designed and constructed by the engineer Gustave Eiffel and his company for the 1889 Exposition Universelle (World's Fair) to celebrate the 100th anniversary of the French Revolution, the tower was initially criticized by some of France's leading artists and intellectuals. However, it quickly became a beloved landmark and a symbol of French pride. Standing 324 meters tall, the tower is made of wrought iron and consists of thousands of metal parts, including over 18,000 individual iron rivets. It is renowned for its architectural and engineering design, and it is visited by millions of people each year, making it one of the most visited paid monuments in the world."}
src/text_embedding.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import os
4
+
5
+ class TextExtractor:
6
+ def __init__(self, model_name, proxy=None):
7
+ """
8
+ Initialize the TextExtractor with a specified model and optional proxy settings.
9
+
10
+ Parameters:
11
+ - model_name (str): The name of the pre-trained model to load from HuggingFace Hub.
12
+ - proxy (str, optional): The proxy address to use for HTTP and HTTPS requests.
13
+ """
14
+ # if proxy is None:
15
+ # proxy = 'http://localhost:8234'
16
+
17
+ # if proxy:
18
+ # os.environ['HTTP_PROXY'] = proxy
19
+ # os.environ['HTTPS_PROXY'] = proxy
20
+ try:
21
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ self.model = AutoModel.from_pretrained(model_name)
23
+ except:
24
+ print('try switch on local_files_only')
25
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
26
+ self.model = AutoModel.from_pretrained(model_name, local_files_only=True)
27
+
28
+ self.model.eval()
29
+
30
+ def extract(self, sentences):
31
+ """
32
+ Extract sentence embeddings for the provided sentences.
33
+
34
+ Parameters:
35
+ - sentences (list of str): A list of sentences to extract embeddings for.
36
+
37
+ Returns:
38
+ - torch.Tensor: The normalized sentence embeddings.
39
+ """
40
+ encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
41
+
42
+ with torch.no_grad():
43
+ model_output = self.model(**encoded_input)
44
+ sentence_embeddings = model_output[0][:, 0]
45
+
46
+ sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
47
+ return sentence_embeddings
48
+
49
+ import pandas as pd
50
+ def get_qas(excel_file = None):
51
+
52
+ defaule_excel_file = 'data/output_fixid.xlsx'
53
+ if excel_file is None:
54
+ excel_file = defaule_excel_file
55
+
56
+ # 读取Excel文件
57
+ df = pd.read_excel(excel_file)
58
+
59
+ df = df[df["question"].notna()]
60
+ df = df[df["summary"].notna()]
61
+
62
+ datas = []
63
+
64
+ # 遍历DataFrame的每一行
65
+ for index, row in df.iterrows():
66
+ id = row['id']
67
+ question = row['question']
68
+ short_answer = row['summary']
69
+ category = row['category']
70
+
71
+ texts = [question, short_answer]
72
+
73
+ data_value = {
74
+ "texts":texts,
75
+ }
76
+
77
+ data = {
78
+ "id":id,
79
+ "value":data_value
80
+ }
81
+
82
+ datas.append(data)
83
+
84
+ return datas
85
+
86
+ from tqdm import tqdm
87
+
88
+ def extract_embedding(datas, text_extractor):
89
+ """
90
+ Extract embeddings for each item in the provided data.
91
+
92
+ Parameters:
93
+ - datas (list of dict): A list of dictionaries containing text data.
94
+
95
+ Returns:
96
+ - list of dict: The input data with added embeddings.
97
+ """
98
+ for data in tqdm(datas):
99
+ texts = data["value"]["texts"]
100
+ text = "。".join(texts)
101
+ embeddings = text_extractor.extract(text)
102
+ embeddings_list = embeddings.tolist() # Convert tensor to list of lists
103
+ data["value"]["embedding"] = embeddings_list
104
+ return datas
105
+
106
+ def save_parquet(datas, file_path):
107
+ """
108
+ Save the provided data to a Parquet file.
109
+
110
+ Parameters:
111
+ - datas (list of dict): A list of dictionaries containing text data and embeddings.
112
+ - file_path (str): The path to the output Parquet file.
113
+ """
114
+ # Flatten the data for easier conversion to DataFrame
115
+ flattened_data = []
116
+ for data in datas:
117
+ id = data["id"]
118
+ texts = data["value"]["texts"]
119
+ text = "。".join(texts)
120
+ embedding = data["value"]["embedding"]
121
+ flattened_data.append({
122
+ "id": id,
123
+ "text": text,
124
+ "embedding": embedding
125
+ })
126
+
127
+ # Create DataFrame
128
+ df = pd.DataFrame(flattened_data)
129
+
130
+ # Save DataFrame to Parquet
131
+ df.to_parquet(file_path, index=False)
132
+
133
+ import pandas as pd
134
+ import os
135
+
136
+ def get_id2embedding(regen=False, parquet_file='datas/qa_with_embedding.parquet'):
137
+ """
138
+ Get a dictionary mapping IDs to embeddings. Regenerate embeddings if specified.
139
+
140
+ Parameters:
141
+ - parquet_file (str): The path to the Parquet file.
142
+ - regen (bool): Whether to regenerate embeddings.
143
+
144
+ Returns:
145
+ - dict: A dictionary mapping IDs to list of float embeddings.
146
+ """
147
+ if regen or not os.path.exists(parquet_file):
148
+ print("Regenerating embeddings...")
149
+ # Example usage:
150
+ model_name = 'BAAI/bge-small-zh-v1.5'
151
+ text_extractor = TextExtractor(model_name)
152
+
153
+ datas = get_qas()
154
+ print("Extracting embeddings for", len(datas), "data items")
155
+
156
+ datas = extract_embedding(datas, text_extractor)
157
+ save_parquet(datas, parquet_file)
158
+
159
+ df = pd.read_parquet(parquet_file)
160
+
161
+ id2embedding = {}
162
+ for index, row in df.iterrows():
163
+ id = row['id']
164
+ embedding = row['embedding']
165
+ id2embedding[id] = embedding[0]
166
+
167
+ return id2embedding
168
+
169
+ import torch
170
+ from sklearn.metrics.pairwise import cosine_similarity
171
+ import heapq
172
+
173
+ def __get_id2top30map(id2embedding):
174
+ """
175
+ Get a dictionary mapping IDs to their top 30 nearest neighbors based on cosine similarity.
176
+
177
+ Parameters:
178
+ - id2embedding (dict): A dictionary mapping IDs to list of float embeddings.
179
+
180
+ Returns:
181
+ - dict: A dictionary mapping each ID to a list of the top 30 nearest neighbor IDs.
182
+ """
183
+ ids = list(id2embedding.keys())
184
+ embeddings = torch.tensor([id2embedding[id] for id in ids])
185
+
186
+ # Compute cosine similarity matrix
187
+ cos_sim_matrix = cosine_similarity(embeddings)
188
+
189
+ id2top30map = {}
190
+ for i, id in enumerate(ids):
191
+ # Get the similarity scores for the current ID
192
+ sim_scores = cos_sim_matrix[i]
193
+
194
+ # Get the top 30 indices (excluding the current ID itself)
195
+ top_indices = heapq.nlargest(31, range(len(sim_scores)), key=lambda x: sim_scores[x])
196
+ top_indices.remove(i) # Remove the index of the current ID
197
+
198
+ # Map the indices back to IDs
199
+ top_30_ids = [ids[idx] for idx in top_indices[:30]]
200
+
201
+ id2top30map[id] = top_30_ids
202
+
203
+ return id2top30map
204
+
205
+ import pickle
206
+
207
+ def get_id2top30map( id2embedding = None ):
208
+ default_save_pkl = "data/id2top30map.pkl"
209
+ if id2embedding is None:
210
+ if os.path.exists(default_save_pkl):
211
+ with open(default_save_pkl, 'rb') as f:
212
+ id2top30map = pickle.load(f)
213
+ else:
214
+ print("No embedding found, generating new one...")
215
+ id2embedding = get_id2embedding(regen=False)
216
+ id2top30map = __get_id2top30map(id2embedding)
217
+ with open(default_save_pkl, 'wb') as f:
218
+ pickle.dump(id2top30map, f)
219
+ else:
220
+ id2top30map = __get_id2top30map(id2embedding)
221
+
222
+ return id2top30map
223
+
224
+
225
+
226
+ if __name__ == '__main__':
227
+ if False:
228
+ # Example usage:
229
+ model_name = 'BAAI/bge-small-zh-v1.5'
230
+ sentences = ["样例数据-1", "样例数据-2"]
231
+
232
+ text_extractor = TextExtractor(model_name)
233
+ embeddings = text_extractor.extract(sentences)
234
+ print("Sentence embeddings:", embeddings)
235
+
236
+ datas = get_qas()
237
+
238
+ print("extract embedding for ", len(datas), " datas")
239
+
240
+ datas = extract_embedding(datas, text_extractor )
241
+
242
+ default_parquet_save_name = "data/qa_with_embedding.parquet"
243
+
244
+ save_parquet(datas, default_parquet_save_name)
245
+ if True:
246
+ id2embedding = get_id2embedding(regen=False)
247
+ print(len(id2embedding[4]))
248
+ id2top30map = get_id2top30map( None )
249
+ print("ID to Top 30 Neighbors dictionary:", id2top30map[4])
250
+
251
+ if True:
252
+
253
+ start_id = 332
254
+ visited_ids = [start_id]
255
+ current_queue = [start_id]
256
+
257
+ expend_num = 5
258
+
259
+ for iteration in range(10):
260
+ current_node = current_queue.pop(0)
261
+ top30 = id2top30map[current_node]
262
+ current_expend = []
263
+ for id in top30:
264
+ if id not in visited_ids:
265
+ visited_ids.append(id)
266
+ current_queue.append(id)
267
+ current_expend.append(id)
268
+ if len(current_expend) >= expend_num:
269
+ break
270
+ display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expend])
271
+ print(display_text)
272
+
273
+ from get_qa_and_image import get_qa_and_image
274
+ image_datas = get_qa_and_image()
275
+
276
+ id2index = {}
277
+
278
+ for i, data in enumerate(image_datas):
279
+ id2index[data['id']] = i
280
+
281
+ indexes = [id2index[i] for i in visited_ids if i in id2index]
282
+ image_names = [image_datas[index]['value']['image'] for index in indexes]
283
+
284
+ target_copy_folder = "data/asso_collection"
285
+
286
+ import shutil
287
+ # copy image into target_copy_folder
288
+ for image_name in image_names:
289
+ shutil.copy(image_name, target_copy_folder)