"""Few-shot text classification example using SetFit.

The script loads the "sst2" dataset, samples a small training subset,
fine-tunes a SetFit model built on a pretrained sentence-transformers
checkpoint, evaluates it on the validation split, pushes the trained model
to the Hub and finally reloads a model from the Hub to run inference on two
example sentences.
"""

from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset

# Load the dataset; the "train" and "validation" splits are used below.
dataset = load_dataset("sst2")

# Keep only a small labelled sample for training (few-shot setting).
# NOTE(review): per the setfit docs `num_samples` is the number of examples
# drawn per label -- confirm against the installed setfit version.
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Start from a pretrained sentence-transformers checkpoint.
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Configure the trainer.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    # NOTE(review): in the setfit docs `num_iterations` controls how many text
    # pairs are generated for contrastive learning and `num_epochs` how many
    # epochs are run over those pairs -- confirm for the installed version.
    num_iterations=20,
    num_epochs=1,
    # The dataset stores its text under "sentence"; map it to "text".
    # The "label" column keeps its name.
    column_mapping={"sentence": "text", "label": "label"},
)

# Train, then evaluate on `eval_dataset` using the configured metric.
trainer.train()
metrics = trainer.evaluate()

# Upload the trained model to the Hub (requires being authenticated).
trainer.push_to_hub("my-awesome-setfit-model")

# Download a model from the Hub.
# NOTE(review): this repo id assumes the push above landed under the
# "lewtun" namespace -- replace it with your own account/organisation.
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")

# Run inference on raw sentences.
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])