Add cross validation
- app/cli.py +24 -9
- app/model.py +34 -5
- notebook.ipynb +11 -3
app/cli.py
CHANGED
@@ -90,6 +90,13 @@ def predict(model_path: Path, text: list[str]) -> None:
     show_default=True,
     type=click.IntRange(1, None),
 )
+@click.option(
+    "--cv",
+    default=5,
+    help="Number of cross-validation folds",
+    show_default=True,
+    type=click.IntRange(1, 50),
+)
 @click.option(
     "--seed",
     default=42,
@@ -97,19 +104,26 @@ def predict(model_path: Path, text: list[str]) -> None:
     show_default=True,
     type=click.IntRange(-1, None),
 )
+@click.option(
+    "--force",
+    is_flag=True,
+    help="Overwrite the model file if it already exists",
+)
 def train(
     dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
     max_features: int,
+    cv: int,
     seed: int,
+    force: bool,
 ) -> None:
     """Train the model on the provided dataset"""
     import joblib
 
     from app.constants import MODELS_DIR
-    from app.model import create_model, load_data, train_model
+    from app.model import create_model, evaluate_model, load_data, train_model
 
     model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
-    if model_path.exists():
+    if model_path.exists() and not force:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
 
     click.echo("Preprocessing dataset... ", nl=False)
@@ -122,16 +136,17 @@ def train(
 
     # click.echo("Training model... ", nl=False)
     click.echo("Training model... ")
-    accuracy = train_model(model, text_data, label_data)
-    joblib.dump(model, model_path)
-    click.echo("Model saved to: ", nl=False)
-    click.secho(str(model_path), fg="blue")
-
+    accuracy, text_test, text_label = train_model(model, text_data, label_data)
     click.echo("Model accuracy: ", nl=False)
     click.secho(f"{accuracy:.2%}", fg="blue")
 
-
-
+    click.echo("Model saved to: ", nl=False)
+    joblib.dump(model, model_path)
+    click.secho(str(model_path), fg="blue")
+
+    click.echo("Evaluating model... ", nl=False)
+    acc_mean, acc_std = evaluate_model(model, text_test, text_label, cv=cv)
+    click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
 
 
 def cli_wrapper() -> None:
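Put together, the updated `train` command now trains on a split, reports holdout accuracy, saves the pipeline, and then cross-validates on the held-out portion. A minimal sketch of that flow in plain Python, using the function names from this diff (the `load_data` and `create_model` signatures are not shown in this commit, so the arguments passed to them below are assumptions):

import joblib

from app.constants import MODELS_DIR
from app.model import create_model, evaluate_model, load_data, train_model

# Assumed calls: load_data(dataset) and create_model(max_features, seed)
# are not part of this diff.
text_data, label_data = load_data("imdb50k")
model = create_model(max_features=20000, seed=42)

# train_model now returns the held-out test split alongside the accuracy.
accuracy, text_test, label_test = train_model(model, text_data, label_data)
joblib.dump(model, MODELS_DIR / "imdb50k_tfidf_ft-20000.pkl")

# evaluate_model runs k-fold cross-validation on the held-out split.
acc_mean, acc_std = evaluate_model(model, text_test, label_test, cv=5)
print(f"holdout: {accuracy:.2%}, cv: {acc_mean:.2%} ± {acc_std:.2%}")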
app/model.py
CHANGED
@@ -13,7 +13,7 @@ from nltk.stem import WordNetLemmatizer
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 from sklearn.linear_model import LogisticRegression
-from sklearn.model_selection import train_test_split
+from sklearn.model_selection import cross_val_score, train_test_split
 from sklearn.pipeline import Pipeline
 
 from app.constants import (
@@ -28,7 +28,7 @@ from app.constants import (
     URL_REGEX,
 )
 
-__all__ = ["load_data", "create_model", "train_model"]
+__all__ = ["load_data", "create_model", "train_model", "evaluate_model"]
 
 
 class TextCleaner(BaseEstimator, TransformerMixin):
@@ -293,7 +293,7 @@ def train_model(
     text_data: list[str],
     label_data: list[int],
     seed: int = 42,
-) -> float:
+) -> tuple[float, list[str], list[int]]:
     """Train the sentiment analysis model.
 
     Args:
@@ -303,7 +303,7 @@
         seed: Random seed (None for random seed)
 
     Returns:
-
+        Model accuracy and test data
     """
     text_train, text_test, label_train, label_test = train_test_split(
         text_data,
@@ -316,4 +316,33 @@
         warnings.simplefilter("ignore")
         model.fit(text_train, label_train)
 
-    return model.score(text_test, label_test)
+    return model.score(text_test, label_test), text_test, label_test
+
+
+def evaluate_model(
+    model: Pipeline,
+    text_test: list[str],
+    label_test: list[int],
+    cv: int = 5,
+) -> tuple[float, float]:
+    """Evaluate the model using cross-validation.
+
+    Args:
+        model: Trained model
+        text_test: Text data
+        label_test: Label data
+        seed: Random seed (None for random seed)
+        cv: Number of cross-validation folds
+
+    Returns:
+        Mean accuracy and standard deviation
+    """
+    scores = cross_val_score(
+        model,
+        text_test,
+        label_test,
+        cv=cv,
+        scoring="accuracy",
+        n_jobs=-1,
+    )
+    return scores.mean(), scores.std()
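One design note: `cross_val_score` clones the estimator and re-fits it on every fold, so `evaluate_model` retrains the pipeline `cv` times on subsets of the held-out data rather than scoring the already-fitted model returned by `train_model`. A minimal standalone illustration of the new helper; the toy texts and the small TF-IDF pipeline below are placeholders, not the app's real model:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from app.model import evaluate_model  # added in this commit

# Placeholder data; the real CLI passes the held-out split from train_model.
texts = ["great movie", "terrible plot", "loved it", "awful acting"] * 10
labels = [1, 0, 1, 0] * 10

pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression())

# cross_val_score fits a fresh clone per fold, so a prior fit() is not required.
mean_acc, std_acc = evaluate_model(pipeline, texts, labels, cv=5)
print(f"{mean_acc:.2%} ± {std_acc:.2%}")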
notebook.ipynb
CHANGED
@@ -668,9 +668,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 24,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Fitting 3 folds for each of 10 candidates, totalling 30 fits\n"
+     ]
+    }
+   ],
    "source": [
     "# SVM\n",
     "svm_clf = SVC(random_state=SEED)\n",
@@ -680,7 +688,7 @@
     "    svm_clf,\n",
     "    {\n",
     "        \"C\": np.logspace(-4, 4, 20),\n",
-    "        \"kernel\": [\"linear\", \"poly\", \"rbf\"
+    "        \"kernel\": [\"linear\", \"poly\", \"rbf\"],\n",
     "        \"degree\": [2, 3, 4],\n",
     "    },\n",
     ")\n",
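For context, this hunk only shows the estimator and the parameter grid; the recorded output ("Fitting 3 folds for each of 10 candidates, totalling 30 fits") is what scikit-learn prints when `verbose` is enabled, and 10 candidates points to a randomized search rather than an exhaustive one (the full grid here would have 180 combinations). The sketch below therefore assumes a `RandomizedSearchCV(n_iter=10, cv=3)` wrapper; the notebook's actual search object and training variables are not part of this diff:

import numpy as np
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

SEED = 42  # assumed; the notebook defines SEED in an earlier cell

# SVM
svm_clf = SVC(random_state=SEED)
svm_search = RandomizedSearchCV(  # assumed wrapper, inferred from the output line
    svm_clf,
    {
        "C": np.logspace(-4, 4, 20),
        "kernel": ["linear", "poly", "rbf"],
        "degree": [2, 3, 4],
    },
    n_iter=10,         # 10 candidates, as in the recorded output
    cv=3,              # 3 folds, as in the recorded output
    random_state=SEED,
    verbose=1,         # prints the "Fitting 3 folds ..." progress line
)
# svm_search.fit(text_train, label_train)  # training data comes from earlier cells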