|
import streamlit as st |
|
import transformers |
|
from dotenv import load_dotenv, find_dotenv |
|
import os |
|
|
|
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings |
|
from langchain_community.vectorstores import MongoDBAtlasVectorSearch |
|
|
|
from huggingface_hub import InferenceClient |
|
from pymongo import MongoClient |
|
from pymongo.collection import Collection |
|
from typing import Dict, Any |
|
from datetime import datetime |
|
|
|
# Secrets come from Streamlit's secrets store (.streamlit/secrets.toml or
# the Streamlit Cloud secrets UI) — never hard-code these.
MONGO_URI = st.secrets["MONGO_URI"]

HF_TOKEN = st.secrets["HF_TOKEN"]

# MongoDB Atlas coordinates: database, collection, and the name of the
# Atlas Search index used for vector similarity queries.
DB_NAME = "txts"

COLLECTION_NAME = "txts_collection"

VECTOR_SEARCH_INDEX = "vector_index"
|
|
|
|
|
@st.cache_resource
def init_mongodb():
    """Connect to MongoDB Atlas and return the target collection.

    Decorated with ``st.cache_resource`` so the client is created once
    and shared across Streamlit reruns.
    """
    client = MongoClient(MONGO_URI)
    database = client[DB_NAME]
    return database[COLLECTION_NAME]
|
|
|
@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    """Build the MongoDB Atlas vector store (cached across reruns).

    Returns:
        A ``MongoDBAtlasVectorSearch`` bound to ``DB_NAME.COLLECTION_NAME``
        and the ``VECTOR_SEARCH_INDEX`` Atlas Search index.

    Fix: the original read the module-level global ``embedding_model``,
    which is only assigned later at the bottom of the file — calling this
    function before that assignment raised NameError. We now obtain the
    embedding model from its own cached factory, so this function is
    self-contained. A leftover debug ``print`` was also removed.
    """
    return MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=MONGO_URI,
        namespace=f"{DB_NAME}.{COLLECTION_NAME}",
        embedding=init_embedding_model(),
        index_name=VECTOR_SEARCH_INDEX,
    )
|
|
|
@st.cache_resource
def init_embedding_model() -> HuggingFaceInferenceAPIEmbeddings:
    """Create the hosted sentence-transformers embedding client.

    Cached via ``st.cache_resource`` so the client is constructed once
    per server process.
    """
    embedder = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_TOKEN,
        model_name="sentence-transformers/all-mpnet-base-v2",
    )
    return embedder
|
|
|
def get_context_from_retrived_docs(retrieved_docs):
    """Concatenate the text of retrieved documents into one context string.

    Each document's ``page_content`` is kept verbatim; chunks are separated
    by a blank line so the LLM can see document boundaries.
    """
    chunks = [doc.page_content for doc in retrieved_docs]
    return "\n\n".join(chunks)
|
|
|
def format_prompt(user_query, retreived_context):
    """Assemble the RAG prompt from the user's question and retrieved context.

    Args:
        user_query: The question the user typed into the text area.
        retreived_context: Concatenated text of the retrieved documents.
            (Parameter name keeps the original spelling — callers pass it
            by keyword.)

    Returns:
        The complete prompt string handed to the chat model.

    Fix: the original chained ``.format(...)`` onto an f-string. The
    f-string had already interpolated both values, so the extra call was
    a no-op at best and raised KeyError/IndexError whenever the query or
    context contained literal ``{`` / ``}`` characters.
    """
    prompt = f"""Use the following pieces of context to answer the question at the end.

START OF CONTEXT:

{retreived_context}

END OF CONTEXT:

START OF QUESTION:

{user_query}

END OF QUESTION:

If you do not know the answer, just say that you do not know.

NEVER assume things.

"""
    return prompt
|
|
|
|
|
|
|
# ---- Wire up cached resources ------------------------------------------
# Order matters: `embedding_model` must be bound before init_vector_search()
# runs, because the vector store is built around it.
mongodb_collection = init_mongodb()

embedding_model = init_embedding_model()

vector_search = init_vector_search()

hf_client = InferenceClient(api_key=HF_TOKEN)


# ---- UI: question box + retrieval-augmented answer ---------------------
user_query = st.text_area('Ask a question about CTP Class')

if user_query:
    # Pull the 10 most similar chunks from the Atlas vector index.
    retrieved = vector_search.similarity_search(query=user_query, k=10)

    # Flatten the chunks into one context string and build the prompt.
    context = get_context_from_retrived_docs(retrieved)
    prompt = format_prompt(user_query=user_query, retreived_context=context)

    # Ask the hosted chat model to answer using only the retrieved context.
    messages = [
        {
            "role": "system",
            "content": 'you are an assistant, answer the question below',
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]
    response = hf_client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        messages=messages,
        max_tokens=1400,
        temperature=0.2,
    )

    model_response = response.choices[0].message.content
    st.text(model_response)
|
|