vincenttruum
commited on
Commit
•
043aa62
1
Parent(s):
7820bb0
test
Browse files- newtest.py +39 -2
newtest.py
CHANGED
@@ -1,3 +1,40 @@
|
|
1 |
-
from datasets import load_dataset
|
2 |
-
from setfit import SetFitModel, SetFitTrainer
|
3 |
from sentence_transformers.losses import CosineSimilarityLoss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
|
|
2 |
from sentence_transformers.losses import CosineSimilarityLoss
|
3 |
+
|
4 |
+
from setfit import SetFitModel, SetFitTrainer, sample_dataset
|
5 |
+
|
6 |
+
|
7 |
+
# Load a dataset from the Hugging Face Hub
|
8 |
+
dataset = load_dataset("sst2")
|
9 |
+
|
10 |
+
# Simulate the few-shot regime by sampling 8 examples per class
|
11 |
+
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
|
12 |
+
eval_dataset = dataset["validation"]
|
13 |
+
|
14 |
+
# Load a SetFit model from Hub
|
15 |
+
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
|
16 |
+
|
17 |
+
# Create trainer
|
18 |
+
trainer = SetFitTrainer(
|
19 |
+
model=model,
|
20 |
+
train_dataset=train_dataset,
|
21 |
+
eval_dataset=eval_dataset,
|
22 |
+
loss_class=CosineSimilarityLoss,
|
23 |
+
metric="accuracy",
|
24 |
+
batch_size=16,
|
25 |
+
num_iterations=20, # The number of text pairs to generate for contrastive learning
|
26 |
+
num_epochs=1, # The number of epochs to use for contrastive learning
|
27 |
+
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
|
28 |
+
)
|
29 |
+
|
30 |
+
# Train and evaluate
|
31 |
+
trainer.train()
|
32 |
+
metrics = trainer.evaluate()
|
33 |
+
|
34 |
+
# Push model to the Hub
|
35 |
+
trainer.push_to_hub("my-awesome-setfit-model")
|
36 |
+
|
37 |
+
# Download from Hub and run inference
|
38 |
+
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")
|
39 |
+
# Run inference
|
40 |
+
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
|