driwnet committed
Commit
e080f49
1 Parent(s): 07ee737

Upload enytrenador_model.py

Files changed (1)
  1. enytrenador_model.py +108 -0
enytrenador_model.py ADDED
@@ -0,0 +1,108 @@
+ """
+ MODIFIED: (efv) Use STSb-multi-mt Spanish
+ source: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py
+ ---
+ This example trains BERT (or any other transformer model, like RoBERTa, DistilBERT, etc.) on the STSbenchmark from scratch.
+ It generates sentence embeddings that can be compared using cosine similarity to measure their similarity.
+ Usage:
+ python enytrenador_model.py
+ OR
+ python enytrenador_model.py pretrained_transformer_model_name
+ """
+ from torch.utils.data import DataLoader
+ import math
+ from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models
+ from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
+ from sentence_transformers.readers import InputExample
+ import logging
+ from datetime import datetime
+ import sys
+
+ from datasets import load_dataset
+
+ #### Just some code to print debug information to stdout
+ logging.basicConfig(format='%(asctime)s - %(message)s',
+                     datefmt='%Y-%m-%d %H:%M:%S',
+                     level=logging.INFO,
+                     handlers=[LoggingHandler()])
+ #### /print debug information to stdout
+
+
+ # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
+ model_name = sys.argv[1] if len(sys.argv) > 1 else 'distilbert-base-uncased'
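+ # e.g.: python enytrenador_model.py bert-base-multilingual-cased  (example invocation; any Hugging Face model id should work)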
+
+ # Training hyperparameters
+ train_batch_size = 16
+ num_epochs = 4
+ model_save_path = 'output/training_stsbenchmark_' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+ # Use a Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
+ word_embedding_model = models.Transformer(model_name)
+
+ # Apply mean pooling to get one fixed-sized sentence vector
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
+                                pooling_mode_mean_tokens=True,
+                                pooling_mode_cls_token=False,
+                                pooling_mode_max_tokens=False)
+
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
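+
+ # The model now maps sentences to fixed-size vectors. A quick sanity check
+ # (illustrative sketch, not part of the original script; the sentence is made up):
+ #   embeddings = model.encode(["Aquesta es una frase d'exemple"])
+ #   print(embeddings.shape)  # (1, embedding_dimension)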
+
+ # Convert the dataset to a DataLoader ready for training
+ logging.info("Read stsb-multi-mt train dataset from my documents")
+
+ train_samples = []
+ dev_samples = []
+ test_samples = []
+
+ def samples_from_dataset(dataset):
+     # Wrap each row in an InputExample, rescaling the 0-5 similarity score to [0, 1]
+     samples = [InputExample(texts=[e['sentence1'], e['sentence2']], label=e['similarity_score'] / 5)
+                for e in dataset]
+     return samples
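+
+ # For instance, a row ("Un home toca la guitarra", "Un home toca un instrument", 4.2)
+ # would become InputExample(texts=[...], label=0.84). (Illustrative row, not taken from the data.)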
+
+ print("Loading the datasets")
+ # A single CSV file passed via data_files is exposed by the datasets library under
+ # the "train" split, whatever role the file plays here.
+ train_samples = load_dataset("csv", data_files="Bases_dades/Catala/stsb-ca-train.csv", split="train",
+                              column_names=['sentence1', 'sentence2', 'similarity_score'])
+ print("Dataset loaded")
+ train_samples = samples_from_dataset(train_samples)
+ print("Train samples created")
+ print("Loading dev samples")
+ dev_samples = samples_from_dataset(load_dataset("csv", data_files="Bases_dades/Catala/stsb-ca-dev.csv", split="train",
+                                                 column_names=['sentence1', 'sentence2', 'similarity_score']))
+ print("Dev samples created")
+ print("Loading test samples")
+ test_samples = samples_from_dataset(load_dataset("csv", data_files="Bases_dades/Catala/stsb-ca-test.csv", split="train",
+                                                  column_names=['sentence1', 'sentence2', 'similarity_score']))
+ print("Test samples created")
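+
+ # Since column_names is passed explicitly, the CSVs are expected to have no header
+ # row, e.g. (made-up line): Un home camina,Una persona passeja,3.8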
+
+ train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
+ train_loss = losses.CosineSimilarityLoss(model=model)
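+ # CosineSimilarityLoss computes the cosine similarity of the two sentence embeddings
+ # and regresses it against the gold label with a squared-error loss.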
+
+
+ logging.info("Read stsb-multi-mt dev dataset")
+ evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
+
+
+ # Configure the training
+ warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
+ logging.info("Warmup-steps: {}".format(warmup_steps))
+
+
+ ## Train the model
+ model.fit(train_objectives=[(train_dataloader, train_loss)],
+           evaluator=evaluator,
+           epochs=num_epochs,
+           evaluation_steps=1000,
+           warmup_steps=warmup_steps,
+           output_path=model_save_path)
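+ # With an evaluator and an output_path, fit() scores the model on the dev set every
+ # evaluation_steps and keeps the best-scoring checkpoint in output_path.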
+
+
+ ##############################################################################
+ #
+ # Load the stored model and evaluate its performance on the STS benchmark dataset
+ #
+ ##############################################################################
+
+ # Uncomment to evaluate the best stored checkpoint instead of the in-memory model:
+ #model = SentenceTransformer(model_save_path)
+ test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='stsb-multi-mt-test')
+ test_evaluator(model, output_path=model_save_path)
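+
+ # Example of using the trained model afterwards (illustrative sketch; the sentences
+ # are made up):
+ #   from sentence_transformers import util
+ #   trained = SentenceTransformer(model_save_path)
+ #   emb = trained.encode(["El gat dorm al sofa", "Un animal descansa"])
+ #   print(util.pytorch_cos_sim(emb[0], emb[1]))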