|
from langchain.vectorstores import Chroma |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from transformers import pipeline |
|
import torch |
|
from langchain.llms import HuggingFacePipeline |
|
from langchain.embeddings import SentenceTransformerEmbeddings |
|
from langchain.chains import RetrievalQA |
|
import gradio as gr |
|
|
|
def chat(chat_history, user_input):
    """Answer *user_input* with the RAG chain and stream the reply.

    Args:
        chat_history: list of (user, bot) message tuples from gr.Chatbot.
        user_input: the user's question as a string.

    Yields:
        ``chat_history`` extended with one (user_input, partial_answer)
        tuple per character, producing a "typing" effect in the Gradio UI.
    """
    # qa_chain is the module-level RetrievalQA chain built in this file.
    bot_response = qa_chain({"query": user_input})["result"]

    # Yield at least once even for an empty answer so the UI still
    # updates instead of silently showing nothing.
    if not bot_response:
        yield chat_history + [(user_input, "")]
        return

    # Reveal the answer one character at a time via prefix slices
    # (avoids the original's redundant ''.join / + "" and the
    # incremental string concatenation).
    for end in range(1, len(bot_response) + 1):
        yield chat_history + [(user_input, bot_response[:end])]
|
|
|
# ---------------------------------------------------------------------------
# Model + retriever setup (runs at import time; downloads weights on first use)
# ---------------------------------------------------------------------------

# LaMini-Flan-T5 (783M params): a small instruction-tuned seq2seq model.
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    device_map="auto",            # let accelerate decide device placement
    torch_dtype = torch.float32)  # full precision -> also works on CPU

# Embedding model for query-time similarity search; presumably the same
# model was used when "ipc_vector_data" was built -- TODO confirm.
embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")

# Open the persisted Chroma index from disk (no ingestion happens here).
db = Chroma(persist_directory="ipc_vector_data", embedding_function=embeddings)

# Wrap the model in an HF text2text pipeline so LangChain can call it.
pipe = pipeline(
    'text2text-generation',
    model = base_model,
    tokenizer = tokenizer,
    max_length = 512,    # cap on generated sequence length
    do_sample = True,    # sample with low temperature -> mostly focused output
    temperature = 0.3,
    top_p= 0.95
    )
local_llm = HuggingFacePipeline(pipeline=pipe)

# RetrievalQA "stuff" chain: the top-2 most similar chunks are stuffed
# verbatim into the prompt alongside the question.
qa_chain = RetrievalQA.from_chain_type(llm=local_llm,
                                  chain_type='stuff',
                                  retriever=db.as_retriever(search_type="similarity", search_kwargs={"k":2}),
                                  # responses include the retrieved docs
                                  # (not consumed by chat(), which reads
                                  # only 'result')
                                  return_source_documents=True,
                                  )
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: logo, chat window, input box, and three action buttons.
# ---------------------------------------------------------------------------
with gr.Blocks() as gradioUI:

    # Banner/logo image shown at the top of the page.
    gr.Image('lawbot511.png')

    with gr.Row():
        chatbot = gr.Chatbot()
    with gr.Row():
        input_query = gr.TextArea(label='Input',show_copy_button=True)

    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            clear_input_btn = gr.Button("Clear Input")
        with gr.Column():
            clear_chat_btn = gr.Button("Clear Chat")

    # Submit streams chat()'s generator output into the chatbot; a second,
    # unqueued handler clears the input box immediately on click.
    submit_btn.click(chat, [chatbot, input_query], chatbot)
    submit_btn.click(lambda: gr.update(value=""), None, input_query, queue=False)
    # The "Clear" buttons reset their component by returning None.
    clear_input_btn.click(lambda: None, None, input_query, queue=False)
    clear_chat_btn.click(lambda: None, None, chatbot, queue=False)

# queue() is required for generator (streaming) handlers; launch() starts
# the local web server and blocks.
gradioUI.queue().launch()
|
|