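# Hugging Face Hub file page header (kept as a comment so the script stays valid Python):
#   MyFirstModel / newtest.py
#   vincenttruum - "test" - cbecf0e
#   raw / history / blame - No virus - 2.06 kB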
from datasets import load_dataset, concatenate_datasets
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
# Load the "ag_news" dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

seed = 20  # shuffle seed shared by both splits
labels = 4  # label ids 0 .. labels - 1 are sampled
samples_per_label = 8  # rows kept per label id in each sampled split


def _sample_per_label(split, num_labels, per_label, shuffle_seed):
    """Return a label-balanced subset of *split*.

    For every label id in ``range(num_labels)`` the rows whose ``"label"``
    column equals that id are filtered out, shuffled with ``shuffle_seed``
    and truncated to the first ``per_label`` rows.  The per-label subsets
    are then concatenated in ascending label order.

    Args:
        split: A dataset split exposing ``filter``/``shuffle``/``select``
            and a ``"label"`` column (e.g. ``dataset["train"]``).
        num_labels: Number of consecutive label ids, starting at 0.
        per_label: Number of rows to keep for each label id.
        shuffle_seed: Seed passed to ``shuffle`` for every label subset.

    Returns:
        The concatenation of the ``num_labels`` sampled subsets.
    """
    subsets = []
    for label_id in range(num_labels):
        # The lambda is invoked inside this iteration, so closing over the
        # loop variable is safe here.
        matching = split.filter(lambda row: row["label"] == label_id)
        subsets.append(
            matching.shuffle(seed=shuffle_seed).select(range(per_label))
        )
    return concatenate_datasets(subsets)


# Few-shot training set and an equally small, balanced evaluation set
train_dataset = _sample_per_label(dataset["train"], labels, samples_per_label, seed)
test_dataset = _sample_per_label(dataset["test"], labels, samples_per_label, seed)
# Checkpoint id that the SetFit model is initialised from
model_id = "sentence-transformers/all-mpnet-base-v2"
model = SetFitModel.from_pretrained(model_id)

# Gather the trainer settings in one place, then build the trainer from them
trainer_settings = dict(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=64,
    # Number of iterations used to generate text pairs for contrastive learning
    num_iterations=20,
    # Number of epochs used for the contrastive learning phase
    num_epochs=1,
)
trainer = SetFitTrainer(**trainer_settings)
# Fit on the sampled training set, then score on the sampled evaluation set
trainer.train()
metrics = trainer.evaluate()

# Summarise the run: checkpoint used, training-set size and accuracy
report_lines = (
    f"model used: {model_id}",
    f"train dataset: {len(train_dataset)} samples",
    f"accuracy: {metrics['accuracy']}",
)
print("\n".join(report_lines))
# Save the trained model to the local path "my_first_test".
# NOTE: this call is given only a path; no Hub push is requested here.
trainer.model.save_pretrained("my_first_test")

# Reload the model from the path it was just saved to
model = SetFitModel.from_pretrained("my_first_test")

# Run inference on two ad-hoc sentences
preds = model(["i loved France!", "pineapple on pizza is the worst when watching football"])

# Class id -> class name, keyed by the integer id so predictions can be
# looked up directly instead of being converted to strings first
label = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
output = [label[int(tt.item())] for tt in preds]

# NOTE(review): `output` is never printed or returned, and `q` is never read;
# presumably `q = 1` is a debugger breakpoint anchor - confirm before removing.
q = 1