import os
import ast
import logging
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from db.db_utils import get_connection
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set the device to CPU
device = torch.device("cpu")
# Directory to save/load the model
save_directory = './specificity-model'
# Define label_mapping globally
label_mapping = {}
# Check if the model exists
if os.path.exists(save_directory):
    logger.info(f"Loading the existing model from {save_directory}...")
    tokenizer = BertTokenizer.from_pretrained(save_directory)
    model = BertForSequenceClassification.from_pretrained(save_directory)
    # Load the label mapping (parsed with ast.literal_eval rather than eval,
    # which would execute arbitrary code from the file)
    label_mapping_path = os.path.join(save_directory, 'label_mapping.txt')
    if os.path.exists(label_mapping_path):
        with open(label_mapping_path, 'r') as f:
            label_mapping = ast.literal_eval(f.read())
else:
    logger.info("Loading BERT tokenizer and model...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
# Get data from database
logger.info("Connecting to the database...")
db_conn = get_connection()
db_cursor = db_conn.cursor()
logger.info("Fetching data from the database...")
db_cursor.execute("SELECT input_word, specificity FROM mappings WHERE specificity IS NOT NULL AND reviewed = TRUE AND is_food = TRUE")
results = db_cursor.fetchall()
db_cursor.close()
db_conn.close()
texts, labels = zip(*[(row[0], row[1]) for row in results])
logger.info(f"Fetched {len(texts)} records from the database.")
# Convert labels to integers
logger.info("Converting labels to integers...")
# Build a fresh mapping only when one was not loaded from disk, so a reloaded
# model keeps its original label-to-index assignment; sorting keeps the
# mapping deterministic across runs.
if not label_mapping:
    label_mapping = {label: idx for idx, label in enumerate(sorted(set(labels)))}
labels = [label_mapping[label] for label in labels]
# Split data into training and testing sets
logger.info("Splitting data into training and testing sets...")
X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)
# Tokenize the data
logger.info("Tokenizing the data...")
train_encodings = tokenizer(list(X_train), truncation=True, padding=True, max_length=128)
test_encodings = tokenizer(list(X_test), truncation=True, padding=True, max_length=128)
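# The tokenizer returns a BatchEncoding of python lists (input_ids,
# attention_mask, ...); SpecificityDataset below converts them to tensors
# one item at a time.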
class SpecificityDataset(torch.utils.data.Dataset):
    """Wraps the tokenized encodings and integer labels as a map-style Dataset."""
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # The Trainer moves each batch to the model's device, so building the
        # tensors on CPU here is sufficient.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
logger.info("Creating datasets...")
train_dataset = SpecificityDataset(train_encodings, y_train)
test_dataset = SpecificityDataset(test_encodings, y_test)
training_args = TrainingArguments(
    output_dir='./specificity-results',  # output directory
    num_train_epochs=8,                  # number of training epochs
    per_device_train_batch_size=16,      # batch size for training
    per_device_eval_batch_size=64,       # batch size for evaluation
    warmup_steps=500,                    # warmup steps for the learning rate scheduler
    weight_decay=0.01,                   # strength of weight decay
    logging_dir='./logs',                # directory for storing logs
    logging_steps=10,
    evaluation_strategy="epoch"          # run evaluation at the end of every epoch
)
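
# Optional: the Trainer reports only loss by default; a minimal accuracy
# metric could look like this sketch. Wire it in by passing
# compute_metrics=compute_metrics to the Trainer below if desired.
def compute_metrics(eval_pred):
    predictions = eval_pred.predictions.argmax(axis=-1)  # logits arrive as a numpy array
    accuracy = (predictions == eval_pred.label_ids).mean()
    return {"accuracy": float(accuracy)}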
logger.info("Initializing the Trainer...")
trainer = Trainer(
    model=model,                  # the instantiated 🤗 Transformers model to be trained
    args=training_args,           # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=test_dataset     # evaluation dataset
)
logger.info("Starting training...")
trainer.train()
logger.info("Evaluating the model...")
eval_result = trainer.evaluate()
logger.info(f"Evaluation results: {eval_result}")
# Save the model and tokenizer
logger.info(f"Saving the model to {save_directory}...")
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
# Save the label mapping
with open(os.path.join(save_directory, 'label_mapping.txt'), 'w') as f:
    f.write(str(label_mapping))
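# Design note: str()/ast.literal_eval round-trips the dict; json.dump/json.load
# would work just as well if a more portable format is preferred.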
# Keep the model on CPU and switch to eval mode (dropout off) for inference
model.to(device)
model.eval()
def classify_text_to_specificity(text):
    logger.info(f"Classifying text: {text}")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    # Map the predicted class ID back to the original label
    inv_label_mapping = {v: k for k, v in label_mapping.items()}
    return inv_label_mapping[predicted_class_id]
# Example usage
# for example_text in ["produce items", "bananas", "milk", "mixed items", "random assortment", "heterogeneous mixture"]:
#     predicted_specificity = classify_text_to_specificity(example_text)
#     logger.info(f"The predicted specificity for '{example_text}' is '{predicted_specificity}'")
#     logger.info("----------")