# LawBot511 / LawGPT / LawGPT.py
# Uploaded by naitik31 -- commit 2df226a ("Upload 6 files")
from langchain.vectorstores import Chroma
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
import torch
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.chains import RetrievalQA
import gradio as gr
def chat(chat_history, user_input):
    """Answer ``user_input`` and stream the reply into the chat history.

    Runs the module-level ``qa_chain`` on the query and reads the ``'result'``
    entry of its output. It then yields a sequence of histories in which the
    bot's reply grows piece by piece (one character at a time for a string
    answer), so the chat widget can render the text progressively.

    Args:
        chat_history: Existing list of ``(user_message, bot_message)`` pairs.
            ``None`` or an empty value is treated as an empty history.
        user_input: The question text to send to the QA chain.

    Yields:
        A new list made of the previous history plus one
        ``(user_input, partial_reply)`` pair. The list passed in is never
        mutated.
    """
    # Work on a copy so the caller's history is untouched, and tolerate None
    # (e.g. a cleared chat) instead of failing on ``None + [...]``.
    history = list(chat_history) if chat_history else []

    answer = qa_chain({"query": user_input})["result"]

    if not answer:
        # Without this, an empty answer would yield nothing at all and the
        # user's turn would never show up in the conversation.
        yield history + [(user_input, "")]
        return

    partial_reply = ""
    for piece in answer:
        partial_reply += piece
        yield history + [(user_input, partial_reply)]
# Hugging Face checkpoint shared by the tokenizer and the seq2seq model.
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the model weights as float32, passing device_map="auto" to the loader.
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float32,
    device_map="auto",
)
# Sentence-transformer embeddings handed to the vector store below.
embeddings = SentenceTransformerEmbeddings(
    model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)

# Chroma vector store backed by the "ipc_vector_data" persistence directory.
db = Chroma(
    embedding_function=embeddings,
    persist_directory="ipc_vector_data",
)
# Text-to-text generation pipeline around the loaded model and tokenizer.
# Sampling is enabled (temperature 0.3, top-p 0.95) with max_length 512.
pipe = pipeline(
    task="text2text-generation",
    model=base_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
    temperature=0.3,
    top_p=0.95,
)

# Wrap the transformers pipeline so LangChain can use it as an LLM.
local_llm = HuggingFacePipeline(pipeline=pipe)
# Similarity retriever over the Chroma store, configured with k=2.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 2},
)

# "stuff"-type RetrievalQA chain; source documents are included in its output.
qa_chain = RetrievalQA.from_chain_type(
    llm=local_llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
)
# Chat UI layout: banner image, chat window, query box and a row of buttons.
with gr.Blocks() as gradioUI:
    gr.Image("lawbot511.png")

    with gr.Row():
        chatbot = gr.Chatbot()

    with gr.Row():
        input_query = gr.TextArea(show_copy_button=True, label="Input")

    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            clear_input_btn = gr.Button("Clear Input")
        with gr.Column():
            clear_chat_btn = gr.Button("Clear Chat")

    # "Submit" has two listeners: one streams the answer into the chat
    # window, the other empties the query box.
    submit_btn.click(fn=chat, inputs=[chatbot, input_query], outputs=chatbot)
    submit_btn.click(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=input_query,
        queue=False,
    )

    # The clear buttons reset their target component by sending it None.
    clear_input_btn.click(
        fn=lambda: None,
        inputs=None,
        outputs=input_query,
        queue=False,
    )
    clear_chat_btn.click(
        fn=lambda: None,
        inputs=None,
        outputs=chatbot,
        queue=False,
    )

# Enable the request queue, then start the app.
gradioUI.queue().launch()