import os
from threading import Thread

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from accelerate import dispatch_model, infer_auto_device_map
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import snapshot_download

# Hugging Face access token (read from the environment; may be None for public models)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

MODEL_NAME_OR_PATH = 'StevenChen16/llama3-8b-Lawyer'

DESCRIPTION = '''
Wealth Wizards Logo

AI Lawyer

wealth wizards

'''

LICENSE = """

---
Built with model "StevenChen16/Llama3-8B-Lawyer", based on "meta-llama/Meta-Llama-3-8B"
"""

PLACEHOLDER = """

AI Lawyer

Ask me anything about US and Canada law...

""" css = """ h1 { text-align: center; display: block; } #duplicate-button { margin: auto; color: white; background: #1565c0; border-radius: 100vh; } """ # Load the tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH) # Load the model with disk offloading print("Loading the model with disk offloading...") model = AutoModelForCausalLM.from_pretrained( MODEL_NAME_OR_PATH, trust_remote_code=True, low_cpu_mem_usage=True # Optimize memory usage during loading ) # Specify an offload folder and map the model to disk and available GPUs device_map = infer_auto_device_map(model, max_memory={"cpu": "50GB", "cuda:0": "16GB"}) dispatch_model( model, device_map=device_map, offload_folder="./offload" # Folder for offloaded weights ) terminators = [ tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>") ] # Embedding model and FAISS vector store def create_embedding_model(model_name): return HuggingFaceEmbeddings( model_name=model_name, model_kwargs={'trust_remote_code': True} ) embedding_model = create_embedding_model('intfloat/multilingual-e5-large-instruct') try: print("Downloading vector store from HuggingFace Hub...") repo_path = snapshot_download( repo_id="StevenChen16/laws.faiss", repo_type="model" ) print("Loading vector store...") vector_store = FAISS.load_local( folder_path=repo_path, embeddings=embedding_model, allow_dangerous_deserialization=True ) print("Vector store loaded successfully") except Exception as e: raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}") background_prompt = ''' As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law. Your purpose is to provide accurate, comprehensive, and professional legal information... [Shortened for brevity] ''' def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8): """ Query similar documents from vector store. """ retriever = vector_store.as_retriever( search_type="similarity_score_threshold", search_kwargs={"score_threshold": relevance_threshold, "k": k} ) similar_docs = retriever.invoke(query) context = [doc.page_content for doc in similar_docs] return " ".join(context) if context else "" def chat_llama3_8b(message: str, history: list, temperature=0.6, max_new_tokens=4096) -> str: """ Generate a streaming response using the LLaMA model. 
""" citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7) conversation = [] for user, assistant in history: conversation.extend([ {"role": "user", "content": str(user)}, {"role": "assistant", "content": str(assistant)} ]) final_message = f"{background_prompt}\n{message}" if not citation else f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}" conversation.append({"role": "user", "content": final_message}) input_ids = tokenizer.apply_chat_template( conversation, return_tensors="pt" ).to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generation_config = { "input_ids": input_ids, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": temperature > 0, "temperature": temperature, "eos_token_id": terminators } thread = Thread(target=model.generate, kwargs=generation_config) thread.start() accumulated_text = [] for text_chunk in streamer: accumulated_text.append(text_chunk) yield "".join(accumulated_text) # Gradio interface chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, label='Gradio ChatInterface') with gr.Blocks(fill_height=True, css=css) as demo: gr.Markdown(DESCRIPTION) gr.ChatInterface( fn=chat_llama3_8b, chatbot=chatbot, fill_height=True, examples=[ ['What are the key differences between a sole proprietorship and a partnership?'], ['What legal steps should I take if I want to start a business in the US?'], ['Can you explain the concept of "duty of care" in negligence law?'], ['What are the legal requirements for obtaining a patent in Canada?'], ['How can I protect my intellectual property when sharing my idea with potential investors?'] ], cache_examples=False, ) gr.Markdown(LICENSE) if __name__ == "__main__": demo.launch()