from io import StringIO
import sys
import os

# Set the EasyOCR cache directory to a writable location
os.environ["EASYOCR_CACHE_DIR"] = "/app/.EASYOCR"

import easyocr

# Monkey-patch easyocr.Reader to force the model_storage_directory parameter
_original_init = easyocr.Reader.__init__

def new_init(self, *args, **kwargs):
    # Drop a duplicate lang_list kwarg when languages were already passed positionally
    if args and "lang_list" in kwargs:
        del kwargs["lang_list"]
    kwargs.setdefault("model_storage_directory", "/app/.EasyOCR")
    _original_init(self, *args, **kwargs)

easyocr.Reader.__init__ = new_init
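
# With the patch applied, any Reader created later picks up the writable model
# directory automatically. A minimal sketch (hypothetical usage, not part of the app;
# "en" and "scan.png" are illustrative):
#   reader = easyocr.Reader(["en"])        # model files land in /app/.EasyOCR
#   results = reader.readtext("scan.png")  # works even when the home directory is read-only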
#from huggingface_hub import login
import gradio as gr
import json
import csv
import hashlib
import uuid
import logging
from typing import Annotated, List, Dict, Sequence, TypedDict

# LangChain & related imports
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool, StructuredTool
from pydantic import BaseModel, Field
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers import EnsembleRetriever

# Extraction for documents
from langchain_docling.loader import ExportType
from langchain_docling import DoclingLoader
from docling.chunking import HybridChunker

# Extraction for HTML
from langchain_community.document_loaders import WebBaseLoader
from urllib.parse import urlparse

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import InjectedStore
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.memory import MemorySaver
from langchain.embeddings import init_embeddings
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import (
    SystemMessage,
    AIMessage,
    HumanMessage,
    BaseMessage,
    ToolMessage,
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Suppress all logs at or below WARNING for user experience.
# Note: this also silences this module's own logger.info() calls; comment it out when debugging.
logging.disable(logging.WARNING)

# HF_TOKEN = os.getenv("HF_TOKEN")  # Read from environment variable
# if HF_TOKEN:
#     login(token=HF_TOKEN)  # Log in to Hugging Face Hub
# else:
#     print("Warning: HF_TOKEN not found in environment variables.")

GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # Read from environment variable
if not GROQ_API_KEY:
    print("Warning: GROQ_API_KEY not found in environment variables.")

EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
# =============================================================================
# Document Extraction Functions
# =============================================================================
def extract_documents(doc_path: str) -> List[str]:
    """
    Recursively collects all file paths from folder 'doc_path'.
    Used by load_files_from_folder() to find documents to parse.
    """
    extracted_docs = []
    for root, _, files in os.walk(doc_path):
        for file in files:
            file_path = os.path.join(root, file)
            extracted_docs.append(file_path)
    return extracted_docs

def _generate_uuid(page_content: str) -> str:
    """Generate a deterministic UUID for a chunk of text from its MD5 hash."""
    md5_hash = hashlib.md5(page_content.encode()).hexdigest()
    return str(uuid.UUID(md5_hash[0:32]))
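
# Because the UUID is derived from the MD5 of the content, identical chunks always
# map to the same ID, so re-ingesting the same text does not mint new IDs. A sketch:
#   _generate_uuid("hello") == _generate_uuid("hello")   # True: same text, same UUID
#   _generate_uuid("hello") != _generate_uuid("world")   # True: different text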
def load_file(file_path: str) -> List[Document]:
    """
    Load a file from the given path and return a list of Document objects.
    """
    _documents = []
    # Load the file and extract the text chunks
    try:
        loader = DoclingLoader(
            file_path=file_path,
            export_type=ExportType.DOC_CHUNKS,
            chunker=HybridChunker(tokenizer=EMBED_MODEL_ID),
        )
        docs = loader.load()
        logger.info(f"Total parsed doc-chunks: {len(docs)} from source: {file_path}")
        for d in docs:
            # Tag each chunk with the source file and a unique ID
            doc = Document(
                page_content=d.page_content,
                metadata={
                    "source": file_path,
                    "doc_id": _generate_uuid(d.page_content),
                    "source_type": "file",
                }
            )
            _documents.append(doc)
        logger.info(f"Total generated LangChain document chunks: {len(_documents)}.")
    except Exception as e:
        logger.error(f"Error loading file: {file_path}. Exception: {e}")
    return _documents
# Define function to load documents from a folder
def load_files_from_folder(doc_path: str) -> List[Document]:
    """
    Load documents from the given folder path and return a list of Document objects.
    """
    _documents = []
    # Collect all file paths from the given folder
    extracted_docs = extract_documents(doc_path)
    # Iterate through each document and extract the text chunks
    for file_path in extracted_docs:
        _documents.extend(load_file(file_path))
    return _documents
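
# A minimal usage sketch (the folder name is illustrative, not from the app config):
#   docs = load_files_from_folder("./Documents/references")
#   print(len(docs), docs[0].metadata)
#   # -> e.g. {'source': './Documents/references/intro.pdf', 'doc_id': '...', 'source_type': 'file'}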
# =============================================================================
# Load structured data in a CSV file into LangChain Document format
def load_mcq_csvfiles(file_path: str) -> List[Document]:
    """
    Load structured data from an MCQ CSV file at the given path and return a list of Document objects.
    Expected format: each row is comma-separated into "mcq_number", "mcq_type", "text_content".
    """
    _documents = []
    # Ensure we process only CSV files
    if not file_path.endswith(".csv"):
        return _documents  # Skip non-CSV files
    try:
        # Open and read the CSV file
        with open(file_path, mode='r', encoding='utf-8') as file:
            reader = csv.DictReader(file)
            for row in reader:
                # Ensure required columns exist in the row
                if not all(k in row for k in ["mcq_number", "mcq_type", "text_content"]):
                    logger.error(f"Skipping row due to missing fields: {row}")
                    continue
                # Build a Document from the row's "mcq_number", "mcq_type" and "text_content" fields
                doc = Document(
                    page_content=row["text_content"],  # text_content segments are separated by "|"
                    metadata={
                        "source": f"{file_path}_{row['mcq_number']}",  # file_path + mcq_number
                        "doc_id": _generate_uuid(f"{file_path}_{row['mcq_number']}"),  # Unique ID
                        "source_type": row["mcq_type"],  # MCQ type
                    }
                )
                _documents.append(doc)
        logger.info(f"Successfully loaded {len(_documents)} LangChain document chunks from {file_path}.")
    except Exception as e:
        logger.error(f"Error loading file: {file_path}. Exception: {e}")
    return _documents
# Define function to load structured CSV data from a folder
def load_files_from_folder_mcq(doc_path: str) -> List[Document]:
    """
    Load MCQ CSV files from the given folder path and return a list of Document objects.
    """
    _documents = []
    # Collect all CSV file paths from the given folder
    extracted_docs = [
        os.path.join(doc_path, file) for file in os.listdir(doc_path)
        if file.endswith(".csv")  # Process only CSV files
    ]
    # Iterate through each file and extract the text chunks
    for file_path in extracted_docs:
        _documents.extend(load_mcq_csvfiles(file_path))
    return _documents
# =============================================================================
# Website Extraction Functions
# =============================================================================
# (reuses the _generate_uuid helper defined above)
def ensure_scheme(url):
    parsed_url = urlparse(url)
    if not parsed_url.scheme:
        return 'http://' + url  # Default to http, or use 'https://' if preferred
    return url
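
# Behaviour sketch (illustrative inputs):
#   ensure_scheme("example.com")          -> "http://example.com"
#   ensure_scheme("https://example.com")  -> "https://example.com"  (unchanged)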
def extract_html(url: List[str]) -> List[Document]:
    """
    Extracts text from the HTML content of the web pages listed in 'url'.
    Returns a list of LangChain 'Document' objects.
    """
    if isinstance(url, str):
        url = [url]
    # Ensure all URLs have a scheme
    web_paths = [ensure_scheme(u) for u in url]
    loader = WebBaseLoader(web_paths)
    loader.requests_per_second = 1
    docs = loader.load()
    # Iterate through each document, clean the content (removing excessive line
    # returns and whitespace) and store it in a LangChain Document
    _documents = []
    for doc in docs:
        # Collapse all runs of whitespace (newlines, tabs, repeated spaces) to single spaces
        doc.page_content = " ".join(doc.page_content.split())
        # Store it in a LangChain Document
        web_doc = Document(
            page_content=doc.page_content,
            metadata={
                "source": doc.metadata.get("source"),
                "doc_id": _generate_uuid(doc.page_content),
                "source_type": "web"
            }
        )
        _documents.append(web_doc)
    return _documents
# =============================================================================
# Vector Store Initialisation
# =============================================================================
embedding_model = HuggingFaceEmbeddings(model_name=EMBED_MODEL_ID)

# Initialise vector stores
general_vs = Chroma(
    collection_name="general_vstore",
    embedding_function=embedding_model,
    persist_directory="./general_db"
)
mcq_vs = Chroma(
    collection_name="mcq_vstore",
    embedding_function=embedding_model,
    persist_directory="./mcq_db"
)
in_memory_vs = Chroma(
    collection_name="in_memory_vstore",
    embedding_function=embedding_model
)
# Split the documents into smaller chunks for better embedding coverage
def split_text_into_chunks(docs: List[Document]) -> List[Document]:
    """
    Splits a list of Documents into smaller text chunks using
    RecursiveCharacterTextSplitter while preserving metadata.
    Returns a list of Document objects.
    """
    if not docs:
        return []
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,    # Split into chunks of 1000 characters
        chunk_overlap=200,  # Overlap by 200 characters
        add_start_index=True
    )
    chunked_docs = splitter.split_documents(docs)
    return chunked_docs  # List of Document objects
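
# With chunk_size=1000 and chunk_overlap=200, consecutive chunks advance by roughly
# 800 characters, so a 2,400-character document yields about three chunks starting
# near offsets 0, 800 and 1600 (add_start_index records the exact offsets in metadata;
# actual boundaries shift to the nearest separator).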
# =============================================================================
# Retrieval Tools
# =============================================================================
# Define a simple similarity search retrieval tool on mcq_vs
class MCQRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="Search topic.")
    k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")

def mcq_retriever(input: str, k: int = 2) -> Dict[str, str]:
    # Retrieve the top k most similar MCQ question documents from the vector store
    docs_func = mcq_vs.as_retriever(
        search_type="similarity",
        search_kwargs={
            'k': k,
            'filter': {"source_type": "mcq_question"}
        },
    )
    docs_qns = docs_func.invoke(input)  # k is already set via search_kwargs
    # Extract the document IDs from the retrieved documents
    doc_ids = [d.metadata.get("doc_id") for d in docs_qns if "doc_id" in d.metadata]
    # Retrieve full documents based on the doc_ids
    docs = mcq_vs.get(where={'doc_id': {"$in": doc_ids}})
    qns_list = {}
    for i, d in enumerate(docs['metadatas']):
        qns_list[d['source'] + " " + d['source_type']] = docs['documents'][i]
    return qns_list
# Create a StructuredTool from the function
mcq_retriever_tool = StructuredTool.from_function(
    func=mcq_retriever,
    name="MCQ Retrieval Tool",
    description=(
        """
        Use this tool to retrieve an MCQ question set when Human asks to generate a quiz related to a topic.
        DO NOT GIVE THE ANSWERS to Human before Human has answered all the questions.
        If Human gives answers for questions you do not know, SAY you do not have the questions for the answers
        and ASK if Human wants you to generate a new quiz, then SAVE THE QUIZ with the Summary Tool before ending the conversation.
        Input must be a JSON string with the schema:
        - input (str): The search topic to retrieve an MCQ question set related to the topic.
        - k (int): Number of question sets to retrieve.
        Example usage: input='What is AI?', k=5
        Returns:
        - A dict of MCQ questions:
          Key: 'metadata of question', e.g. './Documents/mcq/mcq.csv_Qn31 mcq_question' with suffix ['question', 'answer', 'answer_reason', 'options', 'wrong_options_reason']
          Value: Text content
        """
    ),
    args_schema=MCQRetrievalTool,
    response_format="content",
    return_direct=False,  # Do not return the tool's output directly to the user
    verbose=False         # Set True to log the tool's progress
)
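
# A direct-invocation sketch for testing the tool outside the agent loop
# (topic string and returned key/value are illustrative):
#   result = mcq_retriever_tool.invoke({"input": "neural networks", "k": 2})
#   # -> {'./Documents/mcq/mcq.csv_Qn3 mcq_question': 'Which of the following ...', ...}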
# -----------------------------------------------------------------------------
# Retrieve more documents with higher diversity using MMR (Maximal Marginal Relevance) from the general vector store
# Useful if the dataset has many similar documents
class GenRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="User query.")
    k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")

def gen_retriever(input: str, k: int = 2) -> List[str]:
    # Use the vector store's retriever to fetch documents
    docs_func = general_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    docs = docs_func.invoke(input)  # k is already set via search_kwargs
    return [d.page_content for d in docs]
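
# Note on lambda_mult: it trades relevance against diversity in MMR scoring —
# 1.0 means pure relevance (minimum diversity), 0.0 maximum diversity, so the 0.25
# used here leans strongly toward diverse results. Conceptually, each candidate scores:
#   score = lambda_mult * sim(query, doc) - (1 - lambda_mult) * max(sim(doc, already_selected))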
# Create a StructuredTool from the function
general_retriever_tool = StructuredTool.from_function(
    func=gen_retriever,
    name="Assistant References Retrieval Tool",
    description=(
        """
        Use this tool to retrieve reference information from the Assistant reference database for Human queries
        related to a topic, or when Human asks to generate guides to learn or study a topic.
        Input must be a JSON string with the schema:
        - input (str): The user query.
        - k (int): Number of results to retrieve.
        Example usage: input='What is AI?', k=5
        Returns:
        - A list of retrieved documents' content strings.
        """
    ),
    args_schema=GenRetrievalTool,
    response_format="content",
    return_direct=False,  # Do not return the tool's output directly to the user
    verbose=False         # Set True to log the tool's progress
)
# -----------------------------------------------------------------------------
# Retrieve more documents with higher diversity using MMR (Maximal Marginal Relevance) from the in-memory vector store
# Query the in-memory vector store only
class InMemoryRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="User query.")
    k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")

def in_memory_retriever(input: str, k: int = 2) -> List[str]:
    # Use the vector store's retriever to fetch documents
    docs_func = in_memory_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    docs = docs_func.invoke(input)  # k is already set via search_kwargs
    return [d.page_content for d in docs]

# Create a StructuredTool from the function
in_memory_retriever_tool = StructuredTool.from_function(
    func=in_memory_retriever,
    name="In-Memory Retrieval Tool",
    description=(
        """
        Use this tool when Human asks Assistant to retrieve information from documents that Human has uploaded.
        Input must be a JSON string with the schema:
        - input (str): The user query.
        - k (int): Number of results to retrieve.
        """
    ),
    args_schema=InMemoryRetrievalTool,
    response_format="content",
    return_direct=False,  # Whether to return the tool's output directly
    verbose=False         # Set True to log the tool's progress
)
# -----------------------------------------------------------------------------
# Web Extraction Tool
class WebExtractionRequest(BaseModel):
    input: str = Field(..., title="input", description="Search text.")
    url: str = Field(
        ...,
        title="url",
        description="Web URL(s) to extract content from. If multiple URLs, separate them with a comma."
    )
    k: int = Field(5, title="Number of Results", description="The number of results to retrieve.")

# Extract content from a web URL, load into in_memory_vstore
def extract_web_path_tool(input: str, url: str, k: int = 5) -> List[str]:
    """
    Extract content from the web URLs based on the user's input.
    Args:
    - input: The input text to search for.
    - url: URLs to extract content from.
    - k: Number of results to retrieve.
    Returns:
    - A list of retrieved documents' content strings.
    """
    if isinstance(url, str):
        url = [u.strip() for u in url.split(",")]  # The schema allows comma-separated URLs
    # Extract content from the web
    html_docs = extract_html(url)
    if not html_docs:
        return [f"No content extracted from {url}."]
    # Split the documents into smaller chunks for better embedding coverage
    chunked_texts = split_text_into_chunks(html_docs)
    in_memory_vs.add_documents(chunked_texts)  # Add the chunked texts to the in-memory vector store
    # Retrieve content back from the in-memory vector store,
    # restricted to the URLs just ingested
    docs_func = in_memory_vs.as_retriever(
        search_type="mmr",
        search_kwargs={
            'k': k,
            'lambda_mult': 0.25,
            # Match the schemed URLs stored in the chunk metadata
            'filter': {"source": {"$in": [ensure_scheme(u) for u in url]}}
        },
    )
    docs = docs_func.invoke(input)  # k is already set via search_kwargs
    return [d.page_content for d in docs]
# Create a StructuredTool from the function
web_extraction_tool = StructuredTool.from_function(
    func=extract_web_path_tool,
    name="Web Extraction Tool",
    description=(
        "Assistant should use this tool to extract content from web URLs based on the user's input. "
        "The extracted content is first loaded into the in-memory vector store, then the top k results are retrieved."
    ),
    args_schema=WebExtractionRequest,
    return_direct=False,  # Whether to return the tool's output directly
    verbose=False         # Set True to log the tool's progress
)
# -----------------------------------------------------------------------------
# Ensemble Retrieval from General and In-Memory Vector Stores
class EnsembleRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="User query.")
    k: int = Field(5, title="Number of Results", description="Number of results.")

def ensemble_retriever(input: str, k: int = 5) -> List[str]:
    # Build one MMR retriever per vector store
    general_retrieval = general_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    in_memory_retrieval = in_memory_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    ensemble_retriever = EnsembleRetriever(
        retrievers=[general_retrieval, in_memory_retrieval],
        weights=[0.5, 0.5]
    )
    docs = ensemble_retriever.invoke(input)
    return [d.page_content for d in docs]

# Create a StructuredTool from the function
ensemble_retriever_tool = StructuredTool.from_function(
    func=ensemble_retriever,
    name="Ensemble Retriever Tool",
    description=(
        """
        Use this tool to retrieve information both from the reference database and
        from documents that Human has uploaded.
        Input must be a JSON string with the schema:
        - input (str): The user query.
        - k (int): Number of results to retrieve.
        """
    ),
    args_schema=EnsembleRetrievalTool,
    response_format="content",
    return_direct=False
)
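
# EnsembleRetriever merges the two ranked lists with weighted Reciprocal Rank Fusion:
# each document scores sum(weight / (rank + c)) across retrievers (c is a smoothing
# constant), so with equal 0.5 weights a document ranked highly by both stores beats
# one that only a single store returns.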
###############################################################################
# LLM Model Setup
###############################################################################
TEMPERATURE = 0.5

# model = ChatOpenAI(
#     model="unsloth/llama-3-8b-Instruct-bnb-4bit",
#     temperature=TEMPERATURE,
#     timeout=None,
#     max_retries=2,
#     api_key="not_required",
#     base_url="http://localhost:8000/v1",  # Use the vLLM instance URL
#     verbose=True
# )

model = ChatGroq(
    model_name="deepseek-r1-distill-llama-70b",
    temperature=TEMPERATURE,
    api_key=GROQ_API_KEY,
    verbose=True
)
###############################################################################
# 1. Initialize memory + config
###############################################################################
in_memory_store = InMemoryStore(
    index={
        "embed": init_embeddings("huggingface:sentence-transformers/all-MiniLM-L6-v2"),
        "dims": 384,  # Embedding dimensions of all-MiniLM-L6-v2
    }
)

# A memory saver to checkpoint conversation states
checkpointer = MemorySaver()

# Initialize config with user & thread info
config = {}
config["configurable"] = {
    "user_id": "user_1",
    "thread_id": 0,
}
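
# The checkpointer keys conversation state by thread_id, so bumping thread_id starts
# a fresh conversation while long-term memories stay addressable per user. A sketch:
#   config["configurable"]["thread_id"] = 1  # subsequent turns run in a new thread
#   # checkpoints for thread 0 remain retrievable under the old config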
###############################################################################
# 2. Define MessagesState
###############################################################################
class MessagesState(TypedDict):
    """The state of the agent.
    The key 'messages' uses add_messages as a reducer,
    so each time this state is updated, new messages are appended.
    See https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
    """
    messages: Annotated[Sequence[BaseMessage], add_messages]
###############################################################################
# 3. Memory Tools
###############################################################################
def save_memory(summary_text: str, *, config: RunnableConfig, store: BaseStore) -> str:
    """Save the given memory for the current user and return the key."""
    user_id = config.get("configurable", {}).get("user_id")
    thread_id = config.get("configurable", {}).get("thread_id")
    namespace = (user_id, "memories")
    memory_id = thread_id
    store.put(namespace, memory_id, {"memory": summary_text})
    return f"Saved to memory key: {memory_id}"
def update_memory(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
    # Extract the messages list from the state
    messages = state["messages"]
    # Convert LangChain messages to dictionaries before storing
    messages_dict = [{"role": msg.type, "content": msg.content} for msg in messages]
    # Get the user and thread ids from the config
    user_id = config.get("configurable", {}).get("user_id")
    thread_id = config.get("configurable", {}).get("thread_id")
    # Namespace the memory
    namespace = (user_id, "memories")
    # Create a new memory ID
    memory_id = f"{thread_id}"
    store.put(namespace, memory_id, {"memory": messages_dict})
    return f"Saved to memory key: {memory_id}"
# Define a Pydantic schema for the recall_memory tool
# https://langchain-ai.github.io/langgraphjs/reference/classes/checkpoint.InMemoryStore.html
class RecallMemory(BaseModel):
    query_text: str = Field(..., title="Search Text", description="The text to search from memories for similar records.")
    k: int = Field(5, title="Number of Results", description="Number of results to retrieve.")

def recall_memory(query_text: str, k: int = 5) -> str:
    """Retrieve user memories from in_memory_store (uses the module-level config for the user id)."""
    user_id = config.get("configurable", {}).get("user_id")
    memories = [
        m.value["memory"] for m in in_memory_store.search((user_id, "memories"), query=query_text, limit=k)
        if "memory" in m.value
    ]
    return f"User memories: {memories}"

# Create a StructuredTool from the function
recall_memory_tool = StructuredTool.from_function(
    func=recall_memory,
    name="Recall Memory Tool",
    description="""
    Retrieve memories relevant to the user's query.
    """,
    args_schema=RecallMemory,
    response_format="content",
    return_direct=False,
    verbose=False
)
###############################################################################
# 4. Summarize Node (using StructuredTool)
###############################################################################
# Define a Pydantic schema for the Summary tool
class SummariseConversation(BaseModel):
    summary_text: str = Field(..., title="text", description="Write a summary of the entire conversation here")

def summarise_node(summary_text: str):
    """
    Final node that summarizes the entire conversation for the current thread,
    saves it in memory, increments the thread_id, and ends the conversation.
    Returns a confirmation string.
    """
    user_id = config["configurable"]["user_id"]
    current_thread_id = config["configurable"]["thread_id"]
    new_thread_id = str(int(current_thread_id) + 1)
    # Prepare configuration for saving memory with the updated thread id
    config_for_saving = {
        "configurable": {
            "user_id": user_id,
            "thread_id": new_thread_id
        }
    }
    key = save_memory(summary_text, config=config_for_saving, store=in_memory_store)
    return f"Summary saved under key: {key}"

# Create a StructuredTool from the function (this wraps summarise_node)
summarise_tool = StructuredTool.from_function(
    func=summarise_node,
    name="Summary Tool",
    description="""
    Summarize the current conversation in no more than
    1000 words. Also retain any unanswered quiz questions along with
    your internal answers so the next conversation thread can continue.
    Do not reveal solutions to the user yet. Use this tool to save
    the current conversation to memory and then end the conversation.
    """,
    args_schema=SummariseConversation,
    response_format="content",
    return_direct=False,
    verbose=True
)
###############################################################################
# 5. Build the Graph
###############################################################################
graph_builder = StateGraph(MessagesState)

# Use the built-in ToolNode from langgraph that calls any declared tools.
tools = [
    mcq_retriever_tool,
    web_extraction_tool,
    ensemble_retriever_tool,
    general_retriever_tool,
    in_memory_retriever_tool,
    recall_memory_tool,
    summarise_tool,
]
tool_node = ToolNode(tools=tools)
#end_node = ToolNode(tools=[summarise_tool])

# Wrap the model with tools
llm_with_tools = model.bind_tools(tools)
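
# After bind_tools, the model can emit structured tool calls; a response that needs a
# tool carries them on response.tool_calls, shaped roughly like (illustrative values):
#   [{"name": "Recall Memory Tool", "args": {"query_text": "last quiz", "k": 5}, "id": "call_abc"}]
# tools_condition inspects exactly this field to decide whether to route to the tool node.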
###############################################################################
# 6. The agent's main node: call_model
###############################################################################
def call_model(state: MessagesState, config: RunnableConfig):
    """
    The main agent node that calls the LLM with the user + system messages.
    Since the chat model expects a list of BaseMessage objects,
    we convert any dict messages to HumanMessage objects.
    If the LLM requests a tool call, we'll route to the 'tools' node next
    (depending on the condition).
    """
    # Convert message dicts to HumanMessage instances if needed.
    messages = []
    for m in state["messages"]:
        if isinstance(m, dict):
            # Use role from dict (defaulting to 'user' if missing)
            messages.append(HumanMessage(content=m.get("content", ""), role=m.get("role", "user")))
        else:
            messages.append(m)
    # Invoke the LLM (with tools) using the converted messages.
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}
def call_summary(state: MessagesState, config: RunnableConfig):
    # Currently unused: see the commented-out 'summary' node below.
    system_message = """
    Summarize the current conversation in no more than
    1000 words. Also retain any unanswered quiz questions along with
    your internal answers.
    """
    messages = []
    for m in state["messages"]:
        if isinstance(m, dict):
            # Replace dict messages with the summarization instruction
            messages.append(AIMessage(content=system_message, role=m.get("role", "assistant")))
        else:
            messages.append(m)
    summaries = llm_with_tools.invoke(messages)
    summary_content = summaries.content
    # Call the Summary Tool manually with a synthetic tool call
    message_with_single_tool_call = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "Summary Tool",
                "args": {"summary_text": summary_content},
                "id": "tool_call_id",
                "type": "tool_call",
            }
        ],
    )
    tool_node.invoke({"messages": [message_with_single_tool_call]})
###############################################################################
# 7. Add Nodes & Edges, Then Compile
###############################################################################
graph_builder.add_node("agent", call_model)
graph_builder.add_node("tools", tool_node)
#graph_builder.add_node("summary", call_summary)

# Entry point
graph_builder.set_entry_point("agent")

# def custom_tools_condition(llm_output: dict) -> str:
#     """Return which node to go to next based on the LLM output."""
#     # The LLM's JSON might have a field like {"name": "Recall Memory Tool", "arguments": {...}}.
#     tool_name = llm_output.get("name", None)
#     # If the LLM calls "Summary Tool", jump directly to the 'summary' node
#     if tool_name == "Summary Tool":
#         return "summary"
#     # If the LLM calls any other recognized tool, go to 'tools'
#     valid_tool_names = [t.name for t in tools]  # all tools in the main tool_node
#     if tool_name in valid_tool_names:
#         return "tools"
#     # If there's no recognized tool name, assume we're done => go to summary
#     return "__end__"
# graph_builder.add_conditional_edges(
#     "agent",
#     custom_tools_condition,
#     {
#         "tools": "tools",
#         "summary": "summary",
#         "__end__": "summary",
#     }
# )

# If the LLM requests a tool, go to "tools"; otherwise end the turn
graph_builder.add_conditional_edges("agent", tools_condition)
#graph_builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", "__end__": "summary"})
#graph_builder.add_conditional_edges("agent", lambda llm_output: "tools" if llm_output.get("name", None) in [t.name for t in tools] else "summary", {"tools": "tools", "__end__": "summary"})

# After a tool runs, return to the agent for a final answer or more tool calls
graph_builder.add_edge("tools", "agent")
#graph_builder.add_edge("agent", "summary")
#graph_builder.set_finish_point("summary")

# Compile the graph with checkpointing and persistent store
graph = graph_builder.compile(checkpointer=checkpointer, store=in_memory_store)

#from langgraph.prebuilt import create_react_agent
#graph = create_react_agent(llm_with_tools, tools=tool_node, checkpointer=checkpointer, store=in_memory_store)

#from IPython.display import Image, display
#display(Image(graph.get_graph().draw_mermaid_png()))
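
# A minimal smoke test of the compiled graph (hypothetical query; the shared config
# carries the thread_id so state persists across calls):
#   out = graph.invoke({"messages": [{"role": "user", "content": "What is AI?"}]}, config)
#   print(out["messages"][-1].content)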
########################################
# Gradio Chatbot Application
########################################
from gradio import ChatMessage  # gradio itself is already imported above

system_prompt = "You are a helpful Assistant. You will always use the tools available to you from {tools} to address user queries."
########################################
# Upload_documents
########################################
def upload_documents(file_list: List):
    """
    Load documents into the in-memory vector store.
    """
    _documents = []
    for doc_path in file_list:
        _documents.extend(load_file(doc_path))
    # Split the documents into smaller chunks for better embedding coverage
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,    # Split into chunks of 300 characters
        chunk_overlap=50,  # Overlap by 50 characters
        add_start_index=True
    )
    chunked_texts = splitter.split_documents(_documents)
    in_memory_vs.add_documents(chunked_texts)
    return f"Uploaded {len(file_list)} documents into in-memory vector store."
########################################
# Submit_queries (ChatInterface Function)
########################################
def submit_queries(message, _messages):
    """
    - message: dict with {"text": ..., "files": [...]}
    - _messages: chat history (list of ChatMessage), managed by Gradio
    """
    _messages = []  # Start a fresh message list for this turn; Gradio keeps the history
    user_text = message.get("text", "")
    user_files = message.get("files", [])
    # Process user-uploaded files
    if user_files:
        for file_obj in user_files:
            _messages.append(ChatMessage(role="user", content=f"Uploaded file: {file_obj}"))
        upload_response = upload_documents(user_files)
        _messages.append(ChatMessage(role="assistant", content=upload_response))
        yield _messages
        return  # Exit early since we don't need to process text or call the LLM
    # Run the agent graph on the user text if present
    if user_text:
        events = graph.stream(
            {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                ]
            },
            config,
            stream_mode="values"
        )
        for event in events:
            response = event["messages"][-1]
            if isinstance(response, AIMessage):
                if response.tool_calls:  # The model requested a tool call
                    _messages.append(
                        ChatMessage(role="assistant",
                                    content=str(response.tool_calls[0]["args"]),
                                    metadata={"title": str(response.tool_calls[0]["name"]),
                                              "id": config["configurable"]["thread_id"]
                                              }
                                    ))
                    yield _messages
                else:
                    _messages.append(ChatMessage(role="assistant",
                                                 content=response.content,
                                                 metadata={"id": config["configurable"]["thread_id"]
                                                           }
                                                 ))
                    yield _messages
    return _messages
########################################
# 3) Save Chat History
########################################
CHAT_HISTORY_FILE = "chat_history.json"

def save_chat_history(history):
    """
    Saves the chat history into a JSON file.
    """
    session_history = [
        {
            "role": msg.role,  # gradio ChatMessage exposes .role, not .is_user
            "content": msg.content
        }
        for msg in history
    ]
    with open(CHAT_HISTORY_FILE, "w", encoding="utf-8") as f:
        json.dump(session_history, f, ensure_ascii=False, indent=4)
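
# Resulting file shape (illustrative contents):
#   [
#     {"role": "user", "content": "What is AI?"},
#     {"role": "assistant", "content": "Artificial Intelligence is ..."}
#   ]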
########################################
# 6) Main Gradio Interface
########################################
with gr.Blocks() as AI_Tutor:
    gr.Markdown("# AI Tutor Chatbot (Gradio App)")
    # Primary Chat Interface
    chat_interface = gr.ChatInterface(
        fn=submit_queries,
        type="messages",
        chatbot=gr.Chatbot(
            label="Chat Window",
            height=500,
            type="messages"
        ),
        textbox=gr.MultimodalTextbox(
            interactive=True,
            file_count="multiple",
            file_types=[".pdf", ".ppt", ".pptx", ".doc", ".docx", ".md", "image"],
            sources=["upload"],
            label="Type your query here:",
            placeholder="Enter your question...",
        ),
        title="AI Tutor Chatbot",
        description="Ask me anything about Artificial Intelligence!",
        multimodal=True,
        save_history=True,
    )

if __name__ == "__main__":
    AI_Tutor.launch(server_name="0.0.0.0", server_port=7860)