import spaces  # Hugging Face ZeroGPU support
import os
from threading import Thread

import gradio as gr
import torch
import faiss
import fitz  # PyMuPDF
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
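
# Hugging Face access token for the gated Gemma weights.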
token = os.environ.get("HF_TOKEN")
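
# Korean sentence-embedding model used for all retrieval below.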
ST = SentenceTransformer("jhgan/ko-sroberta-multitask")


def extract_text_from_pdf(pdf_path):
    """Extract the full text of a PDF with PyMuPDF."""
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text
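

# Extract the raw statute text.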
pdf_path = "laws.pdf"
law_text = extract_text_from_pdf(pdf_path)
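
# Split the statute text into lines and embed each one.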
law_sentences = law_text.split("\n")
law_embeddings = ST.encode(law_sentences)
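
# Exact L2 nearest-neighbor index over the statute embeddings.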
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)
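
# Load a Korean legal QA dataset and FAISS-index its question embeddings.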
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")
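
# Load Gemma 2 27B IT in 4-bit NF4 quantization to reduce GPU memory use.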
model_id = "google/gemma-2-27b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token,
)

SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.
You must answer in Korean."""


@spaces.GPU
def search_law(query, k=5):
    """Return the k statute lines nearest to the query, with their L2 distances."""
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]


@spaces.GPU
def search_qa(query, k=3):
    """Return the answers of the k QA pairs whose questions best match the query."""
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]


def format_prompt(prompt, law_docs, qa_docs):
    """Assemble the user question and retrieved context into one prompt string."""
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc[0]}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT


@spaces.GPU
def talk(prompt, history):
    """Retrieve legal context for the prompt and stream the model's answer."""
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # crude truncation to bound prompt length

    # Gemma's chat template has no system role, so prepend the system prompt to the user turn.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n" + formatted_prompt}]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread so tokens can be yielded as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch(debug=True)