# prompt: create a FastAPI route handler whose parameters are calat, wehth, state and x

from fastapi import APIRouter, HTTPException
from babyagi.classesa import da

import psycopg2
from sentence_transformers import SentenceTransformer
from fastapi import APIRouter, HTTPException

router = APIRouter(prefix="/leaning", tags=["leaning"])


@router.get("/route/{calat}/{wehth}/{state}/{x}")
async def route(calat: float, wehth: float, state: str, x: int):
    """Run the similarity search seeded by the path parameters.

    Served at ``GET /leaning/route/{calat}/{wehth}/{state}/{x}``. FastAPI
    converts and validates each path parameter according to its annotation.

    Returns:
        ``{"result": "OK"}`` once ``calculate`` has completed.
    """
    from fastapi.concurrency import run_in_threadpool

    # TODO: validate inputs (an unused check requiring 0.0 <= calat <= 90.0,
    # answering HTTP 400 otherwise, was sketched here).

    # The previous call ``calculate(x, y, z, c)`` referenced undefined names and
    # raised NameError on every request; ``calculate`` takes one query string.
    # Concatenate the parameters without separators, mirroring how row text is
    # assembled before it is embedded.
    # NOTE(review): the mapping of path parameters onto the query text is an
    # assumption -- confirm against the intended search semantics.
    query = f"{calat}{wehth}{state}{x}"

    # calculate() does blocking database I/O and SentenceTransformer inference;
    # run it in a worker thread so the event loop is not stalled.
    await run_in_threadpool(calculate, query)

    return {"result": "OK"}

class ProductDatabase:
    """Access layer for the ``diamondprice`` table with pgvector embeddings.

    Each row can carry a sentence embedding in ``vector_col``. The embedding is
    produced by a SentenceTransformer model, and the column is queried with
    pgvector's ``<=>`` (cosine distance) operator.
    """

    def __init__(self, database_url):
        """Store the connection URL and load the embedding model.

        Args:
            database_url: connection string passed to ``psycopg2.connect``.

        No connection is opened here; call :meth:`connect` before any query.
        """
        self.database_url = database_url
        self.conn = None  # set by connect(), reset by close()
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def connect(self):
        """Open a psycopg2 connection to ``self.database_url``."""
        self.conn = psycopg2.connect(self.database_url)

    def close(self):
        """Close the connection if one is open; safe to call more than once."""
        if self.conn is not None:
            self.conn.close()
            # Forget the closed connection so a second close() is a no-op.
            self.conn = None

    def setup_vector_extension_and_column(self):
        """Install pgvector and add ``vector_col`` to ``diamondprice``.

        Both statements use ``IF NOT EXISTS`` and can be re-run. The column used
        to be added to ``products``, but every other method reads and writes
        ``diamondprice.vector_col``, so the column they rely on was never created.
        """
        with self.conn.cursor() as cursor:
            # Install the pgvector extension.
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")

            # Add the vector column. The dimension (384) must match the
            # embedding size of the model loaded in __init__.
            cursor.execute("ALTER TABLE diamondprice ADD COLUMN IF NOT EXISTS vector_col vector(384);")

            self.conn.commit()

    def get_embedding(self, text):
        """Return the model's embedding of ``text`` (an ndarray)."""
        return self.model.encode(text)

    def insert_vector(self, product_id, text):
        """Embed ``text`` and store it in ``vector_col`` of row ``product_id``.

        Commits immediately after the UPDATE.
        """
        vector = self.get_embedding(text).tolist()  # ndarray -> list for psycopg2
        with self.conn.cursor() as cursor:
            cursor.execute("UPDATE diamondprice SET vector_col = %s WHERE id = %s", (vector, product_id))
            self.conn.commit()

    def search_similar_vectors(self, query_text, top_k=50):
        """Return the ``top_k`` rows whose vectors are closest to ``query_text``.

        Each row is ``(id, price, carat, cut, color, clarity, depth, table,
        x, y, z, distance)``, ordered by ascending distance. Rows without a
        vector are skipped.
        """
        query_vector = self.get_embedding(query_text).tolist()  # ndarray -> list for psycopg2
        with self.conn.cursor() as cursor:
            cursor.execute("""
                SELECT id,price,carat, cut, color, clarity, depth, diamondprice.table, x, y, z, vector_col <=> %s::vector AS distance
                FROM diamondprice
                WHERE vector_col IS NOT NULL
                ORDER BY distance asc
                LIMIT %s;
            """, (query_vector, top_k))
            return cursor.fetchall()

    def search_similar_all(self, query_text, top_k=5):
        """Return every row of ``diamondprice`` (price excluded), ordered by id.

        Despite its name this performs no similarity search. ``query_text`` and
        ``top_k`` are accepted only for backward compatibility and are unused.
        The old version embedded ``query_text`` (a wasted model call) and passed
        it as parameters to SQL that has no placeholders.

        Each row is ``(id, carat, cut, color, clarity, depth, table, x, y, z)``.
        """
        with self.conn.cursor() as cursor:
            cursor.execute("""
                SELECT id,carat, cut, color, clarity, depth, diamondprice.table, x, y, z
                FROM diamondprice
                order by id asc
                limit 10000000
            """)
            return cursor.fetchall()

def calculate(query: str, *, database_url=None, backfill_embeddings=False):
    """Run a pgvector similarity search for ``query`` against ``diamondprice``.

    Row text is built by concatenating the column values without separators,
    so a query looks like ``"2.03Very GoodJSI2"`` (carat, cut, color, clarity).

    Args:
        query: free text to embed and compare with the stored row vectors.
        database_url: connection string; defaults to the ``DATABASE_URL``
            environment variable.
        backfill_embeddings: when true, embed and store a vector for every row
            before searching. This replaces a hard-coded ``DEBUG`` flag that
            was always 0.

    Returns:
        The rows from ``ProductDatabase.search_similar_vectors`` (closest first).

    Raises:
        RuntimeError: if no database URL is given and ``DATABASE_URL`` is unset.
    """
    import os

    # Credentials used to be hard-coded here (so they are exposed in version
    # control and should be rotated); read them from configuration instead.
    database_url = database_url or os.environ.get("DATABASE_URL")
    if not database_url:
        raise RuntimeError("DATABASE_URL is not set")

    db = ProductDatabase(database_url)
    db.connect()
    try:
        # Install pgvector and add the vector column (both IF NOT EXISTS).
        db.setup_vector_extension_and_column()
        print("Vector extension installed and column added successfully.")

        # The full-table fetch used to run on every call even though only the
        # (disabled) backfill used it; it now runs only when requested.
        if backfill_embeddings:
            _backfill_embeddings(db)

        results = db.search_similar_vectors(query)
        print("Search results:")
        for row in results:
            print(row)
        # Previously ``res_all += result + ""`` added a tuple to a str
        # (TypeError). The rows are returned instead.
        # TODO: send the results to the chat.
        return results
    finally:
        db.close()


def _backfill_embeddings(db):
    """Embed every row of ``diamondprice`` and store the vector on that row."""
    for row in db.search_similar_all("1"):
        print(row)
        row_id = row[0]
        # Concatenate carat..z (columns 1-9) without separators.
        sample_text = "".join(str(value) for value in row[1:10])
        print(sample_text)
        db.insert_vector(row_id, sample_text)
#router = APIRouter()