Spaces:
Runtime error
```python
# app.py
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# ---- CONFIG ----
ADAPTER_REPO = "richardprobe/opt-350-chris-adapter"  # your LoRA repo
ADAPTER_NAME = "finetune_adapter"                    # how you saved it
SYSTEM_PROMPT = "You are Richard. Be concise and casual."

# If the adapter is private on the Hub, set HF_TOKEN in the Space secrets
HF_TOKEN = os.getenv("HF_TOKEN", None)


# ------------- Loading -------------
def load_model_and_tokenizer():
    # Inspect the adapter config to find the base model it was trained on
    print("Reading adapter config...")
    peft_cfg = PeftConfig.from_pretrained(ADAPTER_REPO, token=HF_TOKEN)
    base_id = peft_cfg.base_model_name_or_path
    print(f"Base model detected: {base_id}")

    # Tokenizer from the base (the adapter may also carry added tokens)
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=HF_TOKEN)

    # Safety: many decoder-only models don't define a pad token
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    # Non-quantized load so we can merge the LoRA weights
    print("Loading base model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=dtype, device_map="auto", token=HF_TOKEN
    )

    print("Loading adapter and merging...")
    peft = PeftModel.from_pretrained(
        base, ADAPTER_REPO, adapter_name=ADAPTER_NAME, token=HF_TOKEN
    )
    # This bakes the LoRA weights into the base weights and returns a plain model
    merged = peft.merge_and_unload()  # equivalent to merge_adapter + unload
    merged.eval()

    # Use <|end|> as EOS if the tokenizer knows about it
    try:
        end_id = tok.convert_tokens_to_ids("<|end|>")
        if end_id is not None and end_id != tok.unk_token_id:
            merged.config.eos_token_id = end_id
    except Exception:
        pass

    return tok, merged


tokenizer, model = load_model_and_tokenizer()


# ------------- Prompt building -------------
def build_prompt(history, user_msg):
    """
    Render the chat format using the added tokens that were used during training.
    History is a list of (user, assistant) tuples from ChatInterface.
    """
    segments = []
    if SYSTEM_PROMPT:
        # If you trained with a system token, add it here. Otherwise keep it as plain text.
        segments.append(f"<|system|>{SYSTEM_PROMPT}<|end|>")
    for u, a in history or []:
        if u:
            segments.append(f"<|user|>{u}<|end|>")
        if a:
            segments.append(f"<|assistant|>{a}<|end|>")
    segments.append(f"<|user|>{user_msg}<|end|>")
    segments.append("<|assistant|>")
    return "\n".join(segments)


# ------------- Inference -------------
def chat_generate(message, history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
    prompt = build_prompt(history, message)
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=float(temperature) > 0,
        repetition_penalty=float(repetition_penalty),
        eos_token_id=getattr(model.config, "eos_token_id", tokenizer.eos_token_id),
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # Return only the newly generated assistant tokens
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # If <|end|> isn't marked as special, strip it manually
    text = text.replace("<|end|>", "").strip()
    return text


# ------------- UI -------------
demo = gr.ChatInterface(
    fn=chat_generate,
    title="OPT-350M + LoRA (Chris style)",
    description="Loads the base model from the adapter's config, merges the LoRA, and chats using your training tokens.",
    additional_inputs=[
        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p"),
        gr.Slider(16, 512, value=256, step=16, label="Max new tokens"),
        gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    examples=[
        ["What are you up to?", 0.7, 0.95, 256, 1.1],
        ["You coming?", 0.7, 0.95, 256, 1.1],
        ["I'm on the can", 0.7, 0.95, 256, 1.1],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    # A queue helps avoid device contention; hiding the API avoids schema issues
    demo.queue(max_size=8)
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)
```