# Hugging Face Space page banner (scrape residue): "Spaces: Sleeping"
import os | |
import streamlit as st | |
import asyncio | |
from typing_extensions import TypedDict, List | |
from IPython.display import Image, display | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from langchain.schema import Document | |
from langgraph.graph import START, END, StateGraph | |
from langchain.prompts import PromptTemplate | |
import uuid | |
from langchain_groq import ChatGroq | |
from langchain_community.utilities import GoogleSerperAPIWrapper | |
from langchain_chroma import Chroma | |
from langchain_community.document_loaders import NewsURLLoader | |
from langchain_community.retrievers.wikipedia import WikipediaRetriever | |
from sentence_transformers import SentenceTransformer | |
from langchain.vectorstores import Chroma | |
from langchain_community.document_loaders import UnstructuredURLLoader, NewsURLLoader | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import WebBaseLoader | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.output_parsers import JsonOutputParser | |
from langchain_community.vectorstores.utils import filter_complex_metadata | |
from langchain.schema import Document | |
from langgraph.graph import START, END, StateGraph | |
from langchain_community.document_loaders.directory import DirectoryLoader | |
from langchain.document_loaders import TextLoader | |
from functions import * | |
# API credentials are read from the process environment.
lang_api_key = os.getenv("lang_api_key")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
groq_api_key = os.getenv("groq_api_key")


def _set_env_if_present(name, value):
    """Export ``value`` as environment variable ``name`` unless it is ``None``.

    ``os.environ`` only accepts ``str`` values, so assigning the ``None`` that
    ``os.getenv`` returns for a missing secret raises ``TypeError`` at import
    time. Missing values are therefore skipped.

    Args:
        name: Environment variable to set.
        value: Value to export, or ``None`` to leave ``name`` untouched.

    Returns:
        ``True`` if the variable was set, ``False`` if ``value`` was ``None``.
    """
    if value is None:
        return False
    os.environ[name] = value
    return True


# LangSmith tracing is only switched on when its API key is actually available;
# enabling tracing without a key would just produce failing trace uploads.
if _set_env_if_present("LANGCHAIN_API_KEY", lang_api_key):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
    os.environ["LANGCHAIN_PROJECT"] = "Info Assistant"
_set_env_if_present("GROQ_API_KEY", groq_api_key)
_set_env_if_present("SERPER_API_KEY", SERPER_API_KEY)
def _render_header():
    """Configure the Streamlit page and draw the title, logo and intro text."""
    # st.set_page_config must be the first Streamlit command of a script run.
    st.set_page_config(page_title="Info Assistant: ",
                       page_icon=":books:")
    st.header("Info Assistant :" ":books:")
    logo_path = "digital-a-high-resolution-logo-transparent (2).png"
    link = "https://digitala.lt/"
    st.logo(logo_path, link=link, icon_image=None)
    st.markdown("""
###### Get support of **"Info Assistant"**, who has in memory a lot of Data Science related articles.
If it can't answer based on its knowledge base, information will be found on the internet :books:
""")


def _init_session_messages():
    """Seed the chat history in ``st.session_state`` on a session's first run."""
    if "messages" not in st.session_state:
        # NOTE(review): this greeting mentions Lithuanian law documents, while the
        # page intro describes Data Science articles -- confirm which is intended.
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hi, I'm a chatbot who is based on respublic of Lithuania law documents. How can I help you?"}
        ]


def _build_graph(llm, retriever):
    """Build and compile the LangGraph question-answering workflow.

    Flow: ask_question -> (toxicity check: "good" -> retrieve, "bad" -> END)
    -> retrieve -> grade_documents -> ("search" -> web_search ->) generate
    -> grading: "useful" -> END, "not supported" -> generate again,
    "not useful" -> transform_query -> retrieve.

    Args:
        llm: Chat model passed to every LLM-backed chain constructor.
        retriever: Retriever used by the ask_question and retrieve nodes.

    Returns:
        The compiled graph, as returned by ``StateGraph.compile()``.
    """
    class GraphState(TypedDict):
        """
        Represents the state of our graph.
        Attributes:
            question: question
            generation: LLM generation
            search: whether to add search
            documents: list of documents
            steps: step log (presumably names of executed nodes -- confirm in functions.py)
            generation_count: generations count
        """
        question: str
        generation: str
        search: str
        documents: List[str]
        steps: List[str]
        generation_count: int

    # Build each LLM-backed chain once, instead of re-creating it inside the
    # node lambdas on every invocation (the generate/transform loops can run
    # the same node several times per question).
    retrieval_grader = retrieval_grader_grader(llm)
    qa_chain = QA_chain(llm)
    question_rewriter = create_question_rewriter(llm)
    toxicity_checker = create_toxicity_checker(llm)
    hallucination_checker = create_hallucination_checker(llm)
    helpfulness_checker = create_helpfulness_checker(llm)

    workflow = StateGraph(GraphState)

    # Nodes
    workflow.add_node("ask_question", lambda state: ask_question(state, retriever))
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    workflow.add_node("grade_documents", lambda state: grade_documents(state, retrieval_grader))
    workflow.add_node("generate", lambda state: generate(state, qa_chain))
    workflow.add_node("web_search", web_search)
    workflow.add_node("transform_query", lambda state: transform_query(state, question_rewriter))

    # Edges
    workflow.set_entry_point("ask_question")
    workflow.add_conditional_edges(
        "ask_question",
        lambda state: grade_question_toxicity(state, toxicity_checker),
        {
            "good": "retrieve",
            "bad": END,
        },
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    # NOTE(review): "not supported" loops straight back to generate. Nothing here
    # bounds that loop, so termination relies on
    # grade_generation_v_documents_and_question (or LangGraph's recursion limit).
    # Confirm.
    workflow.add_conditional_edges(
        "generate",
        lambda state: grade_generation_v_documents_and_question(
            state, hallucination_checker, helpfulness_checker
        ),
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    return workflow.compile()


def main():
    """Streamlit entry point: render the UI, build the RAG graph, answer questions."""
    _render_header()
    _init_session_messages()

    search_type = st.selectbox(
        "Choose search type. Options are [Max marginal relevance search (mmr) , Similarity search (similarity). Default value (similarity)]",
        options=["mmr", "similarity"],
        index=1
    )
    k = st.select_slider(
        "Select amount of documents to be retrieved. Default value (4): ",
        options=list(range(2, 16)),
        value=4
    )
    llm = ChatGroq(
        model="gemma2-9b-it",  # Specify the Gemma2 9B model
        temperature=0.0,
        max_tokens=400,
        max_retries=3
    )
    retriever = create_retriever_from_chroma(vectorstore_path="docs/chroma/", search_type=search_type, k=k, chunk_size=550, chunk_overlap=40)
    custom_graph = _build_graph(llm, retriever)

    if user_question := st.text_input("Ask a question about your documents:"):
        asyncio.run(handle_userinput(user_question, custom_graph))


if __name__ == "__main__":
    main()