kenken999 commited on
Commit
71a8168
·
1 Parent(s): abfee45
babyagi/babyagi.py CHANGED
@@ -24,6 +24,31 @@ from transformers import AutoTokenizer, AutoModel
24
  import torch
25
  import numpy
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # モデル名を指定
28
  model_name = "sentence-transformers/all-MiniLM-L6-v2"
29
 
@@ -745,3 +770,32 @@ def main():
745
 
746
  if __name__ == "__main__":
747
  main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  import torch
25
  import numpy
26
 
27
+ import psycopg2
28
+
29
+ class ProductDatabase:
30
+ def __init__(self, database_url):
31
+ self.database_url = database_url
32
+ self.conn = None
33
+
34
+ def connect(self):
35
+ self.conn = psycopg2.connect(self.database_url)
36
+
37
+ def close(self):
38
+ if self.conn:
39
+ self.conn.close()
40
+
41
+ def fetch_data(self):
42
+ with self.conn.cursor() as cursor:
43
+ cursor.execute("SELECT id FROM products")
44
+ rows = cursor.fetchall()
45
+ return rows
46
+
47
+ def update_data(self, product_id, new_price):
48
+ with self.conn.cursor() as cursor:
49
+ cursor.execute("UPDATE products SET price = %s WHERE id = %s", (new_price, product_id))
50
+ self.conn.commit()
51
+
52
  # モデル名を指定
53
  model_name = "sentence-transformers/all-MiniLM-L6-v2"
54
 
 
770
 
771
  if __name__ == "__main__":
772
  main()
773
+
774
+ def test_postgres():
775
+ # データベース接続情報
776
+ DATABASE_URL = "postgresql://miyataken999:yz1wPf4KrWTm@ep-odd-mode-93794521.us-east-2.aws.neon.tech/neondb?sslmode=require"
777
+
778
+ # ProductDatabaseクラスのインスタンスを作成
779
+ db = ProductDatabase(DATABASE_URL)
780
+
781
+ # データベースに接続
782
+ db.connect()
783
+
784
+ try:
785
+ # データを取得
786
+ products = db.fetch_data()
787
+ print("Fetched products:")
788
+ for product in products:
789
+ print(product)
790
+
791
+ # データを更新(例: 価格を更新)
792
+ for product in products:
793
+ product_id = product[0]
794
+ print(product_id)
795
+ #new_price = product[2] * 1.1 # 価格を10%増加させる
796
+ #db.update_data(product_id, new_price)
797
+ #print(f"Updated product ID {product_id} with new price {new_price}")
798
+
799
+ finally:
800
+ # 接続を閉じる
801
+ db.close()
babyagi/classes ADDED
File without changes
babyagi/classesa/vector.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg2
2
+ from sentence_transformers import SentenceTransformer
3
+
4
+ class ProductDatabase:
5
+ def __init__(self, database_url):
6
+ self.database_url = database_url
7
+ self.conn = None
8
+ self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
9
+
10
+ def connect(self):
11
+ self.conn = psycopg2.connect(self.database_url)
12
+
13
+ def close(self):
14
+ if self.conn:
15
+ self.conn.close()
16
+
17
+ def setup_vector_extension_and_column(self):
18
+ with self.conn.cursor() as cursor:
19
+ # pgvector拡張機能のインストール
20
+ cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
21
+
22
+ # ベクトルカラムの追加
23
+ cursor.execute("ALTER TABLE products ADD COLUMN IF NOT EXISTS vector_col vector(384);")
24
+
25
+ self.conn.commit()
26
+
27
+ def get_embedding(self, text):
28
+ embedding = self.model.encode(text)
29
+ return embedding
30
+
31
+ def insert_vector(self, product_id, text):
32
+ vector = self.get_embedding(text).tolist() # ndarray をリストに変換
33
+ with self.conn.cursor() as cursor:
34
+ cursor.execute("UPDATE products SET vector_col = %s WHERE id = %s", (vector, product_id))
35
+ self.conn.commit()
36
+
37
+ def search_similar_vectors(self, query_text, top_k=5):
38
+ query_vector = self.get_embedding(query_text).tolist() # ndarray をリストに変換
39
+ with self.conn.cursor() as cursor:
40
+ cursor.execute("""
41
+ SELECT id, vector_col <=> %s::vector AS distance
42
+ FROM products
43
+ ORDER BY distance
44
+ LIMIT %s;
45
+ """, (query_vector, top_k))
46
+ results = cursor.fetchall()
47
+ return results
48
+
49
+ def main():
50
+ # データベース接続情報
51
+ DATABASE_URL = "postgresql://miyataken999:yz1wPf4KrWTm@ep-odd-mode-93794521.us-east-2.aws.neon.tech/neondb?sslmode=require"
52
+
53
+ # ProductDatabaseクラスのインスタンスを作成
54
+ db = ProductDatabase(DATABASE_URL)
55
+
56
+ # データベースに接続
57
+ db.connect()
58
+
59
+ try:
60
+ # pgvector拡張機能のインストールとカラムの追加
61
+ db.setup_vector_extension_and_column()
62
+ print("Vector extension installed and column added successfully.")
63
+
64
+ # サンプルデータの挿入
65
+ sample_text = """検査にはどのぐらい時間かかりますか?⇒当日に分かります。
66
+ 法人取引やってますか?⇒大丈夫ですよ。成約時に必要な書類の説明
67
+ LINEで金粉送って、査定はできますか?⇒できますが、今お話した内容と同様で、検査が必要な旨を返すだけなので、金粉ではなく、他のお品物でLINE査定くださいと。
68
+ 分かりました、またどうするか検討して連絡しますと"""
69
+ sample_product_id = 1 # 実際の製品IDを使用
70
+ db.insert_vector(sample_product_id, sample_text)
71
+ db.insert_vector(2, sample_text)
72
+
73
+ print(f"Vector inserted for product ID {sample_product_id}.")
74
+
75
+
76
+ # ベクトル検索
77
+ query_text = "今お話した内容と同様で"
78
+ results = db.search_similar_vectors(query_text)
79
+ print("Search results:")
80
+ for result in results:
81
+ print(result)
82
+
83
+ finally:
84
+ # 接続を閉じる
85
+ db.close()
86
+
87
+ if __name__ == "__main__":
88
+ main()