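"""Evaluate a fine-tuned token-classification (NER) model on a 2-column CoNLL file.

Example invocation (the script's filename is not given in the source, so
``evaluate.py`` below is a placeholder):

    python evaluate.py --model_dir outputs/bert-base-cased-timeNER \
        --test_path data/test.conll
"""
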
import argparse
import os
from typing import List, Tuple

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report


def read_conll_2col(path: str) -> Tuple[List[List[str]], List[List[str]]]:
"""Reads 2-column CoNLL (TOKEN TAG) with blank lines between sentences."""
toks, labs = [], []
all_toks, all_labs = [], []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.rstrip("\n")
if not line:
if toks:
all_toks.append(toks)
all_labs.append(labs)
toks, labs = [], []
continue
parts = line.split()
if len(parts) < 2:
# tolerate malformed lines
continue
tok, tag = parts[0], parts[-1]
toks.append(tok)
labs.append(tag)
if toks:
all_toks.append(toks)
all_labs.append(labs)
return all_toks, all_labs
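

# Illustrative input only (the B-TIME tag is assumed from the timeNER model
# directory; any tag set in the last column works). A file containing
#
#     Meet    O
#     at      O
#     3pm     B-TIME
#
#     See     O
#     you     O
#
# parses to ([["Meet", "at", "3pm"], ["See", "you"]],
#            [["O", "O", "B-TIME"], ["O", "O"]]).

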
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="outputs/bert-base-cased-timeNER",
help="Path to the fine-tuned model directory (with config.json, tokenizer files, weights).")
parser.add_argument("--test_path", type=str, default="data/test.conll",
help="Path to 2-column CoNLL test file.")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--max_length", type=int, default=256)
args = parser.parse_args()
assert os.path.exists(args.model_dir), f"Model dir not found: {args.model_dir}"
assert os.path.exists(args.test_path), f"Test file not found: {args.test_path}"
# Load tokenizer & model
print(f"Loading model from: {args.model_dir}")
tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(args.model_dir)
model.eval()
# Labels
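    # Illustrative mapping only: id2label might look like
    # {0: "O", 1: "B-TIME", 2: "I-TIME"}; the actual labels come from the
    # fine-tuned config.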
id2label = model.config.id2label
label2id = model.config.label2id
labels_sorted = [id2label[i] for i in range(len(id2label))]
print(f"Model labels: {labels_sorted}")
# Read test set
print(f"Reading test set: {args.test_path}")
tokens_list, tags_list = read_conll_2col(args.test_path)
num_sents = len(tokens_list)
num_tokens = sum(len(s) for s in tokens_list)
print(f"Loaded {num_sents} sentences / {num_tokens} tokens")
# Sanity check: test labels should be subset of model labels
uniq_test_labels = sorted({t for seq in tags_list for t in seq})
missing = [t for t in uniq_test_labels if t not in label2id]
if missing:
print(f"⚠️ Warning: test labels not in model: {missing}")
# Build a simple dataset for batching convenience
ds = Dataset.from_dict({"tokens": tokens_list, "ner_tags": tags_list})
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Evaluation loop (word-piece alignment; label only first subword)
all_preds: List[List[str]] = []
all_refs: List[List[str]] = []
# iterate in batches
for start in range(0, len(ds), args.batch_size):
batch = ds[start : start + args.batch_size]
batch_tokens = batch["tokens"] # List[List[str]]
batch_refs = batch["ner_tags"] # List[List[str]]
# Tokenize as split words so we can map back using fast tokenizer encodings
encodings = tokenizer(
batch_tokens,
is_split_into_words=True,
truncation=True,
max_length=args.max_length,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
logits = model(
input_ids=encodings["input_ids"].to(device),
attention_mask=encodings["attention_mask"].to(device),
                token_type_ids=encodings["token_type_ids"].to(device) if "token_type_ids" in encodings else None,
).logits # (bsz, seq_len, num_labels)
pred_ids = logits.argmax(dim=-1).cpu().numpy() # (bsz, seq_len)
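        # Note: greedy per-token argmax decoding — there is no CRF or transition
        # model, so technically invalid BIO sequences (e.g. I-X right after O)
        # can appear in the predictions.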
# Recover word-level predictions using encodings.word_ids()
for i, word_labels in enumerate(batch_refs):
            word_ids = encodings.word_ids(i)  # List[Optional[int]], one entry per subword
seq_pred_ids = pred_ids[i]
word_level_preds: List[str] = []
            prev_wid = None
            for tok_idx, wid in enumerate(word_ids):
                if wid is None:
                    # special tokens ([CLS]/[SEP]) and padding have no word id
                    continue
                if wid != prev_wid:
                    # first subword of this word → take its prediction
                    label_id = int(seq_pred_ids[tok_idx])
                    word_level_preds.append(id2label[label_id])
                    prev_wid = wid
                # subsequent subwords → skipped (standard NER eval)
# Trim to same length as references (in case of truncation)
L = min(len(word_labels), len(word_level_preds))
all_refs.append(word_labels[:L])
all_preds.append(word_level_preds[:L])
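            # Caveat: when a sentence is truncated, tokens beyond max_length are
            # silently dropped from scoring; raise --max_length to cover them.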
# Metrics
p = precision_score(all_refs, all_preds)
r = recall_score(all_refs, all_preds)
f1 = f1_score(all_refs, all_preds)
print("\n Results on test set")
print(f"Precision: {p:.4f}")
print(f"Recall : {r:.4f}")
print(f"F1 : {f1:.4f}")
print("\nSeqeval classification report")
print(classification_report(all_refs, all_preds, digits=4))
if __name__ == "__main__":
main()