Classification metrics can't handle a mix of multiclass-multioutput and multilabel-indicator targets

#225
by AMCalejandro - opened

Hello,

I am trying to find the best hyperparameters for a downstream gene classification task (distinguishing dosage-sensitive from dosage-insensitive genes) on cells from patients with a specific phenotype.

I followed the gene classification notebook to modify the dataset so that 'label' marks each gene as dosage sensitive or dosage insensitive.
Then I split the data into training and validation sets, set up the trainer (with a collator for gene classification), and start the hyperparameter search.

This is the error message I am getting:

Trial status: 12 ERROR | 4 RUNNING | 1 PENDING
Current time: 2023-08-20 12:02:51. Total running time: 1hr 22min 13s
Logical resource usage: 12.0/12 CPUs, 4.0/4 GPUs (0.0/1.0 accelerator_type:V100)
[martinezcarraa2@biowulf DOSAGE_HYPERPARAMS]$
[martinezcarraa2@biowulf DOSAGE_HYPERPARAMS]$ cat /home/martinezcarraa2/ray_results/_objective_2023-08-20_10-40-38/_objective_e8616145_12_learning_rate=0.0001,lr_scheduler_type=polynomial,num_train_epochs=1,per_device_train_batch_size=12,seed=29_2023-08-20_11-23-56/error.txt
Failure # 1 (occurred at 2023-08-20_11-45-01)
ray::ImplicitFunc.train() (pid=2040340, ip=10.2.23.52, actor_id=d00ff41a94654f583dbbfb7a01000000, repr=_objective)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 375, in train
    raise skipped from exception_cause(skipped)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/ray/tune/trainable/function_trainable.py", line 349, in entrypoint
    return self._trainable_func(
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/ray/tune/trainable/function_trainable.py", line 666, in _trainable_func
    output = fn()
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/integrations.py", line 348, in dynamic_modules_import_trainable
    return trainable(*args, **kwargs)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/ray/tune/trainable/util.py", line 324, in inner
    return trainable(config, **fn_kwargs)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/integrations.py", line 249, in _objective
    local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/trainer.py", line 1901, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/trainer.py", line 2226, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/trainer.py", line 2934, in evaluate
    output = eval_loop(
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/transformers/trainer.py", line 3222, in evaluation_loop
    metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
  File "hyperparam_optimiz_for_disease_classifier.py", line 178, in compute_metrics
    acc = accuracy_score(labels, preds)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/sklearn/utils/_param_validation.py", line 211, in wrapper
    return func(*args, **kwargs)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 220, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/gpfs/gsfs12/users/martinezcarraa2/conda/envs/hackgneformer/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 93, in _check_targets
    raise ValueError(
ValueError: Classification metrics can't handle a mix of multiclass-multioutput and multilabel-indicator targets
[martinezcarraa2@biowulf DOSAGE_HYPERPARAMS]$
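A likely cause, offered as an assumption since the script itself is not shown: the traceback shows `compute_metrics` coming from a disease (cell) classification script, presumably written for one label per cell. Gene classification is token-level, so `label_ids` is a 2D (batch, sequence length) array in which tokens without a label carry -100 (the ignore index used for Hugging Face token classification), and `predictions.argmax(-1)` becomes a 2D array of 0s and 1s. scikit-learn's `type_of_target` calls the first `multiclass-multioutput` and the second `multilabel-indicator`, which is the mix `_check_targets` rejects. The arrays below are hypothetical toy values, not data from this run.

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical token-level labels for 2 sequences of 4 tokens:
# 0 = dosage insensitive, 1 = dosage sensitive, -100 = unlabeled or padded token
labels = np.array([[0, 1, -100, -100],
                   [1, -100, 0, -100]])

# Hypothetical logits with shape (batch, seq_len, num_labels)
logits = np.array([[[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.4, 0.6]],
                   [[0.1, 2.2], [1.1, 0.3], [1.7, 0.2], [0.2, 0.8]]])

preds = logits.argmax(-1)  # shape (2, 4), values only 0 or 1

print(type_of_target(labels))  # -> multiclass-multioutput (2D with values -100, 0, 1)
print(type_of_target(preds))   # -> multilabel-indicator (2D with values 0 and 1 only)
```

Passing these two arrays to `accuracy_score(labels, preds)` raises the same `ValueError` as in the traceback above.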
AMCalejandro changed discussion title from TypeError: forward() got an unexpected keyword argument 'label' to lassification metrics can't handle a mix of multiclass-multioutput and multilabel-indicator targets
ctheodoris changed discussion title from lassification metrics can't handle a mix of multiclass-multioutput and multilabel-indicator targets to Classification metrics can't handle a mix of multiclass-multioutput and multilabel-indicator targets

Thank you for your interest in Geneformer! I have not encountered this error. Could you provide more specific information about the code you are running to elicit this error so we can help troubleshoot? This may be relevant: https://stackoverflow.com/questions/54589669/confusion-matrix-error-classification-metrics-cant-handle-a-mix-of-multilabel
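One way to make the metrics line up, sketched under the same assumption (token-level labels where -100 marks tokens to ignore), is to flatten both arrays and drop the ignored positions before calling scikit-learn. `flatten_token_predictions` and `IGNORE_INDEX` are hypothetical names; `compute_metrics` follows the Hugging Face `Trainer` convention of receiving an object with `.predictions` and `.label_ids`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

IGNORE_INDEX = -100  # assumed label for tokens without a dosage-sensitivity label


def flatten_token_predictions(logits, label_ids, ignore_index=IGNORE_INDEX):
    """Turn (batch, seq_len, num_labels) logits and (batch, seq_len) labels
    into 1D y_true / y_pred arrays that contain only labeled tokens."""
    preds = np.asarray(logits).argmax(-1)  # (batch, seq_len) predicted class ids
    labels = np.asarray(label_ids)
    mask = labels != ignore_index          # True only where a real label exists
    return labels[mask], preds[mask]       # boolean indexing flattens to 1D


def compute_metrics(pred):
    """Trainer-style metric function; `pred` exposes .predictions and .label_ids."""
    y_true, y_pred = flatten_token_predictions(pred.predictions, pred.label_ids)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, average="macro"),
    }
```

Because the mask comes from the labels, unlabeled and padded tokens never reach scikit-learn, so both arrays are 1D with the same target type. If the model returns extra outputs, `pred.predictions` may be a tuple, and the logits would need to be selected from it first.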

ctheodoris changed discussion status to closed
