RobbiePasquale committed
Commit 7f47926
Parent: c9a5651

Upload 2 files

Files changed (2)
  1. distill.py +264 -0
  2. main_menu_new.py +191 -0
distill.py ADDED
@@ -0,0 +1,264 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from datasets import load_dataset
+ from typing import List, Optional
+ import argparse
+ import os
+ import json
+ import jsonlines
+ from tqdm import tqdm
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.utils.tensorboard import SummaryWriter
+
+ # Set up device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class CustomDataset(Dataset):
+     def __init__(self, inputs, labels):
+         self.inputs = inputs
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def __getitem__(self, idx):
+         return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}
+
+ def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
+     dataset = load_dataset(dataset_name, config)
+     if queries:
+         def filter_func(example):  # called once per example, so "text" is a single string
+             return any(query.lower() in example["text"].lower() for query in queries)
+         dataset = dataset.filter(filter_func)  # batched=True would pass lists here and break .lower()
+     return dataset
+
+ def prepare_data(tokenizer, dataset, max_length, batch_size):
+     # Tokenize once; for causal-LM distillation the labels simply mirror the inputs
+     tokenized_inputs = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
+     tokenized_labels = tokenized_inputs["input_ids"].clone()
+
+     # Create custom dataset
+     custom_dataset = CustomDataset(tokenized_inputs["input_ids"], tokenized_labels)
+
+     # Split into training and validation sets
+     train_size = int(0.9 * len(custom_dataset))
+     val_size = len(custom_dataset) - train_size
+     train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])
+
+     # Create DataLoaders
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+
+     return train_loader, val_loader
+
+ def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
+     teacher.eval()
+     student.train()
+     total_loss = 0
+
+     for batch in tqdm(data_loader, desc="Training"):
+         inputs = batch["input_ids"].to(device)
+         labels = batch["labels"].to(device)  # unused below: the loss is teacher-student KL only
+
+         optimizer.zero_grad()
+         with autocast():
+             with torch.no_grad():
+                 teacher_outputs = teacher(inputs).logits
+                 teacher_logits = teacher_outputs / temperature
+
+             student_outputs = student(inputs).logits
+             student_logits = student_outputs / temperature
+
+             # Compute KL Divergence Loss
+             loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
+             loss = loss * (temperature ** 2)  # Scale loss by temperature squared
+
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+
+         total_loss += loss.item()
+
+     avg_loss = total_loss / len(data_loader)
+     return avg_loss
+
+ def validate(teacher, student, data_loader, criterion, temperature=2.0):
+     teacher.eval()
+     student.eval()
+     total_loss = 0
+
+     with torch.no_grad():
+         for batch in tqdm(data_loader, desc="Validation"):
+             inputs = batch["input_ids"].to(device)
+             labels = batch["labels"].to(device)
+
+             teacher_outputs = teacher(inputs).logits
+             teacher_logits = teacher_outputs / temperature
+
+             student_outputs = student(inputs).logits
+             student_logits = student_outputs / temperature
+
+             loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
+             loss = loss * (temperature ** 2)
+
+             total_loss += loss.item()
+
+     avg_loss = total_loss / len(data_loader)
+     return avg_loss
+
+ def save_checkpoint(state, save_dir, epoch):
+     os.makedirs(save_dir, exist_ok=True)
+     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
+     torch.save(state, checkpoint_path)
+     print(f"Checkpoint saved at {checkpoint_path}")
+
+ def load_checkpoint(model, optimizer, scheduler, scaler, save_dir, epoch):
+     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
+     if os.path.isfile(checkpoint_path):
+         checkpoint = torch.load(checkpoint_path)
+         model.load_state_dict(checkpoint['model_state_dict'])
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+         scaler.load_state_dict(checkpoint['scaler_state_dict'])
+         print(f"Loaded checkpoint from {checkpoint_path}")
+     else:
+         print(f"No checkpoint found at {checkpoint_path}")
+
+ def distill_model(
+     teacher_model_name: str,
+     student_model_name: str,
+     dataset_name: str,
+     config: str,
+     distill_full_model: bool = True,
+     query_terms: Optional[List[str]] = None,
+     num_epochs: int = 3,
+     batch_size: int = 4,
+     max_length: int = 128,
+     learning_rate: float = 5e-5,
+     temperature: float = 2.0,
+     save_path: str = "./distilled_model",
+     log_dir: str = "./logs",
+     checkpoint_dir: str = "./checkpoints",
+     early_stopping_patience: int = 3
+ ):
+     # Initialize TensorBoard writer
+     writer = SummaryWriter(log_dir=log_dir)
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load teacher and student models
+     teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
+     student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
+
+     # Freeze teacher model parameters; only the student is trained
+     for param in teacher.parameters():
+         param.requires_grad = False
+
+     # Load and prepare dataset
+     if distill_full_model:
+         dataset = load_dataset(dataset_name, config)
+     else:
+         dataset = load_filtered_dataset(dataset_name, config, query_terms)
+
+     train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
+
+     # Define optimizer, scheduler, and scaler for mixed precision
+     optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
+     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
+     scaler = GradScaler()
+
+     # Define loss criterion
+     criterion = nn.KLDivLoss(reduction="batchmean")
+
+     best_val_loss = float('inf')
+     epochs_no_improve = 0
+
+     # Training loop
+     for epoch in range(1, num_epochs + 1):
+         print(f"\nEpoch {epoch}/{num_epochs}")
+         print("-" * 20)
+
+         # Training
+         train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
+         print(f"Training Loss: {train_loss:.4f}")
+         writer.add_scalar("Loss/Train", train_loss, epoch)
+
+         # Validation
+         val_loss = validate(teacher, student, val_loader, criterion, temperature)
+         print(f"Validation Loss: {val_loss:.4f}")
+         writer.add_scalar("Loss/Validation", val_loss, epoch)
+
+         # Check for improvement
+         if val_loss < best_val_loss:
+             best_val_loss = val_loss
+             epochs_no_improve = 0
+             # Save the best model
+             save_checkpoint({
+                 'epoch': epoch,
+                 'model_state_dict': student.state_dict(),
+                 'optimizer_state_dict': optimizer.state_dict(),
+                 'scheduler_state_dict': scheduler.state_dict(),
+                 'scaler_state_dict': scaler.state_dict(),
+                 'best_val_loss': best_val_loss
+             }, checkpoint_dir, epoch)
+             # Save the model as the best one
+             student.save_pretrained(save_path)
+             tokenizer.save_pretrained(save_path)
+             print(f"Best model saved at epoch {epoch}")
+         else:
+             epochs_no_improve += 1
+             print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
+             if epochs_no_improve >= early_stopping_patience:
+                 print("Early stopping triggered")
+                 break
+
+         # Step the scheduler
+         scheduler.step()
+
+     writer.close()
+     print("\nDistillation completed.")
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one.")
+     parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
+     parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
+     parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
+     parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
+     parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill the full model or not")
+     parser.add_argument("--query_terms", type=str, nargs="+", help="Query terms for filtering the dataset")
+     parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
+     parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
+     parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
+     parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
+     parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
+     parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
+     parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
+     parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
+     parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
+     return parser.parse_args()
+
+ if __name__ == "__main__":
+     args = parse_args()
+     distill_model(
+         teacher_model_name=args.teacher_model_name,
+         student_model_name=args.student_model_name,
+         dataset_name=args.dataset_name,
+         config=args.config,
+         distill_full_model=args.distill_full_model,
+         query_terms=args.query_terms,
+         num_epochs=args.num_epochs,
+         batch_size=args.batch_size,
+         max_length=args.max_length,
+         learning_rate=args.learning_rate,
+         temperature=args.temperature,
+         save_path=args.save_path,
+         log_dir=args.log_dir,
+         checkpoint_dir=args.checkpoint_dir,
+         early_stopping_patience=args.early_stopping_patience
+     )
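
Note: the loss in train_step/validate above is the standard temperature-scaled distillation objective. A minimal, self-contained sketch of just that computation follows; the dummy logits and tensor shapes are illustrative, not taken from distill.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature = 2.0  # same default as distill.py
criterion = nn.KLDivLoss(reduction="batchmean")

# Dummy logits standing in for teacher(inputs).logits / student(inputs).logits;
# the (batch, seq_len, vocab_size) shape here is illustrative only.
teacher_logits = torch.randn(2, 8, 100)
student_logits = torch.randn(2, 8, 100, requires_grad=True)

# Soften both distributions with the temperature, then take the divergence of the
# student's log-probabilities against the teacher's probabilities.
loss = criterion(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
)
loss = loss * temperature ** 2  # rescale so gradients stay comparable across temperatures
loss.backward()
print(loss.item())

nn.KLDivLoss expects log-probabilities as its input and plain probabilities as its target, and the temperature-squared factor compensates for the soft-target gradients shrinking as 1/T², following the usual Hinton-style distillation recipe.
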
main_menu_new.py ADDED
@@ -0,0 +1,191 @@
+ # main_menu_new.py
+
+ import argparse
+ import sys
+ import os
+ from train_agent import train_agent
+ from test_agent import TestAgent, run_test_session
+ from twisted.internet import reactor, task
+ from lightbulb_custom import main as lightbulb_custom_main
+ from distill import distill_model  # distillation entry point from distill.py in this commit
+ from transformers import logging
+
+ # Suppress transformers warnings for cleaner output
+ logging.set_verbosity_error()
+
+ def parse_main_args():
+     parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
+
+     # Task selection
+     parser.add_argument('--task', type=str, choices=[
+         'train_llm_world',
+         'train_agent',
+         'test_agent',
+         'inference_llm',
+         'inference_world_model',
+         'advanced_inference',
+         'distill_full_model',       # full-model distillation
+         'distill_domain_specific'   # selective, query-filtered distillation
+     ],
+     required=True,
+     help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference, distill_full_model, distill_domain_specific')
+
+     # Common arguments
+     parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
+     parser.add_argument('--student_model_name', type=str, default='distilgpt2', help='Name of the student model for distillation')
+     parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
+     parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
+     parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
+     parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
+     parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
+     parser.add_argument('--temperature', type=float, default=2.0, help='Distillation temperature')
+     parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate')
+
+     # Distillation output arguments
+     parser.add_argument('--save_path', type=str, default="./distilled_model", help="Path to save the distilled model")
+     parser.add_argument('--log_dir', type=str, default="./logs", help="Directory for TensorBoard logs")
+     parser.add_argument('--checkpoint_dir', type=str, default="./checkpoints", help="Directory to save checkpoints")
+     parser.add_argument('--early_stopping_patience', type=int, default=3, help="Early stopping patience")
+
+     # Inference-specific arguments
+     parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
+     parser.add_argument('--inference_mode', type=str, choices=['without_world_model', 'world_model', 'world_model_tree_of_thought'], help='Inference mode')
+     parser.add_argument('--beam_size', type=int, default=5, help='Beam size for beam search during inference')
+     parser.add_argument('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict at each step during inference')
+     parser.add_argument('--mcts_iterations', type=int, default=10, help='Number of MCTS iterations during inference')
+     parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS during inference')
+
+     # Distillation mode arguments
+     parser.add_argument('--distill_full_model', action="store_true", help="Whether to distill the full model or not")
+     parser.add_argument('--query_terms', type=str, nargs="+", help="Query terms for domain-specific distillation")
+
+     # Load model for inference
+     parser.add_argument('--load_model', type=str, help='Path to load the distilled model for inference')
+
+     return parser.parse_args()
+
+ def main():
+     # Parse arguments for the main function
+     args = parse_main_args()
+
+     # Execute tasks based on user input
+     if args.task == 'train_llm_world':
+         print("Starting LLM and World Model Training...")
+         # Directly call the world model main function with appropriate arguments
+         sys.argv = [
+             'lightbulb_custom.py',
+             '--mode', 'train',
+             '--model_name', args.model_name,
+             '--dataset_name', args.dataset_name,
+             '--dataset_config', args.dataset_config,
+             '--batch_size', str(args.batch_size),
+             '--num_epochs', str(args.num_epochs),
+             '--max_length', str(args.max_length)
+         ]
+         lightbulb_custom_main()
+
+     elif args.task == 'train_agent':
+         print("Starting Agent Training...")
+         # Call the train_agent function from train_agent.py using the Twisted reactor
+         d = task.deferLater(reactor, 0, train_agent)
+         d.addErrback(lambda failure: print(f"An error occurred: {failure.getTraceback()}"))
+         d.addBoth(lambda _: reactor.stop())
+         reactor.run()
+
+     elif args.task == 'test_agent':
+         print("Starting Test Agent...")
+         test_agent = TestAgent()
+         if args.query:
+             # Directly process a single query
+             result = test_agent.process_query(args.query)
+             print("\nAgent's response:")
+             print(result)
+         else:
+             # Run the interactive session
+             reactor.callWhenRunning(run_test_session)
+             reactor.run()
+
+     elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
+         print("Starting Inference Task...")
+         # Prepare the arguments for lightbulb_custom.py based on the selected inference task
+
+         # Map the main_menu task to lightbulb_custom.py's inference_mode
+         inference_mode_map = {
+             'inference_llm': 'without_world_model',
+             'inference_world_model': 'world_model',
+             'advanced_inference': 'world_model_tree_of_thought'
+         }
+
+         selected_inference_mode = inference_mode_map.get(args.task, 'world_model_tree_of_thought')
+
+         # Construct sys.argv for lightbulb_custom.py
+         lightbulb_inf_args = [
+             'lightbulb_custom.py',
+             '--mode', 'inference',
+             '--model_name', args.model_name,
+             '--query', args.query,
+             '--max_length', str(args.max_length),
+             '--inference_mode', selected_inference_mode,
+             '--beam_size', str(args.beam_size),
+             '--n_tokens_predict', str(args.n_tokens_predict),
+             '--mcts_iterations', str(args.mcts_iterations),
+             '--mcts_exploration_constant', str(args.mcts_exploration_constant)
+         ]
+
+         # Include additional arguments if they exist
+         if args.load_model:
+             lightbulb_inf_args += ['--load_model', args.load_model]
+
+         # Update sys.argv and call the inference main function
+         sys.argv = lightbulb_inf_args
+         lightbulb_custom_main()
+
+     elif args.task == 'distill_full_model':
+         print("Starting Full Model Distillation...")
+         distill_model(
+             teacher_model_name=args.model_name,
+             student_model_name=args.student_model_name,
+             dataset_name=args.dataset_name,
+             config=args.dataset_config,
+             distill_full_model=True,
+             query_terms=None,
+             num_epochs=args.num_epochs,
+             batch_size=args.batch_size,
+             max_length=args.max_length,
+             learning_rate=args.learning_rate,
+             temperature=args.temperature,
+             save_path=args.save_path,
+             log_dir=args.log_dir,
+             checkpoint_dir=args.checkpoint_dir,
+             early_stopping_patience=args.early_stopping_patience
+         )
+
+     elif args.task == 'distill_domain_specific':
+         print("Starting Domain-Specific Distillation...")
+         if not args.query_terms:
+             print("Error: --query_terms must be provided for domain-specific distillation.")
+             sys.exit(1)
+         distill_model(
+             teacher_model_name=args.model_name,
+             student_model_name=args.student_model_name,
+             dataset_name=args.dataset_name,
+             config=args.dataset_config,
+             distill_full_model=False,
+             query_terms=args.query_terms,
+             num_epochs=args.num_epochs,
+             batch_size=args.batch_size,
+             max_length=args.max_length,
+             learning_rate=args.learning_rate,
+             temperature=args.temperature,
+             save_path=args.save_path,
+             log_dir=args.log_dir,
+             checkpoint_dir=args.checkpoint_dir,
+             early_stopping_patience=args.early_stopping_patience
+         )
+
+     else:
+         print(f"Unknown task: {args.task}")
+         sys.exit(1)
+
+ if __name__ == "__main__":
+     main()
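
For reference, a sketch of how these entry points might be invoked. This assumes the two files above sit in the working directory alongside the other modules main_menu_new.py imports (train_agent.py, test_agent.py, lightbulb_custom.py); the model and dataset names are the scripts' own defaults, and the epoch/batch values are arbitrary smoke-test settings.

# CLI, via the standalone script:
#   python distill.py --teacher_model_name gpt2 --student_model_name distilgpt2 \
#       --dataset_name wikitext --config wikitext-2-raw-v1 --distill_full_model
#
# CLI, via the menu:
#   python main_menu_new.py --task distill_domain_specific \
#       --model_name gpt2 --student_model_name distilgpt2 \
#       --dataset_name wikitext --dataset_config wikitext-2-raw-v1 \
#       --query_terms physics chemistry

# Or programmatically, calling distill_model directly (needs only distill.py):
from distill import distill_model

distill_model(
    teacher_model_name="gpt2",
    student_model_name="distilgpt2",
    dataset_name="wikitext",
    config="wikitext-2-raw-v1",
    distill_full_model=True,
    num_epochs=1,   # short run for smoke-testing
    batch_size=2,
)

The menu's distill_full_model and distill_domain_specific branches pass the same arguments through to distill_model; the only difference is whether the dataset is used whole or first filtered by the --query_terms substrings.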