jesse-tong committed on
Commit ae47555 · 1 Parent(s): 9e5f013

Add LSTM fine tuning
.gitignore CHANGED
@@ -4,4 +4,7 @@ __pycache__/
 *.pyc
 *.pyo
 *.pyd
-*.db
+*.db
+metrics.txt
+predictions.txt
+*.pth
dataset.py CHANGED
@@ -29,7 +29,7 @@ class DocumentDataset(Dataset):
                 f"but found range [{min_label}, {max_label}]")
         logger.warning(f"Unique label values: {sorted(unique_labels)}")
 
-        # Fix labels by remapping them to start from 0
+        # Fix labels by remapping them to start from 0 (some datasets might have labels starting from 1)
         if min_label != 0:
             logger.warning(f"Auto-correcting labels to be zero-indexed...")
             label_map = {original: idx for idx, original in enumerate(sorted(unique_labels))}
@@ -132,8 +132,20 @@ def create_data_loaders(train_data, val_data, test_data, tokenizer_name='bert-ba
     test_dataset = DocumentDataset(test_texts, test_labels, tokenizer_name, max_length, num_classes)
 
     # Create data loaders
-    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-    val_loader = DataLoader(val_dataset, batch_size=batch_size)
-    test_loader = DataLoader(test_dataset, batch_size=batch_size)
+    if len(train_dataset.texts) == 0:
+        logger.warning("Training dataset is empty. Check your data loading and splitting.")
+        train_loader = None
+    else:
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+    if len(val_dataset.texts) == 0:
+        logger.warning("Validation dataset is empty. Check your data loading and splitting.")
+        val_loader = None
+    else:
+        val_loader = DataLoader(val_dataset, batch_size=batch_size)
+    if len(test_dataset.texts) == 0:
+        logger.warning("Test dataset is empty. Check your data loading and splitting.")
+        test_loader = None
+    else:
+        test_loader = DataLoader(test_dataset, batch_size=batch_size)
 
     return train_loader, val_loader, test_loader
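With this change, create_data_loaders can return None for any split that comes back empty, so callers should check the loaders before iterating. A minimal usage sketch (the file name and hyperparameter values are placeholders; the function signatures are the ones shown in this diff):

    # Hedged usage sketch: guard against None loaders from create_data_loaders
    from dataset import load_data, create_data_loaders

    train_data, val_data, test_data = load_data(
        "train.csv", text_col="text", label_col="label",
        validation_split=0.1, test_split=0.1, seed=42
    )
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data, val_data, test_data,
        tokenizer_name="bert-base-uncased", max_length=512,
        batch_size=16, num_classes=4
    )
    for name, loader in [("train", train_loader), ("val", val_loader), ("test", test_loader)]:
        if loader is None:
            print(f"{name} split is empty; skipping")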
dataset_lstm.py ADDED
@@ -0,0 +1,163 @@
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import numpy as np
4
+ import pandas as pd
5
+ from collections import Counter
6
+ import re
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class LSTMTokenizer:
12
+ """
13
+ Simple tokenizer for LSTM models
14
+ """
15
+ def __init__(self, max_vocab_size=30000, max_seq_length=512):
16
+ self.word2idx = {}
17
+ self.idx2word = {}
18
+ self.word2idx['<pad>'] = 0
19
+ self.word2idx['<unk>'] = 1
20
+ self.idx2word[0] = '<pad>'
21
+ self.idx2word[1] = '<unk>'
22
+ self.vocab_size = 2 # Start with pad and unk tokens
23
+ self.max_vocab_size = max_vocab_size
24
+ self.max_seq_length = max_seq_length
25
+
26
+ def fit(self, texts):
27
+ """Build vocabulary from texts"""
28
+ word_counts = Counter()
29
+
30
+ # Clean and tokenize texts
31
+ for text in texts:
32
+ words = self._tokenize(text)
33
+ word_counts.update(words)
34
+
35
+ # Sort by frequency and take most common words
36
+ vocab_words = [word for word, count in word_counts.most_common(self.max_vocab_size - 2)]
37
+
38
+ # Add words to vocabulary
39
+ for word in vocab_words:
40
+ if word not in self.word2idx:
41
+ self.word2idx[word] = self.vocab_size
42
+ self.idx2word[self.vocab_size] = word
43
+ self.vocab_size += 1
44
+
45
+ logger.info(f"Vocabulary size: {self.vocab_size}")
46
+ return self
47
+
48
+ def _tokenize(self, text):
49
+ """Simple tokenization by splitting on whitespace and removing punctuation"""
50
+ text = text.lower()
51
+ # Remove punctuation and split on whitespace
52
+ text = re.sub(r'[^\w\s]', '', text)
53
+ return text.split()
54
+
55
+ def encode(self, text, padding=True, truncation=True):
56
+ """Convert text to token ids"""
57
+ words = self._tokenize(text)
58
+
59
+ # Truncate if needed
60
+ if truncation and len(words) > self.max_seq_length:
61
+ words = words[:self.max_seq_length]
62
+
63
+ # Convert to indices
64
+ ids = [self.word2idx.get(word, self.word2idx['<unk>']) for word in words]
65
+
66
+ # Create attention mask (1 for tokens, 0 for padding)
67
+ attention_mask = [1] * len(ids)
68
+
69
+ # Pad if needed
70
+ if padding and len(ids) < self.max_seq_length:
71
+ padding_length = self.max_seq_length - len(ids)
72
+ ids = ids + [self.word2idx['<pad>']] * padding_length
73
+ attention_mask = attention_mask + [0] * padding_length
74
+
75
+ return {
76
+ 'input_ids': torch.tensor(ids, dtype=torch.long),
77
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
78
+ }
79
+
80
+ class LSTMDataset(Dataset):
81
+ """Dataset for LSTM model"""
82
+ def __init__(self, texts, labels, tokenizer):
83
+ self.texts = texts
84
+ self.labels = labels
85
+ self.tokenizer = tokenizer
86
+
87
+ def __len__(self):
88
+ return len(self.texts)
89
+
90
+ def __getitem__(self, idx):
91
+ text = str(self.texts[idx])
92
+ label = self.labels[idx]
93
+
94
+ # Tokenize
95
+ encoding = self.tokenizer.encode(text)
96
+
97
+ return {
98
+ 'input_ids': encoding['input_ids'],
99
+ 'attention_mask': encoding['attention_mask'],
100
+ 'label': torch.tensor(label, dtype=torch.long)
101
+ }
102
+
103
+ def prepare_lstm_data(data_path, text_col='text', label_col='label',
104
+ max_vocab_size=30000, max_seq_length=512,
105
+ val_split=0.1, test_split=0.1, batch_size=32, seed=42):
106
+ """
107
+ Load data and prepare for LSTM model
108
+ """
109
+ # Load data
110
+ if data_path.endswith('.csv'):
111
+ df = pd.read_csv(data_path)
112
+ elif data_path.endswith('.tsv'):
113
+ df = pd.read_csv(data_path, sep='\t')
114
+ else:
115
+ raise ValueError("Unsupported file format. Please provide CSV or TSV file.")
116
+
117
+ # Convert labels to numeric if they aren't already
118
+ if not np.issubdtype(df[label_col].dtype, np.number):
119
+ label_map = {label: idx for idx, label in enumerate(sorted(df[label_col].unique()))}
120
+ df['label_numeric'] = df[label_col].map(label_map)
121
+ labels = df['label_numeric'].values
122
+ logger.info(f"Label mapping: {label_map}")
123
+ else:
124
+ labels = df[label_col].values
125
+ # Make sure labels start from 0
126
+ min_label = labels.min()
127
+ if min_label != 0:
128
+ label_map = {label: idx for idx, label in enumerate(sorted(set(labels)))}
129
+ labels = np.array([label_map[label] for label in labels])
130
+
131
+ texts = df[text_col].values
132
+
133
+ # Split data
134
+ np.random.seed(seed)
135
+ indices = np.random.permutation(len(texts))
136
+
137
+ test_size = int(test_split * len(texts))
138
+ val_size = int(val_split * len(texts))
139
+ train_size = len(texts) - test_size - val_size
140
+
141
+ train_indices = indices[:train_size]
142
+ val_indices = indices[train_size:train_size + val_size]
143
+ test_indices = indices[train_size + val_size:]
144
+
145
+ train_texts, train_labels = texts[train_indices], labels[train_indices]
146
+ val_texts, val_labels = texts[val_indices], labels[val_indices]
147
+ test_texts, test_labels = texts[test_indices], labels[test_indices]
148
+
149
+ # Create tokenizer and fit on training data
150
+ tokenizer = LSTMTokenizer(max_vocab_size=max_vocab_size, max_seq_length=max_seq_length)
151
+ tokenizer.fit(train_texts)
152
+
153
+ # Create datasets
154
+ train_dataset = LSTMDataset(train_texts, train_labels, tokenizer)
155
+ val_dataset = LSTMDataset(val_texts, val_labels, tokenizer)
156
+ test_dataset = LSTMDataset(test_texts, test_labels, tokenizer)
157
+
158
+ # Create data loaders
159
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
160
+ val_loader = DataLoader(val_dataset, batch_size=batch_size)
161
+ test_loader = DataLoader(test_dataset, batch_size=batch_size)
162
+
163
+ return train_loader, val_loader, test_loader, tokenizer.vocab_size
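For a quick sanity check of LSTMTokenizer, the sketch below fits it on a tiny made-up corpus and encodes one document (the sample strings and sizes are illustrative; the API is the one defined above):

    # Hedged sketch: fit the custom tokenizer and inspect one encoding
    from dataset_lstm import LSTMTokenizer

    docs = ["Stocks rallied after the earnings report.",
            "The home team won the championship game."]
    tok = LSTMTokenizer(max_vocab_size=100, max_seq_length=16)
    tok.fit(docs)

    enc = tok.encode(docs[0])
    print(enc["input_ids"].shape)            # padded to max_seq_length, e.g. torch.Size([16])
    print(int(enc["attention_mask"].sum()))  # number of real (non-pad) tokens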
distill_bert_to_lstm.py ADDED
@@ -0,0 +1,193 @@
1
+ import argparse
2
+ import os
3
+ import logging
4
+ import torch
5
+ import random
6
+ import numpy as np
7
+ from model import DocBERT
8
+ from models.lstm_model import DocumentBiLSTM
9
+ from dataset import load_data, create_data_loaders
10
+ from dataset_lstm import prepare_lstm_data
11
+ from knowledge_distillation import DistillationTrainer
12
+ from transformers import BertTokenizer
13
+
14
+ # Setup logging
15
+ logging.basicConfig(
16
+ format="%(asctime)s - %(levelname)s - %(message)s",
17
+ level=logging.INFO,
18
+ datefmt="%Y-%m-%d %H:%M:%S",
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
+ def set_seed(seed):
23
+ """Set all seeds for reproducibility"""
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ if torch.cuda.is_available():
28
+ torch.cuda.manual_seed_all(seed)
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+
32
+ def tokenize_for_lstm(texts, bert_tokenizer, max_seq_length=512):
33
+ """
34
+ Convert BERT tokenization format to format suitable for LSTM
35
+ This is a simple approach that just takes whole words from BERT tokenization
36
+ """
37
+ from collections import Counter
38
+
39
+ # Create vocabulary from all texts
40
+ word_counts = Counter()
41
+ all_words = []
42
+
43
+ for text in texts:
44
+ # Simple tokenization by splitting on whitespace
45
+ words = text.lower().split()
46
+ word_counts.update(words)
47
+ all_words.extend(words)
48
+
49
+ # Create word->index mapping
50
+ word2idx = {'<pad>': 0, '<unk>': 1}
51
+ for idx, (word, _) in enumerate(word_counts.most_common(30000 - 2), 2):
52
+ word2idx[word] = idx
53
+
54
+ vocab_size = len(word2idx)
55
+ logger.info(f"Created vocabulary with {vocab_size} tokens")
56
+
57
+ return word2idx, vocab_size
58
+
59
+ def main():
60
+ parser = argparse.ArgumentParser(description="Distill knowledge from BERT to LSTM for document classification")
61
+
62
+ # Data arguments
63
+ parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)")
64
+ parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
65
+ parser.add_argument("--label_column", type=str, default="label", help="Name of the label column")
66
+ parser.add_argument("--val_split", type=float, default=0.1, help="Validation set split ratio")
67
+ parser.add_argument("--test_split", type=float, default=0.1, help="Test set split ratio")
68
+
69
+ # BERT model arguments
70
+ parser.add_argument("--bert_model", type=str, default="bert-base-uncased", help="BERT model to use")
71
+ parser.add_argument("--bert_model_path", type=str, required=True, help="Path to saved BERT model weights")
72
+ parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length")
73
+
74
+ # LSTM model arguments
75
+ parser.add_argument("--embedding_dim", type=int, default=300, help="Dimension of word embeddings in LSTM")
76
+ parser.add_argument("--hidden_dim", type=int, default=256, help="Hidden dimension of LSTM")
77
+ parser.add_argument("--num_layers", type=int, default=2, help="Number of LSTM layers")
78
+ parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability")
79
+
80
+ # Distillation arguments
81
+ parser.add_argument("--temperature", type=float, default=2.0, help="Temperature for softening probability distributions")
82
+ parser.add_argument("--alpha", type=float, default=0.5, help="Weight for distillation loss vs. regular loss")
83
+ parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
84
+
85
+ # Training arguments
86
+ parser.add_argument("--batch_size", type=int, default=16, help="Training batch size")
87
+ parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for LSTM")
88
+ parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
89
+
90
+ # Other arguments
91
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
92
+ parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save models")
93
+
94
+ args = parser.parse_args()
95
+
96
+ # Set seed for reproducibility
97
+ set_seed(args.seed)
98
+
99
+ # Create output directory if it doesn't exist
100
+ if not os.path.exists(args.output_dir):
101
+ os.makedirs(args.output_dir)
102
+
103
+ # Load and prepare data for both BERT and LSTM
104
+ logger.info("Loading and preparing data...")
105
+
106
+ # Load data first
107
+ train_data, val_data, test_data = load_data(
108
+ args.data_path,
109
+ text_col=args.text_column,
110
+ label_col=args.label_column,
111
+ validation_split=args.val_split,
112
+ test_split=args.test_split,
113
+ seed=args.seed
114
+ )
115
+
116
+ # Create BERT data loaders
117
+ logger.info("Creating BERT data loaders...")
118
+ bert_train_loader, bert_val_loader, bert_test_loader = create_data_loaders(
119
+ train_data,
120
+ val_data,
121
+ test_data,
122
+ tokenizer_name=args.bert_model,
123
+ max_length=args.max_seq_length,
124
+ batch_size=args.batch_size,
125
+ num_classes=args.num_classes
126
+ )
127
+
128
+ # Create LSTM data loaders
129
+ logger.info("Creating LSTM data loaders...")
130
+ lstm_train_loader, lstm_val_loader, lstm_test_loader, vocab_size = prepare_lstm_data(
131
+ args.data_path,
132
+ text_col=args.text_column,
133
+ label_col=args.label_column,
134
+ max_vocab_size=30000,
135
+ max_seq_length=args.max_seq_length,
136
+ batch_size=args.batch_size,
137
+ seed=args.seed
138
+ )
139
+
140
+ logger.info(f"LSTM Vocabulary size: {vocab_size}")
141
+
142
+ # Load pre-trained BERT model (teacher)
143
+ logger.info("Loading pre-trained BERT model (teacher)...")
144
+ bert_model = DocBERT(
145
+ num_classes=args.num_classes,
146
+ bert_model_name=args.bert_model,
147
+ dropout_prob=0.1
148
+ )
149
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
150
+ # Load saved BERT weights
151
+ bert_model.load_state_dict(torch.load(args.bert_model_path, map_location=device))
152
+ logger.info(f"Loaded teacher model from {args.bert_model_path}")
153
+
154
+ # Initialize LSTM model (student)
155
+ logger.info("Initializing LSTM model (student)...")
156
+ lstm_model = DocumentBiLSTM(
157
+ vocab_size=BertTokenizer.from_pretrained(args.bert_model).vocab_size,  # sized to BERT's token ids, since the trainer below feeds BERT-tokenized batches to the student
158
+ embedding_dim=args.embedding_dim,
159
+ hidden_dim=args.hidden_dim,
160
+ output_dim=args.num_classes,
161
+ n_layers=args.num_layers,
162
+ dropout=args.dropout
163
+ )
164
+
165
+ # Print model sizes for comparison
166
+ bert_params = sum(p.numel() for p in bert_model.parameters())
167
+ lstm_params = sum(p.numel() for p in lstm_model.parameters())
168
+ logger.info(f"BERT model size: {bert_params:,} parameters")
169
+ logger.info(f"LSTM model size: {lstm_params:,} parameters")
170
+ logger.info(f"Size reduction: {bert_params / lstm_params:.1f}x")
171
+
172
+ # Initialize distillation trainer
173
+ trainer = DistillationTrainer(
174
+ teacher_model=bert_model,
175
+ student_model=lstm_model,
176
+ train_loader=bert_train_loader, # Using BERT loader to match tokenization
177
+ val_loader=bert_val_loader,
178
+ test_loader=bert_test_loader,
179
+ temperature=args.temperature,
180
+ alpha=args.alpha,
181
+ lr=args.learning_rate,
182
+ weight_decay=1e-5
183
+ )
184
+
185
+ # Train with knowledge distillation
186
+ logger.info("Starting knowledge distillation...")
187
+ save_path = os.path.join(args.output_dir, "distilled_lstm_model.pth")
188
+ trainer.train(epochs=args.epochs, save_path=save_path)
189
+
190
+ logger.info("Knowledge distillation completed!")
191
+
192
+ if __name__ == "__main__":
193
+ main()
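An illustrative invocation, using the same dataset columns as example_uses.txt below (all paths and values are placeholders, not tested):

    python distill_bert_to_lstm.py --data_path "./train.csv" --bert_model_path "./bert_base_uncased/best_model.pth" --num_classes 4 --text_column "Description" --label_column "Class Index" --batch_size 16 --epochs 20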
example_uses.txt ADDED
@@ -0,0 +1 @@
1
+ python .\inference_example.py --model_path "./bert_base_uncased/best_model.pth" --num_classes 4 --class_names "World" "Sports" "Business" "Science" --text_column "Description" --label_column "Class Index" --data_path "./train.csv" --inference_batch_limit 10
inference_example.py ADDED
@@ -0,0 +1,104 @@
1
+ from model import DocBERT
2
+ from dataset import load_data, create_data_loaders
3
+ from trainer import Trainer
4
+ import argparse
5
+ import os, sklearn.metrics  # the metrics submodule must be imported explicitly
6
+ import numpy as np
7
+ import torch
8
+
9
+ if __name__ == "__main__":
10
+ parser = argparse.ArgumentParser(description="Document Classification with Distillation")
11
+ parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset")
12
+ parser.add_argument("--bert_model", type=str, default="bert-base-uncased", help="Pre-trained BERT model name")
13
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
14
+ parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length for BERT")
15
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training and evaluation")
16
+ parser.add_argument("--num_classes", type=int, required=True, help="Number of classes for classification")
17
+ parser.add_argument("--text_column", type=str, default="text", help="Column name for text data")
18
+ parser.add_argument("--label_column", type=str, default="label", help="Column name for labels")
19
+ parser.add_argument("--class_names", type=str, nargs='+', required=True, help="List of class names for classification")
20
+ parser.add_argument("--inference_batch_limit", type=int, default=-1, help="Limit for inference batch counts")
21
+ parser.add_argument("--print_predictions", type=bool, default=False, help="Print predictions to console")
22
+ args = parser.parse_args()
23
+
24
+ class_names = args.class_names
25
+
26
+ # Set device
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ train_data, val_data, test_data = load_data(args.data_path,
29
+ text_col=args.text_column,
30
+ label_col=args.label_column,
31
+ validation_split=0.0,
32
+ test_split=1.0)
33
+ train_loader, val_loader, test_loader = create_data_loaders(train_data=train_data,
34
+ val_data=val_data,
35
+ test_data=test_data,
36
+ tokenizer_name=args.bert_model,
37
+ batch_size=args.batch_size,
38
+ max_length=args.max_seq_length)
39
+
40
+ model = DocBERT(bert_model_name=args.bert_model, num_classes=args.num_classes)
41
+ model.load_state_dict(torch.load(args.model_path, map_location=device))
42
+ model = model.to(device)
43
+
44
+ all_labels = np.array([], dtype=int)
45
+ all_predictions = np.array([], dtype=int)
46
+ batch_window_index = 0
47
+ batch_size = args.batch_size
48
+
49
+ # Inference
50
+ for batch in test_loader:
51
+ input_ids = batch['input_ids']
52
+ attention_mask = batch['attention_mask']
53
+ token_type_ids = batch['token_type_ids']
54
+ labels = batch['label']
55
+
56
+ input_ids = input_ids.to(device)
57
+ attention_mask = attention_mask.to(device)
58
+ token_type_ids = token_type_ids.to(device)
59
+ labels = labels.to(device)
60
+ all_labels = np.append(all_labels, labels.cpu().numpy())
61
+
62
+ with torch.no_grad():
63
+ outputs = model(input_ids, attention_mask=attention_mask)
64
+ logits = outputs
65
+ predictions = torch.argmax(logits, dim=-1)
66
+ all_predictions = np.append(all_predictions, predictions.cpu().numpy())
67
+
68
+ if args.print_predictions:
69
+ for i in range(len(predictions)):
70
+ idx = int(i)
71
+ print(f"Text: {test_data[0][batch_window_index*batch_size + idx]}")
72
+ print(f"True Label: {labels[idx].item()}, Predicted Label: {predictions[idx].item()}")
73
+ print(f"Predicted Class: {class_names[predictions[idx].item() if len(class_names) > predictions[idx].item() else 'Unknown']}")
74
+ print(f"True Class: {class_names[labels[idx].item()] if len(class_names) > predictions[idx].item() else 'Unknown'}")
75
+ print("-" * 50)
76
+
77
+ batch_window_index += 1
78
+ if args.inference_batch_limit > 0 and batch_window_index >= args.inference_batch_limit:
79
+ break
80
+
81
+ # Calculate accuracy, F1 score, recall, and precision
82
+ accuracy = sklearn.metrics.accuracy_score(all_labels, all_predictions)
83
+ f1 = sklearn.metrics.f1_score(all_labels, all_predictions, average='weighted')
84
+ precision = sklearn.metrics.precision_score(all_labels, all_predictions, average='weighted')
85
+ recall = sklearn.metrics.recall_score(all_labels, all_predictions, average='weighted')
86
+
87
+ print(f"Accuracy: {accuracy}")
88
+ print(f"F1 Score: {f1}")
89
+ print(f"Precision: {precision}")
90
+ print(f"Recall: {recall}")
91
+
92
+ with open("predictions.txt", "w") as f:
93
+ for i in range(len(all_labels)):
94
+ idx = int(i)
95
+ f.write(f"Text: {test_data[0][idx]}\n")
96
+ f.write(f"True Label: {all_labels[idx]}, Predicted Label: {all_predictions[idx]}\n")
97
+ f.write(f"Predicted Class: {class_names[all_predictions[idx]] if len(class_names) > all_predictions[idx] else "Unknown"}, True Class: {class_names[all_labels[idx]] if len(class_names) > all_predictions[idx] else "Unknown"}\n")
98
+ f.write("-" * 50 + "\n")
99
+
100
+ with open("metrics.txt", "w") as f:
101
+ f.write(f"Accuracy: {accuracy}\n")
102
+ f.write(f"F1 Score: {f1}\n")
103
+ f.write(f"Precision: {precision}\n")
104
+ f.write(f"Recall: {recall}\n")
knowledge_distillation.py ADDED
@@ -0,0 +1,232 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import logging
7
+ import os
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class DistillationTrainer:
12
+ """
13
+ Trainer for knowledge distillation from teacher model (BERT) to student model (LSTM)
14
+ """
15
+ def __init__(
16
+ self,
17
+ teacher_model,
18
+ student_model,
19
+ train_loader,
20
+ val_loader,
21
+ test_loader=None,
22
+ temperature=2.0,
23
+ alpha=0.5, # Weight for distillation loss vs. regular loss
24
+ lr=0.001,
25
+ weight_decay=1e-5,
26
+ device=None
27
+ ):
28
+ self.teacher_model = teacher_model
29
+ self.student_model = student_model
30
+ self.train_loader = train_loader
31
+ self.val_loader = val_loader
32
+ self.test_loader = test_loader
33
+ self.temperature = temperature
34
+ self.alpha = alpha
35
+
36
+ self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37
+ logger.info(f"Using device: {self.device}")
38
+
39
+ # Move models to device
40
+ self.teacher_model.to(self.device)
41
+ self.student_model.to(self.device)
42
+
43
+ # Set teacher model to evaluation mode
44
+ self.teacher_model.eval()
45
+
46
+ # Optimizer for student model
47
+ self.optimizer = torch.optim.Adam(
48
+ self.student_model.parameters(),
49
+ lr=lr,
50
+ weight_decay=weight_decay
51
+ )
52
+
53
+ # Learning rate scheduler
54
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
55
+ self.optimizer, mode='max', factor=0.5, patience=2, verbose=True
56
+ )
57
+
58
+ # Loss functions
59
+ self.ce_loss = nn.CrossEntropyLoss() # For hard targets
60
+
61
+ # Tracking metrics
62
+ self.best_val_f1 = 0.0
63
+ self.best_model_state = None
64
+
65
+ def distillation_loss(self, student_logits, teacher_logits, labels, temperature, alpha):
66
+ """
67
+ Compute the knowledge distillation loss
68
+
69
+ Args:
70
+ student_logits: Output from student model
71
+ teacher_logits: Output from teacher model
72
+ labels: Ground truth labels
73
+ temperature: Temperature for softening probability distributions
74
+ alpha: Weight for distillation loss vs. cross-entropy loss
75
+
76
+ Returns:
77
+ Combined loss
78
+ """
79
+ # Softmax with temperature for soft targets
80
+ soft_targets = F.softmax(teacher_logits / temperature, dim=1)
81
+ soft_prob = F.log_softmax(student_logits / temperature, dim=1)
82
+
83
+ # Distillation loss (KL divergence)
84
+ distill_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
85
+
86
+ # Standard cross entropy with hard targets
87
+ ce_loss = self.ce_loss(student_logits, labels)
88
+
89
+ # Weighted combination of the two losses
90
+ loss = alpha * distill_loss + (1 - alpha) * ce_loss
91
+
92
+ return loss
93
+
94
+ def train(self, epochs, save_path='best_distilled_model.pth'):
95
+ """
96
+ Train student model with knowledge distillation
97
+ """
98
+ logger.info(f"Starting distillation training for {epochs} epochs")
99
+ logger.info(f"Temperature: {self.temperature}, Alpha: {self.alpha}")
100
+
101
+ for epoch in range(epochs):
102
+ self.student_model.train()
103
+ train_loss = 0.0
104
+ all_preds = []
105
+ all_labels = []
106
+
107
+ # Training loop
108
+ train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
109
+ for batch in train_iterator:
110
+ # Move batch to device
111
+ input_ids = batch['input_ids'].to(self.device)
112
+ attention_mask = batch['attention_mask'].to(self.device)
113
+ labels = batch['label'].to(self.device)
114
+
115
+ # Get teacher predictions (no grad needed for teacher)
116
+ with torch.no_grad():
117
+ teacher_logits = self.teacher_model(
118
+ input_ids=input_ids,
119
+ attention_mask=attention_mask
120
+ )
121
+
122
+ # Forward pass through student model
123
+ student_logits = self.student_model(
124
+ input_ids=input_ids,
125
+ attention_mask=attention_mask
126
+ )
127
+
128
+ # Calculate distillation loss
129
+ loss = self.distillation_loss(
130
+ student_logits,
131
+ teacher_logits,
132
+ labels,
133
+ self.temperature,
134
+ self.alpha
135
+ )
136
+
137
+ # Backward and optimize
138
+ self.optimizer.zero_grad()
139
+ loss.backward()
140
+ torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), 1.0)
141
+ self.optimizer.step()
142
+
143
+ train_loss += loss.item()
144
+
145
+ # Calculate accuracy for progress tracking
146
+ _, preds = torch.max(student_logits, 1)
147
+ all_preds.extend(preds.cpu().tolist())
148
+ all_labels.extend(labels.cpu().tolist())
149
+
150
+ # Update progress bar
151
+ train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})
152
+
153
+ # Calculate training metrics
154
+ train_loss = train_loss / len(self.train_loader)
155
+ train_acc = sum(1 for p, l in zip(all_preds, all_labels) if p == l) / len(all_preds)
156
+
157
+ # Evaluate on validation set
158
+ val_loss, val_acc, val_f1 = self.evaluate()
159
+
160
+ # Update learning rate based on validation performance
161
+ self.scheduler.step(val_f1)
162
+
163
+ # Save best model
164
+ if val_f1 > self.best_val_f1:
165
+ self.best_val_f1 = val_f1
166
+ self.best_model_state = {k: v.detach().clone() for k, v in self.student_model.state_dict().items()}  # deep copy so later training steps don't overwrite the saved best weights
167
+ torch.save({
168
+ 'epoch': epoch,
169
+ 'model_state_dict': self.student_model.state_dict(),
170
+ 'optimizer_state_dict': self.optimizer.state_dict(),
171
+ 'val_f1': val_f1,
172
+ }, save_path)
173
+ logger.info(f"New best model saved with validation F1: {val_f1:.4f}")
174
+
175
+ logger.info(f"Epoch {epoch+1}/{epochs}: "
176
+ f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
177
+ f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")
178
+
179
+ # Load best model for final evaluation
180
+ if self.best_model_state is not None:
181
+ self.student_model.load_state_dict(self.best_model_state)
182
+ logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")
183
+
184
+ # Final evaluation on test set if provided
185
+ if self.test_loader:
186
+ test_loss, test_acc, test_f1 = self.evaluate(self.test_loader, "Test")
187
+ logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
188
+
189
+ def evaluate(self, data_loader=None, phase="Validation"):
190
+ """
191
+ Evaluate the student model
192
+ """
193
+ if data_loader is None:
194
+ data_loader = self.val_loader
195
+
196
+ self.student_model.eval()
197
+ eval_loss = 0.0
198
+ all_preds = []
199
+ all_labels = []
200
+
201
+ with torch.no_grad():
202
+ for batch in tqdm(data_loader, desc=f"[{phase}]"):
203
+ input_ids = batch['input_ids'].to(self.device)
204
+ attention_mask = batch['attention_mask'].to(self.device)
205
+ labels = batch['label'].to(self.device)
206
+
207
+ # Forward pass through student
208
+ student_logits = self.student_model(
209
+ input_ids=input_ids,
210
+ attention_mask=attention_mask
211
+ )
212
+
213
+ # Calculate regular CE loss (no distillation during evaluation)
214
+ loss = self.ce_loss(student_logits, labels)
215
+ eval_loss += loss.item()
216
+
217
+ # Get predictions
218
+ _, preds = torch.max(student_logits, 1)
219
+ all_preds.extend(preds.cpu().tolist())
220
+ all_labels.extend(labels.cpu().tolist())
221
+
222
+ # Calculate metrics
223
+ eval_loss = eval_loss / len(data_loader)
224
+
225
+ # Accuracy
226
+ accuracy = sum(1 for p, l in zip(all_preds, all_labels) if p == l) / len(all_preds)
227
+
228
+ # F1 score (macro-averaged)
229
+ from sklearn.metrics import f1_score
230
+ f1 = f1_score(all_labels, all_preds, average='macro')
231
+
232
+ return eval_loss, accuracy, f1
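The objective in distillation_loss is the temperature-scaled KL term between the softened teacher and student distributions, weighted by alpha and T^2, plus (1 - alpha) times the ordinary cross-entropy on the hard labels. A self-contained sketch of the same combination on random logits (no models involved; all values are arbitrary):

    # Hedged sketch: the distillation objective above, evaluated on dummy logits
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    T, alpha = 2.0, 0.5
    student_logits = torch.randn(8, 4)
    teacher_logits = torch.randn(8, 4)
    labels = torch.randint(0, 4, (8,))

    soft_targets = F.softmax(teacher_logits / T, dim=1)
    soft_prob = F.log_softmax(student_logits / T, dim=1)
    distill = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (T ** 2)
    ce = F.cross_entropy(student_logits, labels)
    loss = alpha * distill + (1 - alpha) * ce
    print(float(loss))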
models/__init__.py ADDED
File without changes
models/lstm_model.py ADDED
@@ -0,0 +1,163 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ try:
+     from torchtext.vocab import GloVe  # For loading pre-trained word embeddings (optional dependency)
+ except ImportError:
+     GloVe = None  # keep the module importable without torchtext; DocumentLSTM falls back to random embeddings
5
+
6
+ class DocumentLSTM(nn.Module):
7
+ """
8
+ LSTM model for document classification using GloVe embeddings
9
+ """
10
+ def __init__(self, num_classes, vocab_size=30000, embedding_dim=300,
11
+ hidden_dim=256, num_layers=2, bidirectional=True,
12
+ dropout_rate=0.3, use_pretrained=True, padding_idx=0):
13
+ super(DocumentLSTM, self).__init__()
14
+
15
+ self.hidden_dim = hidden_dim
16
+ self.num_layers = num_layers
17
+ self.bidirectional = bidirectional
18
+ self.num_directions = 2 if bidirectional else 1
19
+
20
+ # Embedding layer (with option to use pre-trained GloVe)
21
+ if use_pretrained:
22
+ # Initialize with GloVe embeddings
23
+ try:
24
+ glove = GloVe(name='6B', dim=embedding_dim)
25
+ # You'd need to map your vocabulary to GloVe indices
26
+ # This is a simplified placeholder
27
+ self.embedding = nn.Embedding.from_pretrained(
28
+ glove.vectors[:vocab_size],
29
+ padding_idx=padding_idx,
30
+ freeze=False
31
+ )
32
+ except Exception as e:
33
+ print(f"Could not load pretrained embeddings: {e}")
34
+ # Fall back to random initialization
35
+ self.embedding = nn.Embedding(
36
+ vocab_size, embedding_dim, padding_idx=padding_idx
37
+ )
38
+ else:
39
+ # Random initialization
40
+ self.embedding = nn.Embedding(
41
+ vocab_size, embedding_dim, padding_idx=padding_idx
42
+ )
43
+
44
+ # LSTM layer
45
+ self.lstm = nn.LSTM(
46
+ embedding_dim,
47
+ hidden_dim,
48
+ num_layers=num_layers,
49
+ bidirectional=bidirectional,
50
+ batch_first=True,
51
+ dropout=dropout_rate if num_layers > 1 else 0
52
+ )
53
+
54
+ # Attention mechanism
55
+ self.attention = nn.Linear(hidden_dim * self.num_directions, 1)
56
+
57
+ # Layer normalization
58
+ self.layer_norm = nn.LayerNorm(hidden_dim * self.num_directions)
59
+
60
+ # Dropout layer
61
+ self.dropout = nn.Dropout(dropout_rate)
62
+
63
+ # Classification layer
64
+ self.classifier = nn.Linear(hidden_dim * self.num_directions, num_classes)
65
+
66
+ def forward(self, input_ids, attention_mask=None, **kwargs):
67
+ """
68
+ Forward pass through LSTM model
69
+
70
+ Args:
71
+ input_ids: Tensor of token ids [batch_size, seq_len]
72
+ attention_mask: Tensor indicating which tokens to attend to [batch_size, seq_len]
73
+ """
74
+ # Word embeddings
75
+ embedded = self.embedding(input_ids) # [batch_size, seq_len, embedding_dim]
76
+
77
+ # Pass through LSTM
78
+ lstm_out, (hidden, cell) = self.lstm(embedded)
79
+ # lstm_out: [batch_size, seq_len, hidden_dim * num_directions]
80
+
81
+ # Apply attention
82
+ if attention_mask is not None:
83
+ # Apply attention mask (1 for tokens to attend to, 0 for padding)
84
+ attention_mask = attention_mask.unsqueeze(-1) # [batch_size, seq_len, 1]
85
+ attention_scores = self.attention(lstm_out) # [batch_size, seq_len, 1]
86
+ attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e10)
87
+ attention_weights = F.softmax(attention_scores, dim=1) # [batch_size, seq_len, 1]
88
+
89
+ # Weighted sum
90
+ context_vector = torch.sum(attention_weights * lstm_out, dim=1) # [batch_size, hidden_dim * num_directions]
91
+ else:
92
+ # If no attention mask, use the last hidden state
93
+ if self.bidirectional:
94
+ # For bidirectional LSTM, concatenate last hidden states from both directions
95
+ last_hidden = torch.cat([hidden[-2], hidden[-1]], dim=1) # [batch_size, hidden_dim * 2]
96
+ else:
97
+ last_hidden = hidden[-1] # [batch_size, hidden_dim]
98
+
99
+ context_vector = last_hidden
100
+
101
+ # Layer normalization
102
+ normalized = self.layer_norm(context_vector)
103
+
104
+ # Dropout
105
+ dropped = self.dropout(normalized)
106
+
107
+ # Classification
108
+ logits = self.classifier(dropped)
109
+
110
+ return logits
111
+
112
+ class DocumentBiLSTM(nn.Module):
113
+ """
114
+ A simpler BiLSTM implementation that doesn't require pre-loaded embeddings
115
+ Good for getting started quickly
116
+ """
117
+ def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
118
+ n_layers=2, dropout=0.5, pad_idx=0):
119
+ super().__init__()
120
+
121
+ self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
122
+
123
+ self.lstm = nn.LSTM(embedding_dim,
124
+ hidden_dim,
125
+ num_layers=n_layers,
126
+ bidirectional=True,
127
+ dropout=dropout if n_layers > 1 else 0,
128
+ batch_first=True)
129
+
130
+ self.fc = nn.Linear(hidden_dim * 2, output_dim)
131
+
132
+ self.dropout = nn.Dropout(dropout)
133
+
134
+ def forward(self, input_ids, attention_mask=None, **kwargs):
135
+ # input_ids = [batch size, seq len]
136
+
137
+ # embedded = [batch size, seq len, emb dim]
138
+ embedded = self.embedding(input_ids)
139
+
140
+ # Apply dropout to embeddings
141
+ embedded = self.dropout(embedded)
142
+
143
+ if attention_mask is not None:
144
+ # Create packed sequence for variable length sequences
145
+ # This is a simplified version - in practice you'd use pack_padded_sequence
146
+ # but that requires knowing the actual sequence lengths
147
+ pass
148
+
149
+ # output = [batch size, seq len, hid dim * num directions]
150
+ # hidden = [n layers * num directions, batch size, hid dim]
151
+ # cell = [n layers * num directions, batch size, hid dim]
152
+ output, (hidden, cell) = self.lstm(embedded)
153
+
154
+ # Concatenate the final forward and backward hidden states
155
+ hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
156
+
157
+ # Apply dropout to hidden state
158
+ hidden = self.dropout(hidden)
159
+
160
+ # prediction = [batch size, output dim]
161
+ prediction = self.fc(hidden)
162
+
163
+ return prediction
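A quick shape check for DocumentBiLSTM with random token ids (the dimensions are arbitrary; the forward signature is the one defined above):

    # Hedged sketch: dummy forward pass through the student model
    import torch
    from models.lstm_model import DocumentBiLSTM

    model = DocumentBiLSTM(vocab_size=1000, embedding_dim=64, hidden_dim=32, output_dim=4)
    input_ids = torch.randint(0, 1000, (2, 20))  # [batch size, seq len]
    logits = model(input_ids, attention_mask=torch.ones(2, 20, dtype=torch.long))
    print(logits.shape)  # expected: torch.Size([2, 4])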
train.py CHANGED
@@ -121,7 +121,7 @@ def main():
 
     # Train the model
     logger.info("Starting training...")
-    save_path = os.path.join(args.output_dir, "best_model.pth")
+    save_path = os.path.join(args.output_dir, "bert-base-uncased")
    trainer.train(epochs=args.epochs, save_path=save_path)
 
     logger.info("Training completed!")