import os
import uuid

import streamlit as st
from typing import Annotated
from typing_extensions import TypedDict, List
from IPython.display import Image, display
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_chroma import Chroma
from langchain_community.document_loaders import NewsURLLoader, UnstructuredURLLoader, WebBaseLoader
from langchain_community.document_loaders.directory import DirectoryLoader
from langchain.document_loaders import TextLoader
from langchain_community.retrievers.wikipedia import WikipediaRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain.retrievers import WebResearchRetriever
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from exa_py import Exa
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "Civilinės_teises_Asistente_V1_Embed"

lang_api_key = os.getenv("LANGCHAIN_API_KEY")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")
exa_api_key = os.getenv("exa_api_key")

# Pass the key loaded from the environment, not the literal string "exa_api_key".
exa = Exa(api_key=exa_api_key)
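
# The keys above are expected to come from the environment (for example, host/Space
# secrets or a local .env). The guard below is a minimal, assumed sketch so missing
# keys are reported at startup rather than failing at the first API call; it is not
# part of the original code.
for _name, _value in {"GROQ_API_KEY": groq_api_key, "SERPER_API_KEY": SERPER_API_KEY, "exa_api_key": exa_api_key}.items():
    if not _value:
        print(f"Warning: environment variable {_name} is not set")
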
def create_retriever_from_chroma(vectorstore_path="docs/chroma/", search_type='mmr', k=7, chunk_size=300, chunk_overlap=30):
    model_name = "Alibaba-NLP/gte-multilingual-base"
    # gte-multilingual-base ships custom model code, so remote code must be trusted.
    model_kwargs = {'device': 'cpu', 'trust_remote_code': True}
    encode_kwargs = {'normalize_embeddings': True}

    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )

    # Reuse an existing vector store if one has already been persisted.
    if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
        vectorstore = Chroma(persist_directory=vectorstore_path, embedding_function=embeddings)
    else:
        st.write("Vector store doesn't exist and will be created now")
        loader = DirectoryLoader('./data/', glob="./*.txt", loader_cls=TextLoader)
        docs = loader.load()

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap,
            separators=["\n\n \n\n", "\n\n\n", "\n\n", r"In \[[0-9]+\]", r"\n+", r"\s+"],
            is_separator_regex=True
        )
        split_docs = text_splitter.split_documents(docs)

        vectorstore = Chroma.from_documents(
            documents=split_docs, embedding=embeddings, persist_directory=vectorstore_path,
        )

    retriever = vectorstore.as_retriever(search_type=search_type, search_kwargs={"k": k})
    return retriever

def handle_userinput(user_question, custom_graph):
    # Add the user's question to the chat history and display it in the UI
    st.session_state.messages.append({"role": "user", "content": user_question})
    st.chat_message("user").write(user_question)

    # Generate a unique thread ID for the graph's state
    config = {"configurable": {"thread_id": str(uuid.uuid4())}}

    try:
        # Invoke the custom graph with the input question
        state_dict = custom_graph.invoke(
            {"question": user_question, "steps": []}, config
        )
        docs = state_dict["documents"]

        with st.sidebar:
            st.subheader("Dokumentai, kuriuos Birutė gavo kaip kontekstą")
            with st.spinner("Processing"):
                for doc in docs:
                    # Extract document content
                    content = doc
                    # Extract document metadata if available
                    # metadata = doc.metadata.get('original_doc_name', 'unknown')
                    # Display content and metadata
                    st.write(f"Dokumentas: {content}")

        # Check if a response (generation) was produced by the graph
        if 'generation' in state_dict and state_dict['generation']:
            response = state_dict["generation"]
            # Add the assistant's response to the chat history and display it
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.chat_message("assistant").write(response)
        else:
            st.chat_message("assistant").write("Your question violates toxicity rules or contains sensitive information.")
    except Exception as e:
        # Display an error message in case of failure
        st.chat_message("assistant").write("Klaida: Arba per didelis kontekstas suteiktas modeliui, arba užklausų serveryje yra per daug")

def create_workflow(retriever):
    class GraphState(TypedDict):
        """
        Represents the state of our graph.

        Attributes:
            question: question
            generation: LLM generation
            search: whether to add search
            documents: list of documents
            steps: list of processing steps taken so far
            generation_count: number of generation attempts
        """
        question: Annotated[str, "Single"]  # Ensuring only one value per step
        generation: str
        search: str
        documents: List[str]
        steps: List[str]
        generation_count: int

    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.2,
        max_tokens=600,
        max_retries=3,
    )
    llm_checker = ChatGroq(
        model="llama3-groq-70b-8192-tool-use-preview",
        temperature=0.1,
        max_tokens=400,
        max_retries=3,
    )

    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("ask_question", lambda state: ask_question(state))
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    workflow.add_node("grade_documents", lambda state: grade_documents(state, retrieval_grader_grader(llm_checker)))
    workflow.add_node("generate", lambda state: generate(state, QA_chain(llm)))
    workflow.add_node("web_search", web_search)
    # workflow.add_node("transform_query", lambda state: transform_query(state, create_question_rewriter(llm)))

    # Build graph
    workflow.set_entry_point("ask_question")
    workflow.add_edge("ask_question", "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    # workflow.add_edge("retrieve", "generate")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    workflow.add_edge("generate", END)

    custom_graph = workflow.compile()
    return custom_graph

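# The IPython import at the top of the file is presumably intended for inspecting the
# compiled graph in a notebook. A minimal sketch of that usage (assuming a notebook
# context; not required for the Streamlit app) would be:
#
#     custom_graph = create_workflow(create_retriever_from_chroma())
#     display(Image(custom_graph.get_graph().draw_mermaid_png()))
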
def retrieval_grader_grader(llm):
    """
    Function to create a grader object using a passed LLM model.

    Args:
        llm: The language model to be used for grading.

    Returns:
        Callable: A pipeline function that grades relevance based on the LLM.
    """
    class GradeDocuments(BaseModel):
        """Ar faktas gali būti, nors truputį, naudingas atsakant į klausimą."""
        binary_score: str = Field(
            description="Dokumentai yra aktualūs klausimui, 'yes' arba 'no'"
        )

    # Create the structured LLM grader using the passed LLM
    structured_llm_grader = llm.with_structured_output(GradeDocuments)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate mokytojas, vertinantis viktoriną. Jums bus suteikta:
        1/ KLAUSIMAS {question}
        2/ Studento pateiktas FAKTAS {documents}

        Jūs vertinate RELEVANCE RECALL:
        yes reiškia, kad FAKTAS yra susijęs su KLAUSIMU.
        no reiškia, kad FAKTAS nesusijęs su KLAUSIMU.
        yes yra aukščiausias (geriausias) balas. no yra žemiausias balas, kurį galite duoti.
        Jeigu iš Studento pateikto FAKTO galima gauti bet kokių įžvalgų, susijusių su KLAUSIMU, duokite įvertinimą yes.
        Žingsnis po žingsnio paaiškinkite savo samprotavimus. Įsitikinkite, kad jūsų samprotavimai ir išvados yra teisingi.
        Iš pradžių venkite tiesiog nurodyti teisingą atsakymą.

        Klausimas: {question} \n
        FAKTAS: \n\n {documents} \n\n

        Suteikite dvejetainį balą „yes“ arba „no“, kad nurodytumėte, ar dokumentas yra susijęs su klausimu. \n
        Pateikite dvejetainį balą kaip JSON su vienu raktu „binary_score“ ir be įžangos ar paaiškinimo.
        """,
        input_variables=['documents', 'question'],
    )

    # Combine the prompt with the structured LLM grader
    retrieval_grader = prompt | structured_llm_grader
    # Return the grader object
    return retrieval_grader

def create_toxicity_checker(llm):
    """
    Function to create a toxicity checker object using a passed LLM model.

    Args:
        llm: The language model to be used for toxicity checking.

    Returns:
        Callable: A pipeline function that checks for toxicity based on the LLM.
    """
    # Define the class for toxicity check inside the function
    class PromptGuard(BaseModel):
        """Binary score for toxicity check on question."""
        score: str = Field(
            description="Ar klausimas yra nesaugus, 'taip' arba 'ne'"
        )

    # Create the structured LLM toxicity checker using the passed LLM
    structured_llm_toxicity_checker = llm.with_structured_output(PromptGuard)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate klasifikatorius, kuris įvertins tekstą pagal pateiktus kriterijus.
        Čia pateikiami vertinimo kriterijai:

        (1) įžeidinėjimai
        (2) grasinimai
        (3) labai neigiami komentarai
        (4) bet kokia asmenį identifikuojanti informacija, pvz., API raktai

        Rezultatas:
        „taip“ reiškia, kad tekstas atitinka BET KOKĮ kriterijų.
        „ne“ reiškia, kad tekstas neatitinka nė vieno kriterijaus.

        Štai klausimas: {question}

        Suteikite dvejetainį balą „taip“ arba „ne“, kad nurodytumėte, ar klausimas atitinka bent vieną iš šių kriterijų.\n
        Pateikite dvejetainį balą kaip JSON su vienu raktu „score“ ir be įžangos ar paaiškinimo.""",
        input_variables=["question"],
    )

    # Combine the prompt with the structured LLM toxicity checker
    toxicity_grader = prompt | structured_llm_toxicity_checker
    # Return the toxicity checker object
    return toxicity_grader

def grade_question_toxicity(state, toxicity_grader):
    """
    Grades the question for toxicity.

    Args:
        state (dict): The current graph state.

    Returns:
        str: 'good' if the question passes the toxicity check, 'bad' otherwise.
    """
    steps = state["steps"]
    steps.append("prompt guard")
    score = toxicity_grader.invoke({"question": state["question"]})
    grade = getattr(score, 'score', None)

    # The grader prompt answers in Lithuanian, so accept both "taip" and "yes".
    if grade and grade.lower() in ("taip", "yes"):
        return "bad"
    else:
        return "good"

def create_helpfulness_checker(llm):
    """
    Function to create a helpfulness checker object using a passed LLM model.

    Args:
        llm: The language model to be used for checking the helpfulness of answers.

    Returns:
        Callable: A pipeline function that checks if the student's answer is helpful.
    """
    class HelpfulnessChecker(BaseModel):
        """Binary score for helpfulness check on the answer."""
        score: str = Field(
            description="Ar atsakymas yra naudingas?, 'taip' arba 'ne'"
        )

    # Create the structured LLM helpfulness checker using the passed LLM
    structured_llm_helpfulness_checker = llm.with_structured_output(HelpfulnessChecker)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jums bus pateiktas KLAUSIMAS {question} ir ATSAKYMAS {generation}.
        Įvertinkite ATSAKYMĄ pagal šiuos kriterijus:

        Aktualumas: ATSAKYMAS turi būti tiesiogiai susijęs su KLAUSIMU ir konkrečiai į jį atsakyti.
        Pakankamumas: ATSAKYME turi būti pakankamai informacijos, kad būtų galima visapusiškai atsakyti į KLAUSIMĄ. Jei ATSAKYME vartojamos tokios frazės kaip „nežinau“, „neturiu pakankamai informacijos“, „pateiktuose dokumentuose apie tai neužsimenama“ ar panašūs posakiai, kuriais vengiama tiesiogiai atsakyti į KLAUSIMĄ, įvertinkite „ne“.
        Aiškumas ir glaustumas: ATSAKYMAS turi būti aiškus, be jokių nereikalingų frazių ar pasikartojimų. Jei jame yra perteklinė arba netiesioginė informacija, o ne tiesioginis atsakymas, įvertinkite „ne“.

        Vertinimo instrukcijos:
        „taip“ reiškia, kad ATSAKYMAS atitinka visus šiuos kriterijus ir yra tiesiogiai susijęs su KLAUSIMU.
        „ne“ reiškia, kad ATSAKYMAS neatitinka visų šių kriterijų.
        Jei tekste randate tokias frazes kaip „nežinau“, „nepakanka informacijos“ ar panašias, balas yra „ne“.

        Pateikite balą kaip JSON su vienu raktu „score“ ir be papildomo teksto""",
        input_variables=["generation", "question"]
    )

    # Combine the prompt with the structured LLM helpfulness checker
    helpfulness_grader = prompt | structured_llm_helpfulness_checker
    # Return the helpfulness checker object
    return helpfulness_grader

def create_hallucination_checker(llm):
    """
    Function to create a hallucination checker object using a passed LLM model.

    Args:
        llm: The language model to be used for checking hallucinations in the student's answer.

    Returns:
        Callable: A pipeline function that checks if the student's answer contains hallucinations.
    """
    class HallucinationChecker(BaseModel):
        """Binary score indicating whether the answer is grounded in the documents."""
        score: str = Field(
            description="Ar dokumentas yra susijęs su atsakymu?, 'taip' arba 'ne'"
        )

    # Create the structured LLM hallucination checker using the passed LLM
    structured_llm_hallucination_checker = llm.with_structured_output(HallucinationChecker)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate mokytojas, vertinantis viktoriną.
        Jums bus pateikti FAKTAI ir MOKINIO ATSAKYMAS.
        Jūs vertinate MOKINIO ATSAKYMĄ pagal šaltinį FAKTAI. Sutelkite dėmesį į MOKINIO ATSAKYMO teisingumą ir bet kokių haliucinacijų aptikimą.
        Įsitikinkite, kad MOKINIO ATSAKYMAS atitinka šiuos kriterijus:
        (1) jame nėra informacijos, nesusijusios su FAKTAIS
        (2) MOKINIO ATSAKYMAS turi būti visiškai pagrįstas pirminiuose dokumentuose pateikta informacija

        Rezultatas:
        „taip“ reiškia, kad mokinio atsakymas atitinka visus kriterijus. Tai aukščiausias (geriausias) balas.
        „ne“ reiškia, kad mokinio atsakymas neatitinka visų kriterijų. Tai yra žemiausias galimas balas, kurį galite duoti.

        Žingsnis po žingsnio paaiškinkite savo samprotavimus, kad įsitikintumėte, jog argumentai ir išvados yra teisingi.
        Iš pradžių venkite tiesiog nurodyti teisingą atsakymą.

        MOKINIO ATSAKYMAS: {generation} \n
        FAKTAI: \n\n {documents} \n\n

        Suteikite dvejetainį balą „taip“ arba „ne“, kad nurodytumėte, ar atsakymas yra pagrįstas pateiktais faktais. \n
        Pateikite dvejetainį balą kaip JSON su vienu raktu „score“ ir be įžangos ar paaiškinimo.""",
        input_variables=["generation", "documents"],
    )

    # Combine the prompt with the structured LLM hallucination checker
    hallucination_grader = prompt | structured_llm_hallucination_checker
    # Return the hallucination checker object
    return hallucination_grader

def create_question_rewriter(llm):
    """
    Function to create a question rewriter object using a passed LLM model.

    Args:
        llm: The language model to be used for rewriting questions.

    Returns:
        Callable: A pipeline function that rewrites questions for optimized vector store retrieval.
    """
    # Define the prompt template for question rewriting
    re_write_prompt = PromptTemplate(
        template="""Esate klausimų perrašytojas, kurio specializacija yra Lietuvos teisė, tobulinantis klausimus, kad būtų galima optimizuoti jų paiešką iš teisinių dokumentų. Jūsų tikslas – išaiškinti teisinę intenciją, pašalinti dviprasmiškumą ir pakoreguoti formuluotes taip, kad jos atspindėtų teisinę kalbą, daugiausia dėmesio skiriant atitinkamiems raktiniams žodžiams, siekiant užtikrinti tikslų informacijos gavimą iš Lietuvos teisės šaltinių.
        Man nereikia paaiškinimų, tik perrašyto klausimo.
        Štai pradinis klausimas: \n\n {question}. Patobulintas klausimas be paaiškinimų : \n""",
        input_variables=["question"],
    )

    # Combine the prompt with the LLM and output parser
    question_rewriter = re_write_prompt | llm | StrOutputParser()
    # Return the question rewriter object
    return question_rewriter

def transform_query(state, question_rewriter):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("question_transformation")

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    print(f"Transformed question: {better_question}")
    return {"documents": documents, "question": better_question}

def format_google_results_search(google_results):
    formatted_documents = []

    # Extract data from answerBox (currently not added to the returned documents)
    answer_box = google_results.get("answerBox", {})
    answer_box_title = answer_box.get("title", "No title")
    answer_box_answer = answer_box.get("answer", "No text")

    # Extract and add organic results as separate Documents
    for result in google_results.get("organic", []):
        title = result.get("title", "No title")
        link = result.get("link", "Nėra svetainės adreso")
        snippet = result.get("snippet", "No snippet available")

        document = Document(
            metadata={
                "Organinio rezultato pavadinimas": title,
            },
            page_content=(
                f"Pavadinimas: {title} "
                f"Straipsnio ištrauka: {snippet} "
                f"Nuoroda: {link} "
            )
        )
        formatted_documents.append(document)
    return formatted_documents

def format_google_results_news(google_results):
    formatted_documents = []

    # Loop through each organic result and create a Document for it
    for result in google_results['organic']:
        title = result.get('title', 'No title')
        link = result.get('link', 'No link')
        description = result.get('description', 'No description')
        snippet = result.get('snippet', 'No summary available')
        text = result.get('text', 'No text')

        # Create a Document object with similar metadata structure to WikipediaRetriever
        document = Document(
            metadata={
                'Title': title,
                'Description': description,
                'Text': text,
                'Snippet': snippet,
                'Source': link
            },
            page_content=snippet  # Using the snippet as the page content
        )
        formatted_documents.append(document)
    return formatted_documents

def QA_chain(llm):
    """
    Creates a question-answering chain using the provided language model.

    Args:
        llm: The language model to use for generating answers.

    Returns:
        A runnable chain (prompt | llm | parser) configured with the question-answering prompt.
    """
    # Define the prompt template
    prompt = PromptTemplate(
        template="""Esi teisės asistentas, kurio užduotis yra atsakyti konkrečiai, informatyviai ir glaustai, pagrindžiant savo atsakymą į klausimą pagal pateiktus dokumentus.
        Atsakymas turi būti lietuvių kalba. Nesikartok.
        Jei negali atsakyti į klausimą, pasakyk: Atsiprašau, nežinau atsakymo į jūsų klausimą.
        Neužduok papildomų klausimų.

        Klausimas: {question}
        Dokumentai: {documents}
        Atsakymas:
        """,
        input_variables=["question", "documents"],
    )

    rag_chain = prompt | llm | StrOutputParser()
    return rag_chain

def grade_generation_v_documents_and_question(state, hallucination_grader, answer_grader):
    """
    Determines whether the generation is grounded in the documents and answers the question.
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    generation_count = state.get("generation_count", 0)  # Default to 0 to avoid KeyError/None comparisons
    print(f"generation number: {generation_count}")

    # Grading hallucinations
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = getattr(score, 'score', None)

    # Check hallucination (the graders answer in Lithuanian, so accept "taip" as well)
    if grade and grade.lower() in ("yes", "taip"):
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = getattr(score, 'score', None)
        if grade and grade.lower() in ("yes", "taip"):
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        if generation_count > 1:
            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, TRANSFORM QUERY---")
            return "not useful"
        else:
            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
            # The retry counter itself is incremented in generate()
            print(f"generation number: {state.get('generation_count', 0)}")
            return "not supported"

def ask_question(state):
    """
    Initialize question

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Question
    """
    steps = state["steps"]
    question = state["question"]
    generation_count = state.get("generation_count", 0)
    steps.append("question_asked")
    return {"question": question, "steps": steps, "generation_count": generation_count}

def retrieve(state, retriever):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state
        retriever: The retriever object

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    steps = state["steps"]
    question = state["question"]
    documents = retriever.invoke(question)
    steps.append("retrieve_documents")
    return {"documents": documents, "question": question, "steps": steps}

def generate(state, qa_chain):
    """
    Generate answer
    """
    question = state["question"]
    documents = state["documents"]
    # stream() returns a generator; it is rendered downstream by Streamlit.
    generation = qa_chain.stream({"documents": documents, "question": question})
    steps = state["steps"]
    steps.append("generate_answer")
    generation_count = state.get("generation_count", 0) + 1

    return {
        "documents": documents,
        "question": question,
        "generation": generation,
        "steps": steps,
        "generation_count": generation_count  # Include generation_count in return
    }

def grade_documents(state, retrieval_grader):
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("grade_document_retrieval")

    filtered_docs = []
    search = "No"

    for d in documents:
        # Call the grading function
        score = retrieval_grader.invoke({"question": question, "documents": d})
        print(f"Grader output for document: {score}")  # Detailed debugging output

        # Extract the grade
        grade = getattr(score, 'binary_score', None)
        if grade and grade.lower() in ["yes", "true", "1", "taip"]:
            filtered_docs.append(d)
        elif len(filtered_docs) < 4:
            # Fall back to web search while fewer than 4 relevant documents were kept
            search = "Yes"

    # Check the decision-making process
    print(f"Final decision - Perform web search: {search}")
    print(f"Filtered documents count: {len(filtered_docs)}")

    return {
        "documents": filtered_docs,
        "question": question,
        "search": search,
        "steps": steps,
    }

def clean_exa_document(doc):
    """
    Extracts and retains only the title, url, text, and summary from the Exa result document.
    """
    return {
        " Pavadinimas: ": doc.title,
        " Apibendrinimas: ": doc.summary,
        " Straipsnio internetinis adresas: ": doc.url,
        " Tekstas: ": doc.text
    }

def web_search(state):
    question = state["question"]
    documents = state.get("documents", [])
    steps = state["steps"]
    steps.append("web_search")

    web_results_list = []

    # Fetch results from Exa, restricted to Lithuanian legal sources
    exa_results_raw = exa.search_and_contents(
        query=question,
        start_published_date="2018-01-01T22:00:01.000Z",
        type="keyword",
        num_results=2,
        text={"max_characters": 7000},
        summary={
            "query": f"Tell in summary a meaning about what is article written. This summary has to be written in a way to be related to {question}. Provide facts, be concise. Do it in Lithuanian language."
        },
        include_domains=["infolex.lt", "vmi.lt", "lrs.lt", "e-seimas.lrs.lt", "teise.pro", "lt.wikipedia.org", "teismai.lt"],
    )

    # Extract results
    exa_results = exa_results_raw.results if hasattr(exa_results_raw, "results") else []
    cleaned_exa_results = [clean_exa_document(doc) for doc in exa_results]

    # Fall back to Google Serper if Exa returned nothing
    if len(cleaned_exa_results) < 1:
        web_results = GoogleSerperAPIWrapper(k=2, gl="lt", hl="lt", type="search").results(question)
        formatted_documents = format_google_results_search(web_results)
        web_results_list.extend(formatted_documents if isinstance(formatted_documents, list) else [formatted_documents])
        combined_documents = documents + cleaned_exa_results + web_results_list
    else:
        combined_documents = documents + cleaned_exa_results

    return {"documents": combined_documents, "question": question, "steps": steps}

def decide_to_generate(state):
    """
    Determines whether to generate an answer or run a web search first.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    search = state["search"]
    if search == "Yes":
        return "search"
    else:
        return "generate"


def decide_to_generate2(state):
    """
    Duplicate of decide_to_generate, kept for a (currently unused) second conditional edge.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    search = state["search"]
    if search == "Yes":
        return "search"
    else:
        return "generate"