vincenttruum
committed on
Commit
•
6a06532
1
Parent(s):
043aa62
test
Browse files- newtest.py +40 -18
newtest.py
CHANGED
@@ -1,40 +1,62 @@
|
|
1 |
-
from datasets import load_dataset
|
2 |
from sentence_transformers.losses import CosineSimilarityLoss
|
3 |
-
|
4 |
-
from setfit import SetFitModel, SetFitTrainer, sample_dataset
|
5 |
-
|
6 |
|
7 |
# Load a dataset from the Hugging Face Hub
|
8 |
-
dataset = load_dataset("
|
9 |
-
|
10 |
-
#
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Load a SetFit model from Hub
|
15 |
-
|
|
|
16 |
|
17 |
# Create trainer
|
18 |
trainer = SetFitTrainer(
|
19 |
model=model,
|
20 |
train_dataset=train_dataset,
|
21 |
-
eval_dataset=
|
22 |
loss_class=CosineSimilarityLoss,
|
23 |
metric="accuracy",
|
24 |
-
batch_size=
|
25 |
-
num_iterations=20, # The number of text pairs to generate for contrastive learning
|
26 |
-
num_epochs=1,
|
27 |
-
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
|
28 |
)
|
29 |
|
30 |
# Train and evaluate
|
31 |
trainer.train()
|
32 |
metrics = trainer.evaluate()
|
33 |
|
|
|
|
|
|
|
|
|
34 |
# Push model to the Hub
|
35 |
-
trainer.push_to_hub("
|
36 |
|
37 |
# Download from Hub and run inference
|
38 |
-
model = SetFitModel.from_pretrained("
|
39 |
# Run inference
|
40 |
-
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
|
|
|
|
1 |
+
from datasets import load_dataset, concatenate_datasets
|
2 |
from sentence_transformers.losses import CosineSimilarityLoss
|
3 |
+
from setfit import SetFitModel, SetFitTrainer
|
|
|
|
|
4 |
|
5 |
# Load a dataset from the Hugging Face Hub
|
6 |
+
dataset = load_dataset("ag_news")
|
7 |
+
|
8 |
+
# create train dataset
|
9 |
+
seed = 20
|
10 |
+
labels = 4
|
11 |
+
samples_per_label = 8
|
12 |
+
sampled_datasets = []
|
13 |
+
# find the number of samples per label
|
14 |
+
for i in range(labels):
|
15 |
+
sampled_datasets.append(
|
16 |
+
dataset["train"].filter(lambda x: x["label"] == i).shuffle(seed=seed).select(range(samples_per_label)))
|
17 |
+
|
18 |
+
# concatenate the sampled datasets
|
19 |
+
train_dataset = concatenate_datasets(sampled_datasets)
|
20 |
+
|
21 |
+
# create test dataset
|
22 |
+
labels = 4
|
23 |
+
samples_per_label = 8
|
24 |
+
sampled_datasets = []
|
25 |
+
# find the number of samples per label
|
26 |
+
for i in range(labels):
|
27 |
+
sampled_datasets.append(
|
28 |
+
dataset["test"].filter(lambda x: x["label"] == i).shuffle(seed=seed).select(range(samples_per_label)))
|
29 |
+
test_dataset = concatenate_datasets(sampled_datasets)
|
30 |
|
31 |
# Load a SetFit model from Hub
|
32 |
+
model_id = "sentence-transformers/all-mpnet-base-v2"
|
33 |
+
model = SetFitModel.from_pretrained(model_id)
|
34 |
|
35 |
# Create trainer
|
36 |
trainer = SetFitTrainer(
|
37 |
model=model,
|
38 |
train_dataset=train_dataset,
|
39 |
+
eval_dataset=test_dataset,
|
40 |
loss_class=CosineSimilarityLoss,
|
41 |
metric="accuracy",
|
42 |
+
batch_size=64,
|
43 |
+
num_iterations=1, # 20, # The number of text pairs to generate for contrastive learning
|
44 |
+
num_epochs=1, # The number of epochs to use for constrastive learning
|
|
|
45 |
)
|
46 |
|
47 |
# Train and evaluate
|
48 |
trainer.train()
|
49 |
metrics = trainer.evaluate()
|
50 |
|
51 |
+
print(f"model used: {model_id}")
|
52 |
+
print(f"train dataset: {len(train_dataset)} samples")
|
53 |
+
print(f"accuracy: {metrics['accuracy']}")
|
54 |
+
|
55 |
# Push model to the Hub
|
56 |
+
trainer.push_to_hub("MyFirstModel")
|
57 |
|
58 |
# Download from Hub and run inference
|
59 |
+
model = SetFitModel.from_pretrained("VinceItsMe/MyFirstModel")
|
60 |
# Run inference
|
61 |
+
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
|
62 |
+
q = 1
|