import os
import time

import pandas as pd
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import CrossEncoder
from tqdm import tqdm

from backend.services.embedding_service import get_text_embedding
from backend.utils import logger

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

load_dotenv()
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")

logger = logger.get_logger()

NAMESPACE = "health-care-dataset"
INDEX_NAME = "health-care-index"
PINECONE = Pinecone(api_key=PINECONE_API_KEY)


def rerank_results(query, results, score_threshold=0.5):
    """Rerank Pinecone matches with the cross-encoder, dropping low-scoring pairs."""
    pairs = [(query, result["metadata"]["question"]) for result in results]
    scores = reranker.predict(pairs)

    # Keep only pairs the cross-encoder scores at or above the threshold,
    # ordered by that reranker score (highest first).
    filtered = [(score, result) for score, result in zip(scores, results) if score >= score_threshold]
    return [result for _, result in sorted(filtered, key=lambda pair: pair[0], reverse=True)]
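
# Illustrative sketch (not part of the pipeline): reranking raw matches from a
# Pinecone query. Assumes `index` has been initialized below and that the
# embedding model matches the index dimension.
#
#   query = "What are the early symptoms of diabetes?"
#   matches = index.query(
#       vector=get_text_embedding([query])[0],
#       top_k=10,
#       namespace=NAMESPACE,
#       include_metadata=True,
#   ).get("matches", [])
#   top_matches = rerank_results(query, matches)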


def initialize_pinecone_index(pinecone, index_name, dimension=384, metric="cosine", cloud="aws", region="us-east-1"):
    """
    Retrieves an existing Pinecone index or creates a new one if it does not exist.

    This function checks for the presence of the specified index. If the index does not exist,
    it initiates the creation process, waits until the index is ready, and then returns the index.

    Args:
        pinecone (Pinecone): Pinecone client instance.
        index_name (str): Name of the index to retrieve or create.
        dimension (int, optional): Vector dimension for the index. Default is 384.
        metric (str, optional): Distance metric for the index. Default is "cosine".
        cloud (str, optional): Cloud provider for hosting the index. Default is "aws".
        region (str, optional): Region where the index will be hosted. Default is "us-east-1".

    Returns:
        pinecone.Index: The Pinecone index instance, or None if an error occurs.

    Example:
        >>> index = initialize_pinecone_index(PINECONE, "sample_index")
        Logs: "Index 'sample_index' is ready and accessible."
    """
    try:
        logger.info(f"Checking if the index '{index_name}' exists...")

        if not pinecone.has_index(index_name):
            logger.info(f"Index '{index_name}' does not exist. Creating a new index...")

            pinecone.create_index(
                name=index_name,
                dimension=dimension,
                metric=metric,
                spec=ServerlessSpec(cloud=cloud, region=region)
            )
            logger.info(f"Index '{index_name}' creation initiated. Waiting for it to be ready...")

            while True:
                index_status = pinecone.describe_index(index_name)
                if index_status.status.get("ready", False):
                    index = pinecone.Index(index_name)
                    logger.info(f"Index '{index_name}' is ready and accessible.")
                    return index
                else:
                    logger.debug(f"Index '{index_name}' is not ready yet. Checking again in 1 second.")
                    time.sleep(1)
        else:
            index = pinecone.Index(index_name)
            logger.info(f"Index '{index_name}' already exists. Returning the existing index.")
            return index

    except Exception as e:
        logger.error(f"Error occurred while getting or creating the Pinecone index: {str(e)}", exc_info=True)
        return None


index = initialize_pinecone_index(PINECONE, INDEX_NAME)


def delete_records_by_ids(ids_to_delete):
    """
    Deletes the specified IDs from the Pinecone index.

    This function deletes entries from the module-level `index` based on the provided list of IDs.
    It logs a success message if the deletion succeeds and returns an error message if it fails.

    Args:
        ids_to_delete (list):
            A list of unique identifiers (IDs) to be deleted from the index.

    Returns:
        str: A string describing the failure, if an error occurs.
            On success, nothing is returned and "IDs deleted successfully." is logged.

    Example:
        >>> delete_records_by_ids(['id_123', 'id_456'])
        Logs: "IDs deleted successfully."

    Notes:
        - The function uses the module-level `index` created by `initialize_pinecone_index`.
        - Deletion occurs within the specified `NAMESPACE`.
    """
    try:
        index.delete(ids=ids_to_delete, namespace=NAMESPACE)
        logger.info("IDs deleted successfully.")
    except Exception as e:
        return f"Failed to delete the IDs: {e}"


def retrieve_relevant_metadata(embedding, prompt, n_result=3, score_threshold=0.47):
    """
    Retrieves relevant question-answer metadata from Pinecone for a given query embedding.

    Args:
        embedding (list): Embedding vector for the query.
        prompt (str): Original user prompt, used only for error logging.
        n_result (int, optional): Number of top results to retrieve. Default is 3.
        score_threshold (float, optional): Minimum similarity score to keep a match. Default is 0.47.

    Returns:
        list: Dicts with 'question', 'answer', 'instruction', 'score', and 'id' keys,
            or a single-element fallback list if nothing relevant is found or an error occurs.
    """
    try:
        response = index.query(
            top_k=n_result,
            vector=embedding,
            namespace=NAMESPACE,
            include_metadata=True
        )

        filtered_results = [
            {
                "question": entry.get('metadata', {}).get('question', 'N/A'),
                "answer": entry.get('metadata', {}).get('answer', 'N/A'),
                "instruction": entry.get('metadata', {}).get('instruction', 'N/A'),
                "score": str(entry.get('score', 0)),
                "id": str(entry.get('id', 'N/A'))
            }
            for entry in response.get('matches', [])
            if float(entry.get('score', 0)) >= score_threshold
        ]

        logger.info(f"Retrieved filtered data: {filtered_results}")
        return filtered_results if filtered_results else [{"response": "No relevant data found."}]

    except Exception as e:
        logger.error(f"Failed to fetch context for prompt: '{prompt}'. Error: {e}")
        return [{"response": "Failed to fetch data due to an error."}]


def upsert_vector_data(df: pd.DataFrame):
    """
    Generates embeddings for the given DataFrame and uploads the data to Pinecone in batches.

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'input', 'output', and 'instruction' columns.

    Returns:
    - None
    """
    try:
        df["embedding"] = [
            get_text_embedding([q])[0]
            for q in tqdm(df["input"], desc="Generating Embeddings")
        ]
    except Exception as e:
        logger.error(f"Error generating embeddings: {e}")
        return

    BATCH_SIZE = 500

    for i in tqdm(range(0, len(df), BATCH_SIZE), desc="Uploading Data to Pinecone"):
        batch = df.iloc[i : i + BATCH_SIZE]

        vectors = []
        for idx, (embedding, (_, row_data)) in enumerate(zip(batch["embedding"], batch.iterrows())):
            question = row_data.get("input")
            vector_id = f"{question[:50]}:{i + idx}"
            metadata = {
                "question": row_data.get("input"),
                "answer": row_data.get("output"),
                "instruction": row_data.get("instruction"),
            }
            vectors.append((vector_id, embedding, metadata))

        try:
            index.upsert(vectors=vectors, namespace=NAMESPACE)
        except Exception as e:
            logger.error(f"Error uploading batch starting at index {i}: {e}")

    logger.info("All question-answer pairs stored successfully!")


def retrieve_context_from_pinecone(embedding, n_result=3, score_threshold=0.4):
    """
    Retrieves relevant context from Pinecone using vector embeddings.

    Args:
    - embedding (list): Embedding vector for the query.
    - n_result (int): Number of top results to retrieve.
    - score_threshold (float): Minimum score threshold for relevance.

    Returns:
    - str: Combined context or fallback message.
    """
    if not embedding or not isinstance(embedding, list):
        logger.warning("Invalid embedding received.")
        return "No relevant context found."

    try:
        response = index.query(
            top_k=n_result,
            vector=embedding,
            namespace=NAMESPACE,
            include_metadata=True
        )

        if 'matches' not in response:
            logger.error("Unexpected response structure from Pinecone.")
            return "No relevant context found."

        filtered_results = []
        for entry in response.get('matches', []):
            score = entry.get('score', 0)
            answer = entry.get('metadata', {}).get('answer', 'N/A')

            if score >= score_threshold:
                filtered_results.append(f"{answer} (Score: {score:.2f})")
            else:
                logger.info(f"Entry skipped due to low score: {score:.2f}")

        context = "\n".join(filtered_results) if filtered_results else "No relevant context found."

        return context

    except Exception as e:
        logger.error(f"Unexpected error in Pinecone retrieval: {e}", exc_info=True)
        return "Error retrieving context. Please try again later."