Question about data splitting

#317
by MinieRosie - opened

Hello, I am continuing to explore your Geneformer package, and it has been really great.

I had a question about how to properly split data. I know that in your publication you split the data by patient, with 70% of patients in training, 15% in validation, and 15% in testing (held out, i.e. out of sample). Your three conditions had a relatively even balance of patients and cells.

For my data, I have two conditions, and they are unbalanced: one condition has 50 patients and the other has 9. I randomly split the patients of each condition separately into 70% training, 15% validation, and 15% test.
For one cell type (type 1) I get good accuracy (0.89), but for another cell type (type 2) I get poor accuracy (0.61). There are about 5 cell types in this tissue, so I am still testing the classifier on the other cell types.
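The patient-level split described above (splitting each condition's patients separately, then letting every cell inherit its patient's split) can be sketched in plain Python as below. This is a standalone illustration, not the Geneformer classifier module's API; `split_patients_by_condition` and `assign_cells` are hypothetical helper names, and the 70/15/15 percentages are an assumption taken from the publication's split.

```python
import random

def split_patients_by_condition(patients_by_condition, train_pct=70, val_pct=15, seed=0):
    """Hypothetical helper: split patient IDs into train/val/test separately
    within each condition, so every split contains patients from both
    conditions and no patient appears in more than one split."""
    rng = random.Random(seed)
    splits = {"train": [], "val": [], "test": []}
    for condition in sorted(patients_by_condition):
        ids = sorted(patients_by_condition[condition])
        rng.shuffle(ids)
        n = len(ids)
        n_train = n * train_pct // 100  # integer arithmetic avoids float rounding surprises
        n_val = n * val_pct // 100
        splits["train"] += ids[:n_train]
        splits["val"] += ids[n_train:n_train + n_val]
        splits["test"] += ids[n_train + n_val:]  # remainder goes to test
    return splits

def assign_cells(cell_patient_ids, splits):
    """Hypothetical helper: each cell inherits the split of the patient it came from."""
    patient_to_split = {p: name for name, ps in splits.items() for p in ps}
    return [patient_to_split[p] for p in cell_patient_ids]

# Patient counts from the question: 50 patients in condition A, 9 in condition B.
patients = {"A": [f"A{i}" for i in range(50)], "B": [f"B{i}" for i in range(9)]}
splits = split_patients_by_condition(patients)
print({name: len(ids) for name, ids in splits.items()})  # {'train': 41, 'val': 8, 'test': 10}
```

With only 9 patients in condition B, this puts 1 of them in validation and 2 in test, so metrics for that condition rest on very few patients.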

I also tried a different way of splitting the data: instead of separating by patient, I pooled all cells within each condition and randomly split the cells (again stratified by condition). With this split, I get a very high accuracy (0.97) for cell type 2. Is this way of splitting the data incorrect, and what is the purpose of splitting by patient? I am also wondering whether the imbalance in cell numbers (about 200,000 for condition A vs. 25,000 for condition B) is inflating my accuracy, and why the accuracy differs so much between the two splitting methods. What do you think is the best way to split my data into training, validation, and test sets?
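One quick way to see how much the cell-count imbalance alone could inflate accuracy is to score a trivial baseline that always predicts the majority condition. The sketch below uses the approximate counts from the question and computes accuracy and macro F1 (the unweighted mean of per-class F1) in plain Python; the function names are illustrative, and `sklearn.metrics.f1_score(y_true, y_pred, average="macro")` computes the same macro F1 if scikit-learn is available.

```python
def f1_for_label(y_true, y_pred, label):
    """One-vs-rest F1 for a single class label."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1, so each class counts equally regardless of size."""
    labels = sorted(set(y_true) | set(y_pred))
    return sum(f1_for_label(y_true, y_pred, lab) for lab in labels) / len(labels)

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Approximate cell counts from the question: ~200,000 condition A, ~25,000 condition B.
y_true = ["A"] * 200_000 + ["B"] * 25_000
# A trivial baseline that ignores the data and always predicts the majority condition.
y_pred = ["A"] * len(y_true)

print(f"accuracy = {accuracy(y_true, y_pred):.3f}")  # accuracy = 0.889
print(f"macro F1 = {macro_f1(y_true, y_pred):.3f}")  # macro F1 = 0.471
```

Accuracy rewards the baseline for the 8:1 imbalance, while macro F1 drops to about 0.47 because the minority class gets an F1 of 0. Any accuracy on these data is easier to interpret next to this baseline.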

Thank you very much for your time and advice.

Thank you for your question! We split by patient because cells from the same patient can be correlated. If the model has already seen cells from patient A during training, it will be easier for it to classify that patient's cells correctly at test time. Splitting by patient therefore gives a better estimate of how well the model generalizes, i.e. how it would perform on data from a new patient it has never seen before.
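The leakage described above can be checked directly by counting how many patients contribute cells to both the training and the test set. This is a sketch on hypothetical toy data (10 patients with 100 cells each); `shared_patients` is an illustrative helper, not part of Geneformer.

```python
import random

def shared_patients(cell_patient_ids, train_idx, test_idx):
    """Patients that contribute cells to both the training and the test set."""
    train_patients = {cell_patient_ids[i] for i in train_idx}
    test_patients = {cell_patient_ids[i] for i in test_idx}
    return train_patients & test_patients

# Hypothetical toy data: 10 patients with 100 cells each.
cell_patient_ids = [f"P{p}" for p in range(10) for _ in range(100)]
rng = random.Random(0)

# Cell-level split: shuffle all cells and hold out 15% of them.
idx = list(range(len(cell_patient_ids)))
rng.shuffle(idx)
n_test = len(idx) * 15 // 100
cell_test, cell_train = idx[:n_test], idx[n_test:]

# Patient-level split: hold out whole patients instead.
held_out = {"P8", "P9"}
pat_test = [i for i, p in enumerate(cell_patient_ids) if p in held_out]
pat_train = [i for i, p in enumerate(cell_patient_ids) if p not in held_out]

n_shared_cell = len(shared_patients(cell_patient_ids, cell_train, cell_test))
n_shared_patient = len(shared_patients(cell_patient_ids, pat_train, pat_test))
print("cell-level split, patients in both sets:", n_shared_cell)
print("patient-level split, patients in both sets:", n_shared_patient)  # 0
```

In this toy setup the cell-level split must share at least one patient (150 test cells cannot be made up of whole 100-cell patients), and in practice it shares nearly all of them, whereas the patient-level split shares none by construction.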

If you notice that the imbalance is causing issues, you can balance the data so that the model doesn't simply learn to predict the majority class, which would already give it good accuracy. It's good to maintain the diversity provided by the larger number of patients, but you could, for example, reduce the number of cells from each patient in the condition that has more patients/cells. Please check out the classifier module for tools to help with splitting the data, stratifying by certain attributes, setting a maximum number of cells per class, etc. When your data are unbalanced, the macro F1 score can be a better measure of model performance than accuracy.
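The suggestion to keep all patients in the larger condition but reduce the cells per patient can be sketched as a simple per-patient cap. This is a plain-Python illustration, not the classifier module's own options; `cap_cells_per_patient` is a hypothetical helper, and the toy data assume 4,000 cells for each of 50 patients, which real data will not match exactly.

```python
import random
from collections import defaultdict

def cap_cells_per_patient(cells, max_per_patient, seed=0):
    """Hypothetical helper. `cells` is a list of (cell_id, patient_id) pairs.
    Randomly keeps at most `max_per_patient` cells from each patient, so every
    patient is retained but no single patient dominates."""
    rng = random.Random(seed)
    by_patient = defaultdict(list)
    for cell_id, patient_id in cells:
        by_patient[patient_id].append(cell_id)
    kept = []
    for patient_id in sorted(by_patient):
        cell_ids = by_patient[patient_id]
        if len(cell_ids) > max_per_patient:
            cell_ids = rng.sample(cell_ids, max_per_patient)
        kept.extend((cell_id, patient_id) for cell_id in cell_ids)
    return kept

# Toy version of the majority condition: 50 patients x 4,000 cells = 200,000 cells.
# (Assumes equal cells per patient; real per-patient counts will vary.)
condition_a = [(f"A{p}_c{c}", f"A{p}") for p in range(50) for c in range(4_000)]
target_total = 25_000                  # roughly the size of the minority condition
max_per_patient = target_total // 50   # 500 cells per patient
balanced_a = cap_cells_per_patient(condition_a, max_per_patient)
print(len(balanced_a), len({p for _, p in balanced_a}))  # 25000 50
```

Patients with fewer cells than the cap keep all of them, so with uneven per-patient counts the total lands at or below the target rather than exactly on it.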

Also, of course I don't know the details of your project, but in some diseases, for example, particular cell types drive the disease; these differ between normal and disease states and would be expected to be distinguished by a classifier, whereas other cell types may not play a major role and are more or less unaffected. So that's another thing to think about in general in terms of the biology.

ctheodoris changed discussion status to closed

It's true that it is unclear whether this other cell type is really driving the disease; at this point that is only a hypothesis in the field. However, I will do as you suggest and balance the data by undersampling cells from the patients in the majority class, so there is no question of the imbalance affecting the model.

Thank you so much for your prompt response.
