import gradio as gr
import logging, os, sys, threading
from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference
lock = threading.Lock()

# When True, invoke() runs document ingestion into the collection instead of retrieval and inference.
RAG_INGESTION = False

# Retrieval-Augmented Generation modes offered in the UI.
RAG_OFF = "Off"
RAG_NAIVE = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# Log to stdout so messages appear in the Space logs; basicConfig already attaches a stdout handler.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)

def invoke(openai_api_key, prompt, rag_option):
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    with lock:
        db, collection = connect_to_database()

        if RAG_INGESTION:
            return rag_ingestion(collection)
        else:
            ### Pre-retrieval processing: index filter
            ### Post-retrieval processing: result filter
            # Example result filter: keep only listings that accommodate 2 guests with 1 bedroom.
            #match_stage = {
            #    "$match": {
            #        "accommodates": { "$eq": 2 },
            #        "bedrooms": { "$eq": 1 }
            #    }
            #}
            #additional_stages = [match_stage]
            ###
"""
projection_stage = {
"$project": {
"_id": 0,
"name": 1,
"accommodates": 1,
"address.street": 1,
"address.government_area": 1,
"address.market": 1,
"address.country": 1,
"address.country_code": 1,
"address.location.type": 1,
"address.location.coordinates": 1,
"address.location.is_location_exact": 1,
"summary": 1,
"space": 1,
"neighborhood_overview": 1,
"notes": 1,
"score": {"$meta": "vectorSearchScore"}
}
}
additional_stages = [projection_stage]
"""
            ###
            review_average_stage = {
                "$addFields": {
                    "averageReviewScore": {
                        "$divide": [
                            {
                                "$add": [
                                    "$review_scores.review_scores_accuracy",
                                    "$review_scores.review_scores_cleanliness",
                                    "$review_scores.review_scores_checkin",
                                    "$review_scores.review_scores_communication",
                                    "$review_scores.review_scores_location",
                                    "$review_scores.review_scores_value",
                                ]
                            },
                            6  # Divide by the number of review score types to get the average
                        ]
                    },
                    # Calculate a score boost factor based on the number of reviews
                    "reviewCountBoost": "$number_of_reviews"
                }
            }

            weighting_stage = {
                "$addFields": {
                    "combinedScore": {
                        # Example formula that combines average review score and review count boost
                        "$add": [
                            {"$multiply": ["$averageReviewScore", 0.9]},  # Weighted average review score
                            {"$multiply": ["$reviewCountBoost", 0.1]}     # Weighted review count boost
                        ]
                    }
                }
            }

            # Apply the combinedScore for sorting
            sorting_stage_sort = {
                "$sort": {"combinedScore": -1}  # Descending order to boost higher combined scores
            }

            additional_stages = [review_average_stage, weighting_stage, sorting_stage_sort]
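            # Worked example with hypothetical numbers (not taken from the dataset): a listing whose
            # six review scores are 10, 9, 10, 10, 9, 10 and which has 25 reviews gets
            #   averageReviewScore = (10 + 9 + 10 + 10 + 9 + 10) / 6 ≈ 9.67
            #   combinedScore      = 9.67 * 0.9 + 25 * 0.1 ≈ 11.2
            # so listings with many reviews are boosted above sparsely reviewed ones with a similar average.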
            ###
            #additional_stages = []
            ###
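            # Assumption: rag_retrieval (see custom_utils) runs an Atlas Vector Search aggregation and
            # appends additional_stages to it, roughly [{"$vectorSearch": {...}}] + additional_stages;
            # the "vectorSearchScore" $meta field referenced in the projection stage above is only
            # available downstream of such a $vectorSearch stage.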
            search_results = rag_retrieval(openai_api_key, prompt, db, collection, additional_stages)
            return rag_inference(openai_api_key, prompt, search_results)

gr.close_all()

PROMPT = "Recommend a place that's modern, spacious, and within walking distance from restaurants."

demo = gr.Interface(
    fn = invoke,
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
              gr.Textbox(label = "Prompt", value = PROMPT, lines = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
    outputs = [gr.Markdown(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"]
)

demo.launch()