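"""Command-line entry point for the RAG pipeline: build the document indexes,
run the pipeline over an evaluation set, and score the outputs with RAGAS.

Example invocations (config file names are illustrative):

    python main.py index -c app_config
    python main.py predict -c app_config -m model_config
    python main.py evaluate
"""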
import os
from typing import Tuple

import click
import pandas as pd
from datasets import Dataset
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_google_vertexai import ChatVertexAI
from loguru import logger
from ragas import evaluate
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import (
    answer_correctness,
    answer_relevancy,
    context_precision,
)
from tqdm import tqdm

from app.chroma import ChromaDenseVectorDB
from app.config.load import load_config
from app.config.models.configs import Config
from app.config.models.vertexai import VertexAIModel
from app.parsers.splitter import DocumentSplitter
from app.pipeline import LLMBundle
from app.ranking import BCEReranker
from app.splade import SpladeSparseVectorDB


def get_hash_mapping_filenames(
    config: Config,
    file_to_hash_fn: str = "file_hash_mappings.snappy.parquet",
    docid_to_hash_fn: str = "docid_hash_mappings.snappy.parquet",
) -> Tuple[str, str]:
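    """Return paths of the file-to-hash and docid-to-hash mapping parquet files."""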
file_hashes_fn = os.path.join(config.embeddings.embeddings_path, file_to_hash_fn)
docid_hashes_fn = os.path.join(config.embeddings.embeddings_path, docid_to_hash_fn)
return file_hashes_fn, docid_hashes_fn


@click.group()
def main():
    """CLI for indexing documents, running predictions, and evaluating results."""


@main.command(name="index")
@click.option(
    "-c",
    "app_config_path",
    required=True,
    type=click.Path(exists=True, dir_okay=False, file_okay=True),
    help="Path to the app configuration file.",
)
def create_index(app_config_path):
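    """Split documents and build the dense (Chroma) and sparse (SPLADE) indexes."""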
config = load_config(app_config_path)
dense_db = ChromaDenseVectorDB(
persist_folder=str(config.embeddings.embeddings_path), config=config
)
splitter = DocumentSplitter(config)
all_docs, all_hash_filename_mappings, all_hash_docid_mappings = splitter.split()
# dense embeddings
dense_db.generate_embeddings(docs=all_docs)
# sparse embeddings
sparse_db = SpladeSparseVectorDB(config)
sparse_db.generate_embeddings(docs=all_docs)
file_hashes_fn, docid_hashes_fn = get_hash_mapping_filenames(config)
all_hash_filename_mappings.to_parquet(
file_hashes_fn, compression="snappy", index=False
)
all_hash_docid_mappings.to_parquet(
docid_hashes_fn, compression="snappy", index=False
)
    logger.info("Document embeddings generated")
@main.command("predict")
@click.option(
"-c",
"app_config_path",
required=True,
type=click.Path(exists=True, dir_okay=False, file_okay=True),
help="Specifies App JavaScript configuration file (should be module exported)",
)
@click.option(
"-m",
"model_config_path",
required=True,
type=click.Path(exists=True, dir_okay=False, file_okay=True),
help="Specifies Model JavaScript configuration file (should be module exported)",
)
def predict_pipeline(app_config_path: str, model_config_path: str):
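    """Run the RAG pipeline over the evaluation questions and store its outputs."""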
config = load_config(app_config_path, model_config_path)
# llm = OpenAIModel(config=config.llm.params)
llm = VertexAIModel(config=config.llm.params)
chain = load_qa_chain(llm=llm.model, prompt=llm.prompt)
store = ChromaDenseVectorDB(
persist_folder=str(config.embeddings.embeddings_path), config=config
)
store._load_retriever()
reranker = BCEReranker()
chunk_sizes = config.embeddings.chunk_sizes
splade = SpladeSparseVectorDB(config=config)
splade.load()
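
    # HyDE chain: asks the LLM to write a hypothetical answer passage, which
    # the pipeline can use for retrieval in addition to the raw question.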
hyde_chain = LLMChain(
llm=llm.model,
prompt=PromptTemplate(
template="Write a short passage to answer the question: {question}",
input_variables=["question"],
),
)
llm_bundle = LLMBundle(
chain=chain,
reranker=reranker,
chunk_sizes=chunk_sizes,
sparse_db=splade,
dense_db=store,
hyde_chain=hyde_chain,
)
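
    # Run the pipeline over the evaluation questions and collect the fields
    # RAGAS expects (question, answer, contexts, and ground truths).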
test_dataset = pd.read_json("evaluation_dataset.json", lines=True)
evaluate_data = {
"question": [],
"answer": [],
"contexts": [], # should be a list[list[str]]
'ground_truth': [],
'context_ground_truth': []
}
    # Limit the run to the first 10 questions.
    test_dataset = test_dataset.head(10)
    for _, row in tqdm(test_dataset.iterrows(), total=len(test_dataset)):
output = llm_bundle.get_and_parse_response(
query=row["question"],
config=config,
)
response = output.response
evaluate_data["question"].append(row["question"])
evaluate_data["answer"].append(response)
evaluate_data["contexts"].append(output.semantic_search)
evaluate_data["ground_truth"].append(row["answer"])
evaluate_data["context_ground_truth"].append(row["context"])
evaluate_dataset = Dataset.from_dict(evaluate_data)
    # Store the evaluation dataset as JSON lines for the `evaluate` command.
    evaluate_dataset.to_pandas().to_json(
        "evaluation_output.json", orient="records", lines=True
    )
@main.command("evaluate")
def evaluate_pipeline():
ragas_vertexai_llm = ChatVertexAI(model_name="gemini-pro")
ragas_vertexai_llm = LangchainLLMWrapper(ragas_vertexai_llm)
vertexai_embeddings = SentenceTransformerEmbeddings(model_name="maidalun1020/bce-embedding-base_v1")
vertexai_embeddings = LangchainEmbeddingsWrapper(vertexai_embeddings)
    metrics = [
        # Accuracy of the generated answer compared to the ground truth.
        answer_correctness,
        # Whether the ground-truth-relevant items in the contexts are ranked at the top.
        context_precision,
        # How pertinent the generated answer is to the question.
        answer_relevancy,
    ]
evaluate_dataset = pd.read_json("evaluation_output.json", lines=True)
evaluate_dataset = Dataset.from_pandas(evaluate_dataset)
evaluate_result = evaluate(
dataset=evaluate_dataset,
metrics=metrics,
llm=ragas_vertexai_llm,
        embeddings=ragas_embeddings,
is_async=True
)
evaluate_result_df = evaluate_result.to_pandas()
    # Drop the bulky context columns before reporting.
    evaluate_result_df = evaluate_result_df.drop(
        columns=["contexts", "context_ground_truth"]
    )
    # Print the mean of the answer_correctness, context_precision,
    # and answer_relevancy columns.
    print(evaluate_result_df.mean(numeric_only=True))
evaluate_result_df.to_csv("evaluation_results.csv", index=False)


if __name__ == "__main__":
main()