|
import logging |
|
from config import AppConfig, ConfigConstants |
|
from data.load_dataset import load_data |
|
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics |
|
from retriever.chunk_documents import chunk_documents |
|
from retriever.embed_documents import embed_documents |
|
from generator.initialize_llm import initialize_generation_llm |
|
from generator.initialize_llm import initialize_validation_llm |
|
from app import launch_gradio |
|
|
|
|
|
# Configure the root logger once at import time: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
def main():
    """Build the RAG pipeline and hand it to the Gradio front end.

    Steps, in order:

    1. Load every dataset named in ``ConfigConstants.DATA_SET_NAMES`` and
       split it into chunks, collecting all chunks in a single list.
    2. Embed the collected chunks into a vector store.
    3. Initialize the generation and validation LLMs from the model names
       in ``ConfigConstants``.
    4. Bundle the vector store and both LLMs into an ``AppConfig`` and pass
       it to ``launch_gradio``.

    Takes no arguments and returns ``None``; progress is reported through
    the ``logging`` module.
    """
    logging.info("Starting the RAG pipeline")

    # Per-dataset chunk sizes that differ from the default.
    # NOTE(review): the reason 'cuad' uses 4000 is not documented in this
    # file -- confirm before changing.
    chunk_size_overrides = {'cuad': 4000}

    all_chunked_documents = []

    for data_set_name in ConfigConstants.DATA_SET_NAMES:
        logging.info("Loading dataset: %s", data_set_name)
        # The raw dataset is only needed for chunking, so it is kept in a
        # loop-local name instead of being accumulated for the whole run.
        dataset = load_data(data_set_name)

        chunk_size = chunk_size_overrides.get(
            data_set_name, ConfigConstants.DEFAULT_CHUNK_SIZE
        )
        chunked_documents = chunk_documents(
            dataset,
            chunk_size=chunk_size,
            chunk_overlap=ConfigConstants.CHUNK_OVERLAP,
        )
        all_chunked_documents.extend(chunked_documents)

    logging.info("Total chunked documents: %s", len(all_chunked_documents))

    vector_store = embed_documents(all_chunked_documents)
    logging.info("Documents embedded")

    gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)
    val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)

    config = AppConfig(vector_store=vector_store, gen_llm=gen_llm, val_llm=val_llm)
    launch_gradio(config)

    logging.info("Finished!!!")
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()