thefish1 committed
Commit 6dcc10e · Parent(s): 6c98343
Files changed (2):
  1. app.py +13 -9
  2. vec_db.py +119 -0
app.py CHANGED
@@ -6,11 +6,11 @@ import re
 from load_data import load_data
 from openai import OpenAI
 from transformers import AutoTokenizer, AutoModel
-from fetch_from_database import encode, insert_keywords_to_weaviate, fetch_summary_from_database, init_database
+from vec_db import encode_list_to_avg, fetch_response_from_db
 import weaviate
 import os
 import subprocess
-
+import torch
 
 # Set the Matplotlib cache directory
 os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
@@ -23,13 +23,18 @@ auth_config = weaviate.AuthApiKey(api_key="8wNsHV3Enc2PNVL8Bspadh21qYAfAvnK2ux3"
 
 
 
-database_client = weaviate.Client(
-    url="https://3a8sbx3s66by10yxginaa.c0.asia-southeast1.gcp.weaviate.cloud",
-    auth_client_secret=auth_config
-)
+URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
+APIKEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
+
+# Connect to a WCS instance
+client = weaviate.connect_to_wcs(
+    cluster_url=URL,
+    auth_credentials=weaviate.auth.AuthApiKey(APIKEY))
+
 
-class_name = "Lhnjames123321"
+class_name = "ad_database02"
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
 model = AutoModel.from_pretrained("bert-base-chinese")
 
@@ -194,9 +199,8 @@ def respond(
 
     query_keywords = list(keywords_dict.keys())
    # max_matches is used here as the distance value
-    class_name = "Lhnjames123321"
 
-    max_matches, top_keywords_list, top_summary = fetch_summary_from_database(query_keywords, class_name)
+    max_matches, top_keywords_list, top_summary = fetch_response_from_db(query_keywords, class_name)
 
 
     print(f"max_matches: {max_matches}")
vec_db.py ADDED
@@ -0,0 +1,119 @@
+import weaviate
+import pandas as pd
+import torch
+import json
+from transformers import AutoTokenizer, AutoModel
+import subprocess
+import os
+
+# Point the Matplotlib cache at a writable directory
+os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
+# Point the Hugging Face Transformers cache at a writable directory
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
+
+# try:
+#     # Command to run the Weaviate Docker container
+#     command = [
+#         "docker", "run",
+#         "-p", "8080:8080",
+#         "-p", "50051:50051",
+#         "cr.weaviate.io/semitechnologies/weaviate:1.24.20"
+#     ]
+#
+#     # Execute the command
+#     subprocess.run(command, check=True)
+#     print("Docker container is running.")
+#
+# except subprocess.CalledProcessError as e:
+#     print(f"An error occurred: {e}")
+
+class_name = 'Lhnjames123321'
+auth_config = weaviate.AuthApiKey(api_key="8wNsHV3Enc2PNVL8Bspadh21qYAfAvnK2ux3")
+client = weaviate.Client(
+    url="https://3a8sbx3s66by10yxginaa.c0.asia-southeast1.gcp.weaviate.cloud",
+    auth_client_secret=auth_config
+)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = AutoModel.from_pretrained("bert-base-chinese").to(device)
+tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
+
+
+def encode_sentences(sentences, model, tokenizer, device):
+    # Tokenize the batch, run BERT, and mean-pool the last hidden state
+    # into one fixed-size vector per sentence.
+    inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True, max_length=512)
+    inputs = inputs.to(device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    embeddings = outputs.last_hidden_state.mean(dim=1)
+    return embeddings.cpu().numpy()
+
+
+# def class_exists(client, class_name):
+#     existing_classes = client.schema.get_classes()
+#     return any(cls['class'] == class_name for cls in existing_classes)
+
+
+def init_weaviate():
+    # if not class_exists(client, class_name):
+    #     class_obj = {
+    #         'class': class_name,
+    #         'vectorIndexConfig': {
+    #             'distance': 'cosine'
+    #         },
+    #     }
+    #     client.schema.create_class(class_obj)
+
+    file_path = 'data.json'
+    sentence_data = []
+
+    # data.json is JSON Lines: one object per line with a 'response' field
+    with open(file_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            try:
+                data = json.loads(line.strip())
+                sentence1 = data.get('response', '')
+                sentence_data.append(sentence1)
+            except json.JSONDecodeError as e:
+                print(f"Error parsing JSON: {e}")
+                continue
+
+    sentence_embeddings = encode_sentences(sentence_data, model, tokenizer, device)
+
+    data = {'sentence': sentence_data,
+            'embeddings': sentence_embeddings.tolist()}
+    df = pd.DataFrame(data)
+
+    # Import each sentence with its precomputed vector
+    with client.batch(batch_size=100) as batch:
+        for i in range(df.shape[0]):
+            print(f'importing data: {i + 1}/{df.shape[0]}')
+            properties = {
+                'sentence_id': i + 1,
+                'sentence': df.sentence[i],
+            }
+            custom_vector = df.embeddings[i]
+            batch.add_data_object(
+                properties,
+                class_name=class_name,
+                vector=custom_vector
+            )
+    print('import completed')
+
+
+def use_weaviate(input_str):
+    # Embed the query and retrieve the five nearest sentences by vector distance
+    query = encode_sentences([input_str], model, tokenizer, device)[0].tolist()
+    nearVector = {
+        'vector': query
+    }
+
+    response = (
+        client.query
+        .get(class_name, ['sentence_id', 'sentence'])
+        .with_near_vector(nearVector)
+        .with_limit(5)
+        .with_additional(['distance'])
+        .do()
+    )
+    print(response)
+    results = response['data']['Get'][class_name]
+    text_list = [result['sentence'] for result in results]
+    return text_list
+
+
+if __name__ == '__main__':
+    init_weaviate()
+    input_str = input("Enter the query text: ")
+    ans = use_weaviate(input_str)
+    print("Query results:", ans)