# Spaces: Sleeping / Sleeping  (page-header residue from the hosting site; not part of the program)
import os | |
import csv | |
from functools import partial | |
from typing import List | |
import time | |
import gradio as gr | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
import chromadb | |
from chromadb.config import Settings | |
from chromadb import Documents, EmbeddingFunction, Embeddings | |
from google import genai | |
from google.genai import types | |
from tqdm import tqdm | |
from google.genai.errors import ClientError | |
class GeminiEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that delegates to a Gemini embedding model.

    An instance is handed to the Chroma collection as its
    ``embedding_function`` so that texts are embedded through the Gemini API.
    """

    def __init__(self, gemini_client: genai.Client, emb_model: str):
        """Remember the client and model name used for every embedding request.

        Args:
            gemini_client: Gemini API client used to request embeddings.
            emb_model: Name of the embedding model passed to ``embed_content``.
        """
        self.gemini_client = gemini_client
        self.emb_model = emb_model

    # ``Documents`` / ``Embeddings`` are already list types in Chroma, so they
    # are used directly instead of being wrapped in another ``List[...]``.
    def __call__(self, input_batch: Documents) -> Embeddings:
        """Embed a batch of texts with a single ``embed_content`` request.

        Args:
            input_batch: The texts to embed.

        Returns:
            The ``values`` of each embedding in the API response, in the
            order the response lists them.
        """
        gemini_out = self.gemini_client.models.embed_content(
            model=self.emb_model,
            contents=input_batch,
        )
        return [e.values for e in gemini_out.embeddings]
def _load_source_links(file2url_path: str) -> dict:
    """Map each article's markdown filename to a ``[title](url)`` link.

    Every CSV row is read as ``url, file stem, title``. Rows with fewer than
    three columns (for example blank lines) are skipped. When a file stem
    appears more than once, the last row wins.
    """
    links = {}
    # newline='' is what the csv module asks for; utf-8 matches how the
    # article files themselves are read.
    with open(file2url_path, 'r', encoding='utf-8', newline='') as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            url, stem, title = row[0], row[1], row[2]
            links[stem + '.md'] = f'[{title}]({url})'
    return links


def create_or_get_chroma_db(
    gemini_client: genai.Client,
    emb_model: str,
    articles_md_root: str,
    file2url_path: str,
    db_root: str,
) -> chromadb.Collection:
    """Open the persistent "Oura_Support_Faq" collection, populating it if empty.

    If the collection already holds at least one document it is returned
    untouched (so a partially filled collection is also treated as ready).
    Otherwise every ``.md`` file in ``articles_md_root`` is split into
    overlapping chunks, tagged with a ``Source`` markdown link built from the
    CSV mapping, and upserted in batches.

    Args:
        gemini_client: Gemini client used by the collection's embedding function.
        emb_model: Embedding model name passed to the embedding function.
        articles_md_root: Directory containing the markdown articles.
        file2url_path: CSV file whose rows are ``url, file stem, title``.
        db_root: Directory of the persistent Chroma database (created if missing).

    Returns:
        The Chroma collection.

    Raises:
        KeyError: If a markdown file has no entry in the CSV mapping.
    """
    # Create the database root directory
    os.makedirs(db_root, exist_ok=True)

    # Initialize the Chroma client
    chroma_client = chromadb.PersistentClient(path=db_root, settings=Settings(anonymized_telemetry=False))

    db = chroma_client.get_or_create_collection(
        name="Oura_Support_Faq",
        embedding_function=GeminiEmbeddingFunction(gemini_client, emb_model),
    )

    # A non-empty collection is taken as already built
    existing_count = db.count()
    if existing_count > 0:
        print(f"Collection already exists with {existing_count} documents.")
        return db

    source_links = _load_source_links(file2url_path)

    # Load and chunk the documents from the articles directory
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True,
        separators=["##", "\n\n", "\n", " ", ""],
    )
    splits = []
    # sorted(): os.listdir returns entries in arbitrary order, and the
    # positional ids assigned below should not depend on it.
    for filename in sorted(os.listdir(articles_md_root)):
        if not filename.endswith('.md'):
            continue
        if filename not in source_links:
            raise KeyError(f"No URL/title entry for {filename!r} in {file2url_path!r}")
        with open(os.path.join(articles_md_root, filename), 'r', encoding='utf-8') as f:
            content = f.read()
        meta = {'Source': source_links[filename]}
        splits.extend(splitter.create_documents([content], metadatas=[meta]))

    # Extract documents, metadata, and IDs from the chunks
    documents = [chunk.page_content for chunk in splits]
    metadatas = [chunk.metadata for chunk in splits]
    ids = [f"id_{i}" for i in range(len(documents))]

    # Batch the upserts so that many documents are sent per call instead of
    # one call per document.
    batch_size = 64
    for i in tqdm(range(0, len(documents), batch_size)):
        slice_batch = slice(i, i + batch_size)
        # upsert adds new documents or updates existing ones with the same id
        db.upsert(
            documents=documents[slice_batch],
            ids=ids[slice_batch],
            metadatas=metadatas[slice_batch],
        )
        time.sleep(2)  # Optional: sleep to avoid hitting API limits

    print(f"Added {len(documents)} documents to the database.")
    return db
def get_prompt_from_question(question: str, db: chromadb.Collection, n: int, verbose: bool = True) -> str:
    """Build the retrieval-augmented prompt for a single user question.

    Queries ``db`` for the ``n`` nearest chunks and lists each one with its
    index, the distance value reported by the query (labelled "Similarity"),
    its ``Source`` metadata and its text, followed by the question and a
    trailing "Answer:" cue.

    Args:
        question: The user's question; used both as query text and in the prompt.
        db: Collection to query.
        n: Number of results requested from the collection.
        verbose: When true, the finished prompt is printed.

    Returns:
        The assembled prompt text.
    """
    hits = db.query(query_texts=[question], n_results=n)
    # A single query text is sent, so each result field is read at index 0.
    chunk_texts = hits['documents'][0]

    pieces = ["\nContext from Oura documentation/forums:"]
    for idx, text in enumerate(chunk_texts):
        distance = hits['distances'][0][idx]
        origin = hits['metadatas'][0][idx]['Source']
        pieces.append(f"\n\nChunk {idx}:\n")
        pieces.append(f"Similarity: {distance:.3f}\n")
        pieces.append(f"Source: {origin}\n")
        pieces.append(text)
    pieces.append("\n\n" + f"User Question: {question}")
    pieces.append("\n\n" + "Answer:")

    prompt = "".join(pieces)
    if verbose:
        print(prompt)
    return prompt
def chatbot_response(
    user_input: str,
    history: list,
    db: chromadb.Collection,
    llm_name: str,
    system_prompt: str,
    turns_to_keep: int,
    num_neighbors_per_query: int,
) -> str:
    """Answer one chat message using retrieved context and recent history.

    The prompt is the last ``turns_to_keep`` turns followed by the
    retrieval-augmented prompt for ``user_input``. If the Gemini API raises a
    ``ClientError``, the whole request (retrieval and generation) is retried
    once after a short pause, using the same prompt layout including the
    conversation history.

    Note: the module-level ``client`` (Gemini client) is read as a global.

    Args:
        user_input: The new user message.
        history: Earlier turns; each turn is indexed as ``turn[0]`` (user)
            and ``turn[1]`` (agent).
        db: Collection queried for context chunks.
        llm_name: Model name passed to ``generate_content``.
        system_prompt: System instruction for the model.
        turns_to_keep: Number of most recent turns included in the prompt.
        num_neighbors_per_query: Number of chunks retrieved per question.

    Returns:
        The text of the model's response.

    Raises:
        ClientError: If the retry also fails.
    """
    # `history[-0:]` would be the whole list, so "keep no turns" is explicit.
    history = history[-turns_to_keep:] if turns_to_keep > 0 else []
    conversation_history = "\n".join(f'User: {turn[0]}; Agent: {turn[1]}' for turn in history)

    max_attempts = 2
    for attempt in range(1, max_attempts + 1):
        try:
            prompt = get_prompt_from_question(user_input, db, num_neighbors_per_query)
            prompt = f"Previous turns: {conversation_history} \n\n New prompt: {prompt}"
            print('==========================================================================')
            print(prompt)
            print('==========================================================================')
            # Generate the response using the Gemini API
            response = client.models.generate_content(
                model=llm_name,
                contents=prompt,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                ),
            )
        except ClientError as e:
            if attempt == max_attempts:
                raise
            print(f"Got the error: {e}.")
            print('Maybe API is busy. Will try in a second...')
            time.sleep(3)
        else:
            print(f"Response: {response.text}")
            print('')
            print('')
            return response.text
# Script entry point: configure models and paths, build or load the vector
# database, and serve the chatbot through a Gradio chat interface.
if __name__ == "__main__":
    llm_name = 'gemini-2.0-flash-001'
    emb_model = 'models/text-embedding-004'
    articles_md_root = './assets/oura_articles'
    file2url_path = './assets/data/links_paths.csv'
    db_root = os.path.join(os.getcwd(), 'assets', f'databases-{emb_model.split("/")[-1]}')  # ./assets/databases-<emb_model_name>/
    turns_to_keep = 5
    num_neighbors_per_query = 5
    system_prompt = """
You are an AI assistant specializing in providing support for the Oura mobile application, assisting users with their inquiries based solely on the provided context.
## Rules:
- **Exclusive Reliance on Provided Context**: Answer questions using only the supplied context. Do not incorporate external knowledge.
- **Handling Insufficient Context**:
- If the context lacks sufficient information, respond with: *"I cannot answer based on the provided information."*
- If a user query contains ambiguous references (e.g., "it", "this") and the context does not clarify them, politely ask for clarification: *"Could you please specify what you mean by 'it'? 😊"*
- **Citations**: Cite information using the 'Source' metadata provided with each chunk. Keep citations sparse—cite once per paragraph or at the end of the relevant section. Use hyper-links.
- **Preference for Relevant Chunks**: Prioritize information from chunks with lower similarity scores, as they are more pertinent.
## Formatting Guidelines:
- **Markdown Usage**: Format responses in Markdown for clarity and readability.
- **Tone**: Maintain a friendly and engaging tone. 😊 A couple of well-placed emojis are encouraged!
- **Image Inclusion**:
- Use HTML for images: `<img src="..." alt="..." style="object-fit: contain; ..." />`
- If the original `alt` text includes "icon", add `width: 50px; height: 50px;` to the `style` attribute, e.g. battery, share, adjustment, menu etc icons.
## Examples:
**Example 1: Sufficient Context**
*User Question*: "How can I change the units of measurement in the Oura app?"
*Context Provided*:
- Chunk 1: "# Oura App Languages
The Oura App is currently available in:
* Danish
...
* Swedish
The Oura App also supports both metric and imperial units of measurement, which can be adjusted through the app's  menu > Settings > Units."
*Response*:
"You can adjust the units of measurement in the Oura app through the app's <img src="https://support.ouraring.com/hc/article_attachments/28601966533139" alt="Icon Bars Menu.png" style="object-fit: contain; width: 50px; height: 50px;"/> menu > **Settings** > **Units**. Source: [Oura App Languages](https://support.ouraring.com/hc/en-us/articles/360058028053-Oura-App-Languages)"
**Example 2: Ambiguous Reference with Insufficient Context**
*User Question*: "Can you do X?"
*Agent Response*: "Yes, you can."
*User Follow-up*: "How to do it?"
*Context Provided*: *(No relevant information about 'X')*
*Response*:
"Could you please specify what you mean by 'it' so I can assist you better? 😊"
**Example 3: Insufficient Context without Ambiguity**
*User Question*: "What is the Oura app's refund policy?"
*Context Provided*: *(No information on refund policy)*
*Response*:
"I cannot answer based on the provided information."
**Example 4: Image Inclusion**
*User Question*: "How to do X?"
*Context Provided*: "Oura App supports X, and you can do it by following these steps:
...

...

...
[](/hc/article_attachments/36252969067283)
..."
*Response*:
"To do X, follow these steps:
...
<img src="https://support.ouraring.com/hc/article_attachments/34549633600147" alt="hw_reset_remastered.gif" style="object-fit: contain;"/>
...
<img src="https://support.ouraring.com/hc/article_attachments/28720126068115" alt="ring battery level icon" style="object-fit: contain; width: 50px; height: 50px;"/>
...
<img src="https://support.ouraring.com/hc/article_attachments/36252969067283" alt="app_ux_today_tab.png" style="object-fit: contain;"/>
..."
(only an icon is resized)
## Final Reminders:
- Base responses strictly on the retrieved context.
- Avoid fabricating information.
- Be friendly and engaging.
- Cite sparsely.
- Use html for images with `object-fit: contain;` style, and resize icons to `50px` width and height.
- When in doubt, seek clarification or acknowledge the lack of information.
"""

    # Fail with an actionable message instead of a bare KeyError when the
    # API key is missing or empty.
    api_key = os.environ.get('GEMINI_KEY')
    if not api_key:
        raise SystemExit("The GEMINI_KEY environment variable must be set to a Gemini API key.")

    # Initialize the Gemini client.
    # NOTE: `client` must remain a module-level name: chatbot_response reads
    # it as a global rather than receiving it as a parameter.
    client = genai.Client(api_key=api_key)

    # Initialize Chroma database: create or load the database
    db = create_or_get_chroma_db(
        gemini_client=client,
        emb_model=emb_model,
        articles_md_root=articles_md_root,
        file2url_path=file2url_path,
        db_root=db_root,
    )

    # Bind everything except (user_input, history), which the chat
    # interface supplies on each call.
    chatbot_response_partial = partial(
        chatbot_response,
        db=db,
        llm_name=llm_name,
        system_prompt=system_prompt,
        turns_to_keep=turns_to_keep,
        num_neighbors_per_query=num_neighbors_per_query,
    )

    # Create the Gradio interface
    with gr.Blocks() as demo:
        gr.ChatInterface(
            title='My Precious: Your Inner Circle of Insight',
            fn=chatbot_response_partial,
        )

    # Launch the Gradio app
    demo.launch()