Spaces:
Sleeping
Sleeping
import functools
import json
import os
import time
from typing import Any

from groq import Groq
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
# Constants (hardcoded)
FILE_PATH = "anjibot_chunks.json"  # JSON file read by load_data()
BATCH_SIZE = 384  # number of records embedded and upserted per batch
INDEX_NAME = "groq-llama-3-rag"  # name of the Pinecone index
# API keys come from the environment. os.getenv is a function and must be
# *called*; subscripting it (os.getenv[...]) raises TypeError at import time.
# The value is None when the variable is not set.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
DIMENSIONS = 768  # vector dimension passed to create_index()
def load_data(file_path: str) -> dict:
    """Parse the JSON document stored at *file_path* and return it.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces for the file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # locale encoding, which can mis-decode non-ASCII text on some systems.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
def initialize_pinecone(api_key: str, index_name: str, dims: int) -> Any:
    """Return a handle to the Pinecone index *index_name*, creating it if absent.

    Args:
        api_key: API key handed to the ``Pinecone`` client.
        index_name: Name of the index to open (and to create when missing).
        dims: Vector dimension used if the index has to be created.

    Returns:
        The object returned by ``pc.Index(index_name)``.
    """
    pc = Pinecone(api_key=api_key)
    existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

    # Check if index already exists; if not, create it
    if index_name not in existing_indexes:
        spec = ServerlessSpec(cloud="aws", region='us-east-1')
        pc.create_index(index_name, dimension=dims, metric='cosine', spec=spec)

    # Wait for the index to be initialized. Polling unconditionally also
    # covers an index that already exists but is not ready yet; for a ready
    # index the loop body never runs.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

    return pc.Index(index_name)
def upsert_data_to_pinecone(index: Any, data: dict, encoder: Any = None) -> None:
    """Embed the records in *data* and upsert them into *index* in batches.

    Args:
        index: Object exposing ``upsert(vectors=...)``.
        data: Column-oriented mapping. It must contain an ``"id"`` sequence
            and a ``"metadata"`` sequence whose items have ``"title"`` and
            ``"content"`` keys; every value in the mapping is sliced per batch.
        encoder: Optional object exposing ``encode(list_of_str)``. When
            omitted, a ``SentenceTransformer('dwzhu/e5-base-4k')`` is created,
            so callers that already hold a loaded model can avoid a reload.
    """
    if encoder is None:
        encoder = SentenceTransformer('dwzhu/e5-base-4k')

    total = len(data['id'])
    for i in tqdm(range(0, total, BATCH_SIZE)):
        # Find end of batch (the last batch may be shorter than BATCH_SIZE)
        i_end = min(total, i + BATCH_SIZE)
        # Create batch: the same [i:i_end] window of every column
        batch = {k: v[i:i_end] for k, v in data.items()}
        # Create embeddings from "title: content" strings
        chunks = [f'{x["title"]}: {x["content"]}' for x in batch["metadata"]]
        embeds = encoder.encode(chunks)
        # Sanity check: one embedding per record in the batch
        assert len(embeds) == (i_end - i)
        # Upsert to Pinecone as (id, vector, metadata) tuples
        to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
        index.upsert(vectors=to_upsert)
def get_docs(query: str, index: Any, encoder: Any, top_k: int) -> list[str]:
    """Return the ``content`` of the *top_k* index matches for *query*.

    Args:
        query: Text to search for.
        index: Object exposing ``query(vector=..., top_k=..., include_metadata=...)``
            whose result has a ``"matches"`` entry.
        encoder: Object exposing ``encode(query)``; the result must provide
            ``tolist()``.
        top_k: Number of matches requested from the index.

    Returns:
        ``match["metadata"]["content"]`` for each returned match, in the order
        the index returned them.
    """
    xq = encoder.encode(query)
    # The vector is sent as a plain list, hence tolist().
    res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
    return [match["metadata"]["content"] for match in res["matches"]]
def get_response(
    query: str,
    docs: list[str],
    groq_client: Any,
    *,
    model: str = "llama3-70b-8192",
) -> str:
    """Ask the chat model to answer *query* using *docs* as context.

    Args:
        query: The user's question, sent as the ``user`` message.
        docs: Context passages appended to the system prompt, separated by
            ``"\\n---\\n"``.
        groq_client: Object exposing ``chat.completions.create(model=..., messages=...)``.
        model: Model identifier passed to the completion call.

    Returns:
        ``choices[0].message.content`` of the completion response.
    """
    # NOTE(review): the prompt wording contains typos ("sarcastica",
    # "If you don't the answer"); the text is left as-is here.
    instructions = (
        "You are Anjibot, the AI course rep of 400 Level Computer Science department. You are always helpful, jovial, can be sarcastica but still sweet.\n"
        "Provide the answer to class related queries using\n"
        "context provided below.\n"
        "If you don't the answer to the user's question based on your pretrained knowledge and the context provided, just direct the user to Anji the human course rep.\n"
        "Anji's phone number: 08145170886.\n\n"
        "CONTEXT:\n"
    )
    # The explicit "+" is essential. Adjacent string literals are fused at
    # compile time, so writing "\n---\n".join(docs) directly after them would
    # make the *entire prompt* the join separator: the instructions would
    # vanish for a single document and be repeated between documents otherwise.
    system_message = instructions + "\n---\n".join(docs)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    chat_response = groq_client.chat.completions.create(
        model=model,
        messages=messages,
    )
    return chat_response.choices[0].message.content
@functools.lru_cache(maxsize=1)
def _get_rag_resources() -> tuple:
    """Build the index, encoder and Groq client once and reuse them afterwards.

    Returns:
        An ``(index, encoder, groq_client)`` tuple.
    """
    # Load data
    data = load_data(FILE_PATH)
    # Initialize Pinecone
    index = initialize_pinecone(PINECONE_API_KEY, INDEX_NAME, DIMENSIONS)
    # Upsert data into Pinecone
    upsert_data_to_pinecone(index, data)
    # Initialize encoder and Groq client
    encoder = SentenceTransformer('dwzhu/e5-base-4k')
    groq_client = Groq(api_key=GROQ_API_KEY)
    return index, encoder, groq_client


def handle_query(user_query: str) -> str:
    """Answer *user_query* with retrieval-augmented generation.

    The data load, index initialisation, upsert and model/client construction
    run on the first call only; the results are cached for later calls
    instead of being repeated for every query. Changes to the data file
    therefore take effect after a restart (or after
    ``_get_rag_resources.cache_clear()``).

    Args:
        user_query: The user's question.

    Returns:
        The text produced by ``get_response``.
    """
    index, encoder, groq_client = _get_rag_resources()
    # Get relevant documents
    docs = get_docs(user_query, index, encoder, top_k=5)
    # Generate and return response
    return get_response(user_query, docs, groq_client)