import gradio as gr import psycopg2 from openai import OpenAI import json import os from typing import List, Dict from pgvector.psycopg2 import register_vector import numpy as np # 가중치 및 임계값 설정 DEFAULT_FULL_WEIGHT = 0.2 DEFAULT_TOPIC_WEIGHT = 0.5 DEFAULT_CUSTOMER_WEIGHT = 0.2 DEFAULT_AGENT_WEIGHT = 0.1 DEFAULT_SIMILARITY_THRESHOLD = 0.5 # DB 연결 설정 def get_db_conn(): return psycopg2.connect( host=os.environ["VECTOR_HOST"], port=5432, dbname=os.environ["VECTOR_DBNAME"], user=os.environ["VECTOR_USER"], password=os.environ["VECTOR_SECRET"] ) # OpenAI 클라이언트 초기화 client = OpenAI() def get_embedding(text: str) -> List[float]: """ 텍스트를 OpenAI의 text-embedding-ada-002 모델을 사용하여 임베딩 벡터로 변환합니다. Java의 float[](float32)와 호환되도록 명시적으로 float32로 변환합니다. Args: text (str): 임베딩할 텍스트 Returns: List[float]: 임베딩 벡터 (float32) """ try: response = client.embeddings.create( input=text, model="text-embedding-ada-002", encoding_format="float" ) # 명시적으로 float32로 변환하여 Java의 float[]와 호환되게 함 return np.array(response.data[0].embedding, dtype=np.float32).tolist() except Exception as e: print(f"임베딩 생성 중 오류 발생: {str(e)}") raise def format_vector_for_pg(vector: List[float]) -> str: """ 임베딩 벡터를 PostgreSQL 포맷으로 변환합니다. 입력된 벡터가 float32 타입인지 확인합니다. """ # 벡터가 float32 타입인지 확인하고, 아니면 변환 # NumPy 배열이 아닌 경우에도 처리 if not isinstance(vector, np.ndarray): vector = np.array(vector, dtype=np.float32) elif vector.dtype != np.float32: vector = vector.astype(np.float32) vector_str = ','.join([f"{x}" for x in vector]) return f"[{vector_str}]" def get_text_value(node: Dict, field_name: str) -> str: """ 딕셔너리에서 텍스트 값을 안전하게 추출합니다. 자바의 getTextValue() 메소드와 동일한 기능입니다. """ if node and field_name in node and node[field_name] is not None: return node[field_name] return None def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]: """ 채팅 데이터에서 유사한 콘텐츠를 검색합니다. Args: query (str): 검색할 쿼리 텍스트 max_results (int): 반환할 최대 결과 수 Returns: List[Dict]: 검색 결과 목록 """ limit = max_results if max_results is not None else 100 # 자바와 동일한 가중치 설정 full_w = DEFAULT_FULL_WEIGHT topic_w = DEFAULT_TOPIC_WEIGHT customer_w = DEFAULT_CUSTOMER_WEIGHT agent_w = DEFAULT_AGENT_WEIGHT threshold = DEFAULT_SIMILARITY_THRESHOLD try: # 쿼리 임베딩 생성 query_embedding = get_embedding(query) # PostgreSQL 포맷으로 벡터 변환 query_vector = format_vector_for_pg(query_embedding) # DB 연결 conn = get_db_conn() register_vector(conn) # 자바 코드와 동일한 SQL 쿼리 구현 sql = """ WITH embeddings AS ( SELECT id, metadata, content, CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> '%s'::vector) ELSE 0 END * %f as full_sim, CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> '%s'::vector) ELSE 0 END * %f as topic_sim, CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> '%s'::vector) ELSE 0 END * %f as customer_sim, CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> '%s'::vector) ELSE 0 END * %f as agent_sim FROM vector_store_multi_embeddings WHERE full_embedding IS NOT NULL OR topic_embedding IS NOT NULL OR customer_embedding IS NOT NULL OR agent_embedding IS NOT NULL ) SELECT id, metadata, content, (full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity FROM embeddings ORDER BY combined_similarity DESC LIMIT %s """ % (query_vector, full_w, query_vector, topic_w, query_vector, customer_w, query_vector, agent_w, limit) with conn.cursor() as cur: cur.execute(sql) rows = cur.fetchall() results = [] for row in rows: id_val = row[0] metadata_json = row[1] content = row[2] similarity_score = float(row[3]) # 메타데이터 파싱 try: metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json result = { "id": id_val, "similarityScore": similarity_score, "content": content, "chatId": get_text_value(metadata, "chatId"), "topic": get_text_value(metadata, "topic") } # 시간 필드 처리 - 타임스탬프를 ISO 형식 문자열로 변환 if "startTime" in metadata and metadata["startTime"] is not None: # PostgreSQL 타임스탬프 또는 숫자일 경우 ISO 문자열로 변환 start_time = metadata["startTime"] if isinstance(start_time, str): # 이미 문자열이면 그대로 사용 result["startTime"] = start_time else: # 타임스탬프나 숫자인 경우 문자열로 변환 from datetime import datetime try: # 밀리초 타임스탬프인 경우 처리 if isinstance(start_time, int) or (isinstance(start_time, str) and start_time.isdigit()): dt = datetime.fromtimestamp(int(start_time)/1000) else: # PostgreSQL 타임스탬프 객체 처리 dt = datetime.fromisoformat(str(start_time).replace('Z', '+00:00')) result["startTime"] = dt.strftime('%Y-%m-%dT%H:%M:%S') except: # 변환 실패시 원본 값 사용 result["startTime"] = start_time if "endTime" in metadata and metadata["endTime"] is not None: # startTime과 동일한 로직 적용 end_time = metadata["endTime"] if isinstance(end_time, str): result["endTime"] = end_time else: from datetime import datetime try: if isinstance(end_time, int) or (isinstance(end_time, str) and end_time.isdigit()): dt = datetime.fromtimestamp(int(end_time)/1000) else: dt = datetime.fromisoformat(str(end_time).replace('Z', '+00:00')) result["endTime"] = dt.strftime('%Y-%m-%dT%H:%M:%S') except: result["endTime"] = end_time results.append(result) except Exception as e: print(f"메타데이터 파싱 오류: {e}") continue # 임계값 필터링 filtered_results = [r for r in results if r["similarityScore"] >= threshold] return filtered_results except Exception as e: print(f"다중 임베딩 검색 중 오류 발생: {str(e)}") return [] finally: if 'conn' in locals(): conn.close() def search_similar_chat_by_date( query: str, start_date: str = None, end_date: str = None, max_results: int = 100 ) -> List[Dict]: """ 지정된 날짜 범위 내의 채팅 데이터를 검색합니다. Args: query (str): 검색할 쿼리 텍스트 start_date (str): 검색 시작 날짜 (YYYY-MM-DD 형식) end_date (str): 검색 종료 날짜 (YYYY-MM-DD 형식) max_results (int): 반환할 최대 결과 수 Returns: List[Dict]: 검색 결과 목록 """ limit = max_results if max_results is not None else 100 # 자바와 동일한 가중치 설정 full_w = DEFAULT_FULL_WEIGHT topic_w = DEFAULT_TOPIC_WEIGHT customer_w = DEFAULT_CUSTOMER_WEIGHT agent_w = DEFAULT_AGENT_WEIGHT threshold = DEFAULT_SIMILARITY_THRESHOLD try: # 쿼리 임베딩 생성 query_embedding = get_embedding(query) # PostgreSQL 포맷으로 벡터 변환 query_vector = format_vector_for_pg(query_embedding) # DB 연결 conn = get_db_conn() register_vector(conn) # 자바 코드와 동일한 SQL 쿼리 시작 sql = """ WITH embeddings AS ( SELECT id, metadata, content, CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> '%s'::vector) ELSE 0 END * %f as full_sim, CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> '%s'::vector) ELSE 0 END * %f as topic_sim, CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> '%s'::vector) ELSE 0 END * %f as customer_sim, CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> '%s'::vector) ELSE 0 END * %f as agent_sim FROM vector_store_multi_embeddings WHERE (full_embedding IS NOT NULL OR topic_embedding IS NOT NULL OR customer_embedding IS NOT NULL OR agent_embedding IS NOT NULL) """ % (query_vector, full_w, query_vector, topic_w, query_vector, customer_w, query_vector, agent_w) # 날짜 필터 추가 if start_date and start_date.strip(): # 시작 시간 추가하여 ISO 형식으로 비교 iso_start_date = start_date + "T00:00:00" sql += f" AND (metadata->>'startTime') >= '{iso_start_date}'" if end_date and end_date.strip(): # 종료 시간 추가하여 ISO 형식으로 비교 iso_end_date = end_date + "T23:59:59" sql += f" AND (metadata->>'startTime') <= '{iso_end_date}'" sql += """ ) SELECT id, metadata, content, (full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity FROM embeddings ORDER BY combined_similarity DESC LIMIT %s """ with conn.cursor() as cur: # 여기서는 limit를 파라미터로 전달 cur.execute(sql, (limit,)) rows = cur.fetchall() results = [] for row in rows: id_val = row[0] metadata_json = row[1] content = row[2] similarity_score = float(row[3]) # 메타데이터 파싱 try: metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json result = { "id": id_val, "similarityScore": similarity_score, "content": content, "chatId": get_text_value(metadata, "chatId"), "topic": get_text_value(metadata, "topic") } # 시간 필드 변환 없이 그대로 사용 (이미 KST로 저장되어 있음) if "startTime" in metadata and metadata["startTime"] is not None: result["startTime"] = metadata["startTime"] if "endTime" in metadata and metadata["endTime"] is not None: result["endTime"] = metadata["endTime"] results.append(result) except Exception as e: print(f"메타데이터 파싱 오류: {e}") continue # 임계값 필터링 (자바 코드와 동일하게 구현) filtered_results = [r for r in results if r["similarityScore"] >= threshold] return filtered_results except Exception as e: print(f"다중 임베딩 날짜 검색 중 오류 발생: {str(e)}") return [] finally: if 'conn' in locals(): conn.close() # Gradio 웹 인터페이스 설정 with gr.Blocks() as demo: gr.Markdown("# Chat Analysis Search") gr.Interface(fn=search_similar_chat, inputs=["text", "number"], outputs="json", api_name="search_similar_chat") gr.Interface(fn=search_similar_chat_by_date, inputs=["text", "text", "text", "number"], outputs="json", api_name="search_similar_chat_by_date") if __name__ == "__main__": demo.launch(mcp_server=True)