import gradio as gr
import torch
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
import threading
from peft import PeftModel
import json
import time

max_token = 1024  # max new tokens per reply; kept below max_seq_length (2048) so prompt + output fit the context

# -----------------------------
# 1️⃣ Set device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# -----------------------------
# 2️⃣ Load base model (skip compilation)
# -----------------------------
base_model_name = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"
# Alternative: "unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit"
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_name,
    max_seq_length=2048,
    dtype=torch.float16,
    load_in_4bit=True,
)

# -----------------------------
# 3️⃣ Load LoRA
# -----------------------------
lora_repo = "Ephraimmm/PIDGIN_gemma-3"
# Alternatives: "Ephraimmm/pigin-gemma-3-0.2", "Ephraimmm/Pidgin_llamma_model"
lora_model = PeftModel.from_pretrained(base_model, lora_repo, adapter_name="adapter_model")
FastLanguageModel.for_inference(lora_model)

# -----------------------------
# 4️⃣ Streaming generation function
# -----------------------------
def generate_response(user_message):
    messages = [
        {
            "role": "system",
            "content": [{
                "type": "text",
                "text": "You are a Nigerian assistant that speaks PIDGIN ENGLISH. "
                        "When asked 'how far', reply 'I de o, how you de'.",
            }],
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": user_message}],
        },
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    ).to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_token,
        temperature=0.1,
        top_p=1.0,
        top_k=None,
        use_cache=False,
    )

    # Run generation in a background thread so the streamer can be drained here.
    def generate():
        lora_model.generate(**generation_kwargs)

    thread = threading.Thread(target=generate)
    thread.start()

    full_response = ""
    for new_token in streamer:
        if new_token:
            full_response += new_token
    thread.join()
    return full_response

# -----------------------------
# 5️⃣ Chat + Save
# -----------------------------
chat_history = []

def chat(user_message):
    bot_response = generate_response(user_message)
    chat_history.append((user_message, bot_response))
    return chat_history, ""  # second output clears the input box

def save_conversation():
    if not chat_history:
        # Return a small placeholder file instead of None (avoids a Gradio error)
        file_path = "conversation_empty.txt"
        with open(file_path, "w", encoding="utf-8") as f:
            f.write("[]")
        return file_path

    conversation = []
    for user_msg, bot_msg in chat_history:
        conversation.append({"role": "user", "content": str(user_msg)})
        conversation.append({"role": "assistant", "content": str(bot_msg)})

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    file_path = f"conversation_{timestamp}.txt"  # .txt extension, JSON-formatted content
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(conversation, f, indent=4, ensure_ascii=False)
    return file_path

# -----------------------------
# 6️⃣ Gradio interface
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Nigerian PIDGIN Assistant")
    gr.Markdown("Chat with a Nigerian assistant that only speaks Pidgin English.")

    chatbot = gr.Chatbot(label="Conversation")
    user_input = gr.Textbox(label="Your message", placeholder="Type your message here...")
    with gr.Row():
        send_button = gr.Button("Send")
        save_button = gr.Button("Save Conversation")
    download_file = gr.File(label="Download Conversation")
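    # Optional sketch (assumption, not in the original script): let Enter in
    # the textbox trigger the same handler as the Send button. Textbox.submit
    # takes the same fn/inputs/outputs signature as Button.click.
    # user_input.submit(chat, inputs=user_input, outputs=[chatbot, user_input])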
    send_button.click(chat, inputs=user_input, outputs=[chatbot, user_input])
    save_button.click(save_conversation, outputs=download_file)

demo.launch()
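# A minimal sketch of alternative launch options (standard Gradio arguments,
# not used in the original script): demo.queue() serializes concurrent
# requests so two users don't run generation on the single model at once,
# and share=True prints a temporary public URL (useful on Colab).
#
# demo.queue()
# demo.launch(share=True)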