# src/models/dataset.py
import json

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from src.schemas.labels import SENTIMENT_LABELS


def load_data(path: str) -> list[dict]:
    """Load a JSONL file: one JSON object per line."""
    with open(path) as f:
        return [json.loads(line) for line in f]
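
# Expected JSONL record shape, inferred from the fields this module reads
# (an informal sketch, not an authoritative schema):
# {
#   "id": ...,
#   "entities": [{
#     "entity_id": ..., "entity_text": ..., "entity_type": ..., "label": ...,
#     "positions": [{
#       "position_text": ..., "length": ...,      # used by deduplicate_positions
#       "marker_text": ...,                       # mode == "marker"
#       "entity_centered_window": ...,            # modes "qa_m" / "qa_b"
#       "qa_m_question": ...,                     # mode == "qa_m"
#       "qa_b_hypotheses": {"<sentiment>": ...},  # mode == "qa_b"
#     }]
#   }]
# }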


def deduplicate_positions(samples: list[dict]) -> list[dict]:
    """Select one position per entity.

    Prefers the position whose position_text matches entity_text exactly
    (case-insensitive); if none matches, falls back to the longest position
    (by the "length" field). Entities with no positions pass through unchanged.
    """
out = []
for s in samples:
new_entities = []
for e in s["entities"]:
positions = e["positions"]
if not positions:
new_entities.append(e)
continue
exact = [
p for p in positions
if p["position_text"].lower() == e["entity_text"].lower()
]
if exact:
best = max(exact, key=lambda p: p["length"])
else:
best = max(positions, key=lambda p: p["length"])
new_entities.append({**e, "positions": [best]})
out.append({**s, "entities": new_entities})
return out
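
# Example with hypothetical values: for entity_text "ACME Corp", a position
# with position_text "acme corp" wins over one with "ACME Corporation"
# (case-insensitive exact match); if only "ACME" and "ACME Corporation" were
# present, the one with the larger "length" field would be kept instead.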


def flatten_to_examples(
samples: list[dict],
mode: str,
) -> list[dict]:
"""Flatten augmented data to one example per (entity, position) pair.
Reads pre-computed fields from the augmented JSONL:
marker -> seg_a = marker_text, seg_b = None
qa_m -> seg_a = entity_centered_window, seg_b = qa_m_question
    qa_b   -> one binary example per (position, sentiment class), using qa_b_hypotheses
"""
    # Validate the mode up front so an unknown mode fails fast, even when
    # the input contains no samples or positions.
    if mode not in ("marker", "qa_m", "qa_b"):
        raise ValueError(f"Unknown mode: {mode!r}")
    sentiments = list(SENTIMENT_LABELS.classes)
    label2id = SENTIMENT_LABELS.label2id
    examples = []
for s in samples:
for e in s["entities"]:
label_str = e.get("label")
base = {
"sample_id": s["id"],
"entity_id": e["entity_id"],
"entity_text": e["entity_text"],
"entity_type": e["entity_type"],
}
for p in e["positions"]:
if mode == "marker":
ex = {**base, "seg_a": p["marker_text"], "seg_b": None}
if label_str in label2id:
ex["label"] = label2id[label_str]
examples.append(ex)
elif mode == "qa_m":
ex = {
**base,
"seg_a": p["entity_centered_window"],
"seg_b": p["qa_m_question"],
}
if label_str in label2id:
ex["label"] = label2id[label_str]
examples.append(ex)
elif mode == "qa_b":
for sentiment in sentiments:
ex = {
**base,
"seg_a": p["entity_centered_window"],
"seg_b": p["qa_b_hypotheses"][sentiment],
"sentiment": sentiment,
}
if label_str in label2id:
ex["label"] = 1 if sentiment == label_str else 0
examples.append(ex)
return examples
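
# Example: an entity with two positions yields two examples in "marker" or
# "qa_m" mode, and 2 * len(SENTIMENT_LABELS.classes) binary examples in
# "qa_b" mode (one per sentiment hypothesis per position).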


def split_data(
    examples: list[dict], val_frac: float, test_frac: float, seed: int = 42
) -> tuple[list[dict], list[dict], list[dict]]:
    """Split at the *sample* level, so all examples derived from the same
    sample land in the same split (no leakage across train/val/test).
    """
    # Sort the ids: set iteration order is not stable across interpreter runs,
    # so sorting keeps the split reproducible for a fixed seed.
    sample_ids = np.array(sorted({e["sample_id"] for e in examples}))
remaining_ids, test_ids = train_test_split(
sample_ids, test_size=test_frac, random_state=seed
)
val_frac_adj = val_frac / (1.0 - test_frac)
train_ids, val_ids = train_test_split(
remaining_ids, test_size=val_frac_adj, random_state=seed
)
train_set = set(train_ids)
val_set = set(val_ids)
test_set = set(test_ids)
return (
[e for e in examples if e["sample_id"] in train_set],
[e for e in examples if e["sample_id"] in val_set],
[e for e in examples if e["sample_id"] in test_set],
)
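
# Example: split_data(examples, val_frac=0.1, test_frac=0.1) holds out ~10% of
# *samples* for val and ~10% for test; the resulting example counts can deviate
# from 80/10/10 because samples differ in entity and position counts.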


class EntitySentimentDataset(Dataset):
    """Tokenizes flattened examples (seg_a, optional seg_b) for classification."""

    def __init__(self, examples: list[dict], tokenizer, max_len: int):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
ex = self.examples[idx]
seg_a = ex["seg_a"]
seg_b = ex["seg_b"]
if seg_b is None:
enc = self.tokenizer(
seg_a,
max_length=self.max_len,
truncation=True,
padding="max_length",
return_tensors="pt",
)
        else:
            # Pair input: truncate only seg_a (the context window) so the
            # question/hypothesis in seg_b is never cut off.
            enc = self.tokenizer(
                seg_a, seg_b,
                max_length=self.max_len,
                truncation="only_first",
                padding="max_length",
                return_tensors="pt",
            )
item = {
"input_ids": enc["input_ids"].squeeze(0),
"attention_mask": enc["attention_mask"].squeeze(0),
}
if "label" in ex:
item["labels"] = torch.tensor(ex["label"], dtype=torch.long)
return item
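
# Note: every example is padded to max_len, so batches are rectangular and the
# default torch.utils.data.DataLoader collate_fn works without a custom
# padding collator.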


class DeduplicatedEntitySentimentDataset(EntitySentimentDataset):
    """Like EntitySentimentDataset, but with one position per entity.

    Applies deduplicate_positions before flattening, so each entity
    contributes exactly one example in "marker" and "qa_m" modes (one per
    sentiment hypothesis in "qa_b" mode).
    """

    def __init__(self, samples: list[dict], mode: str, tokenizer, max_len: int):
deduped = deduplicate_positions(samples)
examples = flatten_to_examples(deduped, mode=mode)
super().__init__(examples, tokenizer, max_len)
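

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module's API. The data path
    # and tokenizer checkpoint below are placeholders (assumptions), not
    # values taken from this repo.
    from transformers import AutoTokenizer

    samples = load_data("data/augmented.jsonl")  # hypothetical path
    examples = flatten_to_examples(deduplicate_positions(samples), mode="marker")
    train, val, test = split_data(examples, val_frac=0.1, test_frac=0.1)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any HF tokenizer
    ds = EntitySentimentDataset(train, tokenizer, max_len=128)
    item = ds[0]
    print(len(ds), item["input_ids"].shape)  # input_ids has shape (max_len,)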