import os
import time

import openai
import pandas as pd
from datasets import load_dataset
from document_model import Listing, SearchResultItem
from pydantic import ValidationError
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.mongo_client import MongoClient
from pymongo.operations import SearchIndexModel

DB_NAME = "airbnb_dataset"
COLLECTION_NAME = "listings_reviews"
def connect_to_database():
    MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
    mongo_client = MongoClient(MONGODB_ATLAS_CLUSTER_URI, appname="advanced-rag")
    db = mongo_client.get_database(DB_NAME)
    collection = db.get_collection(COLLECTION_NAME)
    return db, collection


def rag_ingestion(collection):
    dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
    dataset_df = pd.DataFrame(dataset)
    listings = process_records(dataset_df)
    collection.delete_many({})
    collection.insert_many(listings)
    return "Manually create a vector search index (on the free tier, this feature is not available via the SDK)"
def rag_retrieval(openai_api_key, prompt, db, collection, stages=[], vector_index="vector_index"):
    # vector_search returns a list of dictionaries matching the SearchResultItem model
    get_knowledge = vector_search(openai_api_key, prompt, db, collection, stages, vector_index)

    # Check if there are any results
    if not get_knowledge:
        return "No results found.", "No source information available."

    # Convert search results into a list of SearchResultItem models
    search_results_models = [
        SearchResultItem(**result)
        for result in get_knowledge
    ]

    # Convert search results into a DataFrame for better rendering in Jupyter
    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
    print("###")
    print(search_results_df)
    print("###")
    return search_results_df
def rag_inference(openai_api_key, prompt, search_results):
    openai.api_key = openai_api_key

    # Generate a response using OpenAI's chat completion API
    content = f"Answer this user question: {prompt} with the following context:\n{search_results}"
    completion = openai.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "You are an AirBnB listing recommendation system.",
            },
            {
                "role": "user",
                "content": content,
            },
        ],
    )
    completion_result = completion.choices[0].message.content
    print("###")
    print(completion_result)
    print("###")
    return completion_result
def process_records(data_frame):
    records = data_frame.to_dict(orient="records")
    # Replace potential NaN/NaT values with None so they serialize cleanly
    for record in records:
        for key, value in record.items():
            # List values
            if isinstance(value, list):
                record[key] = [None if pd.isnull(v) else v for v in value]
            # Scalar values
            else:
                if pd.isnull(value):
                    record[key] = None
    try:
        # Convert each dictionary to a Listing instance
        return [Listing(**record).dict() for record in records]
    except ValidationError as e:
        print("Validation error:", e)
        return []
def vector_search(openai_api_key, user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        openai_api_key (str): OpenAI API key used to embed the query.
        user_query (str): The user's query string.
        db (MongoClient.database): The database object.
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list): Additional aggregation stages to include in the pipeline.
        vector_index (str): Name of the Atlas Vector Search index to use.

    Returns:
        list: A list of matching documents.
    """
    # Generate embedding for the user query
    query_embedding = get_embedding(openai_api_key, user_query)
    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search stage
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,           # specifies the index to use for the search
            "queryVector": query_embedding,  # the vector representing the query
            "path": "text_embeddings",       # field in the documents containing the vectors to search against
            "numCandidates": 150,            # number of candidate matches to consider
            "limit": 20,                     # return the top 20 matches
        }
    }

    # Define the aggregation pipeline with the vector search stage and additional stages
    pipeline = [vector_search_stage] + additional_stages

    # Execute the search
    results = collection.aggregate(pipeline)

    explain_query_execution = db.command(  # sends a database command directly to the MongoDB server
        'explain', {                       # returns information about how MongoDB executes a query without running it
            'aggregate': collection.name,  # name of the collection the aggregation runs on
            'pipeline': pipeline,          # the aggregation pipeline to analyze
            'cursor': {}                   # use default cursor behavior
        },
        verbosity='executionStats')        # detailed statistics about each stage of the pipeline

    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    # millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
    print(vector_search_explain)
    # print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
    return list(results)
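

# A minimal usage sketch for the additional_stages hook: a $project stage that
# trims the returned documents and surfaces the similarity score computed by
# $vectorSearch. The projected listing fields ("name", "summary") are
# illustrative assumptions, not taken from document_model.
EXAMPLE_PROJECTION_STAGE = {
    "$project": {
        "_id": 0,
        "name": 1,     # hypothetical listing field
        "summary": 1,  # hypothetical listing field
        "score": {"$meta": "vectorSearchScore"},  # similarity score from $vectorSearch
    }
}
# e.g. vector_search(openai_api_key, query, db, collection,
#                    additional_stages=[EXAMPLE_PROJECTION_STAGE])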
def vector_search_with_filter(openai_api_key, user_query, db, collection, additional_stages=[], vector_index="vector_index_2"):
    """
    Perform a vector search with a pre-filter in the MongoDB collection based on the user query.

    Args:
        openai_api_key (str): OpenAI API key used to embed the query.
        user_query (str): The user's query string.
        db (MongoClient.database): The database object.
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list): Additional aggregation stages to include in the pipeline.
        vector_index (str): Name of the Atlas Vector Search index to use.

    Returns:
        list: A list of matching documents.
    """
    # Generate embedding for the user query
    query_embedding = get_embedding(openai_api_key, user_query)
    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search stage with a pre-filter on listing attributes
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,           # specifies the index to use for the search
            "queryVector": query_embedding,  # the vector representing the query
            "path": "text_embeddings",       # field in the documents containing the vectors to search against
            "numCandidates": 150,            # number of candidate matches to consider
            "limit": 20,                     # return the top 20 matches
            "filter": {
                "$and": [
                    {"accommodates": {"$gte": 2}},
                    {"bedrooms": {"$lte": 7}}
                ]
            },
        }
    }

    # Define the aggregation pipeline with the vector search stage and additional stages
    pipeline = [vector_search_stage] + additional_stages

    # Execute the search
    results = collection.aggregate(pipeline)

    explain_query_execution = db.command(  # sends a database command directly to the MongoDB server
        'explain', {                       # returns information about how MongoDB executes a query without running it
            'aggregate': collection.name,  # name of the collection the aggregation runs on
            'pipeline': pipeline,          # the aggregation pipeline to analyze
            'cursor': {}                   # use default cursor behavior
        },
        verbosity='executionStats')        # detailed statistics about each stage of the pipeline

    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
    print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
    return list(results)
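

# Note: the pre-filter above only works if the filtered fields are indexed as
# "filter" fields in the vector search index. A sketch of the assumed definition
# for "vector_index_2" (same SearchIndexModel API as create_vector_search_index above):
FILTERED_INDEX_DEFINITION = {
    "fields": [
        {"type": "vector", "path": "text_embeddings", "numDimensions": 1536, "similarity": "cosine"},
        {"type": "filter", "path": "accommodates"},
        {"type": "filter", "path": "bedrooms"},
    ]
}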
def get_embedding(openai_api_key, text):
    """Generate an embedding for the given text using OpenAI's API."""
    # Check for valid input
    if not text or not isinstance(text, str):
        return None

    openai.api_key = openai_api_key
    try:
        embedding = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536,
        ).data[0].embedding
        return embedding
    except Exception as e:
        print(f"Error in get_embedding: {e}")
        return None
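

# A hedged end-to-end sketch of how these pieces fit together. Assumes the
# MONGODB_ATLAS_CLUSTER_URI and OPENAI_API_KEY environment variables are set
# (the latter is an assumption of this sketch) and that the vector search
# index already exists.
if __name__ == "__main__":
    openai_api_key = os.environ["OPENAI_API_KEY"]
    db, collection = connect_to_database()
    # rag_ingestion(collection)  # run once to (re)load the dataset
    query = "Recommend a quiet two-bedroom apartment for a family."
    search_results = rag_retrieval(openai_api_key, query, db, collection)
    rag_inference(openai_api_key, query, search_results)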