File size: 1,495 Bytes
8918ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
# Route all Hugging Face Hub traffic through a mirror. NOTE(review): this must
# run before any library that reads HF_ENDPOINT at import time is imported —
# presumably one of the model/data modules below pulls in `transformers`/`huggingface_hub`;
# verify ordering if imports are reshuffled.
os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
import json  # NOTE(review): unused in this file as shown — confirm before removing
import wandb
from utils.args import parse_args
from utils.logger import setup_logging, print_model_parameters
from data.dataloader import prepare_dataloaders
from models.model_factory import create_models, lora_factory
from training.trainer import Trainer

def main():
    """Script entry point: parse CLI arguments, build the model, data loaders,
    and trainer, then run training followed by a final test pass.

    Side effects: creates ``args.output_dir`` and, when ``args.wandb`` is set,
    opens (and later closes) a Weights & Biases run.
    """
    args = parse_args()

    # Logging first, so everything downstream can report through `logger`.
    logger = setup_logging(args)

    # Optional experiment tracking; the full argument namespace is logged
    # as the run config.
    if args.wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            entity=args.wandb_entity,
            config=vars(args),
        )

    os.makedirs(args.output_dir, exist_ok=True)

    # PEFT-style training methods go through the LoRA factory; every other
    # method uses the generic model factory. Both return the same triple.
    peft_methods = {'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3'}
    factory = lora_factory if args.training_method in peft_methods else create_models
    model, plm_model, tokenizer = factory(args)
    print_model_parameters(model, plm_model, logger)

    # Data pipelines share the tokenizer produced by the model factory.
    train_loader, val_loader, test_loader = prepare_dataloaders(args, tokenizer, logger)

    trainer = Trainer(args, model, plm_model, logger)
    trainer.train(train_loader, val_loader)
    trainer.test(test_loader)

    # Close the tracking run so buffered metrics are flushed.
    if args.wandb:
        wandb.finish()

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()