# Hugging Face Spaces page header captured along with this file (Space status: "Sleeping").
import functools
import os
import subprocess
import sys
import time
import tracemalloc
from datetime import datetime
from typing import Any, Dict, List, Optional

import gradio as gr
import gspread
import networkx as nx
import pinecone
from langchain.chains import GraphQAChain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Pinecone
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from neo4j import GraphDatabase
from oauth2client.service_account import ServiceAccountCredentials
from pydantic import BaseModel, Field
# Custom Configuration for Pydantic
class Config:
    """Shared Pydantic settings: permit field types Pydantic cannot validate on its own."""

    arbitrary_types_allowed = True
# Define base models for type validation
# Validated text record with the same page_content / metadata field names as
# langchain_core's Document. NOTE(review): not referenced elsewhere in the code shown.
class CustomDocument(BaseModel):
    page_content: str  # the document text; required
    metadata: Dict[str, Any] = Field(default_factory=dict)  # a fresh empty dict per instance when omitted

    # Inherit the module-level Pydantic settings (arbitrary_types_allowed = True).
    class Config(Config):
        pass
# Best-effort installation of extra runtime packages.
# NOTE(review): this runs after the imports at the top of the file, so it cannot
# satisfy those imports in the current process; declaring these packages in
# requirements.txt is presumably the more reliable fix.
def _pip_install(*pip_args):
    """Run ``<current python> -m pip install *pip_args`` and return pip's exit status.

    ``sys.executable`` makes the packages land in the interpreter that is running
    this app (not whatever ``pip`` happens to be first on PATH), and the command is
    passed as an argument list, so no shell is involved. A non-zero status is
    reported on stderr but not raised, keeping start-up best-effort.
    """
    completed = subprocess.run(
        [sys.executable, "-m", "pip", "install", *pip_args],
        check=False,
    )
    if completed.returncode != 0:
        command = " ".join(pip_args)
        print(
            f"warning: pip install {command} failed with exit status {completed.returncode}",
            file=sys.stderr,
        )
    return completed.returncode


_pip_install("sentence-transformers")
_pip_install("gspread", "oauth2client")
_pip_install("-U", "langchain-huggingface", "langchain-community")
# Google Sheets Setup
# Authorise with the service-account key file in the working directory and open
# the first worksheet of the feedback spreadsheet (used by the feedback helpers).
scope = [
    "https://spreadsheets.google.com/feeds",
    "https://www.googleapis.com/auth/drive",
]
creds = ServiceAccountCredentials.from_json_keyfile_name(
    filename="./gen-lang-client-0300122402-f2f67b3e8c27.json",
    scopes=scope,
)
client = gspread.authorize(creds)
spreadsheet_id = "10953i4ZOhvpNAyyrFdo-4TvVwfVWppt0H2xXM7yYQF4"
sheet = client.open_by_key(spreadsheet_id).sheet1
def store_feedback_in_sheet(feedback, question, rag_response, graphrag_response):
    """Append one feedback row to the Google Sheet.

    Column order: timestamp (naive local time, "YYYY-MM-DD HH:MM:SS"), question,
    RAG response, GraphRAG response, feedback.
    """
    sheet.append_row([
        datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        question,
        rag_response,
        graphrag_response,
        feedback,
    ])
def load_data():
    """Fetch every sheet record; return ``(last ten records, total record count)``."""
    records = sheet.get_all_records()
    total = len(records)
    return records[-10:], total
def add_review(question, rag_response, graphrag_response, feedback):
    """Persist one feedback entry, then return load_data()'s (recent records, count)."""
    store_feedback_in_sheet(
        feedback=feedback,
        question=question,
        rag_response=rag_response,
        graphrag_response=graphrag_response,
    )
    return load_data()
# RAG Setup
# Load the raw text corpus and split it into chunks of roughly 3000 characters
# (4-character overlap) for indexing.
text_path = r"./text_chunks.txt"
loader = TextLoader(text_path, encoding='utf-8')
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=3000, chunk_overlap=4)
docs = text_splitter.split_documents(documents)
# Pin the embedding model explicitly instead of relying on the library default:
# the Pinecone index in this file is created with dimension=768, and
# all-mpnet-base-v2 (LangChain's default for HuggingFaceEmbeddings) yields 768-dim
# vectors. A silent change of default would otherwise break index compatibility.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Hugging Face Setup
# Hosted-inference LLM; the API token comes from HUGGINGFACEHUB_API_TOKEN.
# NOTE(review): `llm` is not referenced anywhere else in the code shown.
repo_id = "meta-llama/Meta-Llama-3-8B"
llm = HuggingFaceHub(
    huggingfacehub_api_token=os.getenv('HUGGINGFACEHUB_API_TOKEN'),
    repo_id=repo_id,
    model_kwargs=dict(temperature=0.8, top_k=50),
)
# Pinecone Setup
# SECURITY(review): a fallback API key literal is committed below; treat it as
# leaked, rotate it, and supply PINECONE_API_KEY through the environment instead.
pinecone.init(
    environment='gcp-starter',
    api_key=os.getenv('PINECONE_API_KEY', '6396a319-9bc0-49b2-97ba-400e96eff377'),
)
index_name = "langchain-demo"
if index_name in pinecone.list_indexes():
    # Index already exists: connect the vector store to it.
    docsearch = Pinecone.from_existing_index(index_name, embeddings)
else:
    # Index missing: create it (cosine metric, 768 dimensions) and upload the split documents.
    pinecone.create_index(name=index_name, metric="cosine", dimension=768)
    docsearch = Pinecone.from_documents(docs, embeddings, index_name=index_name)
# Setup LLMs
# The key is read from GROQ_API_KEY first (same env-first pattern as the Pinecone
# key). SECURITY: the fallback literal below is committed in source — treat it as
# leaked, rotate it, set GROQ_API_KEY (e.g. as a Space secret), then delete it.
rag_llm = ChatGroq(
    model="Llama3-8b-8192",
    temperature=0,
    groq_api_key=os.getenv(
        "GROQ_API_KEY", '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'
    ),
)
# Prompt for the vector-store RAG pipeline; {context} and {question} are filled per request.
template = """
You are a Thai rice assistant. These humans will ask you questions about Thai rice.
Answer the question only in Thai language.
Use the following piece of context to answer the question.
If you don't know the answer, just say you don't know.
Keep the answer within 2 sentences and concise.
Context: {context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(input_variables=["context", "question"], template=template)
# Modified RAG chain setup
def get_context(question: str) -> str:
    """Return the text of the Pinecone similarity-search hits for *question*, space-joined."""
    hits = docsearch.similarity_search(question)
    return " ".join(hit.page_content for hit in hits)
def get_rag_response(question: str) -> str:
    """Answer *question* with the vector-store RAG pipeline.

    Retrieves context via get_context, fills the Thai-rice prompt and sends it to
    the Groq chat model.

    Returns:
        The model's answer text.
    """
    context = get_context(question)
    # ``invoke`` replaces the deprecated ``predict``; on a chat model it returns a
    # message object whose ``content`` holds the answer text.
    message = rag_llm.invoke(prompt.format(context=context, question=question))
    return message.content
# Graph RAG Setup
# The key is read from GROQ_API_KEY first (same env-first pattern as the Pinecone
# key). SECURITY: the fallback literal below is committed in source — treat it as
# leaked, rotate it, set GROQ_API_KEY (e.g. as a Space secret), then delete it.
graphrag_llm = ChatGroq(
    model="Llama3-8b-8192",
    temperature=0,
    groq_api_key=os.getenv(
        "GROQ_API_KEY", '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'
    ),
)
# Neo4j Setup
# Connection settings can be overridden with NEO4J_URI / NEO4J_USERNAME /
# NEO4J_PASSWORD; the literals are only fallbacks. SECURITY: the fallback
# password is committed in source — treat it as leaked, rotate it, and provide
# NEO4J_PASSWORD through the environment, then delete the literal.
uri = os.getenv("NEO4J_URI", "neo4j+s://86a1bbb8.databases.neo4j.io")
user = os.getenv("NEO4J_USERNAME", "neo4j")
password = os.getenv("NEO4J_PASSWORD", "zURGb3tO9VFtKxPEw98WrQnphvJQbBDofDM9uG3O8O8")
driver = GraphDatabase.driver(uri, auth=(user, password))
def fetch_nodes(tx):
    """Return every node's internal id and labels as dicts keyed "id" and "labels".

    Meant to be executed through ``session.execute_read`` (see populate_networkx_graph).
    """
    cypher = "MATCH (n) RETURN id(n) AS id, labels(n) AS labels"
    return tx.run(cypher).data()
def fetch_relationships(tx):
    """Return every directed relationship as dicts keyed "source", "target" and "relation".

    "source"/"target" are the internal ids of the start/end nodes and "relation"
    is the relationship type. Meant to be executed through ``session.execute_read``.
    """
    cypher = "MATCH (n)-[r]->(m) RETURN id(n) AS source, id(m) AS target, type(r) AS relation"
    return tx.run(cypher).data()
def populate_networkx_graph():
    """Mirror every Neo4j node and relationship into a directed NetworkX graph.

    Nodes are keyed by the Neo4j internal id (from fetch_nodes) and carry a
    ``labels`` attribute. Each ``(n)-[r]->(m)`` relationship becomes an
    ``n -> m`` edge whose ``relation`` attribute is the relationship type.

    Returns:
        networkx.DiGraph
    """
    # A DiGraph keeps the direction of (n)-[r]->(m). NetworkxEntityGraph builds a
    # DiGraph itself and rejects anything else passed to its constructor; with an
    # undirected graph, traversal can report a relationship back-to-front.
    graph = nx.DiGraph()
    with driver.session() as session:
        nodes = session.execute_read(fetch_nodes)
        relationships = session.execute_read(fetch_relationships)
    # The query results are plain lists, so the session can be closed before building.
    graph.add_nodes_from((node['id'], {'labels': node['labels']}) for node in nodes)
    graph.add_edges_from(
        (rel['source'], rel['target'], {'relation': rel['relation']})
        for rel in relationships
    )
    # NOTE(review): nodes are keyed by numeric Neo4j ids, so lookups by entity
    # *name* (as the graph QA chain presumably performs) may never match a node —
    # consider keying by a name property instead; confirm against the data.
    return graph
# Wrap the Neo4j-derived graph for LangChain and build the graph QA chain on it.
networkx_graph = populate_networkx_graph()
graph = NetworkxEntityGraph()
# The prepared graph is assigned to the wrapper's private ``_graph`` attribute
# rather than passed to the constructor.
graph._graph = networkx_graph
graphrag_chain = GraphQAChain.from_llm(
    graph=graph,
    llm=graphrag_llm,
    verbose=True,
)
# Define functions to measure memory and time
def measure_memory_and_time(func):
    """Wrap *func* so every call also reports elapsed time and peak traced memory.

    The wrapper returns ``(result, elapsed_seconds, peak_kb)``: *result* is what
    *func* returned, *elapsed_seconds* is wall-clock time of the call, and
    *peak_kb* is the peak memory traced by ``tracemalloc`` during the call,
    in KB (bytes / 1024). Exceptions from *func* propagate unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tracemalloc.start()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        try:
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing; otherwise an exception in func would leave
            # tracemalloc running and slowing every later allocation.
            tracemalloc.stop()
        return result, elapsed, peak / 1024

    return wrapper
# NOTE: get_rag_response is defined earlier in this file (next to get_context).
# It is intentionally not redefined here: an identical second definition would
# silently shadow the first, so changes made to one copy could be lost.
def get_graphrag_response(question: str) -> str:
    """Answer *question* with the knowledge-graph QA chain.

    The system instructions are folded into the single query string passed to
    the chain. NOTE(review): that whole string, instructions included, is
    presumably also what the chain uses for entity extraction — confirm.

    Returns:
        The chain's answer text.
    """
    system_prompt = "You are a Thai rice assistant that gives concise and direct answers. Do not explain the process, just provide the answer, provide the answer only in Thai."
    formatted_question = f"System Prompt: {system_prompt}\n\nQuestion: {question}"
    # ``invoke`` replaces the deprecated ``Chain.run``; it takes and returns dicts
    # keyed by the chain's own input/output keys.
    outputs = graphrag_chain.invoke({graphrag_chain.input_key: formatted_question})
    return outputs[graphrag_chain.output_key]
# Modify compare_models to collect and display metrics
def compare_models(question: str) -> dict:
    """Answer *question* with both pipelines and report each one's cost.

    get_rag_response and get_graphrag_response return plain answer strings, so
    each call is wrapped with measure_memory_and_time here to obtain the
    ``(response, seconds, KB)`` triple. Unpacking the bare string (as before)
    raised ValueError or silently split the answer into characters.

    Returns:
        Dict with keys "RAG Response", "RAG Time (s)", "RAG Memory (KB)",
        "GraphRAG Response", "GraphRAG Time (s)", "GraphRAG Memory (KB)";
        times and memory are rounded to 2 decimals.
    """
    rag_response, rag_time, rag_memory = measure_memory_and_time(get_rag_response)(question)
    graphrag_response, graphrag_time, graphrag_memory = measure_memory_and_time(
        get_graphrag_response
    )(question)
    # Combine responses with metrics
    return {
        "RAG Response": rag_response,
        "RAG Time (s)": round(rag_time, 2),
        "RAG Memory (KB)": round(rag_memory, 2),
        "GraphRAG Response": graphrag_response,
        "GraphRAG Time (s)": round(graphrag_time, 2),
        "GraphRAG Memory (KB)": round(graphrag_memory, 2),
    }
# Gradio Interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Question entry.
            question_input = gr.Textbox(
                label="ถามคำถามเกี่ยวกับข้าว:",
                placeholder="Enter your question about Thai rice",
            )
            submit_btn = gr.Button(value="ถาม")

            # Read-only answers and metrics for each pipeline.
            rag_output = gr.Textbox(label="RAG Response", interactive=False)
            rag_time_output = gr.Textbox(label="RAG Time (s)", interactive=False)
            rag_memory_output = gr.Textbox(label="RAG Memory (KB)", interactive=False)
            graphrag_output = gr.Textbox(label="GraphRAG Response", interactive=False)
            graphrag_time_output = gr.Textbox(label="GraphRAG Time (s)", interactive=False)
            graphrag_memory_output = gr.Textbox(label="GraphRAG Memory (KB)", interactive=False)

            # User preference between the two answers.
            # NOTE(review): the choices say "A"/"B" while the outputs are labelled
            # RAG/GraphRAG — presumably A = RAG and B = GraphRAG; confirm.
            feedback = gr.Radio(
                label="Which response is better?",
                choices=["A ดีกว่า", "B ดีกว่า", "เท่ากัน", "แย่ทั้งคู่"],
            )
            submit_feedback = gr.Button(value="Submit Feedback")

    def display_results(question):
        """Run compare_models on *question* and return its six values in output-component order."""
        ordered_keys = (
            "RAG Response",
            "RAG Time (s)",
            "RAG Memory (KB)",
            "GraphRAG Response",
            "GraphRAG Time (s)",
            "GraphRAG Memory (KB)",
        )
        results = compare_models(question)
        return tuple(results[key] for key in ordered_keys)

    # Event handlers
    submit_btn.click(
        fn=display_results,
        inputs=[question_input],
        outputs=[
            rag_output,
            rag_time_output,
            rag_memory_output,
            graphrag_output,
            graphrag_time_output,
            graphrag_memory_output,
        ],
    )
    # add_review's return value is not connected to any output component.
    submit_feedback.click(
        fn=add_review,
        inputs=[question_input, rag_output, graphrag_output, feedback],
    )
    # load_data runs on every page load; its result is likewise not displayed.
    demo.load(fn=load_data, inputs=None)
# Launch the app
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly, not on import.
    demo.launch()