Fine-tuning the adapter takes longer than fine-tuning the whole model

#74
by hail75 - opened

I want to train only the adapter on my data. The problems are:

  • It is much slower than with 'lora_main_params_trainable': True
  • When I set a default task, the passage adapter is used to encode my queries for the loss and the evaluation metrics

Any ideas on how to solve this? Thanks a lot.
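
```python
from datasets import load_dataset
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

# `jina` is the SentenceTransformer model loaded earlier (loading code not shown in this post).

# Hold out 1000 examples for evaluation.
dataset = load_dataset('json', data_files='jina-ft-data.jsonl')['train'].train_test_split(test_size=1000, seed=42)
train_dataset = dataset['train']
eval_dataset = dataset['test']

loss = MultipleNegativesRankingLoss(jina)

args = SentenceTransformerTrainingArguments(
    output_dir='jina-embeddings-v3',
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    lr_scheduler_type='cosine',
    warmup_ratio=0.1,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy='steps',
    eval_steps=1000,
    save_strategy='steps',
    save_steps=1000,
    save_total_limit=2,
    logging_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model='cosine_accuracy',
)

evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative_1"],
)

# Baseline score before fine-tuning.
print(evaluator(jina))

trainer = SentenceTransformerTrainer(
    model=jina,
    loss=loss,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=evaluator,
)

trainer.train()
```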
Jina AI org

When you set `'lora_main_params_trainable': True`, the main parameters are trained, which slows down the process compared to training only a single adapter (which is about 1% of the total weights). In this case, you should not set a task and instead use the main parameters directly.

> When I set a default task, the passage adapter is used to encode my queries for the loss and the evaluation metrics
> Any ideas on how to solve this? Thanks a lot.

Yes, unfortunately there is currently no easy way to fine-tune the passage and query adapters at the same time.

> When you set `'lora_main_params_trainable': True`, the main parameters are trained, which slows down the process compared to training only a single adapter (which is about 1% of the total weights). In this case, you should not set a task and instead use the main parameters directly.

@jupyterjazz Maybe you misunderstood: when I set `'lora_main_params_trainable': True` without specifying a task, training is 2 times faster than training only the adapter.
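
Since the two posts disagree about which configuration should be faster, one way to narrow it down is to check how many parameters are actually marked trainable in each setup before calling `trainer.train()`. The following is only a minimal sketch, not something from this thread: `trainable_summary` and `FakeParam` are hypothetical names, and the parameter names and sizes are made up so the snippet runs without PyTorch or a model download. It relies only on the `(name, param)` pairs that `torch.nn.Module.named_parameters()` yields, where each `param` has `numel()` and `requires_grad`.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter (illustration only)."""
    size: int
    requires_grad: bool

    def numel(self) -> int:
        return self.size


def trainable_summary(named_params: Iterable[Tuple[str, object]]) -> Tuple[int, int, List[str]]:
    """Return (trainable count, total count, names of trainable parameters)."""
    trainable = 0
    total = 0
    names: List[str] = []
    for name, param in named_params:
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
            names.append(name)
    return trainable, total, names


# Made-up parameter names and sizes, for illustration only.
params = [
    ("encoder.layer.0.weight", FakeParam(9900, False)),
    ("encoder.layer.0.lora_A", FakeParam(50, True)),
    ("encoder.layer.0.lora_B", FakeParam(50, True)),
]

trainable, total, names = trainable_summary(params)
print(f"trainable: {trainable} / {total} ({100 * trainable / total:.1f}%)")
print(names)
```

With the real model, the same function could be called as `trainable_summary(jina.named_parameters())` in both configurations. If the adapter-only run reports roughly 1% trainable weights and is still slower, the slowdown is probably coming from somewhere other than the number of trained parameters.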
