# app.py
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
# ---- CONFIG ----
ADAPTER_REPO = "richardprobe/opt-350-chris-adapter" # your LoRA repo
ADAPTER_NAME = "finetune_adapter" # how you saved it
SYSTEM_PROMPT = "You are Richard. Be concise and casual."
# If the adapter is private on the Hub, set HF_TOKEN in the Space secrets
HF_TOKEN = os.getenv("HF_TOKEN", None)
# ------------- Loading -------------
def load_model_and_tokenizer():
    # Inspect the adapter config to get its base model
    print("Reading adapter config...")
    peft_cfg = PeftConfig.from_pretrained(ADAPTER_REPO, token=HF_TOKEN)
    base_id = peft_cfg.base_model_name_or_path
    print(f"Base model detected: {base_id}")

    # Tokenizer from the base (the adapter may also carry added tokens)
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=HF_TOKEN)

    # Safety: many decoder-only models don't define a pad token
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    # Non-quantized load so we can merge
    print("Loading base model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=dtype, device_map="auto", token=HF_TOKEN
    )
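    # Note: device_map="auto" relies on the `accelerate` package being installed.
    # Assumption about the Space setup: requirements.txt should list torch,
    # transformers, peft, accelerate, and gradio for this file to run.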
print("Loading adapter and merging...")
peft = PeftModel.from_pretrained(
base, ADAPTER_REPO, adapter_name=ADAPTER_NAME, token=HF_TOKEN
)
# This bakes LoRA weights into the base weights and returns a plain model
merged = peft.merge_and_unload() # equivalent to merge_adapter + unload
merged.eval()
    # We'll use <|end|> as EOS if it exists
    try:
        end_id = tok.convert_tokens_to_ids("<|end|>")
        if end_id is not None and end_id != tok.unk_token_id:
            merged.config.eos_token_id = end_id
    except Exception:
        pass

    return tok, merged

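# Load once at import time so every Gradio request reuses the same merged model.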
tokenizer, model = load_model_and_tokenizer()
# ------------- Prompt building -------------
def build_prompt(history, user_msg):
    """
    Render your chat format using the added tokens that were used during training.
    `history` is a list of (user, assistant) tuples from gr.ChatInterface.
    """
    segments = []
    if SYSTEM_PROMPT:
        # If you trained with a system token, add it here. Otherwise keep it as plain text.
        segments.append(f"<|system|>{SYSTEM_PROMPT}<|end|>")
    for u, a in history or []:
        if u:
            segments.append(f"<|user|>{u}<|end|>")
        if a:
            segments.append(f"<|assistant|>{a}<|end|>")
    segments.append(f"<|user|>{user_msg}<|end|>")
    segments.append("<|assistant|>")
    return "\n".join(segments)

# ------------- Inference -------------
def chat_generate(message, history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
    prompt = build_prompt(history, message)
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=float(temperature) > 0,
        repetition_penalty=float(repetition_penalty),
        eos_token_id=getattr(model.config, "eos_token_id", tokenizer.eos_token_id),
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Return only the newly generated assistant part, not the echoed prompt
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # If your <|end|> isn't marked as special, strip it manually
    text = text.replace("<|end|>", "").strip()
    return text

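# Quick sanity check outside Gradio (assumption: run manually, not used by the Space):
#   print(chat_generate("What are you up to?", history=[]))
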
# ------------- UI -------------
demo = gr.ChatInterface(
    fn=chat_generate,
    title="OPT-350M + LoRA (Chris style)",
    description="Loads the base model from the adapter's config, merges LoRA, and chats using your training tokens.",
    additional_inputs=[
        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p"),
        gr.Slider(16, 512, value=256, step=16, label="Max new tokens"),
        gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    examples=[
        ["What are you up to?", 0.7, 0.95, 256, 1.1],
        ["You coming?", 0.7, 0.95, 256, 1.1],
        ["I'm on the can", 0.7, 0.95, 256, 1.1],
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    # Queue helps avoid device contention; hide the API to avoid schema issues
    demo.queue(max_size=8)
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)