"""
Quick smoke-test script: load a fine-tuned adapter and generate a single prediction.
Uses the same prompt format as finetune_hx_unsloth.py to keep behavior aligned.
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| from pathlib import Path |
|
|
| import pandas as pd |
| import torch |
| from peft import PeftModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
|
|
# Few-shot demonstrations prepended to every prompt by build_prompt().
# NOTE: the labels must come from the same closed set the system instruction
# requests ("hatespeech", "offensive", "normal"). The previous "Hate"/"Not
# Hate" labels contradicted that instruction and could steer the model toward
# out-of-vocabulary answers.
FEW_SHOT_EXAMPLES = [
    {
        "text": (
            "in britain the prime minister a childless woman london mayor a moslem invader in command "
            "of london police a lesbian in command of london fire brigade a fat woman scotland first "
            "minister a childless woman"
        ),
        "label": "hatespeech",
    },
    {
        "text": (
            "the non partisan congressional budget office cbo has released their report on sens "
            "lindsey graham and dick durbin the dream act the report shows that their legislation "
            "would add billions to the deficit over a ten year period and lots more immigrant scum"
        ),
        "label": "hatespeech",
    },
    {
        "text": (
            "i live and work with many legal mexican immigrants who are great citizens and trump supporters "
            "they have no problem with deporting illegals maga"
        ),
        "label": "normal",
    },
]
|
|
|
|
def build_prompt(post: str) -> str:
    """Assemble the few-shot classification prompt for a single post.

    Layout: a <system> instruction, numbered <human>/<bot> example exchanges
    separated by blank lines, then the target post with an open <bot> tag for
    the model to complete.
    """
    cleaned = (post or "").strip()
    lines: list[str] = [
        "<system>: You are a content moderation assistant. Label each post as hatespeech, offensive, or normal according to the HateXplain policy. Classify the following social media post. Respond with exactly one label: hatespeech, offensive, or normal."
    ]
    for number, shot in enumerate(FEW_SHOT_EXAMPLES, start=1):
        lines += [
            f"Example {number}:",
            f"<human>: {shot['text']}",
            f"<bot>: {shot['label']}",
            "",
        ]
    lines += [f"<human>: {cleaned}", "<bot>:"]
    return "\n".join(lines).strip()
|
|
|
|
def fetch_post(args: argparse.Namespace) -> str:
    """Return the post to classify: the --text override if given, else a CSV row.

    Raises IndexError when --sample-index falls outside the dataset.
    """
    if args.text:
        return args.text
    frame = pd.read_csv(args.dataset_path)
    if not 0 <= args.sample_index < len(frame):
        raise IndexError(f"sample_index {args.sample_index} outside dataset size {len(frame)}")
    return frame.iloc[args.sample_index][args.text_col]
|
|
|
|
# Fallback locations used when the CLI does not override them: the adapter
# directory shipped one level above this script, and the base model the
# adapter was presumably trained on.
DEFAULT_ADAPTER_DIR = Path(__file__).resolve().parents[1] / "adapter"
DEFAULT_BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
|
|
|
def load_adapter(
    adapter_dir: str,
    base_model: str | None = None,
    load_in_4bit: bool = True,
):
    """Load the fine-tuned LoRA adapter on top of its (optionally 4-bit) base.

    Returns a ``(model, tokenizer)`` pair with the model in eval mode.
    ``base_model`` defaults to ``DEFAULT_BASE_MODEL`` when omitted.
    """
    adapter_location = Path(adapter_dir)
    resolved_base = DEFAULT_BASE_MODEL if base_model is None else base_model

    # The tokenizer saved alongside the adapter carries any added tokens.
    tokenizer = AutoTokenizer.from_pretrained(
        adapter_location, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # bf16 on GPU, fp32 on CPU; NF4 + double quantization mirrors the usual
    # QLoRA setup — presumably matching how the adapter was trained (verify).
    compute_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    bnb_config = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        if load_in_4bit
        else None
    )

    base = AutoModelForCausalLM.from_pretrained(
        resolved_base,
        device_map="auto",
        torch_dtype=compute_dtype,
        trust_remote_code=True,
        quantization_config=bnb_config,
    )
    model = PeftModel.from_pretrained(base, adapter_location, is_trainable=False)
    model.eval()
    return model, tokenizer
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for this smoke test."""
    cli = argparse.ArgumentParser(description="Generate a single HX prediction using a fine-tuned adapter.")
    # Model / adapter locations.
    cli.add_argument("--adapter-dir", default=str(DEFAULT_ADAPTER_DIR), help="Path to the trained adapter directory.")
    cli.add_argument("--base-model", default=None, help="Optional base model path/ID; if omitted we rely on adapter_config metadata.")
    # Input selection: either an explicit --text or a row from the dataset.
    cli.add_argument("--dataset-path", default="hx.csv", help="CSV to sample from when --text is not provided.")
    cli.add_argument("--text-col", default="text", help="Column containing the post text.")
    cli.add_argument("--sample-index", type=int, default=0, help="Row index to read from the dataset.")
    cli.add_argument("--text", default=None, help="Override text instead of pulling from the dataset.")
    # Generation settings.
    cli.add_argument("--max-new-tokens", type=int, default=8, help="Generation length.")
    cli.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature (0 for greedy).")
    cli.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling cutoff.")
    return cli.parse_args()
|
|
|
|
def main() -> None:
    """Run one end-to-end prediction: fetch a post, build the prompt,
    load the adapter, generate, and print the completion."""
    args = parse_args()

    post = fetch_post(args)
    prompt = build_prompt(post)

    model, tokenizer = load_adapter(
        adapter_dir=args.adapter_dir,
        base_model=args.base_model,
        load_in_4bit=True,
    )

    # Left padding/truncation keeps the generation position adjacent to the
    # most recent context, the usual setup for decoder-only inference.
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    do_sample = args.temperature > 0
    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": tokenizer.eos_token_id,
        # Explicit pad token avoids the "Setting pad_token_id" warning.
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        # Sampling knobs are only meaningful when do_sample=True; passing
        # temperature=0.0 in greedy mode makes transformers warn and ignore it.
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p

    model.eval()
    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    completion = tokenizer.decode(
        output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    ).strip()
    print("Prompt:\n", prompt)
    print("\nModel completion:\n", completion)
|
|
|
|
if __name__ == "__main__":
    # Must be set before the first CUDA allocation to take effect; the
    # expandable-segments allocator reduces fragmentation when loading large
    # models. setdefault() respects a value already set by the caller.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    main()
|
|