Spaces:
Runtime error
Runtime error
File size: 1,495 Bytes
8918ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
import json
import wandb
from utils.args import parse_args
from utils.logger import setup_logging, print_model_parameters
from data.dataloader import prepare_dataloaders
from models.model_factory import create_models, lora_factory
from training.trainer import Trainer
def main():
    """Run one end-to-end training job configured from the command line.

    Steps, in order:
      1. Parse CLI arguments with ``parse_args``.
      2. Set up logging and, when ``args.wandb`` is set, start a Weights &
         Biases run using ``args.wandb_project`` / ``args.wandb_run_name`` /
         ``args.wandb_entity``, with every argument recorded as run config.
      3. Create ``args.output_dir`` if it does not already exist.
      4. Build ``(model, plm_model, tokenizer)``: ``lora_factory`` for the
         method names listed below, ``create_models`` for anything else.
      5. Build train/val/test loaders with ``prepare_dataloaders``.
      6. Train on the train/val loaders, then evaluate on the test loader,
         via ``Trainer``.

    When wandb is enabled, the run is always finished. On success it is
    finished exactly as before (``wandb.finish()``). If any step after
    ``wandb.init`` raises, it is finished with ``exit_code=1`` so the run is
    recorded as failed, and the original exception is re-raised.
    """
    args = parse_args()
    logger = setup_logging(args)

    if args.wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            entity=args.wandb_entity,
            config=vars(args),
        )

    # Training methods whose models are built by lora_factory
    # (LoRA-family / adapter variants, judging by the names).
    lora_methods = {"plm-lora", "plm-qlora", "plm-dora", "plm-adalora", "plm-ia3"}

    succeeded = False
    try:
        os.makedirs(args.output_dir, exist_ok=True)

        if args.training_method in lora_methods:
            model, plm_model, tokenizer = lora_factory(args)
        else:
            model, plm_model, tokenizer = create_models(args)
        print_model_parameters(model, plm_model, logger)

        # The dataloaders need the tokenizer returned by the model factory.
        train_loader, val_loader, test_loader = prepare_dataloaders(args, tokenizer, logger)

        trainer = Trainer(args, model, plm_model, logger)
        trainer.train(train_loader, val_loader)
        trainer.test(test_loader)
        succeeded = True
    finally:
        # Close the wandb run on every exit path. On success, exit_code=None
        # is the same default call the script always made; on failure, a
        # non-zero code marks the run as failed before the exception propagates.
        if args.wandb:
            wandb.finish(exit_code=None if succeeded else 1)
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()