FreddyM commited on
Commit
f011140
1 Parent(s): aee4465

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +131 -2
README.md CHANGED
@@ -85,8 +85,137 @@ Finetuned from model [optional]: The model used might have been fine-tuned from
85
  Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
86
 
87
  ## How to Get Started with the Model
88
-
89
- Use the code below to get started with the model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  [More Information Needed]
92
 
 
85
  Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
86
 
87
  ## How to Get Started with the Model
88
+ from transformers import BertTokenizer, BertForSequenceClassification
89
+
90
+ # Número de etiquetas/clases en tu problema de clasificación
91
+ num_etiquetas = 2 # Actualiza con el número correcto de clases
92
+ #1 Descargar y cargar el modelo BERT para clasificación:
93
+ # Descargar el tokenizador y el modelo preentrenado BERT para clasificación
94
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
95
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_etiquetas)
96
+
97
+ from transformers import BertTokenizer, BertForSequenceClassification
98
+ #2. Configuración del optimizador y del dispositivo:
99
+
100
+ from torch.optim import AdamW
101
+
102
+ # Parámetros de optimización
103
+ optimizador = AdamW(model.parameters(), lr=5e-5)
104
+
105
+ # Dispositivo (GPU si está disponible, de lo contrario, CPU)
106
+ dispositivo = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
107
+ model.to(dispositivo)
108
+
109
+ # 3 División del conjunto de datos y creación de DataLoader:
110
+ from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
111
+ from sklearn.model_selection import train_test_split
112
+
113
+ # División del conjunto de datos
114
+ train_idx, val_idx = train_test_split(np.arange(len(labels)), test_size=val_ratio, shuffle=True, stratify=labels)
115
+
116
+ # Creación de DataLoader para entrenamiento
117
+ train_dataloader = DataLoader(
118
+ TensorDataset(token_id[train_idx], attention_masks[train_idx], labels[train_idx]),
119
+ sampler=RandomSampler(train_idx),
120
+ batch_size=batch_size
121
+ )
122
+
123
+ # Creación de DataLoader para validación
124
+ val_dataloader = DataLoader(
125
+ TensorDataset(token_id[val_idx], attention_masks[val_idx], labels[val_idx]),
126
+ sampler=SequentialSampler(val_idx),
127
+ batch_size=batch_size
128
+ )
129
+
130
+
131
+ from sklearn.metrics import precision_score
132
+
133
+ # ...
134
+
135
+ #4Entrenamiento del modelo BERT para clasificación:
136
+
137
+ num_epochs = 3 # ajusta el número de épocas según sea necesario
138
+
139
+ # Ciclo de entrenamiento
140
+ for epoch in trange(num_epochs, desc='Epoch'):
141
+ model.train()
142
+
143
+ for step, batch in enumerate(train_dataloader):
144
+ batch = tuple(t.to(dispositivo) for t in batch)
145
+ input_ids, attention_mask, labels = batch
146
+
147
+ optimizador.zero_grad()
148
+ outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
149
+ loss = outputs.loss
150
+ loss.backward()
151
+ optimizador.step()
152
+
153
+ # Evaluación en el conjunto de validación después de cada época
154
+ model.eval()
155
+
156
+ # Tracking variables
157
+ val_accuracy = []
158
+ val_precision = []
159
+
160
+ for batch in val_dataloader: # Cambiado a val_dataloader en lugar de validation_dataloader
161
+ batch = tuple(t.to(dispositivo) for t in batch)
162
+ b_input_ids, b_input_mask, b_labels = batch
163
+ with torch.no_grad():
164
+ # Forward pass
165
+ eval_output = model(
166
+ b_input_ids,
167
+ token_type_ids=None,
168
+ attention_mask=b_input_mask
169
+ )
170
+ logits = eval_output.logits.detach().cpu().numpy()
171
+ label_ids = b_labels.to('cpu').numpy()
172
+
173
+ # Calculate validation metrics
174
+ b_accuracy, _, _, b_precision = b_metrics(logits, label_ids)
175
+ val_accuracy.append(b_accuracy)
176
+ val_precision.append(b_precision)
177
+
178
+ # Calcular métricas promedio para la época
179
+ avg_val_accuracy = sum(val_accuracy) / len(val_accuracy)
180
+ avg_val_precision = sum(val_precision) / len(val_precision) if len(val_precision) > 0 else float('nan')
181
+
182
+ # Imprimir resultados de la época
183
+ print(f'\nEpoch {epoch + 1}/{num_epochs}')
184
+ print(f' - Training Loss: {loss.item()}')
185
+ print(f' - Validation Accuracy: {avg_val_accuracy}')
186
+ print(f' - Validation Precision: {avg_val_precision}')
187
+
188
+ # Predicción en un nuevo ejemplo
189
+ nueva_oracion = "Nah I don't think he goes to usf, he lives around here though"
190
+
191
+ # Aplicar el tokenizer para obtener los IDs de tokens y la máscara de atención
192
+ encoding = tokenizer.encode_plus(
193
+ nueva_oracion,
194
+ add_special_tokens=True,
195
+ max_length=32, # Ajusta la longitud máxima según sea necesario
196
+ pad_to_max_length=True,
197
+ return_attention_mask=True,
198
+ return_tensors='pt' # Devuelve tensores de PyTorch
199
+ )
200
+
201
+ # Obtener los IDs de tokens y la máscara de atención
202
+ input_ids = encoding['input_ids'].to(dispositivo)
203
+ attention_mask = encoding['attention_mask'].to(dispositivo)
204
+
205
+ # Asegurarse de que las dimensiones sean adecuadas para el modelo BERT
206
+ input_ids = input_ids.view(1, -1) # Cambiar la forma a (1, longitud)
207
+ attention_mask = attention_mask.view(1, -1) # Cambiar la forma a (1, longitud)
208
+
209
+ # Realizar la predicción
210
+ with torch.no_grad():
211
+ output = model(input_ids, attention_mask=attention_mask)
212
+
213
+ # Obtener la clase predicha
214
+ prediccion = 'Clase A' if torch.argmax(output.logits[0]).item() == 0 else 'Clase B'
215
+
216
+ # Imprimir resultados
217
+ print(f'Nueva Oración: {nueva_oracion}')
218
+ print(f'Predicción: {prediccion}')
219
 
220
  [More Information Needed]
221