import torch
import transformers
import gradio as gr

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
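
# Note: these import paths follow the classic pre-0.1 langchain package layout;
# newer langchain releases relocate most of these integrations to langchain_community.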

# Quantize to 4-bit NF4 so the 7B chat model fits in modest GPU memory.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
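
# Note: 4-bit loading requires the bitsandbytes and accelerate packages and a
# CUDA-capable GPU; install them with `pip install bitsandbytes accelerate`.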

model_path = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    device_map="auto",
    offload_folder="offload/",
    cache_dir="cache/",
    quantization_config=nf4_config,
)

# Expose the model to LangChain through a standard text-generation pipeline.
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)
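
# Note: no generation-length cap is set above. If answers truncate or ramble,
# max_new_tokens (a standard text-generation pipeline argument) is the usual knob,
# e.g. adding max_new_tokens=512 to the transformers.pipeline(...) call; the
# value 512 is illustrative, not from the original code.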

llm = HuggingFacePipeline(pipeline=query_pipeline)

# Load the memoir and split it into 1000-character chunks with a small overlap
# so context is not cut off abruptly at chunk boundaries.
loader = TextLoader("MTM_Memoir_txt.txt", encoding="utf-8")
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
all_splits = text_splitter.split_documents(documents)

# Embed each chunk with a sentence-transformers model on the GPU.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda"}

embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Index the chunks in a Chroma vector store persisted to disk.
vectordb = Chroma.from_documents(
    documents=all_splits, embedding=embeddings, persist_directory="chroma_db"
)
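
# Sketch (not part of the original flow): since the store is persisted, a later
# run could reopen it without re-embedding, assuming the same embedding model:
#
#     vectordb = Chroma(persist_directory="chroma_db", embedding_function=embeddings)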

retriever = vectordb.as_retriever()

# "stuff" chain type: retrieved chunks are stuffed verbatim into the LLM prompt.
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    verbose=True,
)
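
# Note: as_retriever() defaults to plain similarity search; the number of chunks
# fetched per query can be pinned explicitly (the k=4 here is illustrative):
#
#     retriever = vectordb.as_retriever(search_kwargs={"k": 4})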

# Quick smoke test of the retrieval chain before launching the UI.
def test_rag(qa, query):
    print(query)
    result = qa.run(query)
    print(f"Result:\t{result}")


test_rag(qa, "Hello, when were you born?")

def preprocess_query(query):
    # Wrap the raw query in a Llama-2 style prompt that fixes the persona.
    intro = (
        "<s>[INST] <<SYS>>\n"
        "You are former Malaysian Prime Minister Tun Dr Mahathir Mohamad, "
        "a visionary leader.\n"
        "<</SYS>>\n\n"
    )
    instruction = f"### Instruction\nRespond to the following query just like yourself: {query}\n\n"
    response = "### Response:\n\n"
    return intro + instruction + response
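
# Example: preprocess_query("When were you born?") returns the system block followed
# by "### Instruction\nRespond to the following query just like yourself: When were
# you born?\n\n### Response:\n\n". Caveat: the Llama-2 chat format normally closes
# the user turn with "[/INST]", which this hybrid prompt omits; a possible
# (untested) refinement would be:
#
#     return intro + instruction + response + "[/INST]"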
|

def get_result(qa=qa, query=""):
    return qa.run(query)

def predict(query):
    # Build the persona prompt, then run it through the retrieval chain.
    processed_query = preprocess_query(query=query)
    result = get_result(query=processed_query)
    return result
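
# Caveat: RetrievalQA embeds the string given to qa.run() to retrieve chunks, so
# passing the full persona prompt also skews the similarity search. A possible
# split (sketch, not in the original) keeps retrieval on the raw query:
#
#     docs = retriever.get_relevant_documents(query)
#     result = qa.combine_documents_chain.run(input_documents=docs, question=processed_query)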

# Minimal Gradio front end: a query box, an output box, and a submit button.
with gr.Blocks() as chatbot_demo:
    query = gr.Textbox(label="Query", placeholder="Ask the Prime Minister")
    output = gr.Textbox(label="Output")
    submit_button = gr.Button("Submit")
    submit_button.click(fn=predict, inputs=query, outputs=output, api_name="predict")

chatbot_demo.launch()