import os
from threading import Thread
from typing import Iterator

import gradio as gr
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
# Optional Hugging Face access token, read from the process environment
# (None when the variable is not set).
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub identifier of the causal language model served by this app.
MODEL_NAME_OR_PATH = "StevenChen16/llama3-8b-Lawyer"

# Markdown shown above the chat widget (currently only a newline).
DESCRIPTION = "\n"

# Markdown footer shown below the chat widget.
LICENSE = (
    "\n"
    "---\n"
    'Built with model "StevenChen16/Llama3-8B-Lawyer", based on "meta-llama/Meta-Llama-3-8B"\n'
)

# Text displayed inside the empty chatbot before the first message.
PLACEHOLDER = (
    "\n"
    "AI Lawyer\n"
    "Ask me anything about US and Canada law...\n"
)

# Custom stylesheet handed to gr.Blocks.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer that belongs to the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)

# Load the model and let `from_pretrained` place the weights while loading.
# The previous version called `infer_auto_device_map` / `dispatch_model`
# without importing them, and passed `offload_folder` to `dispatch_model`,
# whose parameter is `offload_dir`. Passing `device_map`, `max_memory` and
# `offload_folder` to `from_pretrained` needs no extra imports.
print("Loading the model with disk offloading...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,  # Optimize memory usage during loading
    device_map="auto",
    # Accelerate expects GPU entries keyed by integer index (0 == first GPU),
    # not by a "cuda:0" string. NOTE(review): assumes GPU 0 exists on the host.
    max_memory={0: "16GB", "cpu": "50GB"},
    offload_folder="./offload",  # Folder for offloaded weights
)

# Kept as a module-level name for backward compatibility with code that read
# the previously computed map; None if the loader did not record one.
device_map = getattr(model, "hf_device_map", None)

# Token ids that end generation: the tokenizer's EOS token and the
# "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Embedding model and FAISS vector store
def create_embedding_model(model_name):
    """Build a ``HuggingFaceEmbeddings`` wrapper for ``model_name``.

    The wrapper is created with ``trust_remote_code`` enabled in its
    ``model_kwargs``.
    """
    loader_options = {'trust_remote_code': True}
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=loader_options)
    return embeddings
# Embedding model used both to load the index and to embed incoming queries.
embedding_model = create_embedding_model('intfloat/multilingual-e5-large-instruct')

# Download the prebuilt FAISS index from the Hub and load it. Any failure is
# fatal for the app, so it is re-raised as a RuntimeError with the original
# exception attached as its cause.
try:
    print("Downloading vector store from HuggingFace Hub...")
    repo_path = snapshot_download(
        repo_id="StevenChen16/laws.faiss",
        repo_type="model",
    )
    print("Loading vector store...")
    # SECURITY: `allow_dangerous_deserialization=True` opts in to loading
    # serialized data that LangChain flags as unsafe for untrusted sources
    # (pickle-based, per LangChain's FAISS documentation). This is only
    # acceptable while the Hub repository above is trusted.
    vector_store = FAISS.load_local(
        folder_path=repo_path,
        embeddings=embedding_model,
        allow_dangerous_deserialization=True,
    )
    print("Vector store loaded successfully")
except Exception as e:
    raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {e}") from e
# Instruction text prepended to every user message before it is sent to the
# model (see chat_llama3_8b).
background_prompt = """
As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law. Your purpose is to provide accurate, comprehensive, and professional legal information...
[Shortened for brevity]
"""
def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
    """Return the text of stored documents that are similar to ``query``.

    A retriever is built from ``vector_store`` with the
    ``"similarity_score_threshold"`` search type, passing ``k`` and
    ``relevance_threshold`` as its search arguments.

    Returns:
        The ``page_content`` of every retrieved document joined with single
        spaces, or an empty string when nothing is retrieved.
    """
    search_options = {"score_threshold": relevance_threshold, "k": k}
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs=search_options,
    )
    matches = retriever.invoke(query)
    snippets = [match.page_content for match in matches]
    # Joining an empty list already yields "", so no separate branch is needed.
    return " ".join(snippets)
def _history_to_messages(history):
    """Convert Gradio chat history into a list of role/content dicts."""
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Gradio "messages" format: each entry already carries a role.
            messages.append({"role": turn["role"], "content": str(turn["content"])})
        else:
            # Legacy "tuples" format: (user_text, assistant_text) pairs.
            user, assistant = turn
            messages.append({"role": "user", "content": str(user)})
            messages.append({"role": "assistant", "content": str(assistant)})
    return messages


def chat_llama3_8b(message: str, history: list, temperature=0.6, max_new_tokens=4096) -> Iterator[str]:
    """Stream a model reply to ``message``, yielding the text produced so far.

    This is a generator: each yielded value is the full response accumulated
    up to that point, not just the newest chunk.

    Args:
        message: The new user message.
        history: Previous turns, either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature; a value ``<= 0`` disables sampling.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The response text accumulated so far.
    """
    # Retrieve reference passages for the question; empty string if none.
    citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)

    conversation = _history_to_messages(history)
    if citation:
        final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
    else:
        final_message = f"{background_prompt}\n{message}"
    conversation.append({"role": "user", "content": final_message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        # Append the assistant header so the model starts a new assistant
        # turn instead of continuing the user's message.
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    generation_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": terminators,
    }
    if do_sample:
        # Temperature only applies when sampling; omit it for greedy decoding.
        generation_kwargs["temperature"] = temperature

    # Generation runs in a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for text_chunk in streamer:
        response += text_chunk
        yield response

    # The streamer is exhausted, so generation is finishing; reap the thread.
    thread.join()
# Gradio interface
# Chat display widget, created up front and handed to the ChatInterface below.
chatbot = gr.Chatbot(label='Gradio ChatInterface', placeholder=PLACEHOLDER, height=600)

# Sample questions offered as examples; each example is a one-element list.
_example_questions = (
    'What are the key differences between a sole proprietorship and a partnership?',
    'What legal steps should I take if I want to start a business in the US?',
    'Can you explain the concept of "duty of care" in negligence law?',
    'What are the legal requirements for obtaining a patent in Canada?',
    'How can I protect my intellectual property when sharing my idea with potential investors?',
)

# Page layout: description, chat interface, then the license footer.
with gr.Blocks(css=css, fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        examples=[[question] for question in _example_questions],
        cache_examples=False,
    )
    gr.Markdown(LICENSE)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()