# Hugging Face Spaces status at capture time: Runtime error
import os

# Route Hugging Face Hub traffic through the hf-mirror.com mirror by default,
# but respect an HF_ENDPOINT the user has already exported (previously it was
# unconditionally overwritten). This must stay above the imports below:
# presumably some of them pull in Hugging Face libraries that read
# HF_ENDPOINT once, when they are first imported.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

import json

import wandb

from utils.args import parse_args
from utils.logger import setup_logging, print_model_parameters
from data.dataloader import prepare_dataloaders
from models.model_factory import create_models, lora_factory
from training.trainer import Trainer
def main():
    """Run the full train-then-test pipeline configured from the command line.

    Steps, in order:
      1. Parse CLI arguments and set up logging.
      2. When ``args.wandb`` is truthy, start a Weights & Biases run using
         ``args.wandb_project`` / ``args.wandb_run_name`` /
         ``args.wandb_entity``, with the whole argument namespace as config.
      3. Create ``args.output_dir`` if it does not already exist.
      4. Build ``(model, plm_model, tokenizer)``. Use ``lora_factory`` when
         ``args.training_method`` is one of the LoRA-family methods listed
         below; otherwise use ``create_models``. Then report the models via
         ``print_model_parameters``.
      5. Build the train/val/test dataloaders using that tokenizer.
      6. Train on the train/val loaders, then evaluate on the test loader.
      7. Call ``wandb.finish()`` when ``args.wandb`` is truthy.
    """
    # Training methods whose models are built by lora_factory instead of
    # create_models.
    lora_methods = (
        'plm-lora',
        'plm-qlora',
        'plm-dora',
        'plm-adalora',
        'plm-ia3',
    )

    args = parse_args()
    logger = setup_logging(args)

    if args.wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            entity=args.wandb_entity,
            config=vars(args),
        )

    os.makedirs(args.output_dir, exist_ok=True)

    build_models = (
        lora_factory if args.training_method in lora_methods else create_models
    )
    model, plm_model, tokenizer = build_models(args)
    print_model_parameters(model, plm_model, logger)

    # The dataloaders need the tokenizer returned by the model builder, so
    # they can only be built after the models.
    train_loader, val_loader, test_loader = prepare_dataloaders(
        args, tokenizer, logger
    )

    trainer = Trainer(args, model, plm_model, logger)
    trainer.train(train_loader, val_loader)
    trainer.test(test_loader)

    if args.wandb:
        wandb.finish()


if __name__ == "__main__":
    main()