license: mit
inference: false
datasets:
- ibm/otter_stitch
Otter STITCH DistMult Model Card
Model details
Otter models are based on Graph Neural Networks (GNNs) that propagate initial embeddings through a set of layers, each of which updates a node's embedding according to its neighbours. The GNN architecture consists of two main blocks: an encoder and a decoder.
- For the encoder, we first define a projection layer consisting of a set of linear transformations, one per node modality, that projects nodes into a common dimensionality. We then apply several multi-relational graph convolutional layers (R-GCN), which distinguish between the different types of edges between source and target nodes by keeping a separate set of trainable parameters for each edge type.
- For the decoder, we consider a link prediction task, which consists of a scoring function that maps each triple (source node, edge, target node) to a scalar in the interval [0, 1].
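The encoder idea described above can be illustrated with a small pure-Python sketch. It is not the Otter implementation: the node names, embedding sizes, weight values and the mean normalisation per relation are all assumptions chosen for readability, and the layer is a simplified R-GCN-style update (self-loop plus one weight matrix per edge type).

```python
# Simplified sketch: per-modality projection followed by one
# relation-aware aggregation step. All names, sizes and weights
# below are hypothetical and only serve as an illustration.

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def relu(x):
    return [max(0.0, v) for v in x]

# Hypothetical initial embeddings; each modality has its own size.
nodes = {
    "drug_1": ("drug", [1.0, 2.0, 0.0]),     # 3-dimensional
    "protein_1": ("protein", [0.5, -1.0]),   # 2-dimensional
}

# One linear projection per modality, mapping into a common 2-dim space.
projection = {
    "drug": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    "protein": [[1.0, 0.0], [0.0, 1.0]],
}

# One trainable matrix per edge type, plus a self-loop matrix.
relation_weight = {"interacts_with": [[0.5, 0.0], [0.0, 0.5]]}
self_weight = [[1.0, 0.0], [0.0, 1.0]]

# Edges as (source, relation, target) triples.
edges = [("drug_1", "interacts_with", "protein_1")]

def project(nodes):
    """Project every node into the common dimensionality."""
    return {name: matvec(projection[modality], emb)
            for name, (modality, emb) in nodes.items()}

def rgcn_layer(h, edges):
    """One relation-aware update: self-loop plus averaged messages per relation."""
    out = {name: matvec(self_weight, vec) for name, vec in h.items()}
    counts = {}
    for _, rel, dst in edges:
        counts[(dst, rel)] = counts.get((dst, rel), 0) + 1
    for src, rel, dst in edges:
        message = matvec(relation_weight[rel], h[src])
        norm = counts[(dst, rel)]
        out[dst] = [o + m / norm for o, m in zip(out[dst], message)]
    return {name: relu(vec) for name, vec in out.items()}

h0 = project(nodes)        # embeddings in the common space
h1 = rgcn_layer(h0, edges) # embeddings after one message-passing step
```

In this toy graph only `protein_1` has an incoming edge, so only its embedding changes beyond the self-loop; stacking several such layers lets information travel across longer paths.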
Model type:
For link prediction, we consider three choices of scoring function that are commonly used in the literature: DistMult, TransE and a Binary Classifier. The score of each triple is then compared against the actual label using a negative log-likelihood loss.
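As a rough illustration of the scoring step, the sketch below writes out the standard DistMult and TransE formulas and a binary negative log-likelihood in plain Python. Squashing the raw score with a sigmoid to land in [0, 1] is an assumption made here, as are the example vectors; the model card does not state how the scores are normalised.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def distmult_score(h, r, t):
    """DistMult: trilinear product sum_i h_i * r_i * t_i."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))

def transe_score(h, r, t):
    """TransE: negative L2 distance between h + r and t."""
    return -math.sqrt(sum((hi + ri - ti) ** 2 for hi, ri, ti in zip(h, r, t)))

def nll_loss(probabilities, labels):
    """Mean negative log-likelihood for binary labels (1 = true link)."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for p, y in zip(probabilities, labels):
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(labels)

# Hypothetical embeddings for one head, one relation and two candidate tails.
head = [1.0, 0.0, 2.0]
relation = [0.5, 1.0, 0.5]
tail_pos = [2.0, 1.0, 1.0]    # treated as a true link
tail_neg = [-2.0, 1.0, -1.0]  # treated as a corrupted (negative) link

raw = [distmult_score(head, relation, tail_pos),
       distmult_score(head, relation, tail_neg)]
probs = [sigmoid(s) for s in raw]  # assumption: sigmoid maps scores into [0, 1]
loss = nll_loss(probs, [1, 0])
```

The positive triple receives a higher probability than the corrupted one, so the loss is small; swapping the labels would make it large.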
- Flow control: One crucial aspect of pretraining the GNN is the disparity between the data available during pretraining and the data available for downstream tasks. During pretraining, proteins and drugs have numerous associated attributes, whereas during downstream fine-tuning only amino acid sequences and SMILES are available. Consequently, we explore two pretraining scenarios: one that controls the information propagated to the Drug/Protein entities and one without such control. We report results for both cases to show how restricting information flow during pretraining affects the subsequent tasks.
- Noisy Links: Another important consideration is the presence of noisy links in the upstream data and how they affect the downstream tasks. To investigate their impact, we handpick a subset of links from each database that are relevant to drug discovery (see details in the Appendix). We then compare the outcomes of training the GNN on only these restricted links against training on all links present in the graphs.
- Regression: Certain pretraining datasets, such as Uniprot, contain numerical data properties. Hence, we add a regression objective that minimizes the mean squared error (MSE) of the predicted numerical data properties. During learning, the regression objective and the link prediction objective are combined into a single objective function.
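The regression option above combines two losses into one objective. The model card does not say how they are weighted, so the sketch below assumes a simple weighted sum with a hypothetical `regression_weight` parameter and uses mean squared error for the regression term; the numbers are made up.

```python
def mse(predictions, targets):
    """Mean squared error between predicted and true numerical properties."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

def combined_objective(link_loss, regression_loss, regression_weight=1.0):
    """Single training objective.

    regression_weight is a hypothetical knob: the source only states
    that the two objectives are combined, not how.
    """
    return link_loss + regression_weight * regression_loss

link_loss = 0.25                                  # e.g. NLL over scored triples
reg_loss = mse([1.0, 2.0, 4.0], [1.0, 3.0, 2.0])  # (0 + 1 + 4) / 3
total = combined_objective(link_loss, reg_loss, regression_weight=0.5)
```

For the model described in this card the Regression option is listed as "No", which in this sketch would correspond to dropping the regression term entirely.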
Scoring Type | Noisy Links | Flow Control | Regression |
---|---|---|---|
DistMult | No | Yes | No |
Model training data:
The model was trained on STITCH. STITCH (Search Tool for Interacting Chemicals) is a database of known and predicted interactions between chemicals, represented by SMILES strings, and proteins, whose sequences are taken from the STRING database. It contains 10,717,791 triples for 17,572 different chemicals and 1,886,496 different proteins. The graph was split into five subgraphs of roughly equal size, and the GNN was trained sequentially on each of them, each time updating the model obtained from the previous subgraph.
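The sequential scheme described above can be sketched as follows. How the STITCH graph was actually partitioned is not stated, so the contiguous split, the toy triples and the `train_on_subgraph` placeholder are assumptions; the point is only that each step starts from the state left by the previous one.

```python
def split_into_chunks(items, k):
    """Split items into k contiguous chunks whose sizes differ by at most one."""
    base, extra = divmod(len(items), k)
    chunks, start = [], 0
    for i in range(k):
        size = base + (1 if i < extra else 0)
        chunks.append(items[start:start + size])
        start += size
    return chunks

def train_on_subgraph(state, triples):
    """Hypothetical stand-in for one training pass.

    A real pass would update model weights; here the state only
    counts how many triples have been seen so far.
    """
    return {"triples_seen": state["triples_seen"] + len(triples)}

# Toy triples standing in for chemical-protein interactions.
triples = [("chem_%d" % i, "interacts_with", "prot_%d" % i) for i in range(12)]

subgraphs = split_into_chunks(triples, 5)
state = {"triples_seen": 0}
for subgraph in subgraphs:
    # Continue from the state produced by the previous subgraph.
    state = train_on_subgraph(state, subgraph)
```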
Model results:
Dataset | DTI DG | DAVIS | DAVIS | DAVIS | KIBA | KIBA | KIBA |
---|---|---|---|---|---|---|---|
Splits | Temporal | Random | Target | Drug | Random | Target | Drug |
Results | 0.575 | 0.808 | 0.573 | 0.138 | 0.859 | 0.615 | 0.603 |
Paper or resources for more information:
License:
MIT
Where to send questions or comments about the model:
How to use
Clone the repo:
git clone https://github.com/IBM/otter-knowledge.git
cd otter-knowledge
Run the inference for Proteins:
Replace test_data with the path to a CSV file containing the protein sequences, name_of_the_column with the name of the column of the protein sequence in the CSV and output_path with the filename of the JSON file to be created with the embeddings.
python inference.py --input_path test_data --sequence_column name_of_the_column --model_path ibm/otter_stitch_distmult --output_path output_path
Run the inference for Drugs:
Replace test_data with the path to a CSV file containing the drug SMILES, name_of_the_column with the name of the column of the SMILES in the CSV and output_path with the filename of the JSON file to be created with the embeddings.
python inference.py --input_path test_data --sequence_column name_of_the_column --input_type Drug --relation_name smiles --model_path ibm/otter_stitch_distmult --output_path output_path
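Both inference commands expect a CSV file with one column holding the sequences. A minimal way to produce such a file with the Python standard library is sketched below; the file name `drugs_example.csv`, the column name `smiles` and the two example molecules (ethanol and acetic acid) are illustrative choices, not requirements of the repository.

```python
import csv
import os
import tempfile

def write_sequences_csv(path, column_name, sequences):
    """Write a single-column CSV with a header row."""
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow([column_name])
        for sequence in sequences:
            writer.writerow([sequence])

# Hypothetical input: two SMILES strings (ethanol, acetic acid).
smiles = ["CCO", "CC(=O)O"]
csv_path = os.path.join(tempfile.gettempdir(), "drugs_example.csv")
write_sequences_csv(csv_path, "smiles", smiles)

# Read the file back to check that the column is laid out as expected.
with open(csv_path, newline="") as handle:
    rows = list(csv.DictReader(handle))
```

The resulting path and column name would then be passed as `--input_path` and `--sequence_column`; the same helper works for protein sequences with a different column name.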