# Hugging Face Space header (page residue, kept as comments so the file parses):
# Jashneet's picture
# init!
# 5832556 verified
# final for hugging face
import os
import streamlit as st
import pandas as pd
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
from mirascope.core import groq
from pydantic import BaseModel
# Set page config
st.set_page_config(page_title="Smart Course Search with Faiss", page_icon="🔍", layout="wide")

# Groq API key: read from the environment instead of hard-coding it here.
# The previous revision embedded a live secret in source control — that key
# must be rotated. On Hugging Face Spaces, set GROQ_API_KEY as a Space secret.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is not set. Add it as an environment variable or Space secret.")

# Initialize Groq client (used below to screen queries for relevance)
groq_client = Groq(api_key=GROQ_API_KEY)
# Sentence-embedding model, cached so Streamlit reruns reuse a single instance.
@st.cache_resource
def init_model():
    """Load and cache the MiniLM sentence-embedding model (384-dim vectors)."""
    return SentenceTransformer("all-MiniLM-L6-v2")


model = init_model()
# Course catalogue, cached across Streamlit reruns.
@st.cache_data
def load_data():
    """Load the course CSV shipped with the Space.

    Returns an empty DataFrame (and shows an error) when the file is missing.
    """
    # In Hugging Face Spaces the CSV sits next to the app file.
    csv_path = "Analytics_vidhya_final_data.csv"
    if not os.path.exists(csv_path):
        st.error(f"CSV file not found at {csv_path}. Please make sure it's uploaded to your Space.")
        return pd.DataFrame()
    return pd.read_csv(csv_path)


courses_df = load_data()
# The Faiss index and its parallel metadata list live in session_state so they
# survive Streamlit's script reruns within a session.
dimension = 384  # embedding size produced by "all-MiniLM-L6-v2"
if 'faiss_index' not in st.session_state:
    st.session_state['faiss_index'] = faiss.IndexFlatL2(dimension)
    st.session_state['document_store'] = []

# Module-level aliases used by the search helpers below.
faiss_index = st.session_state['faiss_index']
document_store = st.session_state['document_store']
# Function to prepare course data
def prepare_course_data(df):
    """Build parallel lists of embedding texts and metadata dicts from the course DataFrame.

    Rows with a missing or blank Title, Course Curriculum, or Course
    Description are skipped. Cell values are passed through a NaN-safe
    cleaner first: pandas yields float('nan') for missing cells, and the
    previous version crashed with AttributeError on ``nan.strip()``.

    Returns:
        (documents, metadatas): documents[i] is the concatenated text to
        embed for metadatas[i] ({"title", "curriculum", "description"}).
    """
    def _clean(row, col):
        # NaN-safe fetch: missing cells come back as a float, not ''.
        val = row.get(col, '')
        return '' if pd.isna(val) else str(val).strip()

    documents = []
    metadatas = []
    for _, row in df.iterrows():
        title = _clean(row, 'Title')
        curriculum = _clean(row, 'Course Curriculum')
        description = _clean(row, 'Course Description')
        if not (title and curriculum and description):
            continue
        documents.append(f"{title} {curriculum} {description}".strip())
        metadatas.append({
            "title": title,
            "curriculum": curriculum,
            "description": description,
        })
    return documents, metadatas
# Add courses to Faiss index
def add_courses_to_faiss(df):
    """Embed every valid course row and append it to the Faiss index.

    Metadata is appended to `document_store` in the same order as the
    vectors, so Faiss row ids double as indices into the store.

    Returns:
        Number of documents indexed (0 on empty input or failure).
    """
    documents, metadatas = prepare_course_data(df)
    if not documents:
        st.warning("No valid documents to add to the database")
        return 0
    try:
        vectors = np.array(model.encode(documents), dtype="float32")
        faiss_index.add(vectors)
        document_store.extend(metadatas)
        return len(documents)
    except Exception as e:
        st.error(f"Error adding documents to Faiss: {str(e)}")
        return 0
# Faiss search function
def faiss_search(query, k=3):
    """Return up to k nearest courses for `query` from the Faiss index.

    Each result dict carries the stored metadata dict (under both "content"
    and "metadata" — the UI reads "metadata") and a score equal to the
    negated L2 distance, so larger means more similar.
    """
    if faiss_index.ntotal == 0:
        st.warning("Faiss index is empty. Cannot perform search.")
        return []
    query_embedding = model.encode([query])
    distances, indices = faiss_index.search(np.array(query_embedding, dtype="float32"), k)
    results = []
    for rank, idx in enumerate(indices[0]):
        # Faiss pads `indices` with -1 when fewer than k vectors exist; the
        # old `idx < len(document_store)` check let -1 through and silently
        # returned the *last* stored document via negative indexing.
        if 0 <= idx < len(document_store):
            results.append({
                "content": document_store[idx],
                "metadata": document_store[idx],
                "score": -distances[0][rank],
            })
    return results
# Groq search function
def groq_search(user_query):
    """Ask the Groq LLM whether `user_query` relates to the catalogue's fields.

    Returns the model's free-text analysis (which it is instructed to end
    with YES/NO), or an "ERROR: ..." string when the API call fails.
    """
    prompt = f"""
You are an AI assistant specializing in data science, machine learning, artificial intelligence, generative AI, data engineering, and data analytics. Your task is to analyze the following user query and determine if it's related to these fields:
User Query: "{user_query}"
Please provide a detailed response that includes:
1. Whether the query is related to the mentioned fields (data science, ML, AI, GenAI, data engineering, or data analytics).
2. If related, explain how it connects to these fields and suggest potential subtopics or courses that might be relevant.
3. If not directly related, try to find any indirect connections to the mentioned fields.
Your response should be informative and help guide a course recommendation system. End your response with a clear YES if the query is related to the mentioned fields, or NO if it's completely unrelated.
"""
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-8b-8192",
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"Error in Groq API call: {str(e)}")
        return "ERROR: Unable to process the query"
# Mirascope analysis: reduce the free-text Groq answer to a 0/1 flag.
class SearchResult(BaseModel):
    # 1 = query is relevant to the catalogue, 0 = unrelated
    final_output: int


@groq.call("llama-3.1-70b-versatile", response_model=SearchResult)
def extract_relevance(text: str) -> str:
    """Return the prompt Mirascope sends; the decorator parses the reply into SearchResult."""
    return f"""Extract the integer from text whether we move forward or not it can be either 0 or 1: {text}"""
# Streamlit UI
st.title("Smart Course Search System")

# Show Faiss index count
db_count = faiss_index.ntotal
st.write(f"Current number of courses in the Faiss index: {db_count}")

# First run in a session: populate the index from the loaded CSV.
if db_count == 0 and not courses_df.empty:
    added_count = add_courses_to_faiss(courses_df)
    st.success(f"{added_count} courses added to the Faiss index!")
    db_count = faiss_index.ntotal
    st.write(f"Updated number of courses in the Faiss index: {db_count}")
# Search query input: two-stage flow — LLM relevance gate, then vector search.
user_query = st.text_input("Enter your search query")
if user_query:
    with st.spinner("Analyzing your query..."):
        groq_response = groq_search(user_query)
        search_result = extract_relevance(groq_response)
    if search_result.final_output == 1:
        st.success("Your query is relevant to our course catalog. Here are the search results:")
        results = faiss_search(user_query)
        if results:
            st.subheader("Search Results")
            for i, result in enumerate(results, 1):
                with st.expander(f"Result {i}: {result['metadata']['title']}"):
                    st.write(f"Relevance Score: {result['score']:.2f}")
                    st.subheader("Course Curriculum")
                    st.write(result['metadata']['curriculum'])
                    st.subheader("Course Description")
                    st.write(result['metadata']['description'])
        else:
            st.info("No specific courses found for your query. Try a different search term.")
    else:
        st.warning("We're sorry, but we couldn't find any courses directly related to your search query.")
        st.write("Our current catalog focuses on data science, machine learning, artificial intelligence, generative AI, data engineering, and data analytics. Please try a different search term related to these fields.")
# Debug information: optional panel showing index size and a sample record.
if st.checkbox("Show Debug Information"):
    st.subheader("Debug Information")
    st.write(f"Database count: {db_count}")
    if db_count > 0:
        st.write("Sample document:")
        if document_store:
            st.json(document_store[0])

if __name__ == "__main__":
    # Streamlit executes this module top-to-bottom; nothing extra to run here.
    pass