from time import time

import torch
from torch import cuda, bfloat16
import transformers
from transformers import AutoTokenizer

from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA

import gradio as gr
#############################################################################
model_id = "marcolorenzi98/tinyllama-enron-v1"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# set quantization configuration to load large model with less GPU memory
# this requires the `bitsandbytes` library
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)
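# NOTE: bnb_config is defined but not applied below -- the quantization_config
# argument is commented out in from_pretrained(). Uncommenting it (with
# `bitsandbytes` installed and a CUDA GPU available) loads the model in 4-bit.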
##############################################################################
model_config = transformers.AutoConfig.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
##############################################################################
# spaCy embeddings (the `en_core_web_sm` model must be installed); they must
# match the embeddings used when the vectorstore was originally built.
embedding = SpacyEmbeddings(model_name="en_core_web_sm")
# Load the prebuilt Chroma vectorstore that was persisted to disk.
persist_directory = 'Enron_case_RAG/Langchain_ChromaDB'
db3 = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding,
    collection_name="Enron_vectorstore"
)
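# This assumes persist_directory already contains a Chroma collection named
# "Enron_vectorstore"; if it is missing or empty, retrieval returns no
# documents and the answers will not be grounded in the Enron emails.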
##############################################################################
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto"
)
llm = HuggingFacePipeline(pipeline=query_pipeline)
retriever = db3.as_retriever()
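# The retriever returns the chunks most similar to the query (LangChain's
# default is k=4); pass search_kwargs={"k": ...} to as_retriever() to control
# how much context is stuffed into the prompt.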
##############################################################################
def gradio_rag(query):
    # Build a "stuff" RetrievalQA chain: retrieved chunks are stuffed into the
    # prompt and the LLM answers from that context.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=True)
    print(f"Query: {query}\n")
    time_1 = time()
    result = qa.run(query)
    time_2 = time()
    print(f"Inference time: {round(time_2 - time_1, 3)} sec.")
    print("\nResult: ", result)
    # Return the answer so Gradio can show it in the output textbox.
    return result
###############################################################################
demo = gr.Interface(
    fn=gradio_rag,
    inputs=gr.Textbox(label="Please write your request here:", placeholder="example: who is Sheila Chang", lines=5),
    outputs=gr.Textbox(label="Answer:"),
    title='Tiny Llama RAG on the Enron Scandal',
    description="A RAG system based on the SLM TinyLlama, fine-tuned on a dataset of Enron scandal emails.",
    examples=[["Who is Sheila Chang?"],
              ["What were the key factors that led to the collapse of Enron?"],
              ["What were the repercussions of the Enron scandal on the energy industry and financial markets?"],
              ["How did Enron's accounting firm, Arthur Andersen, contribute to the scandal?"]],
    allow_flagging="never"
)
demo.launch(debug=False)