Spaces:
Sleeping
Sleeping
Update training_bert.py
Browse files- training_bert.py +8 -4
training_bert.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
import re
|
|
|
|
|
4 |
from sklearn.model_selection import GroupShuffleSplit
|
|
|
|
|
|
|
5 |
def remove_links(review):
|
6 |
pattern = r'\bhttps?://\S+'
|
7 |
return re.sub(pattern, '', review)
|
@@ -34,7 +39,7 @@ y_val = val.Score
|
|
34 |
X_test = test.drop(columns = 'Score')
|
35 |
y_test = test.Score
|
36 |
|
37 |
-
|
38 |
base_model = 'bert-base-cased'
|
39 |
learning_rate = 2e-5
|
40 |
max_length = 64
|
@@ -78,7 +83,7 @@ def compute_metrics_for_regression(eval_pred):
|
|
78 |
|
79 |
return {"mse": mse, "mae": mae, "r2": r2, "accuracy": accuracy}
|
80 |
|
81 |
-
|
82 |
|
83 |
output_dir = ".."
|
84 |
|
@@ -94,8 +99,7 @@ training_args = TrainingArguments(
|
|
94 |
load_best_model_at_end=True,
|
95 |
weight_decay=0.01,
|
96 |
)
|
97 |
-
|
98 |
-
import torch
|
99 |
class RegressionTrainer(Trainer):
|
100 |
def compute_loss(self, model, inputs, return_outputs=False):
|
101 |
labels = inputs.pop("labels")
|
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
import re
|
4 |
+
from transformers import Trainer
|
5 |
+
import torch
|
6 |
from sklearn.model_selection import GroupShuffleSplit
|
7 |
+
from transformers import AutoTokenizer,AutoModelForSequenceClassification
|
8 |
+
from transformers import TrainingArguments
|
9 |
+
|
10 |
def remove_links(review):
|
11 |
pattern = r'\bhttps?://\S+'
|
12 |
return re.sub(pattern, '', review)
|
|
|
39 |
X_test = test.drop(columns = 'Score')
|
40 |
y_test = test.Score
|
41 |
|
42 |
+
|
43 |
base_model = 'bert-base-cased'
|
44 |
learning_rate = 2e-5
|
45 |
max_length = 64
|
|
|
83 |
|
84 |
return {"mse": mse, "mae": mae, "r2": r2, "accuracy": accuracy}
|
85 |
|
86 |
+
|
87 |
|
88 |
output_dir = ".."
|
89 |
|
|
|
99 |
load_best_model_at_end=True,
|
100 |
weight_decay=0.01,
|
101 |
)
|
102 |
+
|
|
|
103 |
class RegressionTrainer(Trainer):
|
104 |
def compute_loss(self, model, inputs, return_outputs=False):
|
105 |
labels = inputs.pop("labels")
|