Tymec committed

Commit 2c1f9dd
1 Parent(s): e50b20c

Tokenization rework

Files changed (4):
  1. app/cli.py +97 -26
  2. app/data.py +115 -6
  3. app/gui.py +4 -1
  4. app/model.py +62 -87
app/cli.py CHANGED
@@ -55,6 +55,8 @@ def predict(model_path: Path, text: list[str]) -> None:
 
     import joblib
 
+    from app.model import infer_model
+
     text = " ".join(text).strip()
     if not sys.stdin.isatty():
         piped_text = sys.stdin.read().strip()
@@ -69,7 +71,8 @@ def predict(model_path: Path, text: list[str]) -> None:
     click.echo(DONE_STR)
 
     click.echo("Performing sentiment analysis... ", nl=False)
-    prediction = model.predict([text])[0]
+    prediction = infer_model(model, [text])[0]
+    # prediction = model.predict([text])[0]
     if prediction == 0:
         sentiment = click.style("NEGATIVE", fg="red")
     elif prediction == 1:
@@ -82,9 +85,9 @@ def predict(model_path: Path, text: list[str]) -> None:
 @cli.command()
 @click.option(
     "--dataset",
-    required=True,
-    help="Dataset to train the model on",
-    type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
+    default="test",
+    help="Dataset to evaluate the model on",
+    type=click.Choice(["test", "sentiment140", "amazonreviews", "imdb50k"]),
 )
 @click.option(
     "--model",
@@ -100,27 +103,65 @@ def predict(model_path: Path, text: list[str]) -> None:
     show_default=True,
     type=click.IntRange(1, 50),
 )
+@click.option(
+    "--batch-size",
+    default=512,
+    help="Size of the batches used in tokenization",
+    show_default=True,
+)
+@click.option(
+    "--processes",
+    default=8,
+    help="Number of parallel jobs during tokenization",
+    show_default=True,
+)
+@click.option(
+    "--verbose",
+    is_flag=True,
+    help="Show verbose output",
+)
 def evaluate(
-    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
+    dataset: Literal["test", "sentiment140", "amazonreviews", "imdb50k"],
     model_path: Path,
     cv: int,
+    batch_size: int,
+    processes: int,
+    verbose: bool,
 ) -> None:
-    """Evaluate the model on the test dataset"""
+    """Evaluate the model on the specified dataset"""
     import joblib
 
-    from app.data import load_data
+    from app.constants import CACHE_DIR
+    from app.data import load_data, tokenize
     from app.model import evaluate_model
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
+    cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
+    use_cached_data = False
+    if cached_data_path.exists():
+        use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
+
+    if use_cached_data:
+        click.echo("Loading cached data... ", nl=False)
+        token_data, label_data = joblib.load(cached_data_path)
+        click.echo(DONE_STR)
+    else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
+        click.echo("Tokenizing data... ", nl=False)
+        token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
+        joblib.dump((token_data, label_data), cached_data_path, compress=3)
+        click.echo(DONE_STR)
+
+        del text_data
 
     click.echo("Loading model... ", nl=False)
     model = joblib.load(model_path)
     click.echo(DONE_STR)
 
     click.echo("Evaluating model... ", nl=False)
-    acc_mean, acc_std = evaluate_model(model, text_data, label_data, folds=cv)
+    acc_mean, acc_std = evaluate_model(model, token_data, label_data, folds=cv, verbose=verbose)
     click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
 
 
@@ -145,6 +186,18 @@ def evaluate(
     show_default=True,
     type=click.IntRange(1, 50),
 )
+@click.option(
+    "--batch-size",
+    default=512,
+    help="Size of the batches used in tokenization",
+    show_default=True,
+)
+@click.option(
+    "--processes",
+    default=8,
+    help="Number of parallel jobs during tokenization",
+    show_default=True,
+)
 @click.option(
     "--seed",
     default=42,
@@ -157,45 +210,63 @@ def evaluate(
     is_flag=True,
     help="Overwrite the model file if it already exists",
 )
+@click.option(
+    "--verbose",
+    is_flag=True,
+    help="Show verbose output",
+)
 def train(
     dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
     max_features: int,
     cv: int,
+    batch_size: int,
+    processes: int,
     seed: int,
     force: bool,
+    verbose: bool,
 ) -> None:
     """Train the model on the provided dataset"""
     import joblib
 
-    from app.constants import MODELS_DIR
-    from app.data import load_data
-    from app.model import create_model, evaluate_model, train_model
+    from app.constants import CACHE_DIR, MODELS_DIR
+    from app.data import load_data, tokenize
+    from app.model import create_model, train_model
 
     model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
     if model_path.exists() and not force:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
+    cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
+    use_cached_data = False
+    if cached_data_path.exists():
+        use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
 
-    click.echo("Creating model... ", nl=False)
-    model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
-    click.echo(DONE_STR)
+    if use_cached_data:
+        click.echo("Loading cached data... ", nl=False)
+        token_data, label_data = joblib.load(cached_data_path)
+        click.echo(DONE_STR)
+    else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
+        click.echo("Tokenizing data... ", nl=False)
+        token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
+        joblib.dump((token_data, label_data), cached_data_path, compress=3)
+        click.echo(DONE_STR)
+
+        del text_data
 
     click.echo("Training model... ")
-    accuracy = train_model(model, text_data, label_data)
+    model = create_model(max_features, seed=None if seed == -1 else seed, verbose=verbose)
+    trained_model, accuracy = train_model(model, token_data, label_data, folds=cv, seed=seed, verbose=verbose)
     click.echo("Model accuracy: ", nl=False)
     click.secho(f"{accuracy:.2%}", fg="blue")
 
     click.echo("Model saved to: ", nl=False)
-    joblib.dump(model, model_path)
+    joblib.dump(trained_model, model_path, compress=3)
     click.secho(str(model_path), fg="blue")
 
-    click.echo("Evaluating model... ", nl=False)
-    acc_mean, acc_std = evaluate_model(model, text_data, label_data, folds=cv)
-    click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
-
 
 def cli_wrapper() -> None:
     cli(max_content_width=120)
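
Note: both train and evaluate now tokenize up front and cache the result under CACHE_DIR, so repeated runs on the same dataset can skip spaCy. A minimal sketch of that caching pattern outside the CLI, assuming the load_data/tokenize API shown above (the dataset name and parameter values are illustrative):

    import joblib

    from app.constants import CACHE_DIR
    from app.data import load_data, tokenize

    dataset = "imdb50k"  # illustrative choice
    cached = CACHE_DIR / f"{dataset}_tokenized.pkl"

    if cached.exists():
        # Reuse the tokenized corpus from a previous run
        token_data, label_data = joblib.load(cached)
    else:
        text_data, label_data = load_data(dataset)
        token_data = tokenize(text_data, batch_size=512, n_jobs=8, show_progress=True)
        joblib.dump((token_data, label_data), cached, compress=3)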
app/data.py CHANGED
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import bz2
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
 import pandas as pd
+import spacy
+from tqdm import tqdm
 
 from app.constants import (
     AMAZONREVIEWS_PATH,
@@ -12,9 +14,76 @@ from app.constants import (
     IMDB50K_URL,
     SENTIMENT140_PATH,
     SENTIMENT140_URL,
+    TEST_DATASET_PATH,
+    TEST_DATASET_URL,
 )
 
-__all__ = ["load_data"]
+if TYPE_CHECKING:
+    from spacy.tokens import Doc
+
+__all__ = ["load_data", "tokenize"]
+
+
+try:
+    nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
+except OSError:
+    print("Downloading spaCy model...")
+
+    from spacy.cli import download as spacy_download
+
+    spacy_download("en_core_web_sm")
+    nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
+
+
+def _lemmatize(doc: Doc, threshold: int = 2) -> list[str]:
+    """Lemmatize the provided text using spaCy.
+
+    Args:
+        doc: spaCy document
+        threshold: Minimum character length of tokens
+
+    Returns:
+        Lemmatized text
+    """
+    return [
+        token.lemma_.lower().strip()
+        for token in doc
+        if not token.is_stop
+        and not token.is_punct
+        and not token.like_email
+        and not token.like_url
+        and not token.like_num
+        and not (len(token.lemma_) < threshold)
+    ]
+
+
+def tokenize(
+    text_data: list[str],
+    batch_size: int = 512,
+    n_jobs: int = 4,
+    character_threshold: int = 2,
+    show_progress: bool = True,
+) -> list[list[str]]:
+    """Tokenize the provided text using spaCy.
+
+    Args:
+        text_data: Text data to tokenize
+        batch_size: Batch size for tokenization
+        n_jobs: Number of parallel jobs
+        character_threshold: Minimum character length of tokens
+        show_progress: Whether to show a progress bar
+
+    Returns:
+        Tokenized text data
+    """
+    return [
+        _lemmatize(doc, character_threshold)
+        for doc in tqdm(
+            nlp.pipe(text_data, batch_size=batch_size, n_process=n_jobs),
+            total=len(text_data),
+            disable=not show_progress,
+        )
+    ]
 
 
 def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
@@ -104,9 +173,6 @@ def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]:
     # Split the data into labels and text
     labels, texts = zip(*(line.split(" ", 1) for line in dataset))  # NOTE: Occasionally OOM
 
-    # Free up memory
-    del dataset
-
     # Map sentiment values
     sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
 
@@ -147,7 +213,48 @@ def load_imdb50k() -> tuple[list[str], list[int]]:
     return data["review"].tolist(), data["sentiment"].tolist()
 
 
-def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]:
+def load_test(include_neutral: bool = False) -> tuple[list[str], list[int]]:
+    """Load the test dataset and make it suitable for use.
+
+    Args:
+        include_neutral: Whether to include neutral sentiment
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    if not TEST_DATASET_PATH.exists():
+        msg = (
+            f"Test dataset not found at: '{TEST_DATASET_PATH}'\n"
+            "Please download the dataset from:\n"
+            f"{TEST_DATASET_URL}"
+        )
+        raise FileNotFoundError(msg)
+
+    # Load the dataset
+    data = pd.read_csv(TEST_DATASET_PATH)
+
+    # Ignore rows with neutral sentiment
+    if not include_neutral:
+        data = data[data["label"] != 1]
+
+    # Map sentiment values
+    data["label"] = data["label"].map(
+        {
+            0: 0,  # Negative
+            1: 1,  # Neutral
+            2: 2,  # Positive
+        },
+    )
+
+    # Return as lists
+    return data["text"].tolist(), data["label"].tolist()
+
+
+def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k", "test"]) -> tuple[list[str], list[int]]:
     """Load and preprocess the specified dataset.
 
     Args:
@@ -166,6 +273,8 @@ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]:
             return load_amazonreviews(merge=True)
         case "imdb50k":
            return load_imdb50k()
+        case "test":
+            return load_test(include_neutral=False)
         case _:
             msg = f"Unknown dataset: {dataset}"
             raise ValueError(msg)
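
For reference, the new tokenize helper returns one token list per input document. A minimal usage sketch, assuming the API above (the example strings and the commented output are illustrative):

    from app.data import tokenize

    docs = ["I loved this movie!", "Terrible service, would not recommend."]
    tokens = tokenize(docs, batch_size=2, n_jobs=1, show_progress=False)
    # e.g. [["love", "movie"], ["terrible", "service", "recommend"]]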
app/gui.py CHANGED
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING
 import gradio as gr
 import joblib
 
+from app.model import infer_model
+
 if TYPE_CHECKING:
     from sklearn.base import BaseEstimator
 
@@ -31,7 +33,7 @@ def load_model() -> BaseEstimator:
 def sentiment_analysis(text: str) -> str:
     """Perform sentiment analysis on the provided text."""
     model = load_model()
-    prediction = model.predict([text])[0]
+    prediction = infer_model(model, [text])[0]
 
     if prediction == 0:
         return NEGATIVE_LABEL
@@ -52,6 +54,7 @@ demo = gr.Interface(
         ["The movie we watched was boring."],
         ["This website is amazing!"],
     ],
+    allow_flagging=False,
 )
 
 
app/model.py CHANGED
@@ -1,85 +1,25 @@
 from __future__ import annotations
 
-import warnings
+import os
+from typing import TYPE_CHECKING
 
 import numpy as np
-import spacy
 from joblib import Memory
-from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
 from sklearn.pipeline import Pipeline
-from tqdm import tqdm
 
 from app.constants import CACHE_DIR
+from app.data import tokenize
 
-__all__ = ["create_model", "train_model", "evaluate_model"]
-
-try:
-    nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
-except OSError:
-    print("Downloading spaCy model...")
-
-    from spacy.cli import download as spacy_download
-
-    spacy_download("en_core_web_sm")
-    nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
-
-
-class TextTokenizer(BaseEstimator, TransformerMixin):
-    def __init__(
-        self,
-        *,
-        character_threshold: int = 2,
-        batch_size: int = 1024,
-        n_jobs: int = 8,
-        progress: bool = True,
-    ) -> None:
-        self.character_threshold = character_threshold
-        self.batch_size = batch_size
-        self.n_jobs = n_jobs
-        self.progress = progress
-
-    def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextTokenizer:
-        return self
-
-    def transform(self, data: list[str]) -> list[list[str]]:
-        tokenized = []
-        for doc in tqdm(
-            nlp.pipe(data, batch_size=self.batch_size, n_process=self.n_jobs),
-            total=len(data),
-            disable=not self.progress,
-        ):
-            tokens = []
-            for token in doc:
-                # Ignore stop words and punctuation
-                if token.is_stop or token.is_punct:
-                    continue
-                # Ignore emails, URLs and numbers
-                if token.like_email or token.like_email or token.like_num:
-                    continue
-
-                # Lemmatize and lowercase
-                tok = token.lemma_.lower().strip()
-
-                # Format hashtags
-                if tok.startswith("#"):
-                    tok = tok[1:]
-
-                # Ignore short and non-alphanumeric tokens
-                if len(tok) < self.character_threshold or not tok.isalnum():
-                    continue
-
-                # TODO: Emoticons and emojis
-                # TODO: Spelling correction
-
-                tokens.append(tok)
-            tokenized.append(tokens)
-        return tokenized
-
-
-def identity(x: list[str]) -> list[str]:
+if TYPE_CHECKING:
+    from sklearn.base import BaseEstimator
+
+__all__ = ["create_model", "train_model", "evaluate_model", "infer_model"]
+
+
+def _identity(x: list[str]) -> list[str]:
     """Identity function for use in TfidfVectorizer.
 
     Args:
@@ -101,22 +41,21 @@ def create_model(
     Args:
         max_features: Maximum number of features
         seed: Random seed (None for random seed)
-        verbose: Whether to log progress during training
+        verbose: Whether to output additional information
 
     Returns:
         Untrained model
     """
     return Pipeline(
         [
-            ("tokenizer", TextTokenizer(progress=True)),
             (
                 "vectorizer",
                 TfidfVectorizer(
                     max_features=max_features,
                     ngram_range=(1, 2),
                     # disable text processing
-                    tokenizer=identity,
-                    preprocessor=identity,
+                    tokenizer=_identity,
+                    preprocessor=_identity,
                     lowercase=False,
                     token_pattern=None,
                 ),
@@ -130,23 +69,27 @@ def create_model(
 
 def train_model(
     model: BaseEstimator,
-    text_data: list[str],
+    token_data: list[str],
     label_data: list[int],
+    folds: int = 5,
     seed: int = 42,
+    verbose: bool = False,
 ) -> tuple[BaseEstimator, float]:
     """Train the sentiment analysis model.
 
     Args:
         model: Untrained model
-        text_data: Text data
+        token_data: Tokenized text data
         label_data: Label data
+        folds: Number of cross-validation folds
         seed: Random seed (None for random seed)
+        verbose: Whether to output additional information
 
     Returns:
         Trained model and accuracy
     """
     text_train, text_test, label_train, label_test = train_test_split(
-        text_data,
+        token_data,
         label_data,
         test_size=0.2,
         random_state=seed,
@@ -154,50 +97,82 @@ def train_model(
 
     param_distributions = {
         "classifier__C": np.logspace(-4, 4, 20),
-        "classifier__penalty": ["l1", "l2"],
+        "classifier__solver": ["liblinear", "saga"],
     }
 
     search = RandomizedSearchCV(
         model,
         param_distributions,
         n_iter=10,
-        cv=5,
+        cv=folds,
         scoring="accuracy",
         random_state=seed,
         n_jobs=-1,
+        verbose=verbose,
     )
 
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        # model.fit(text_train, label_train)
-        search.fit(text_train, label_train)
+    os.environ["PYTHONWARNINGS"] = "ignore"
+    search.fit(text_train, label_train)
+    del os.environ["PYTHONWARNINGS"]
 
     best_model = search.best_estimator_
     return best_model, best_model.score(text_test, label_test)
 
 
 def evaluate_model(
-    model: Pipeline,
-    text_data: list[str],
+    model: BaseEstimator,
+    token_data: list[str],
     label_data: list[int],
     folds: int = 5,
+    verbose: bool = False,
 ) -> tuple[float, float]:
     """Evaluate the model using cross-validation.
 
     Args:
         model: Trained model
-        text_data: Text data
+        token_data: Tokenized text data
         label_data: Label data
         folds: Number of cross-validation folds
+        verbose: Whether to output additional information
 
     Returns:
         Mean accuracy and standard deviation
     """
+    os.environ["PYTHONWARNINGS"] = "ignore"
     scores = cross_val_score(
         model,
-        text_data,
+        token_data,
         label_data,
         cv=folds,
         scoring="accuracy",
+        n_jobs=-1,
+        verbose=verbose,
     )
+    del os.environ["PYTHONWARNINGS"]
     return scores.mean(), scores.std()
+
+
+def infer_model(
+    model: BaseEstimator,
+    text_data: list[str],
+    batch_size: int = 32,
+    n_jobs: int = 4,
+) -> list[int]:
+    """Predict the sentiment of the provided text documents.
+
+    Args:
+        model: Trained model
+        text_data: Text data
+        batch_size: Batch size for tokenization
+        n_jobs: Number of parallel jobs
+
+    Returns:
+        Predicted sentiments
+    """
+    tokens = tokenize(
+        text_data,
+        batch_size=batch_size,
+        n_jobs=n_jobs,
+        show_progress=False,
+    )
+    return model.predict(tokens)
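
Taken together, the reworked API expects pre-tokenized input for train_model and evaluate_model, while infer_model tokenizes raw text itself. A hedged end-to-end sketch, assuming the functions shown in this commit (dataset name and parameter values are illustrative):

    from app.data import load_data, tokenize
    from app.model import create_model, evaluate_model, infer_model, train_model

    text_data, label_data = load_data("imdb50k")
    token_data = tokenize(text_data, batch_size=512, n_jobs=4)

    model = create_model(20000, seed=42, verbose=False)
    trained_model, accuracy = train_model(model, token_data, label_data, folds=5, seed=42)

    acc_mean, acc_std = evaluate_model(trained_model, token_data, label_data, folds=5)
    prediction = infer_model(trained_model, ["A surprisingly good film."])[0]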