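"""Fine-tune (or load) a BERT classifier that predicts the specificity of food item text.

If a saved model exists in ./specificity-model it is loaded directly; otherwise the
script pulls reviewed training data from the database, fine-tunes bert-base-uncased,
and saves the model, tokenizer, and label mapping for later runs.
"""
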
import ast
import os
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from db.db_utils import get_connection
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the device to CPU
device = torch.device("cpu")

# Directory to save/load the model
save_directory = './specificity-model'

# label_mapping maps label strings to integer class ids; populated from disk or during training
label_mapping = {}

# Check if the model exists
if os.path.exists(save_directory):
    logger.info(f"Loading the existing model from {save_directory}...")
    tokenizer = BertTokenizer.from_pretrained(save_directory)
    model = BertForSequenceClassification.from_pretrained(save_directory)
    # Load the label mapping saved alongside the model
    if os.path.exists(os.path.join(save_directory, 'label_mapping.txt')):
        with open(os.path.join(save_directory, 'label_mapping.txt'), 'r') as f:
            label_mapping = ast.literal_eval(f.read())  # literal_eval is safer than eval for the saved dict
    else:
        logger.warning(f"label_mapping.txt not found in {save_directory}; predicted class ids cannot be mapped back to label names.")
else:
    logger.info("Loading BERT tokenizer and model...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)

    # Get data from database
    logger.info("Connecting to the database...")
    db_conn = get_connection()
    db_cursor = db_conn.cursor()

    logger.info("Fetching data from the database...")
    db_cursor.execute("SELECT input_word, specificity FROM mappings WHERE specificity IS NOT NULL and reviewed = true and is_food = true")
    results = db_cursor.fetchall()
    training_data = [(row[0], row[1]) for row in results]

    texts, labels = zip(*training_data)
    logger.info(f"Fetched {len(texts)} records from the database.")

    # Convert labels to integers
    logger.info("Converting labels to integers...")
    label_mapping = {label: idx for idx, label in enumerate(set(labels))}
    labels = [label_mapping[label] for label in labels]

    # Split data into training and testing sets
    logger.info("Splitting data into training and testing sets...")
    X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)

    # Tokenize the data
    logger.info("Tokenizing the data...")
    train_encodings = tokenizer(list(X_train), truncation=True, padding=True, max_length=128)
    test_encodings = tokenizer(list(X_test), truncation=True, padding=True, max_length=128)

    class SpecificityDataset(torch.utils.data.Dataset):
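        """Wraps the tokenized encodings and integer labels as a torch Dataset for the Trainer."""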
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]).to(device) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(self.labels[idx]).to(device)
            return item

        def __len__(self):
            return len(self.labels)

    logger.info("Creating datasets...")
    train_dataset = SpecificityDataset(train_encodings, y_train)
    test_dataset = SpecificityDataset(test_encodings, y_test)

    training_args = TrainingArguments(
        output_dir='./specificity-results',          # output directory
        num_train_epochs=8,              # number of training epochs
        per_device_train_batch_size=16,  # batch size for training
        per_device_eval_batch_size=64,   # batch size for evaluation
        warmup_steps=500,                # number of warmup steps for learning rate scheduler
        weight_decay=0.01,               # strength of weight decay
        logging_dir='./logs',            # directory for storing logs
        logging_steps=10,
        evaluation_strategy="epoch"
    )

    logger.info("Initializing the Trainer...")
    trainer = Trainer(
        model=model,                         # the instantiated 🤗 Transformers model to be trained
        args=training_args,                  # training arguments, defined above
        train_dataset=train_dataset,         # training dataset
        eval_dataset=test_dataset             # evaluation dataset
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Evaluating the model...")
    eval_result = trainer.evaluate()
    logger.info(f"Evaluation results: {eval_result}")

    # Save the model and tokenizer
    logger.info(f"Saving the model to {save_directory}...")
    model.save_pretrained(save_directory)
    tokenizer.save_pretrained(save_directory)
    # Save the label mapping
    with open(os.path.join(save_directory, 'label_mapping.txt'), 'w') as f:
        f.write(str(label_mapping))

model.to(device)
model.eval()  # switch to inference mode so dropout is disabled during classification

def classify_text_to_specificity(text):
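    """Return the specificity label predicted for `text` by the loaded (or freshly trained) model."""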
    logger.info(f"Classifying text: {text}")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
    with torch.no_grad():  # no gradients needed for inference
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    # Map predicted class ID back to the original label
    inv_label_mapping = {v: k for k, v in label_mapping.items()}
    return inv_label_mapping[predicted_class_id]

# Example usage (runs only when the script is executed directly)
if __name__ == "__main__":
    for example_text in ["produce items", "bananas", "milk", "mixed items", "random assortment", "heterogeneous mixture"]:
        predicted_specificity = classify_text_to_specificity(example_text)
        logger.info(f"The predicted specificity for '{example_text}' is '{predicted_specificity}'")
        logger.info("----------")
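
# Hypothetical usage from another script (the module name below is assumed, since
# this file's name is not shown here):
#
#     from specificity_model import classify_text_to_specificity
#     specificity = classify_text_to_specificity("bananas")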