Asmitha-28 committed
Commit 5cd9136 · verified · Parent: a516256

Upload src/train_router.py with huggingface_hub
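The commit message says the file was pushed with huggingface_hub. As a point of reference, a minimal upload sketch is shown below; the repo id and local path are placeholders, not values taken from this commit, and the token is assumed to come from a prior huggingface-cli login or the HF_TOKEN environment variable.

# Hypothetical upload sketch -- repo_id and paths are illustrative placeholders.
from huggingface_hub import HfApi

api = HfApi()  # picks up the cached token / HF_TOKEN
api.upload_file(
    path_or_fileobj="src/train_router.py",   # local file to push
    path_in_repo="src/train_router.py",      # destination path inside the repo
    repo_id="Asmitha-28/SupportMind",        # placeholder repo id
    commit_message="Upload src/train_router.py with huggingface_hub",
)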

Files changed (1):
  1. src/train_router.py +263 -0
src/train_router.py ADDED:
# src/train_router.py
# Fine-tune DistilBERT for 8-class ticket routing
# SupportMind v1.0 - Asmitha
#
# Memory-optimized for machines with limited RAM:
# - max_length=64 (tickets are short, saves ~4x memory vs 256)
# - per_device_train_batch_size=1 (minimal footprint)
# - gradient_accumulation_steps=16 (effective batch=16)
# - Adafactor optimizer + gradient checkpointing, CPU-only (fp16 disabled)
# - DataFrames and raw datasets freed before model loading

import os
import sys
import gc

# Disable TensorFlow/JAX backends to prevent DLL loading errors under Application Control policies
os.environ['USE_TF'] = '0'
os.environ['USE_JAX'] = '0'
# Limit torch threads to reduce memory pressure
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import pandas as pd
import torch
import logging
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    TrainerCallback,
)
from transformers.trainer_utils import get_last_checkpoint
import psutil
from datasets import Dataset
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, 'data', 'processed')
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'ticket_classifier')

# Shorter max_length - support tickets are typically short;
# 64 tokens is enough to capture intent from these tickets.
MAX_LENGTH = 64


class MemoryProfilerCallback(TrainerCallback):
    """Logs memory usage and a progress summary every `logging_steps` steps."""

    def __init__(self, total_steps: int):
        self.process = psutil.Process(os.getpid())
        self.total_steps = total_steps

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            mem_mb = self.process.memory_info().rss / (1024 * 1024)
            pct = (state.global_step / self.total_steps) * 100 if self.total_steps else 0
            logger.info(
                f"[{pct:5.1f}%] Step {state.global_step}/{self.total_steps} "
                f"| Epoch {state.epoch:.2f} | RAM: {mem_mb:.0f} MB"
            )


def compute_metrics(eval_pred):
    """Compute accuracy metric for evaluation."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).astype(np.float32).mean()
    return {"accuracy": float(accuracy)}


def main():
    train_path = os.path.join(DATA_DIR, 'train.csv')
    val_path = os.path.join(DATA_DIR, 'val.csv')

    if not os.path.exists(train_path):
        logger.error(f"Training data not found at {train_path}. Run data/preprocess.py first.")
        sys.exit(1)

    # ── Step 1: Load & tokenize data ──────────────────────
    logger.info("Loading processed datasets...")
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    logger.info(f"Train: {len(train_df)} samples, Val: {len(val_df)} samples")
    logger.info(f"Label distribution:\n{train_df['label'].value_counts().to_string()}")

    # Check device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_fp16 = device == "cuda"
    logger.info(f"Device: {device} | FP16: {use_fp16}")

    # Convert to HF Datasets
    train_dataset = Dataset.from_pandas(train_df[['text', 'label']])
    val_dataset = Dataset.from_pandas(val_df[['text', 'label']])

    # Free DataFrame memory before tokenization
    del train_df, val_df
    gc.collect()

    logger.info("Initializing tokenizer...")
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH)

    logger.info("Tokenizing datasets...")
    tokenized_train = train_dataset.map(tokenize_function, batched=True, batch_size=64)
    tokenized_val = val_dataset.map(tokenize_function, batched=True, batch_size=64)

    # Free raw datasets
    del train_dataset, val_dataset
    gc.collect()

    # ── Step 2: Compute class weights for imbalanced data ─
    from sklearn.utils.class_weight import compute_class_weight

    labels_array = tokenized_train['label']
    unique_labels = sorted(set(labels_array))
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.array(unique_labels),
        y=np.array(labels_array)
    )
    # Map to all 8 classes (some might be missing from the training split)
    weight_dict = {c: w for c, w in zip(unique_labels, class_weights)}
    weights_tensor = torch.tensor(
        [weight_dict.get(i, 1.0) for i in range(8)], dtype=torch.float32
    )
    logger.info(f"Class weights: {weights_tensor.tolist()}")

    # ── Step 3: Load model ────────────────────────────────
    logger.info("Loading DistilBERT model...")
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased',
        num_labels=8
    )
    param_count = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded. Parameters: {param_count:,}")

    # ── Freeze base layers - only fine-tune the last 2 transformer layers + head ─
    # Freezing layers 0-3 cuts trainable params from 67M to ~7M,
    # reducing peak RAM from ~3.5 GB to ~800 MB. Quality impact is minimal
    # because the ticket vocabulary is similar to DistilBERT's pretraining data.
    for name, param in model.named_parameters():
        param.requires_grad = False  # freeze everything first

    # Unfreeze: last 2 transformer layers (layers 4 and 5 of 6) plus the classification head
    for name, param in model.named_parameters():
        if any(key in name for key in [
            'transformer.layer.4',
            'transformer.layer.5',
            'pre_classifier',
            'classifier',
        ]):
            param.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable params: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

    # Force garbage collection after model load
    gc.collect()

    # ── Step 4: Custom Trainer with weighted loss ─────────
    from torch.nn import CrossEntropyLoss

    class WeightedTrainer(Trainer):
        """Trainer with class-weighted cross-entropy for imbalanced datasets."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fn = CrossEntropyLoss(weight=weights_tensor.to(logits.device))
            loss = loss_fn(logits, labels)
            return (loss, outputs) if return_outputs else loss

    # ── Step 5: Training ──────────────────────────────────
    # batch=1 with gradient_accumulation=16 gives effective batch=16.
    # gradient_checkpointing trades compute for memory (critical on 5GB RAM).
    # Total steps = (train_samples / effective_batch) * epochs,
    # e.g. 2800 / 16 * 5 = 875 steps.
    total_steps = (len(tokenized_train) // 16) * 5

    training_args = TrainingArguments(
        output_dir=os.path.join(BASE_DIR, 'results'),
        num_train_epochs=5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        gradient_checkpointing=True,
        warmup_steps=50,
        weight_decay=0.01,
        learning_rate=3e-5,
        logging_dir=os.path.join(BASE_DIR, 'logs'),
        logging_steps=25,            # Log every 25 steps (~2 min on CPU)
        evaluation_strategy="steps",
        eval_steps=50,               # Evaluate every 50 steps (~4 min)
        save_strategy="steps",
        save_steps=50,               # Must equal eval_steps when load_best_model_at_end=True
        save_total_limit=3,          # Keep 3 checkpoints (~75 steps of safety)
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        fp16=False,
        dataloader_num_workers=0,
        report_to="none",
        use_cpu=True,
        optim="adafactor",           # Much less memory than AdamW
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
        callbacks=[MemoryProfilerCallback(total_steps=total_steps)],
    )

    logger.info("=" * 60)
    logger.info("Starting DistilBERT fine-tuning (5 epochs, weighted loss)...")
    logger.info(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    logger.info(f"  Max sequence length: {MAX_LENGTH}")
    logger.info(f"  Training samples: {len(tokenized_train)}")
    logger.info("=" * 60)

    # get_last_checkpoint expects an existing directory; guard against the first run
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None:
        logger.info(f"Resuming training from checkpoint: {last_checkpoint}")
    else:
        logger.info("No checkpoint found. Starting from scratch.")

    trainer.train(resume_from_checkpoint=last_checkpoint)

    # ── Step 6: Evaluate ──────────────────────────────────
    logger.info("Running final evaluation...")
    eval_results = trainer.evaluate()
    logger.info(f"Eval results: {eval_results}")

    # ── Step 7: Save ──────────────────────────────────────
    logger.info(f"Saving fine-tuned model to {MODEL_DIR}")
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save_pretrained(MODEL_DIR)
    tokenizer.save_pretrained(MODEL_DIR)

    # Save eval results
    import json
    results_path = os.path.join(BASE_DIR, 'results', 'training_results.json')
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w') as f:
        json.dump(eval_results, f, indent=2, default=str)
    logger.info(f"Results saved to {results_path}")

    logger.info("=" * 60)
    logger.info("Training complete! Model is ready for inference.")
    logger.info("=" * 60)


if __name__ == '__main__':
    main()
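The script is run directly (python src/train_router.py) and saves the fine-tuned model and tokenizer to models/ticket_classifier. A minimal inference sketch for that saved checkpoint follows; it assumes the same 8-class, 64-token setup as the training script, returns the raw class id rather than a label name (no id-to-name mapping is present in this commit), and the example ticket text is purely illustrative.

# Hypothetical inference sketch for the saved ticket classifier.
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

MODEL_DIR = "models/ticket_classifier"  # same directory the training script saves to

tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()

def route_ticket(text: str) -> int:
    """Return the predicted class id (0-7) for a single ticket."""
    inputs = tokenizer(text, truncation=True, max_length=64, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return int(logits.argmax(dim=-1).item())

print(route_ticket("My VPN keeps disconnecting after the latest update"))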