# Pidgin_0.1 / app.py: Gradio chat app for a Nigerian Pidgin English assistant
import gradio as gr
import torch
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
import threading
from peft import PeftModel
import json
import time
import os
max_token = 9000  # cap on newly generated tokens per reply (the 2048-token context below effectively bounds this)
# -----------------------------
# 1️⃣ Set device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
# -----------------------------
# 2️⃣ Load base model (skip compilation)
# -----------------------------
# 4-bit quantized Gemma-3 4B instruct (earlier experiments used unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit)
base_model_name = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_name,
    max_seq_length=2048,   # context window for tokenization and generation
    dtype=torch.float16,
    load_in_4bit=True,     # bitsandbytes 4-bit weights to fit small GPUs
)
# -----------------------------
# 3️⃣ Load LoRA
# -----------------------------
# Pidgin LoRA adapter (earlier adapters: Ephraimmm/pigin-gemma-3-0.2, Ephraimmm/Pidgin_llamma_model)
lora_repo = "Ephraimmm/PIDGIN_gemma-3"
lora_model = PeftModel.from_pretrained(base_model, lora_repo, adapter_name="adapter_model")
FastLanguageModel.for_inference(lora_model)  # switch Unsloth into its fast inference mode
# -----------------------------
# 4️⃣ Streaming generation function
# -----------------------------
def generate_response(user_message):
messages = [
{
"role": "system",
"content": [{"type": "text", "text": """You are a Nigerian assistant that speaks PIDGIN ENGLISH.when asked how far reply I de o, how you de"""}]
},
{
"role": "user",
"content": [{"type": "text", "text": user_message}]
}
]
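    # Tokenize through the chat template; returns a BatchEncoding with input_ids and attention_mask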
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
tokenize=True,
return_dict=True
).to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_token,
        temperature=0.1,   # near-greedy sampling keeps replies consistent
        top_p=1.0,
        top_k=None,
        use_cache=False    # note: disabling the KV cache makes generation noticeably slower
    )
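    # Run generate() on a background thread so tokens can be streamed out below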
def generate():
lora_model.generate(**generation_kwargs)
thread = threading.Thread(target=generate)
thread.start()
full_response = ""
for new_token in streamer:
if new_token:
full_response += new_token
thread.join()
return full_response
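# A minimal sketch for smoke-testing the streaming path by hand; it is gated
# behind a hypothetical PIDGIN_SMOKE_TEST env var (not part of the original app)
# so it never runs on a deployed Space:
if os.environ.get("PIDGIN_SMOKE_TEST"):
    print(generate_response("How far?"))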
# -----------------------------
# 5️⃣ Chat + Save
# -----------------------------
chat_history = []  # note: module-level history is shared by all visitors; gr.State would make it per-session
def chat(user_message):
bot_response = generate_response(user_message)
chat_history.append((user_message, bot_response))
return chat_history, "" # also clears input box
def save_conversation():
if not chat_history:
# Return a small empty txt file instead of None (to avoid Gradio error)
file_path = "conversation_empty.txt"
with open(file_path, "w", encoding="utf-8") as f:
f.write("[]")
return file_path
conversation = []
for user_msg, bot_msg in chat_history:
conversation.append({"role": "user", "content": str(user_msg)})
conversation.append({"role": "assistant", "content": str(bot_msg)})
timestamp = time.strftime("%Y%m%d-%H%M%S")
    file_path = f"conversation_{timestamp}.txt"  # JSON content, saved with a .txt extension for easy download
with open(file_path, "w", encoding="utf-8") as f:
json.dump(conversation, f, indent=4, ensure_ascii=False)
return file_path
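# Despite the .txt extension the payload is JSON; a one-turn chat saves as:
# [
#     {"role": "user", "content": "How far?"},
#     {"role": "assistant", "content": "I de o, how you de"}
# ]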
# -----------------------------
# 6️⃣ Gradio interface
# -----------------------------
with gr.Blocks() as demo:
gr.Markdown("# Nigerian PIDGIN Assistant")
gr.Markdown("Chat with a Nigerian assistant that only speaks Pidgin English.")
chatbot = gr.Chatbot(label="Conversation")
user_input = gr.Textbox(label="Your message", placeholder="Type your message here...")
with gr.Row():
send_button = gr.Button("Send")
save_button = gr.Button("Save Conversation")
download_file = gr.File(label="Download Conversation")
send_button.click(chat, inputs=user_input, outputs=[chatbot, user_input])
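    # Optional convenience (not in the original wiring): let Enter submit too,
    # via the standard Textbox.submit event with the same handler
    user_input.submit(chat, inputs=user_input, outputs=[chatbot, user_input])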
save_button.click(save_conversation, outputs=download_file)
demo.launch()