# -*- coding: utf-8 -*-
"""Deploy_CapstoneRagBench.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1OG-77VqKwz3509_osgNgSeOMJ9G6RvB4
"""
# For Legal
import io
import json
import os
import re
import sys
import traceback
import uuid

import faiss
import gradio as gr
import numpy as np
import requests
import torch
from datasets import load_dataset, load_from_disk, Dataset, get_dataset_config_names
from groq import Groq
from sentence_transformers import CrossEncoder
from sklearn.metrics import mean_squared_error, roc_auc_score
from transformers import AutoTokenizer, AutoModel
def retrieve_top_k(query, domain='legal', model_name='nlpaueb/legal-bert-base-uncased', k=8):
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()
    #print(f"In retrieve_top_k Query:{query}")

    # Tokenize and embed the query using mean pooling
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {key: v.to(device) for key, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    # Load FAISS index and dataset
    index_path = f"{domain}_index/faiss.index"
    dataset_path = f"{domain}_dataset"
    faiss_index = faiss.read_index(index_path)
    dataset = load_from_disk(dataset_path)

    # Perform FAISS search
    D, I = faiss_index.search(query_embedding.astype('float32'), k)

    # Retrieve top-k matching chunks
    top_chunks = [dataset[int(idx)]['text'] for idx in I[0]]
    return top_chunks
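# Illustrative usage (assumes the on-disk artifacts "legal_index/faiss.index" and
# "legal_dataset", built elsewhere in this project, are present):
#   chunks = retrieve_top_k("What is the notice period for termination?", domain="legal", k=8)
# `chunks` is then a list of up to 8 raw text passages in the FAISS index's similarity order.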
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#print(device)

dataset = load_dataset("rungalileo/ragbench", "cuad", split="test")

# Read the Groq API key from the environment instead of hard-coding it in the source.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

# Load BGE reranker
reranker = CrossEncoder("BAAI/bge-reranker-base", max_length=512)
def rerank_documents_bge(query, documents, top_n=5, return_scores=False):
    """
    Rerank documents using the BAAI/bge-reranker-base CrossEncoder.

    Args:
        query (str): The query string.
        documents (List[str]): List of candidate documents.
        top_n (int): Number of top results to return.
        return_scores (bool): Whether to return scores along with documents.

    Returns:
        List[str] or List[Tuple[str, float]]
    """
    if not documents:
        return []

    # Prepare (query, doc) pairs
    pairs = [(query, doc) for doc in documents]

    # Predict relevance scores
    scores = reranker.predict(pairs, batch_size=16)

    # Sort by score descending
    reranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)

    if return_scores:
        return reranked[:top_n]
    else:
        return [doc for doc, _ in reranked[:top_n]]
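# Illustrative example (made-up snippets): given
#   docs = ["Either party may terminate with 30 days written notice.",
#           "This Agreement is governed by the laws of Delaware."]
# rerank_documents_bge("termination notice period", docs, top_n=1) should surface the first
# snippet, because the cross-encoder scores each (query, document) pair jointly instead of
# comparing independently computed embeddings.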
def generate_response_rag(query, model, index_dir="legal_index"):
    # Step 1: Retrieve top-k context chunks using the FAISS setup above
    top_chunks = retrieve_top_k(query, 'legal', "nlpaueb/legal-bert-base-uncased")

    # Step 2: Rerank retrieved documents using the cross-encoder
    #reranked_chunks = rerank_documents(query, top_chunks, top_n=15)
    #rerank_and_filter_chunks = filter_by_faithfulness(query, reranked_chunks)
    #reranked_chunks = rerank_and_filter_chunks
    reranked_chunks_bge = rerank_documents_bge(query, top_chunks, top_n=5)
    #sum_context = summarize_context("\n\n".join(reranked_chunks_bge))
    final_context = reranked_chunks_bge

    # Step 3: Prepare context and RAG-style prompt
    context = "\n\n".join(final_context)
    #print(f"Context:{context}")
    prompt = f"""You are a helpful legal assistant.
Using only the information from the retrieved context, answer the following question. If the answer cannot be derived, say "I don't know." Always prefix the answer with **Answer:**

Context: {context}

Question: {query}

Answer:"""

    # Step 4: Call the LLM (LLaMA 3 or any chat model)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "user", "content": prompt}
        ],
        model=model,  # e.g. "llama3-70b-8192", "gemma2-9b-it", "qwen/qwen3-32b", "deepseek-r1-distill-llama-70b", "mistral-saba-24b"
        temperature=0.0
    )
    return chat_completion.choices[0].message.content.strip()
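# Illustrative call (the model name must be one enabled for the configured Groq account):
#   answer = generate_response_rag("What law governs this agreement?", model="llama3-70b-8192")
# The prompt pins the model to the reranked context and asks for an "**Answer:**" prefix,
# which judge_response_rag below relies on to strip any preamble from the generation.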
# Alternative (unused) OpenAI backend, kept for reference:
#   response = openai.chat.completions.create(
#       model="gpt-3.5-turbo",
#       messages=[
#           {"role": "user", "content": prompt}
#       ],
#       temperature=0.0,
#       max_tokens=1024
#   )
#   return response.choices[0].message.content
# JUDGE LLM
def split_into_keyed_sentences(text, prefix):
    """Splits text into sentences with keys like '0a.', '0b.', or 'a.', 'b.', etc."""
    # Basic sentence tokenizer with keys
    sentences = re.split(r'(?<=[.?!])\s+', text.strip())
    keyed = {}
    for i, s in enumerate(sentences):
        key = f"{prefix}{chr(97 + i)}"  # 'a', 'b', ...
        if s:
            keyed[key] = s.strip()
    return keyed
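# Example: split_into_keyed_sentences("First point. Second point.", "0")
# returns {"0a": "First point.", "0b": "Second point."}; with prefix "" the keys are "a", "b", ...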
def judge_response_rag(query, embedder="nlpaueb/legal-bert-base-uncased", domain="legal", k=5):
    top_chunks = retrieve_top_k(query)
    top_chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in top_chunks]

    # Prepare context and split it into keyed sentences
    context = "\n\n".join(top_chunks)
    document_keys = split_into_keyed_sentences(context, "0")
    #print(f"Query:{query}\n====================================================================")

    response = generate_response_rag(query, model="llama3-70b-8192")  # alternative: deepseek-r1-distill-llama-70b
    #print(f"\n====================================\nGenerator Response:{response}")

    # Strip everything before the "**Answer:**" marker (needed e.g. for DeepSeek-style preambles)
    #print("Before Curated:", response)
    if "**Answer" in response:
        response = response[response.find("**Answer"):].replace("**Answer:**", "").replace("**Answer", "").strip()
    print(f"Response from Generator LLM:{response}")
    response_keys = split_into_keyed_sentences(response, "")

    # Rebuild sections for the judge prompt
    documents_formatted = "\n".join([f"{k}. {v}" for k, v in document_keys.items()])
    response_formatted = "\n".join([f"{k}. {v}" for k, v in response_keys.items()])
    #print(f"\n====================================================================")
    #print(f"documents_formatted:{documents_formatted}")
    #print(f"\n====================================================================")
    #print(f"response_formatted:{response_formatted}")
    #print(f"\n====================================================================")
    prompt = f"""I asked someone to answer a question based on one or more documents.
Your task is to review their response and assess whether or not each sentence
in that response is supported by text in the documents. And if so, which
sentences in the documents provide that support. You will also tell me which
of the documents contain useful information for answering the question, and
which of the documents the answer was sourced from.
Here are the documents, each of which is split into sentences. Alongside each
sentence is an associated key, such as '0a.' or '0b.', that you can use to refer
to it:
'''
{documents_formatted}
'''
The question was:
'''
{query}
'''
Here is their response, split into sentences. Alongside each sentence is an
associated key, such as 'a.' or 'b.', that you can use to refer to it. Note
that these keys are unique to the response, and are not related to the keys
in the documents:
'''
{response_formatted}
'''
You must respond with a JSON object matching this schema:
'''
{{
  "relevance_explanation": string,
  "all_relevant_sentence_keys": [string],
  "overall_supported_explanation": string,
  "overall_supported": boolean,
  "sentence_support_information": [
    {{
      "response_sentence_key": string,
      "explanation": string,
      "supporting_sentence_keys": [string],
      "fully_supported": boolean
    }},
  ],
  "all_utilized_sentence_keys": [string]
}}
'''
The relevance_explanation field is a string explaining which documents
contain useful information for answering the question. Provide a step-by-step
breakdown of information provided in the documents and how it is useful for
answering the question.
The all_relevant_sentence_keys field is a list of all document sentence keys
(e.g. '0a') that are relevant to the question. Include every sentence that is
useful and relevant to the question, even if it was not used in the response,
or if only parts of the sentence are useful. Ignore the provided response when
making this judgement and base your judgement solely on the provided documents
and question. Omit sentences that, if removed from the document, would not
impact someone's ability to answer the question.
The overall_supported_explanation field is a string explaining why the response
*as a whole* is or is not supported by the documents. In this field, provide a
step-by-step breakdown of the claims made in the response and the support (or
lack thereof) for those claims in the documents. Begin by assessing each claim
separately, one by one; don't make any remarks about the response as a whole
until you have assessed all the claims in isolation.
The overall_supported field is a boolean indicating whether the response as a
whole is supported by the documents. This value should reflect the conclusion
you drew at the end of your step-by-step breakdown in overall_supported_explanation.
In the sentence_support_information field, provide information about the support
*for each sentence* in the response.
The sentence_support_information field is a list of objects, one for each sentence
in the response. Each object MUST have the following fields:
- response_sentence_key: a string identifying the sentence in the response.
This key is the same as the one used in the response above.
- explanation: a string explaining why the sentence is or is not supported by the
documents.
- supporting_sentence_keys: keys (e.g. '0a') of sentences from the documents that
support the response sentence. If the sentence is not supported, this list MUST
be empty. If the sentence is supported, this list MUST contain one or more keys.
In special cases where the sentence is supported, but not by any specific sentence,
you can use the string "supported_without_sentence" to indicate that the sentence
is generally supported by the documents. Consider cases where the sentence is
expressing inability to answer the question due to lack of relevant information in
the provided context as "supported_without_sentence". In cases where the sentence
is making a general statement (e.g. outlining the steps to produce an answer, or
summarizing previously stated sentences, or a transition sentence), use the
string "general". In cases where the sentence is correctly stating a well-known fact,
like a mathematical formula, use the string "well_known_fact". In cases where the
sentence is performing numerical reasoning (e.g. addition, multiplication), use
the string "numerical_reasoning".
- fully_supported: a boolean indicating whether the sentence is fully supported by
the documents.
- This value should reflect the conclusion you drew at the end of your step-by-step
breakdown in explanation.
- If supporting_sentence_keys is an empty list, then fully_supported must be false.
- Otherwise, use fully_supported to clarify whether everything in the response
sentence is fully supported by the document text indicated in supporting_sentence_keys
(fully_supported = true), or whether the sentence is only partially or incompletely
supported by that document text (fully_supported = false).
The all_utilized_sentence_keys field is a list of all sentence keys (e.g. '0a') that
were used to construct the answer. Include every sentence that either directly supported
the answer, or was implicitly used to construct the answer, even if it was not used
in its entirety. Omit sentences that were not used, and could have been removed from
the documents without affecting the answer.
You must respond with a valid JSON string. Use escapes for quotes, e.g. `\\"`, and
newlines, e.g. `\\n`. Do not write anything before or after the JSON string. Do not
wrap the JSON string in backticks like ``` or ```json.
As a reminder: your task is to review the response and assess which documents contain
useful information pertaining to the question, and how each sentence in the response
is supported by the text in the documents.\
"""
    # Call the Judge LLM
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "user", "content": prompt}
        ],
        model="meta-llama/llama-4-maverick-17b-128e-instruct",  # alternatives: deepseek-r1-distill-llama-70b, llama3-70b-8192
    )
    return documents_formatted, chat_completion.choices[0].message.content.strip()

# Alternative (unused) OpenAI judge backend, kept for reference:
#   chat_completion = openai.chat.completions.create(
#       messages=[
#           {"role": "user",
#            "content": prompt}
#       ],
#       model="gpt-4o",
#       max_tokens=1024,
#   )
#   return documents_formatted, chat_completion.choices[0].message.content
def extract_retrieved_sentence_keys(document_text: str) -> list[str]:
    """
    Extracts sentence keys like '0a.', '0b.', etc. from a formatted document string.

    Parameters:
    - document_text (str): full text of the document with sentence keys

    Returns:
    - List of unique sentence keys in the order they appear
    """
    # Match patterns like 0a., 0b., 0z., plus the 0{., 0|., ... keys produced once
    # sentence letters run past 'z'
    pattern = r'\b0[\w\{\|\}~]\.'
    matches = re.findall(pattern, document_text)
    return list(dict.fromkeys(matches))  # Removes duplicates while preserving order
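# Example: extract_retrieved_sentence_keys("0a. First clause. 0b. Second clause.")
# returns ["0a.", "0b."] (keys keep their trailing dot; only the count is used downstream).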
def compute_ragbench_metrics(judge_response: dict, retrieved_sentence_keys: list[str]) -> dict:
    """
    Computes RAGBench-style metrics from the Judge LLM response.

    Parameters:
    - judge_response (dict): JSON response from the Judge LLM
    - retrieved_sentence_keys (list of str): all sentence keys from the retrieved documents

    Returns:
    - Dictionary with Context Relevance, Context Utilization, Completeness, and Adherence
    """
    R = set(judge_response.get("all_relevant_sentence_keys", []))  # Relevant sentences
    U = set(judge_response.get("all_utilized_sentence_keys", []))  # Utilized sentences
    intersection_RU = R & U

    total_retrieved = len(retrieved_sentence_keys)
    len_R = len(R)
    len_U = len(U)
    len_intersection = len(intersection_RU)

    # Context Relevance: fraction of retrieved context that is relevant
    context_relevance = len_R / total_retrieved if total_retrieved else 0.0

    # Context Utilization: fraction of retrieved context that was used
    context_utilization = len_U / total_retrieved if total_retrieved else 0.0

    # Completeness: fraction of relevant content that was used
    completeness = len_intersection / len_R if len_R else 0.0

    # Adherence: 1 if all response sentences are fully supported, else 0
    is_fully_supported = all(s.get("fully_supported", False)
                             for s in judge_response.get("sentence_support_information", []))
    adherence = 1.0 if is_fully_supported and judge_response.get("overall_supported", False) else 0.0

    return {
        "Context Relevance": round(context_relevance, 4),
        "Context Utilization": round(context_utilization, 4),
        "Completeness": round(completeness, 4),
        "Adherence": adherence
    }
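# Worked example (hypothetical judge output): with 10 retrieved sentence keys,
# R = {"0a", "0b", "0c"} relevant and U = {"0b", "0d"} utilized:
#   Context Relevance   = |R| / 10      = 0.3
#   Context Utilization = |U| / 10      = 0.2
#   Completeness        = |R ∩ U| / |R| = 1/3 ≈ 0.3333
#   Adherence           = 1.0 only if every response sentence is fully_supported
#                         and overall_supported is true, else 0.0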
def compute_rmse(gt, pred):
    return round(np.sqrt(np.mean((np.array(gt) - np.array(pred)) ** 2)), 4)
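# Example: compute_rmse([1.0, 0.0], [0.5, 0.0]) == round(sqrt((0.25 + 0.0) / 2), 4) == 0.3536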
def evaluate_rag_pipeline(q_indices):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def safe_append(gt_list, pred_list, gt_val, pred_val):
        if gt_val is not None and pred_val is not None:
            gt_list.append(gt_val)
            pred_list.append(pred_val)

    def clean_and_parse_json_block(text):
        # Strip a markdown-style code block if present
        #text = text.strip().strip("`").strip()
        code_block_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
        if code_block_match:
            text = code_block_match.group(1).strip()
        # Remove invalid/control characters that break decoding
        text = re.sub(r"[^\x20-\x7E\n\t]", "", text)
        try:
            return json.loads(text)
        except json.JSONDecodeError as e:
            print("❌ JSON Decode Error:", e)
            print("⚠️ Cleaned text:\n", text)
            raise

    gt_relevance, pred_relevance = [], []
    gt_utilization, pred_utilization = [], []
    gt_completeness, pred_completeness = [], []
    gt_adherence, pred_adherence = [], []

    for i in q_indices:
        query = dataset[i]['question']
        print(f"\n\n\nQuery:{i}.{query}\n====================================================================")
        documents_formatted, response = judge_response_rag(
            query, embedder="nlpaueb/legal-bert-base-uncased", domain="legal")
        judge_response = clean_and_parse_json_block(response)
        print(f"\n======================================================================\nResponse:{judge_response}")

        retrieved_sentences = extract_retrieved_sentence_keys(documents_formatted)
        predicted = compute_ragbench_metrics(judge_response, retrieved_sentences)

        # Ground-truth values from the RAGBench annotations
        gt_r = dataset[i].get('relevance_score')
        gt_u = dataset[i].get('utilization_score')
        gt_c = dataset[i].get('completeness_score')
        gt_a = dataset[i].get('gpt3_adherence')

        safe_append(gt_relevance, pred_relevance, gt_r, predicted['Context Relevance'])
        safe_append(gt_utilization, pred_utilization, gt_u, predicted['Context Utilization'])
        safe_append(gt_completeness, pred_completeness, gt_c, predicted['Completeness'])
        if gt_a is not None and predicted['Adherence'] is not None:
            safe_append(gt_adherence, pred_adherence, int(gt_a), int(predicted['Adherence']))

    result = {
        "Context Relevance": compute_rmse(gt_relevance, pred_relevance),
        "Context Utilization": compute_rmse(gt_utilization, pred_utilization),
        "Completeness": compute_rmse(gt_completeness, pred_completeness),
    }
    if len(set(gt_adherence)) == 2:
        result["Adherence"] = compute_rmse(gt_adherence, pred_adherence)
        result["AUC-ROC (Adherence)"] = round(roc_auc_score(gt_adherence, pred_adherence), 4)
    elif gt_adherence:
        result["Adherence"] = compute_rmse(gt_adherence, pred_adherence)
        result["AUC-ROC (Adherence)"] = "N/A - one class only"
    else:
        result["Adherence"] = "N/A - no adherence ground truth"
        result["AUC-ROC (Adherence)"] = "N/A - no adherence ground truth"
    return result
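# The returned dict has this shape (values purely illustrative, not measured results):
#   {"Context Relevance": 0.21, "Context Utilization": 0.18, "Completeness": 0.25,
#    "Adherence": 0.0, "AUC-ROC (Adherence)": "N/A - one class only"}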
# Wrapper to parse textbox input into a list of ints
def evaluate_rag_gradio(q_indices_str):
    # Capture printed logs
    log_stream = io.StringIO()
    sys.stdout = log_stream
    try:
        q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
        results = evaluate_rag_pipeline(q_indices)
        # Return metrics and logs
        logs = log_stream.getvalue()
        return results, logs
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}, log_stream.getvalue()
    finally:
        sys.stdout = sys.__stdout__

iface = gr.Interface(
    fn=evaluate_rag_gradio,
    inputs=gr.Textbox(label="Comma-separated Query Indices (e.g. 89,121,245)", lines=1),
    outputs=[
        gr.JSON(label="Evaluation Metrics (RMSE & AUC-ROC)"),
        gr.Textbox(label="Execution Log", lines=5, interactive=True)
    ],
    title="RAG Evaluation Dashboard",
    description="Evaluate the RAG pipeline across selected queries using LLM-based generation and an LLM judge."
)

iface.launch(debug=True)