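# Gradio RAG demo: a TinyLlama model fine-tuned on the Enron email dataset,
# answering questions over a persisted Chroma vector store through
# LangChain's RetrievalQA chain.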
from torch import cuda, bfloat16
import torch
import transformers
from transformers import AutoTokenizer
from time import time
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
import gradio as gr
#############################################################################
model_id = "marcolorenzi98/tinyllama-enron-v1"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# set quantization configuration to load a large model with less GPU memory
# this requires the `bitsandbytes` library
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)
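# Note: bnb_config is prepared here but not passed to from_pretrained() below
# (the quantization_config argument is commented out), so the model loads at
# full precision. Enabling 4-bit NF4 loading requires a CUDA GPU with
# `bitsandbytes` installed.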
##############################################################################
model_config = transformers.AutoConfig.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
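# device_map='auto' relies on the `accelerate` library to place the model
# weights on the available GPU(s), falling back to CPU if none is found.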
##############################################################################
# Embed and store the texts.
# Supplying a persist_directory stores the embeddings on disk.
embedding = SpacyEmbeddings(model_name="en_core_web_sm")
persist_directory = 'Enron_case_RAG/Langchain_ChromaDB'
# load the prebuilt vector store from disk
db3 = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding,
    collection_name="Enron_vectorstore"
)
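# This assumes the spaCy model is installed (`python -m spacy download
# en_core_web_sm`) and that the persisted collection was built with this same
# embedding function: querying with a different embedder than the one used at
# indexing time would yield meaningless similarity scores.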
##############################################################################
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto"
)
llm = HuggingFacePipeline(pipeline=query_pipeline)
retriever = db3.as_retriever()
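# as_retriever() with no arguments uses plain similarity search over the
# collection (top-k retrieval, k=4 under LangChain's default configuration).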
##############################################################################
def gradio_rag(query):
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff all retrieved chunks into a single prompt
        retriever=retriever,
        verbose=True)
    print(f"Query: {query}\n")
    time_1 = time()
    result = qa.run(query)
    time_2 = time()
    print(f"Inference time: {round(time_2 - time_1, 3)} sec.")
    print("\nResult: ", result)
    return result  # Gradio shows the returned string in the output textbox
###############################################################################
demo = gr.Interface(
    fn=gradio_rag,
    inputs=gr.Textbox(label="Please write your request here:",
                      placeholder="example: Who is Sheila Chang?", lines=5),
    outputs=gr.Textbox(label="Answer:"),
    title='Tiny Llama RAG on the Enron Scandal',
    description="A RAG system based on the SLM TinyLlama, fine-tuned on the Enron scandal email dataset.",
    allow_flagging="never"
)
demo.launch(debug=False)
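# When running outside Hugging Face Spaces, demo.launch(share=True) serves a
# temporary public URL; debug=True keeps the call blocking and surfaces
# errors, which is convenient during notebook development.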