Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import ast | |
| import json | |
| from collections import Counter | |
| from functools import partial | |
| from pathlib import Path | |
| import numpy as np | |
| from datasets import Dataset, DatasetDict | |
| from sklearn.metrics import accuracy_score, f1_score | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| DataCollatorWithPadding, | |
| Trainer, | |
| TrainingArguments, | |
| set_seed, | |
| ) | |
| from paper_classifier import BASE_MODEL_NAME, DEFAULT_MODEL_DIR, MAX_LENGTH, format_input_text | |
# --- Paths and dataset field names --------------------------------------
DATA_PATH = Path("arxivData.json")  # raw arXiv dump: a JSON list of records
OUTPUT_DIR = Path(DEFAULT_MODEL_DIR)  # where the fine-tuned model is written
HF_CACHE_DIR = Path("/tmp/huggingface")  # HF hub download cache directory
TITLE_FIELD = "title"  # record key holding the paper title
ABSTRACT_FIELD = "summary"  # record key holding the abstract text
TAG_FIELD = "tag"  # record key holding the stringified tag list
# --- Training hyperparameters -------------------------------------------
VALIDATION_SIZE = 0.1  # fraction of records held out for validation
NUM_TRAIN_EPOCHS = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 32
LOGGING_STEPS = 50  # log training metrics every N optimizer steps
SEED = 42  # global RNG seed used for splitting and training
# Maps an arXiv category prefix (the part of a term before the first ".",
# e.g. "cs" from "cs.LG") to the coarse subject label used for training.
# Terms whose prefix is absent from this table are treated as unlabeled.
PREFIX_TO_LABEL = {
    "adap-org": "Quantitative Biology",
    "astro-ph": "Physics",
    "cmp-lg": "Computer Science",
    "cond-mat": "Physics",
    "cs": "Computer Science",
    "econ": "Economics",
    "eess": "Electrical Engineering and Systems Science",
    "gr-qc": "Physics",
    "hep-ex": "Physics",
    "hep-lat": "Physics",
    "hep-ph": "Physics",
    "hep-th": "Physics",
    "math": "Mathematics",
    "nlin": "Physics",
    "nucl-th": "Physics",
    "physics": "Physics",
    "q-bio": "Quantitative Biology",
    "q-fin": "Quantitative Finance",
    "quant-ph": "Physics",
    "stat": "Statistics",
}
def normalize_text(value):
    """Collapse all whitespace runs in *value* into single spaces.

    Any falsy input (``None``, empty string, 0, …) yields ``""``.
    Non-string values are stringified first.
    """
    text = str(value) if value else ""
    tokens = text.split()
    return " ".join(tokens)
def parse_top_level_label(raw_tag):
    """Resolve an arXiv tag payload to a coarse subject label.

    *raw_tag* is expected to be a string containing a Python-literal list
    of dicts, each with a ``term`` key (e.g. ``"[{'term': 'cs.LG'}]"``).
    Returns the label for the first recognized prefix, or ``None`` when
    the payload is missing, malformed, or contains no known prefix.
    """
    if not raw_tag:
        return None
    try:
        parsed = ast.literal_eval(str(raw_tag))
    except (SyntaxError, ValueError):
        # Payload was not a valid Python literal; treat as unlabeled.
        return None
    if not isinstance(parsed, list):
        return None
    # Scan entries in order and stop at the first mappable term.
    terms = (entry.get("term") for entry in parsed if isinstance(entry, dict))
    for term in terms:
        if not term:
            continue
        prefix = str(term).split(".", 1)[0]
        label = PREFIX_TO_LABEL.get(prefix)
        if label is not None:
            return label
    return None
def build_records():
    """Load the raw arXiv JSON dump and build (text, label) records.

    Records whose tag payload cannot be mapped to a known subject label
    are skipped (and counted), so downstream label encoding never sees
    ``None`` labels.

    Returns:
        list[dict[str, str]]: records with ``text`` and ``label`` keys.
    """
    with DATA_PATH.open("r", encoding="utf-8") as file:
        raw_records = json.load(file)
    prepared_records: list[dict[str, str]] = []
    skipped = Counter()
    for item in raw_records:
        title = normalize_text(item.get(TITLE_FIELD))
        abstract = normalize_text(item.get(ABSTRACT_FIELD))
        label = parse_top_level_label(item.get(TAG_FIELD))
        # Bug fix: previously unlabeled records (label is None) were kept,
        # which would make sorted() over the label set fail later (mixed
        # None/str) and poison the label2id mapping. Skip and count them.
        if label is None:
            skipped["unmapped_label"] += 1
            continue
        prepared_records.append(
            {
                "text": format_input_text(title, abstract),
                "label": label,
            }
        )
    print(f"Loaded {len(prepared_records)}")
    if skipped:
        print("Skipped records:", dict(skipped))
    label_distribution = Counter(record["label"] for record in prepared_records)
    print("Label distribution:", dict(label_distribution))
    return prepared_records
def build_splits(records):
    """Shuffle-split prepared records into train/validation datasets.

    Uses the module-level ``VALIDATION_SIZE`` fraction and ``SEED`` so the
    split is reproducible across runs.
    """
    full_dataset = Dataset.from_list(records)
    partitioned = full_dataset.train_test_split(
        test_size=VALIDATION_SIZE,
        seed=SEED,
    )
    return DatasetDict(train=partitioned["train"], validation=partitioned["test"])
def preprocess(batch, *, tokenizer, label2id):
    """Tokenize a batch of texts and attach integer class labels.

    Truncates each text to ``MAX_LENGTH`` tokens; padding is deferred to
    the data collator.
    """
    features = tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)
    features["labels"] = list(map(label2id.__getitem__, batch["label"]))
    return features
def compute_metrics(eval_prediction):
    """Compute accuracy and macro-averaged F1 for a Trainer evaluation.

    *eval_prediction* unpacks into ``(logits, labels)``; predictions are
    taken as the argmax over the last logits axis.
    """
    logits, gold = eval_prediction
    predicted = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "macro_f1": f1_score(gold, predicted, average="macro"),
    }
def main() -> None:
    """Fine-tune the base model on the arXiv subject-classification task.

    Loads the JSON dump, derives labels, tokenizes train/validation
    splits, trains with the HF ``Trainer``, then saves the model, the
    tokenizer, and a JSON training summary into ``OUTPUT_DIR``.

    Raises:
        FileNotFoundError: if ``DATA_PATH`` does not exist.
    """
    if not DATA_PATH.exists():
        raise FileNotFoundError(f"Dataset file not found: {DATA_PATH}")
    # Seed all RNGs before any data shuffling for reproducibility.
    set_seed(SEED)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    records = build_records()
    raw_splits = build_splits(records)
    # Deterministic (sorted) label order -> stable id assignment across runs.
    # NOTE(review): this assumes build_records never yields a None label;
    # sorted() would raise on a mixed None/str set — verify upstream filtering.
    label_names = sorted({record["label"] for record in records})
    label2id = {label: index for index, label in enumerate(label_names)}
    id2label = {index: label for label, index in label2id.items()}
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
    )
    # Drop the raw text/label columns so only tokenized features remain.
    tokenized_splits = raw_splits.map(
        partial(preprocess, tokenizer=tokenizer, label2id=label2id),
        batched=True,
        remove_columns=raw_splits["train"].column_names,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
    )
    # Pad dynamically per batch instead of padding everything to MAX_LENGTH.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR.as_posix(),
        do_train=True,
        do_eval=True,
        # NOTE(review): `eval_strategy` (vs. the older `evaluation_strategy`)
        # requires a recent transformers release — confirm the pinned version.
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=LOGGING_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        # Reload the epoch checkpoint with the best macro-F1 at the end.
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        save_total_limit=2,
        report_to=[],  # disable external logging integrations (wandb, etc.)
        seed=SEED,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_splits["train"],
        eval_dataset=tokenized_splits["validation"],
        # NOTE(review): `processing_class` replaces the deprecated `tokenizer`
        # kwarg in newer transformers — confirm version compatibility.
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    metrics = trainer.evaluate()
    trainer.save_model(OUTPUT_DIR.as_posix())
    tokenizer.save_pretrained(OUTPUT_DIR.as_posix())
    # Persist a human-readable record of this run next to the saved model.
    summary_path = OUTPUT_DIR / "training_summary.json"
    summary = {
        "base_model": BASE_MODEL_NAME,
        "data_path": DATA_PATH.as_posix(),
        "output_dir": OUTPUT_DIR.as_posix(),
        "title_field": TITLE_FIELD,
        "abstract_field": ABSTRACT_FIELD,
        "tag_field": TAG_FIELD,
        "labels": label_names,
        "metrics": metrics,
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))
if __name__ == "__main__":
    # Script entry point: run training only when executed directly.
    main()