import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# Configuration
BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"
LORA_ADAPTERS = "Khalid02/fine_tuned_law_llama3_8b_lora-adapters"

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_components():
    """Load the tokenizer and the LoRA-adapted base model (once)."""
    global model, tokenizer
    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        try:
            # Load tokenizer from the base model
            tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

            # Optional 4-bit loading (requires bitsandbytes and
            # `from transformers import BitsAndBytesConfig`)
            # bnb_config = BitsAndBytesConfig(
            #     load_in_4bit=True,
            #     bnb_4bit_quant_type="nf4",
            #     bnb_4bit_compute_dtype=torch.float16,
            #     bnb_4bit_use_double_quant=False,
            # )

            # Load base model with automatic device mapping
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                # quantization_config=bnb_config,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )

            # Validate the adapter config, then load the LoRA adapters
            PeftConfig.from_pretrained(LORA_ADAPTERS)
            model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTERS,
                device_map="auto",
                is_trainable=False,  # Inference only
            )

            # Merge adapters into the base weights for faster inference
            model = model.merge_and_unload()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    return model, tokenizer


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Handle chat responses using the loaded model."""
    global model, tokenizer
    try:
        # Build the conversation from the system prompt, prior turns
        # (Gradio's default tuple-style history), and the new message
        messages = [{"role": "system", "content": system_message}]
        for user_input, bot_response in history:
            if user_input:
                messages.append({"role": "user", "content": user_input})
            if bot_response:
                messages.append({"role": "assistant", "content": bot_response})
        messages.append({"role": "user", "content": message})

        # Format the input with the model's chat template
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate the response
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0.1,
            pad_token_id=tokenizer.eos_token_id,  # Llama has no dedicated pad token
            use_cache=True,
        )

        # Decode only the newly generated tokens
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return response
    except Exception as e:
        return f"Error generating response: {str(e)}"


def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown("# Fine-tuned Llama 3.1 Legal Assistant")
        with gr.Row():
            reload_btn = gr.Button("Reload Model")
            status = gr.Textbox(label="Load Status", interactive=False)

        gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(
                    value="You are a legal expert chatbot. Provide accurate and helpful legal information.",
                    label="System message",
                    lines=2,
                ),
                gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
            ],
        )

        def reload_model():
            """Drop the cached model/tokenizer and load them again."""
            global model, tokenizer
            try:
                model, tokenizer = None, None
                load_components()
                return "Model reloaded successfully!"
            except Exception as e:
                return f"Reload failed: {str(e)}"

        reload_btn.click(reload_model, outputs=status)

    return demo


if __name__ == "__main__":
    # Initial model load
    load_components()

    # Create and launch interface
    demo = create_interface()
    demo.launch()
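# --------------------------------------------------------------------------
# Example (not part of the app itself): querying the running interface from
# another process with gradio_client. This is a minimal sketch; the endpoint
# name ("/chat") and the parameter order follow Gradio's usual ChatInterface
# API, but both should be confirmed on the app's "Use via API" page before
# relying on them.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   answer = client.predict(
#       "What is consideration in contract law?",             # message
#       "You are a legal expert chatbot. Provide accurate "
#       "and helpful legal information.",                     # system message
#       512,                                                  # max new tokens
#       0.7,                                                  # temperature
#       0.95,                                                 # top-p
#       api_name="/chat",
#   )
#   print(answer)
# --------------------------------------------------------------------------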