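"""Gradio chat app for a legal assistant: Meta-Llama-3.1-8B-Instruct with the
Khalid02/fine_tuned_law_llama3_8b_lora-adapters LoRA merged in at load time.

Assumed dependencies (not pinned here): torch, transformers, peft, gradio,
and accelerate (required for device_map="auto").
"""
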
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# Configuration
BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"
LORA_ADAPTERS = "Khalid02/fine_tuned_law_llama3_8b_lora-adapters"

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_components():
    global model, tokenizer
    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        try:
            # Load tokenizer from base model
            tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

            # Configure 4-bit loading
            # bnb_config = BitsAndBytesConfig(
            #     load_in_4bit=True,
            #     bnb_4bit_quant_type="nf4",
            #     bnb_4bit_compute_dtype=torch.float16,
            #     bnb_4bit_use_double_quant=False,
            # )
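            # Note: enabling the 4-bit path above would also require the bitsandbytes
            # package, importing BitsAndBytesConfig from transformers, and
            # uncommenting quantization_config=bnb_config in the call below.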

            # Load base model with correct device mapping
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                # quantization_config=bnb_config,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True
            )

            # Load LoRA adapters with proper config
            config = PeftConfig.from_pretrained(LORA_ADAPTERS)
            model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTERS,
                device_map="auto",
                is_trainable=False  # Important for inference
            )

            # Merge the LoRA weights into the base model so inference runs on a
            # plain transformers model
            model = model.merge_and_unload()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    return model, tokenizer
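
# load_components() is lazy: after the first successful load the globals are kept,
# so later calls (e.g. from the reload handler below) return immediately unless the
# globals have been reset to None first.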


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Handle chat responses using the loaded model"""
    global model, tokenizer
    try:
        # Create conversation history
        messages = [{"role": "system", "content": system_message}]
        for user_input, bot_response in history:
            if user_input:
                messages.append({"role": "user", "content": user_input})
            if bot_response:
                messages.append({"role": "assistant", "content": bot_response})
        messages.append({"role": "user", "content": message})

        # Format input using chat template
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate response
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0.1,  # fall back to greedy decoding at very low temperatures
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        return response
    except Exception as e:
        return f"Error generating response: {str(e)}"


def create_interface():
    """Create Gradio interface"""
    with gr.Blocks() as demo:
        gr.Markdown("# Fine-tuned Llama 3.1 Legal Assistant")

        with gr.Row():
            reload_btn = gr.Button("Reload Model")
            status = gr.Textbox(label="Load Status", interactive=False)

        chat_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a legal expert chatbot. Provide accurate and helpful legal information.",
                           label="System message", lines=2),
                gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
            ]
        )

        def reload_model():
            global model, tokenizer
            try:
                model, tokenizer = None, None
                load_components()
                return "Model reloaded successfully!"
            except Exception as e:
                return f"Reload failed: {str(e)}"

        reload_btn.click(reload_model, outputs=status)

    return demo


if __name__ == "__main__":
    # Initial model load
    load_components()

    # Create and launch interface
    demo = create_interface()
    demo.launch()
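    # The default launch() is what a Hugging Face Space expects; for local testing,
    # launch(share=True) would additionally expose a temporary public URL.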