# Enron_case_RAG/app.py
from torch import cuda, bfloat16
import torch
import transformers
from transformers import AutoTokenizer
from time import time
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
import gradio as gr
#############################################################################
model_id = "marcolorenzi98/tinyllama-enron-v1"
# Informational only: actual model placement is handled by device_map='auto' below.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# Quantization configuration for loading the model in 4-bit to save GPU memory;
# requires the `bitsandbytes` library and a CUDA GPU. It is applied only if
# `quantization_config=bnb_config` is passed to `from_pretrained` below
# (currently commented out).
bnb_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=bfloat16
)
##############################################################################
model_config = transformers.AutoConfig.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
##############################################################################
embedding = SpacyEmbeddings(model_name="en_core_web_sm")
# The texts were embedded and persisted to disk under `persist_directory`;
# load the existing collection back with the same embedding function used to build it.
persist_directory = 'Enron_case_RAG/Langchain_ChromaDB'
db3 = Chroma(persist_directory=persist_directory,
embedding_function=embedding,
collection_name="Enron_vectorstore"
)
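##############################################################################
# Note: CSVLoader and RecursiveCharacterTextSplitter are imported above, but the
# indexing step itself is not part of this app. A minimal sketch of how the
# persisted collection could have been built is kept here for reference only;
# the CSV filename and chunking parameters are assumptions, not taken from this repo.
def build_vectorstore(csv_path="enron_emails.csv"):
    docs = CSVLoader(file_path=csv_path).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=100
    ).split_documents(docs)
    return Chroma.from_documents(
        chunks,
        embedding=embedding,
        collection_name="Enron_vectorstore",
        persist_directory=persist_directory,
    )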
##############################################################################
query_pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=torch.float16,
device_map="auto")
llm = HuggingFacePipeline(pipeline=query_pipeline)
retriever = db3.as_retriever()
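# as_retriever() uses Chroma's default similarity search; the number of chunks
# passed to the LLM could be tuned, e.g. db3.as_retriever(search_kwargs={"k": 4})
# (an illustrative value, not the setting used here).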
##############################################################################
def gradio_rag(query):
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=True)
    print(f"Query: {query}\n")
    time_1 = time()
    result = qa.run(query)
    time_2 = time()
    print(f"Inference time: {round(time_2 - time_1, 3)} sec.")
    print("\nResult: ", result)
    return result
###############################################################################
demo = gr.Interface(
    fn=gradio_rag,
    inputs=gr.Textbox(label="Please write your request here:", placeholder="example: who is Sheila Chang?", lines=5),
    outputs=gr.Textbox(label="Answer:"),
    title='TinyLlama RAG on the Enron Scandal',
    description="A RAG system based on the SLM TinyLlama, fine-tuned on the Enron scandal emails dataset.",
    allow_flagging="never",
)
demo.launch(debug=False)