File size: 5,234 Bytes
f843983
 
8ab10bb
8aba95b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f221f0
dd1f453
8aba95b
 
 
bc7a926
 
 
 
 
 
8aba95b
 
 
 
 
 
 
 
 
f843983
 
 
 
 
 
 
 
 
 
 
 
1fd4ac2
f843983
f0a4e40
 
 
 
fe6ba28
 
 
 
 
 
 
 
8f9e4c8
c812fc2
 
 
 
 
 
 
 
 
 
8f9e4c8
 
 
 
 
 
 
 
 
fe6ba28
 
 
 
 
 
 
 
 
 
 
2e1fa1f
 
 
 
 
 
 
 
fe6ba28
8f9e4c8
fe6ba28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ccb529
fe6ba28
f5ccf05
9ccb529
 
45ba126
 
9ccb529
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import streamlit as st
import asyncio
from typing_extensions import TypedDict, List
from IPython.display import Image, display
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from langchain.prompts import PromptTemplate
import uuid
from langchain_groq import ChatGroq
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_chroma import Chroma
from langchain_community.document_loaders import NewsURLLoader
from langchain_community.retrievers.wikipedia import WikipediaRetriever
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import UnstructuredURLLoader, NewsURLLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from langchain_community.document_loaders.directory import DirectoryLoader
from langchain.document_loaders import TextLoader
from functions import *


# --- Credentials & LangSmith tracing configuration -------------------------
# Keys are read from the process environment. The LangChain and Groq keys are
# read from lower-case variable names and re-exported below under upper-case
# names (LANGCHAIN_API_KEY, GROQ_API_KEY, SERPER_API_KEY).
lang_api_key = os.getenv("lang_api_key")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
groq_api_key = os.getenv("groq_api_key")

# Turn on LangChain (LangSmith) tracing and group traces under one project.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
os.environ["LANGCHAIN_PROJECT"] = "Lithuanian_Law_RAG_QA"

# Export each key only when a value was actually found. os.getenv returns None
# for a missing variable, and assigning None into os.environ raises TypeError,
# which used to crash the whole app at import time. Skipping the assignment
# also leaves any value already exported under the upper-case name intact.
for _env_name, _env_value in (
    ("LANGCHAIN_API_KEY", lang_api_key),
    ("GROQ_API_KEY", groq_api_key),
    ("SERPER_API_KEY", SERPER_API_KEY),
):
    if _env_value is not None:
        os.environ[_env_name] = _env_value
# Don't leak the loop variables into the module namespace.
del _env_name, _env_value





# Initial number of documents the retriever returns. Kept in one place so the
# slider's label and its initial value cannot disagree.
_DEFAULT_K = 4


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        search: whether to add search
        documents: list of documents
        steps: list of step entries — presumably appended to by the graph
            nodes defined in `functions`; confirm there
        generation_count: generations count
    """

    question: str
    generation: str
    search: str
    documents: List[str]
    steps: List[str]
    generation_count: int


def _build_workflow():
    """Assemble and compile the question-answering LangGraph workflow.

    The node functions (``ask_question``, ``retrieve``, ``grade_documents``,
    ``generate``, ``web_search``, ``transform_query``) and the routing
    functions (``grade_question_toxicity``, ``decide_to_generate``,
    ``grade_generation_v_documents_and_question``) are provided by
    ``from functions import *``.

    Returns:
        The result of ``StateGraph.compile()`` over ``GraphState``.
    """
    workflow = StateGraph(GraphState)

    # Nodes
    workflow.add_node("ask_question", ask_question)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("web_search", web_search)
    workflow.add_node("transform_query", transform_query)

    # Edges
    workflow.set_entry_point("ask_question")
    # A question graded "bad" ends the run without any retrieval.
    workflow.add_conditional_edges(
        "ask_question",
        grade_question_toxicity,
        {
            "good": "retrieve",
            "bad": END,
        },
    )

    workflow.add_edge("retrieve", "grade_documents")
    # After grading, either fall back to web search or generate directly.
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    # "not supported" regenerates; "not useful" rewrites the query and
    # retrieves again. NOTE(review): both routes are cycles — termination
    # presumably relies on generation_count checks inside `functions`; confirm.
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    workflow.add_edge("transform_query", "retrieve")

    return workflow.compile()


def main():
    """Render the Streamlit assistant page and answer the user's question.

    Draws the page header, description and retrieval options, builds the
    LLM, the Chroma retriever and the compiled LangGraph workflow. Once the
    user types a question, it runs ``handle_userinput`` on it via
    ``asyncio.run``.
    """
    # Streamlit requires this to be the first Streamlit command of a run.
    st.set_page_config(page_title="Info Assistant: ",
                       page_icon=":books:")

    st.header("Info Assistant :" ":books:")

    # NOTE(review): this text mentions Data Science articles while the greeting
    # below mentions Lithuanian law documents — confirm which is intended.
    st.markdown("""
        ###### Get support of **"Info Assistant"**, who has in memory a lot of Data Science related articles. 
        If it can't answer based on its knowledge base, information will be found on the internet :books:
    """)

    # Seed the chat history once per browser session.
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hi, I'm a chatbot who is  based on respublic of Lithuania law documents. How can I help you?"}
        ]

    # Label previously tagged the MMR option as "(similarity)"; each option is
    # now tagged with the value it actually selects.
    search_type = st.selectbox(
        "Choose search type. Options are [Max marginal relevance search (mmr), "
        "Similarity search (similarity)]. Default value (similarity)",
        options=["mmr", "similarity"],
        index=1,  # -> "similarity"
    )

    # Label previously claimed a default of 5 while the slider started at 4;
    # both now come from _DEFAULT_K.
    k = st.select_slider(
        f"Select amount of documents to be retrieved. Default value ({_DEFAULT_K}): ",
        options=list(range(2, 16)),
        value=_DEFAULT_K,
    )

    # NOTE(review): `llm` is not referenced again in this function (the graph
    # nodes come from `functions`). It is still constructed to keep behavior
    # unchanged — confirm whether it should be passed to the nodes or removed.
    llm = ChatGroq(
        model="gemma2-9b-it",  # Specify the Gemma2 9B model
        temperature=0.0,
        max_tokens=400,
        max_retries=3,
    )

    # NOTE(review): like `llm`, `retriever` is not passed to the graph here;
    # confirm how the `retrieve` node obtains its retriever.
    retriever = create_retriever_from_chroma(
        vectorstore_path="docs/chroma/",
        search_type=search_type,
        k=k,
        chunk_size=350,
        chunk_overlap=30,
    )

    custom_graph = _build_workflow()

    if user_question := st.text_input("Ask a question about your documents:"):
        asyncio.run(handle_userinput(user_question, custom_graph))


if __name__ == "__main__":
    main()