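# supcon/code/inference.py
# Author: IGandarillas1
# Commit fc1c2b8: "Add model contrastive classifier" (2.05 kB)
# (Header recovered from the Hugging Face file viewer; its "raw", "history" and
# "blame" links are omitted.)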
import numpy as np
np.random.seed(42)
import random
random.seed(42)
import pandas as pd
from sklearn.metrics import classification_report
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import json
from copy import deepcopy
import torch
import transformers as transformers
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from src.datasets import ContrastiveClassificationDataset
from src.data_collators import DataCollatorContrastiveClassification
from src.modeling import ContrastiveClassifierModel
from src.metrics import compute_metrics_bce
from transformers import EarlyStoppingCallback
from transformers.utils.hp_naming import TrialShortNamer
from pdb import set_trace
import json
def model_fn(model_dir):
    """Load the RoBERTa tokenizer and the contrastive classifier.

    The name matches the ``model_fn`` hook of model-serving toolkits such as
    the SageMaker Hugging Face inference toolkit. Presumably that is how it is
    invoked (not visible here), so the one-argument signature is kept as-is.

    Args:
        model_dir: checkpoint location, passed to ``ContrastiveClassifierModel``
            as ``checkpoint_path``.

    Returns:
        A ``(model, tokenizer)`` tuple, which ``predict_fn`` unpacks.
    """
    # ``AutoTokenizer`` is not among the names imported from ``transformers``
    # at the top of the file, so the bare name raised NameError. Reach it
    # through the module alias that *is* imported.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        'roberta-base',
        additional_special_tokens=('[COL]', '[VAL]'),
    )
    # len(tokenizer) counts the two added special tokens. Presumably the model
    # uses it to size its token embeddings (confirm in src.modeling).
    model = ContrastiveClassifierModel(
        checkpoint_path=model_dir,
        len_tokenizer=len(tokenizer),
        model='roberta-base',
        frozen=False,
    )
    return model, tokenizer
def _select_matches(df, predictions, threshold=0.5):
    """Return the rows of ``df`` whose prediction is positive, as plain lists.

    Args:
        df: frame with one row per example, in the same order as
            ``predictions``. It is not modified.
        predictions: per-row model outputs, shape ``(n,)`` or ``(n, 1)``.
            A value ``>= threshold`` counts as positive. For outputs that are
            already hard 0/1 labels, this is equivalent to testing ``== 1``.
        threshold: decision boundary applied to ``predictions``.

    Returns:
        The selected rows, each as a list of cell values ending with the
        binarised ``prediction`` (always 1 for returned rows).

    Raises:
        ValueError: if the number of predictions differs from ``len(df)``.
    """
    # Flatten an (n, 1) output head so there is one score per row.
    scores = np.asarray(predictions).reshape(-1)
    if len(scores) != len(df):
        raise ValueError(f"got {len(scores)} predictions for {len(df)} rows")
    # Work on a copy so the caller's frame (the dataset's data) is untouched.
    labelled = df.copy()
    labelled['prediction'] = (scores >= threshold).astype(int)
    return labelled[labelled['prediction'] == 1].values.tolist()


def predict_fn(data, model_and_tokenizer):
    """Run the classifier over ``data["inputs"]`` and return the positive rows.

    Args:
        data: request payload. ``data["inputs"]`` is handed to
            ``ContrastiveClassificationDataset`` (its expected format is
            defined in ``src.datasets``, which is not visible here).
        model_and_tokenizer: the ``(model, tokenizer)`` tuple from ``model_fn``.

    Returns:
        ``{"values": rows}``, where ``rows`` are the dataset's ``data`` rows
        predicted positive. Each row ends with a ``prediction`` value of 1.
    """
    model, tokenizer = model_and_tokenizer
    test_dataset = ContrastiveClassificationDataset(
        data["inputs"],
        dataset_type='test',
        size=512,
        tokenizer='roberta-base',
        dataset='serialized',
    )
    trainer = Trainer(
        model=model,
        data_collator=DataCollatorContrastiveClassification(tokenizer),
        compute_metrics=compute_metrics_bce,
    )
    predict_results = trainer.predict(test_dataset, metric_key_prefix="predict")
    # The previous exact ``== 1`` test only matched outputs equal to 1. A BCE
    # head presumably emits probabilities (TODO confirm in src.modeling), so
    # threshold at 0.5 instead. Hard 0/1 outputs select the same rows as before.
    matches = _select_matches(test_dataset.data, predict_results.predictions)
    return {"values": matches}