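"""Lawgarithm: Gradio chat app for a fine-tuned Llama 3.1 legal assistant.

Loads the unsloth/Meta-Llama-3.1-8B-Instruct base model, applies the
Khalid02/fine_tuned_law_llama3_8b_lora-adapters LoRA adapters, merges them
into the base weights, and serves a chat interface for legal questions.
"""
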
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# Configuration
BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"
LORA_ADAPTERS = "Khalid02/fine_tuned_law_llama3_8b_lora-adapters"

# Global variables for model and tokenizer
model = None
tokenizer = None
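
# The model is loaded once and cached in the globals above; the UI's
# "Reload Model" button clears them to force a fresh load.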
def load_components():
    """Load and cache the tokenizer, base model, and merged LoRA adapters."""
    global model, tokenizer
    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        try:
            # Load tokenizer from base model
            tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

            # Optional 4-bit loading (re-enabling this also requires importing
            # BitsAndBytesConfig from transformers):
            # bnb_config = BitsAndBytesConfig(
            #     load_in_4bit=True,
            #     bnb_4bit_quant_type="nf4",
            #     bnb_4bit_compute_dtype=torch.float16,
            #     bnb_4bit_use_double_quant=False,
            # )

            # Load base model with correct device mapping
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                # quantization_config=bnb_config,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )

            # Validate the adapter config, then load the LoRA adapters
            config = PeftConfig.from_pretrained(LORA_ADAPTERS)
            model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTERS,
                device_map="auto",
                is_trainable=False,  # Important for inference
            )

            # Merge adapters into the base weights for faster inference
            model = model.merge_and_unload()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    return model, tokenizer


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Handle chat responses using the loaded model."""
    global model, tokenizer
    try:
        # Build the conversation from the system prompt and prior turns.
        # Note: this assumes the tuple-style history that gr.ChatInterface
        # passes by default in older Gradio versions; with type="messages"
        # the entries are dicts and this loop would need adjusting.
        messages = [{"role": "system", "content": system_message}]
        for user_input, bot_response in history:
            if user_input:
                messages.append({"role": "user", "content": user_input})
            if bot_response:
                messages.append({"role": "assistant", "content": bot_response})
        messages.append({"role": "user", "content": message})

        # Format input using the model's chat template
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate response
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0.1,
            use_cache=True,
        )

        # Decode only the newly generated tokens
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return response
    except Exception as e:
        return f"Error generating response: {str(e)}"


def create_interface():
    """Create Gradio interface"""
    with gr.Blocks() as demo:
        gr.Markdown("# Fine-tuned Llama 3.1 Legal Assistant")

        with gr.Row():
            reload_btn = gr.Button("Reload Model")
            status = gr.Textbox(label="Load Status", interactive=False)

        chat_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(
                    value="You are a legal expert chatbot. Provide accurate and helpful legal information.",
                    label="System message",
                    lines=2,
                ),
                gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
            ],
        )

        def reload_model():
            global model, tokenizer
            try:
                model, tokenizer = None, None
                load_components()
                return "Model reloaded successfully!"
            except Exception as e:
                return f"Reload failed: {str(e)}"

        reload_btn.click(reload_model, outputs=status)

    return demo


if __name__ == "__main__":
    # Initial model load
    load_components()

    # Create and launch interface
    demo = create_interface()
    demo.launch()