Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer | |
| from transformers import BertTokenizer, BertModel | |
| from typing import Tuple, Dict, List | |
| import pandas as pd | |
| from tqdm import tqdm | |
class FeatureExtractor:
    """Turn raw texts into BERT, TF-IDF and bag-of-words feature matrices.

    The instance owns one pretrained BERT tokenizer/model pair and two
    scikit-learn vectorizers. The vectorizers are stateful: they learn their
    vocabulary when called with ``fit=True`` and can afterwards project new
    texts into the same feature space with ``fit=False``.
    """

    def __init__(self, bert_model_name: str = "bert-base-uncased"):
        """Load the BERT tokenizer/model and configure the vectorizers.

        Args:
            bert_model_name: Name or path handed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        # Both vectorizers share the same settings: unigrams and bigrams,
        # English stop words removed, vocabulary capped at 5000 entries.
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            stop_words='english'
        )
        self.count_vectorizer = CountVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            stop_words='english'
        )

    def get_bert_embeddings(self, texts: List[str],
                            batch_size: int = 32,
                            max_length: int = 512) -> np.ndarray:
        """Extract BERT embeddings for a list of texts.

        Each text is represented by the final hidden state of its first
        ([CLS]) token.

        Args:
            texts: Texts to embed.
            batch_size: Number of texts sent through the model at once.
            max_length: Token limit; longer texts are truncated.

        Returns:
            Array with one row per text. For empty input the result has
            shape ``(0, hidden_size)`` instead of raising.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        texts = list(texts)
        if not texts:
            # np.vstack([]) raises, so answer the empty case explicitly with
            # a correctly shaped matrix that still stacks with other features.
            hidden_size = self.bert_model.config.hidden_size
            return np.empty((0, hidden_size), dtype=np.float32)

        self.bert_model.eval()
        # Inputs must live on the same device as the weights; this is a no-op
        # on CPU but required once the model has been moved to a GPU.
        device = self.bert_model.device
        embeddings = []
        with torch.no_grad():
            for start in tqdm(range(0, len(texts), batch_size)):
                batch_texts = texts[start:start + batch_size]
                # Tokenize and prepare input
                encoded = self.bert_tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors='pt'
                ).to(device)
                # Get BERT embeddings
                outputs = self.bert_model(**encoded)
                # Use [CLS] token embeddings as sentence representation
                cls_vectors = outputs.last_hidden_state[:, 0, :]
                # .cpu() is needed before .numpy() for tensors on a GPU.
                embeddings.append(cls_vectors.cpu().numpy())
        return np.vstack(embeddings)

    def get_tfidf_features(self, texts: List[str],
                           fit: bool = True) -> np.ndarray:
        """Extract TF-IDF features from texts.

        Args:
            texts: Texts to vectorize.
            fit: When True (default) the vectorizer is (re)fitted on
                ``texts`` first. Pass False to reuse the vocabulary learned
                earlier, e.g. for validation or test data.

        Returns:
            Dense array with one row per text.
        """
        vectorizer = self.tfidf_vectorizer
        matrix = (vectorizer.fit_transform(texts) if fit
                  else vectorizer.transform(texts))
        return matrix.toarray()

    def get_count_features(self, texts: List[str],
                           fit: bool = True) -> np.ndarray:
        """Extract Count Vectorizer features from texts.

        Args:
            texts: Texts to vectorize.
            fit: When True (default) the vectorizer is (re)fitted on
                ``texts`` first. Pass False to reuse the vocabulary learned
                earlier, e.g. for validation or test data.

        Returns:
            Dense array with one row per text.
        """
        vectorizer = self.count_vectorizer
        matrix = (vectorizer.fit_transform(texts) if fit
                  else vectorizer.transform(texts))
        return matrix.toarray()

    def extract_all_features(self, texts: List[str],
                             use_bert: bool = True,
                             use_tfidf: bool = True,
                             use_count: bool = True,
                             fit: bool = True) -> Dict[str, np.ndarray]:
        """Extract all features from texts.

        Args:
            texts: Texts to process.
            use_bert: Include BERT embeddings under the key ``'bert'``.
            use_tfidf: Include TF-IDF features under the key ``'tfidf'``.
            use_count: Include count features under the key ``'count'``.
            fit: Forwarded to the TF-IDF and count extractors; False reuses
                their previously fitted vocabularies.

        Returns:
            Mapping from feature name to feature matrix, containing only the
            requested feature sets.
        """
        features = {}
        if use_bert:
            features['bert'] = self.get_bert_embeddings(texts)
        if use_tfidf:
            features['tfidf'] = self.get_tfidf_features(texts, fit=fit)
        if use_count:
            features['count'] = self.get_count_features(texts, fit=fit)
        return features

    def extract_features_from_dataframe(self,
                                        df: pd.DataFrame,
                                        text_column: str,
                                        **kwargs) -> Dict[str, np.ndarray]:
        """Extract features from a dataframe's text column.

        Args:
            df: Dataframe holding the texts.
            text_column: Name of the column to read.
            **kwargs: Forwarded unchanged to ``extract_all_features``.

        Returns:
            The mapping produced by ``extract_all_features``.
        """
        column = df[text_column]
        # The tokenizer and the vectorizers only accept strings: missing
        # cells become empty strings and everything else is stringified.
        texts = [
            "" if missing else str(value)
            for value, missing in zip(column, column.isna())
        ]
        return self.extract_all_features(texts, **kwargs)