import os
import ast
import logging

import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split

from db.db_utils import get_connection

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the device to CPU
device = torch.device("cpu")

# Directory to save/load the model
save_directory = './specificity-model'

# Define label_mapping globally
label_mapping = {}

# Check if a previously trained model exists
if os.path.exists(save_directory):
    logger.info(f"Loading the existing model from {save_directory}...")
    tokenizer = BertTokenizer.from_pretrained(save_directory)
    model = BertForSequenceClassification.from_pretrained(save_directory)
    # Load the label mapping (parsed safely instead of using eval)
    label_mapping_path = os.path.join(save_directory, 'label_mapping.txt')
    if os.path.exists(label_mapping_path):
        with open(label_mapping_path, 'r') as f:
            label_mapping = ast.literal_eval(f.read())
else:
    logger.info("Loading BERT tokenizer and model...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # num_labels should match the number of distinct specificity labels in the data
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)

# Get data from the database
logger.info("Connecting to the database...")
db_conn = get_connection()
db_cursor = db_conn.cursor()

logger.info("Fetching data from the database...")
db_cursor.execute(
    "SELECT input_word, specificity FROM mappings "
    "WHERE specificity IS NOT NULL AND reviewed = true AND is_food = true"
)
results = db_cursor.fetchall()
training_data = [(row[0], row[1]) for row in results]

texts, labels = zip(*training_data)
logger.info(f"Fetched {len(texts)} records from the database.")

# Convert labels to integers; reuse a loaded mapping if one exists,
# otherwise build a deterministic mapping from the sorted label set
logger.info("Converting labels to integers...")
if not label_mapping:
    label_mapping = {label: idx for idx, label in enumerate(sorted(set(labels)))}
labels = [label_mapping[label] for label in labels]

# Split data into training and testing sets
logger.info("Splitting data into training and testing sets...")
X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)

# Tokenize the data
logger.info("Tokenizing the data...")
train_encodings = tokenizer(list(X_train), truncation=True, padding=True, max_length=128)
test_encodings = tokenizer(list(X_test), truncation=True, padding=True, max_length=128)


class SpecificityDataset(torch.utils.data.Dataset):
    """Wraps tokenized encodings and integer labels for the Trainer."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]).to(device) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx]).to(device)
        return item

    def __len__(self):
        return len(self.labels)


logger.info("Creating datasets...")
train_dataset = SpecificityDataset(train_encodings, y_train)
test_dataset = SpecificityDataset(test_encodings, y_test)

training_args = TrainingArguments(
    output_dir='./specificity-results',  # output directory
    num_train_epochs=8,                  # number of training epochs
    per_device_train_batch_size=16,      # batch size for training
    per_device_eval_batch_size=64,       # batch size for evaluation
    warmup_steps=500,                    # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,                   # strength of weight decay
    logging_dir='./logs',                # directory for storing logs
    logging_steps=10,
    evaluation_strategy="epoch"
)

logger.info("Initializing the Trainer...")
trainer = Trainer(
    model=model,                  # the instantiated 🤗 Transformers model to be trained
    args=training_args,           # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=test_dataset     # evaluation dataset
)

logger.info("Starting training...")
trainer.train()

logger.info("Evaluating the model...")
eval_result = trainer.evaluate()
logger.info(f"Evaluation results: {eval_result}")

# Save the model and tokenizer
logger.info(f"Saving the model to {save_directory}...")
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

# Save the label mapping
with open(os.path.join(save_directory, 'label_mapping.txt'), 'w') as f:
    f.write(str(label_mapping))

model.to(device)


def classify_text_to_specificity(text):
    """Predict the specificity label for a single piece of text."""
    logger.info(f"Classifying text: {text}")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
    # Run inference without tracking gradients
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    # Map the predicted class ID back to the original label
    inv_label_mapping = {v: k for k, v in label_mapping.items()}
    return inv_label_mapping[predicted_class_id]


# Example usage
# for example_text in ["produce items", "bananas", "milk", "mixed items", "random assortment", "heterogeneous mixture"]:
#     predicted_specificity = classify_text_to_specificity(example_text)
#     logger.info(f"The predicted specificity for '{example_text}' is '{predicted_specificity}'")
#     logger.info("----------")