File size: 2,056 Bytes
6a06532
c2637f8
6a06532
043aa62
 
6a06532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
043aa62
 
6a06532
 
043aa62
 
 
 
 
6a06532
043aa62
 
6a06532
cbecf0e
6a06532
043aa62
 
 
 
 
 
6a06532
 
 
 
043aa62
cbecf0e
043aa62
 
cbecf0e
043aa62
cbecf0e
 
 
6a06532
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from datasets import load_dataset, concatenate_datasets
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

# Load the "ag_news" dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Few-shot sampling configuration shared by the train and test subsets
seed = 20
labels = 4  # label ids 0 .. labels - 1 are sampled
samples_per_label = 8


def _sample_per_label(split, num_labels=labels, per_label=samples_per_label, shuffle_seed=seed):
    """Return a label-balanced subset of ``split``.

    For every label id in ``range(num_labels)`` the rows whose ``"label"``
    column equals that id are kept, shuffled with ``shuffle_seed``, and the
    first ``per_label`` rows are taken. The per-label subsets are then
    concatenated in ascending label order.

    Args:
        split: A dataset split exposing ``filter``/``shuffle``/``select``
            (here ``dataset["train"]`` or ``dataset["test"]``).
        num_labels: Number of consecutive label ids, starting at 0, to sample.
        per_label: Number of rows to keep for each label id.
        shuffle_seed: Seed passed to ``shuffle`` so the selection is repeatable.

    Returns:
        The concatenation of the per-label subsets.
    """
    subsets = []
    for label_id in range(num_labels):
        # filter() runs eagerly inside this iteration, so the lambda sees the
        # current label_id (no late-binding issue).
        matching = split.filter(lambda row: row["label"] == label_id)
        subsets.append(matching.shuffle(seed=shuffle_seed).select(range(per_label)))
    return concatenate_datasets(subsets)


# Draw samples_per_label examples of every label from each split
train_dataset = _sample_per_label(dataset["train"])
test_dataset = _sample_per_label(dataset["test"])

# Load the pretrained model identified by model_id
model_id = "sentence-transformers/all-mpnet-base-v2"
model = SetFitModel.from_pretrained(model_id)

# Options forwarded to the trainer in addition to the model and datasets
trainer_options = {
    "loss_class": CosineSimilarityLoss,
    "metric": "accuracy",
    "batch_size": 64,
    "num_iterations": 20,  # The number of text pairs to generate for contrastive learning
    "num_epochs": 1,  # The number of epochs to use for contrastive learning
}

# Create the trainer: fit on the sampled train set, evaluate on the sampled test set
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    **trainer_options,
)

# Fit the model, then score it on the evaluation dataset given to the trainer
trainer.train()
metrics = trainer.evaluate()

# Report which model was used, how many training samples it saw, and its accuracy
print(f"model used: {model_id}")
train_size = len(train_dataset)
print(f"train dataset: {train_size} samples")
accuracy = metrics["accuracy"]
print(f"accuracy: {accuracy}")

# Save the trained model under the path "my_first_test"
# (this call writes locally; nothing is pushed to the Hub here)
trainer.model.save_pretrained("my_first_test")

# Reload the model from the path it was just saved to
model = SetFitModel.from_pretrained("my_first_test")

# Run inference on two example sentences
preds = model(["i loved France!", "pineapple on pizza is the worst when watching football"])

# Map each predicted label id (looked up by its string form) to a readable class name
label = {'0': 'World', '1': 'Sports', '2': 'Business', '3': 'Sci/Tech'}
output = [label[str(tt.item())] for tt in preds]