import os
import ast
import torch
import logging
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from db.db_utils import get_connection

# The root logger defaults to WARNING, so configure it explicitly or the
# logging.info calls below produce no output.
logging.basicConfig(level=logging.INFO)
# Set the device to CPU
device = torch.device("cpu")

# Directory to save/load the model
save_directory = './specificity-model'

# Define label_mapping globally
label_mapping = {}
# Check if a previously saved model exists
if os.path.exists(save_directory):
    logging.info(f"Loading the existing model from {save_directory}...")
    tokenizer = BertTokenizer.from_pretrained(save_directory)
    model = BertForSequenceClassification.from_pretrained(save_directory)
    # Load the label mapping (ast.literal_eval parses the dict safely, unlike eval)
    if os.path.exists(os.path.join(save_directory, 'label_mapping.txt')):
        with open(os.path.join(save_directory, 'label_mapping.txt'), 'r') as f:
            label_mapping = ast.literal_eval(f.read())
else:
    logging.info("Loading BERT tokenizer and model...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
# Get labeled training data from the database
logging.info("Connecting to the database...")
db_conn = get_connection()
db_cursor = db_conn.cursor()
logging.info("Fetching data from the database...")
db_cursor.execute("SELECT input_word, specificity FROM mappings WHERE specificity IS NOT NULL AND reviewed = true AND is_food = true")
results = db_cursor.fetchall()
db_cursor.close()
db_conn.close()
training_data = [(row[0], row[1]) for row in results]
texts, labels = zip(*training_data)
logging.info(f"Fetched {len(texts)} records from the database.")
# Convert labels to integers; keep a loaded mapping if one exists, and sort the
# label set so a freshly built mapping is deterministic across runs
logging.info("Converting labels to integers...")
if not label_mapping:
    label_mapping = {label: idx for idx, label in enumerate(sorted(set(labels)))}
labels = [label_mapping[label] for label in labels]
# Split data into training and testing sets
logging.info("Splitting data into training and testing sets...")
X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)

# Tokenize the data
logging.info("Tokenizing the data...")
train_encodings = tokenizer(list(X_train), truncation=True, padding=True, max_length=128)
test_encodings = tokenizer(list(X_test), truncation=True, padding=True, max_length=128)
class SpecificityDataset(torch.utils.data.Dataset):
    """Wraps the tokenizer output and integer labels for the Trainer."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]).to(device) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx]).to(device)
        return item

    def __len__(self):
        return len(self.labels)
logging.info("Creating datasets...") | |
train_dataset = SpecificityDataset(train_encodings, y_train) | |
test_dataset = SpecificityDataset(test_encodings, y_test) | |
training_args = TrainingArguments(
    output_dir='./specificity-results',  # output directory
    num_train_epochs=8,                  # number of training epochs
    per_device_train_batch_size=16,      # batch size for training
    per_device_eval_batch_size=64,       # batch size for evaluation
    warmup_steps=500,                    # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,                   # strength of weight decay
    logging_dir='./logs',                # directory for storing logs
    logging_steps=10,
    evaluation_strategy="epoch"          # evaluate at the end of every epoch
)
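# Optional metrics hook, a minimal sketch that is not wired into the Trainer
# below: passing compute_metrics=compute_metrics to Trainer would make
# trainer.evaluate() report accuracy alongside the loss. numpy is assumed to
# be available (it is a dependency of transformers).
import numpy as np

def compute_metrics(eval_pred):
    logits, label_ids = eval_pred  # EvalPrediction unpacks to (predictions, label_ids)
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == label_ids).mean())}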
logging.info("Initializing the Trainer...") | |
trainer = Trainer( | |
model=model, # the instantiated π€ Transformers model to be trained | |
args=training_args, # training arguments, defined above | |
train_dataset=train_dataset, # training dataset | |
eval_dataset=test_dataset # evaluation dataset | |
) | |
logging.info("Starting training...") | |
trainer.train() | |
logging.info("Evaluating the model...") | |
eval_result = trainer.evaluate() | |
logging.info(f"Evaluation results: {eval_result}") | |
# Save the model and tokenizer
logging.info(f"Saving the model to {save_directory}...")
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

# Save the label mapping
with open(os.path.join(save_directory, 'label_mapping.txt'), 'w') as f:
    f.write(str(label_mapping))

model.to(device)
model.eval()  # switch to inference mode (disables dropout) for the helper below
def classify_text_to_specificity(text):
    logging.info(f"Classifying text: {text}")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
    with torch.no_grad():  # no gradients needed at inference time
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    # Map the predicted class ID back to the original label
    inv_label_mapping = {v: k for k, v in label_mapping.items()}
    return inv_label_mapping[predicted_class_id]
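# A batched variant can cut per-call overhead when classifying many strings at
# once; a minimal sketch, assuming all texts fit in a single batch (the helper
# name is illustrative and not part of the original script).
def classify_texts_to_specificity(texts):
    inv_label_mapping = {v: k for k, v in label_mapping.items()}
    inputs = tokenizer(list(texts), return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=1).tolist()
    return [inv_label_mapping[i] for i in predicted_ids]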
# Example usage
# for example_text in ["produce items", "bananas", "milk", "mixed items", "random assortment", "heterogeneous mixture"]:
#     predicted_specificity = classify_text_to_specificity(example_text)
#     logging.info(f"The predicted specificity for '{example_text}' is '{predicted_specificity}'")
#     logging.info("----------")