Tymec commited on
Commit
204391c
·
1 Parent(s): 0993d5e

Use stopwords from NLTK and download NLTK data

Browse files
Files changed (2) hide show
  1. app/cli.py +6 -4
  2. app/model.py +16 -1
app/cli.py CHANGED
@@ -117,15 +117,17 @@ def train(
117
  click.echo(DONE_STR)
118
 
119
  click.echo("Creating model... ", nl=False)
120
- model = create_model(max_features, seed=None if seed == -1 else seed)
121
  click.echo(DONE_STR)
122
 
123
- click.echo("Training model... ", nl=False)
 
124
  accuracy = train_model(model, text_data, label_data)
125
  joblib.dump(model, model_path)
126
- click.echo(DONE_STR)
 
127
 
128
- click.echo("Model accuracy: ")
129
  click.secho(f"{accuracy:.2%}", fg="blue")
130
 
131
  # TODO: Add hyperparameter options
 
117
  click.echo(DONE_STR)
118
 
119
  click.echo("Creating model... ", nl=False)
120
+ model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
121
  click.echo(DONE_STR)
122
 
123
+ # click.echo("Training model... ", nl=False)
124
+ click.echo("Training model... ")
125
  accuracy = train_model(model, text_data, label_data)
126
  joblib.dump(model, model_path)
127
+ click.echo("Model saved to: ", nl=False)
128
+ click.secho(str(model_path), fg="blue")
129
 
130
+ click.echo("Model accuracy: ", nl=False)
131
  click.secho(f"{accuracy:.2%}", fg="blue")
132
 
133
  # TODO: Add hyperparameter options
app/model.py CHANGED
@@ -5,8 +5,10 @@ import re
5
  import warnings
6
  from typing import Literal
7
 
 
8
  import pandas as pd
9
  from joblib import Memory
 
10
  from nltk.stem import WordNetLemmatizer
11
  from sklearn.base import BaseEstimator, TransformerMixin
12
  from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
@@ -248,28 +250,41 @@ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> t
248
  def create_model(
249
  max_features: int,
250
  seed: int | None = None,
 
251
  ) -> Pipeline:
252
  """Create a sentiment analysis model.
253
 
254
  Args:
255
  max_features: Maximum number of features
256
  seed: Random seed (None for random seed)
 
257
 
258
  Returns:
259
  Untrained model
260
  """
 
 
 
 
 
 
 
261
  return Pipeline(
262
  [
263
  # Text preprocessing
264
  ("clean", TextCleaner()),
265
  ("lemma", TextLemmatizer()),
266
  # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
267
- ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=max_features)),
 
 
 
268
  ("tfidf", TfidfTransformer()),
269
  # Classifier
270
  ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
271
  ],
272
  memory=Memory(CACHE_DIR, verbose=0),
 
273
  )
274
 
275
 
 
5
  import warnings
6
  from typing import Literal
7
 
8
+ import nltk
9
  import pandas as pd
10
  from joblib import Memory
11
+ from nltk.corpus import stopwords
12
  from nltk.stem import WordNetLemmatizer
13
  from sklearn.base import BaseEstimator, TransformerMixin
14
  from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 
250
  def create_model(
251
  max_features: int,
252
  seed: int | None = None,
253
+ verbose: bool = False,
254
  ) -> Pipeline:
255
  """Create a sentiment analysis model.
256
 
257
  Args:
258
  max_features: Maximum number of features
259
  seed: Random seed (None for random seed)
260
+ verbose: Whether to log progress during training
261
 
262
  Returns:
263
  Untrained model
264
  """
265
+ # Download NLTK data if not already downloaded
266
+ nltk.download("wordnet", quiet=True)
267
+ nltk.download("stopwords", quiet=True)
268
+
269
+ # Load English stopwords
270
+ stopwords_en = set(stopwords.words("english"))
271
+
272
  return Pipeline(
273
  [
274
  # Text preprocessing
275
  ("clean", TextCleaner()),
276
  ("lemma", TextLemmatizer()),
277
  # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
278
+ (
279
+ "vectorize",
280
+ CountVectorizer(stop_words=stopwords_en, ngram_range=(1, 2), max_features=max_features),
281
+ ),
282
  ("tfidf", TfidfTransformer()),
283
  # Classifier
284
  ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
285
  ],
286
  memory=Memory(CACHE_DIR, verbose=0),
287
+ verbose=verbose,
288
  )
289
 
290