Tymec's picture
Update docstrings and comments
d4ef46b
raw
history blame
6.05 kB
"""Functions for model training, evaluation, and inference."""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Literal, Sequence
import numpy as np
from joblib import Memory
from sklearn.exceptions import ConvergenceWarning
from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer, TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
from sklearn.pipeline import Pipeline
from app.constants import CACHE_DIR
from app.data import tokenize
if TYPE_CHECKING:
from sklearn.base import BaseEstimator, TransformerMixin
__all__ = ["train_model", "evaluate_model", "infer_model"]
def _identity(x: list[str]) -> list[str]:
"""Identity function for use in vectorizers.
Args:
x: Input data
Returns:
Unchanged input data
"""
return x
def _get_vectorizer(
    name: Literal["tfidf", "count", "hashing"],
    n_features: int,
    min_df: int = 5,
) -> TransformerMixin:
    """Build the vectorizer selected by ``name``.

    Every vectorizer is configured for unigrams and bigrams, with its own
    tokenization, preprocessing and lowercasing switched off.

    Args:
        name: Type of vectorizer ("tfidf", "count" or "hashing")
        n_features: Maximum number of features (``max_features`` for
            tfidf/count, ``n_features`` for hashing)
        min_df: Minimum document frequency (ignored for hashing)

    Returns:
        Vectorizer instance

    Raises:
        ValueError: If the vectorizer is not recognized
    """
    # Options common to all three vectorizers
    passthrough = dict(
        ngram_range=(1, 2),  # unigrams and bigrams
        # disable text processing
        tokenizer=_identity,
        preprocessor=_identity,
        lowercase=False,
        token_pattern=None,
    )

    if name == "tfidf":
        return TfidfVectorizer(max_features=n_features, min_df=min_df, **passthrough)

    if name == "count":
        return CountVectorizer(max_features=n_features, min_df=min_df, **passthrough)

    if name == "hashing":
        too_small = n_features < 2**15
        if too_small:
            # stacklevel=2 attributes the warning to the caller of this function
            warnings.warn(
                "HashingVectorizer may perform poorly with small n_features, default is 2^20.",
                stacklevel=2,
            )
        return HashingVectorizer(n_features=n_features, **passthrough)

    error_text = f"Unknown vectorizer: {name}"
    raise ValueError(error_text)
def train_model(
    token_data: Sequence[Sequence[str]],
    label_data: list[int],
    vectorizer: Literal["tfidf", "count", "hashing"],
    max_features: int,
    min_df: int = 5,
    cv: int = 5,
    n_jobs: int = 4,
    seed: int | None = 42,
) -> tuple[BaseEstimator, float]:
    """Train the sentiment analysis model.

    Holds out 20% of the data for testing, then tunes the regularization
    strength ``C`` of a logistic regression classifier with a randomized
    search (10 candidates drawn from 20 log-spaced values between 1e-4 and
    1e4), scored by cross-validated accuracy on the remaining 80%.

    Args:
        token_data: Tokenized text data
        label_data: Label data
        vectorizer: Which vectorizer to use
        max_features: Maximum number of features
        min_df: Minimum document frequency (ignored for hashing)
        cv: Number of cross-validation folds
        n_jobs: Number of parallel jobs
        seed: Random seed; ``None`` or ``-1`` leaves the random state unset

    Returns:
        Best estimator found by the search and its score on the held-out
        test split

    Raises:
        ValueError: If the vectorizer is not recognized
    """
    # Both None and the -1 sentinel mean "do not fix the random state"
    rs = None if seed is None or seed == -1 else seed

    # Build the vectorizer first so an unknown name fails before any work is done.
    # A separate local keeps the `vectorizer` argument (a string) from being
    # rebound to an object of a different type.
    vectorizer_step = _get_vectorizer(vectorizer, max_features, min_df)

    # Split the data into training and testing sets
    text_train, text_test, label_train, label_test = train_test_split(
        token_data,
        label_data,
        test_size=0.2,
        random_state=rs,
    )

    # Create the model pipeline; pipeline steps are cached under CACHE_DIR
    classifier = LogisticRegression(max_iter=1000, random_state=rs)
    model = Pipeline(
        [("vectorizer", vectorizer_step), ("classifier", classifier)],
        memory=Memory(CACHE_DIR, verbose=0),
    )

    param_dist = {"classifier__C": np.logspace(-4, 4, 20)}

    # Perform randomized search for hyperparameter tuning
    search = RandomizedSearchCV(
        model,
        param_dist,
        cv=cv,
        random_state=rs,
        n_jobs=n_jobs,
        scoring="accuracy",
        n_iter=10,
        verbose=2,
    )

    with warnings.catch_warnings():
        # Report convergence problems once instead of on every fit
        warnings.filterwarnings("once", category=ConvergenceWarning)
        # Silence the warning about slow persistence of cached inputs
        warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took")
        search.fit(text_train, label_train)

    final_model = search.best_estimator_
    return final_model, final_model.score(text_test, label_test)
def evaluate_model(
    model: BaseEstimator,
    token_data: Sequence[Sequence[str]],
    label_data: list[int],
    cv: int = 5,
    n_jobs: int = 4,
) -> tuple[float, float]:
    """Cross-validate the model and summarize its accuracy.

    Args:
        model: Model to evaluate
        token_data: Tokenized text data
        label_data: Label data
        cv: Number of cross-validation folds
        n_jobs: Number of parallel jobs

    Returns:
        Mean and standard deviation of the per-fold accuracy scores
    """
    with warnings.catch_warnings():
        # Silence the warning about slow persistence of cached inputs
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            message="Persisting input arguments took",
        )
        # One accuracy score per cross-validation fold
        fold_scores = cross_val_score(
            model,
            token_data,
            label_data,
            scoring="accuracy",
            cv=cv,
            n_jobs=n_jobs,
            verbose=2,
        )

    mean_accuracy = fold_scores.mean()
    accuracy_spread = fold_scores.std()
    return mean_accuracy, accuracy_spread
def infer_model(
    model: BaseEstimator,
    text_data: list[str],
    batch_size: int = 32,
    n_jobs: int = 4,
) -> list[int]:
    """Predict the sentiment of the provided text documents.

    The raw texts are tokenized first (without a progress display) and the
    resulting tokens are handed to the model's ``predict`` method.

    Args:
        model: Trained model
        text_data: Text data
        batch_size: Batch size for tokenization
        n_jobs: Number of parallel jobs

    Returns:
        Predicted sentiments, exactly as returned by ``model.predict``
        (NOTE(review): likely an array rather than a plain list for
        scikit-learn estimators; confirm against callers)
    """
    tokenizer_options = {
        "batch_size": batch_size,
        "n_jobs": n_jobs,
        "show_progress": False,
    }
    tokenized_docs = tokenize(text_data, **tokenizer_options)
    predictions = model.predict(tokenized_docs)
    return predictions