---
license: mit
language:
- en
library_name: transformers
tags:
- esm
- esm2
- protein language model
- biology
---

# ESM-2 (`esm2_t6_8M_UR50D`)

This is a fine-tuned version of [ESM-2](https://huggingface.co/facebook/esm2_t6_8M_UR50D) for sequence classification.
It assigns each protein sequence to one of two classes: "cytosolic" or "membrane".

## Training and Accuracy

The model was trained using [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_language_modeling.ipynb)
and achieved an evaluation accuracy of 94.83%.

## Using the Model
To classify a protein sequence, try running:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Initialize the tokenizer and model
model_path_directory = "AmelieSchreiber/esm2_t6_8M_UR50D-finetuned-localization"
tokenizer = AutoTokenizer.from_pretrained(model_path_directory)
model = AutoModelForSequenceClassification.from_pretrained(model_path_directory)

# Define a function to predict the category of a protein sequence
def predict_category(sequence):
    # Tokenize the sequence and convert it to tensor format
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=512, padding="max_length")

    # Make prediction
    with torch.no_grad():
        logits = model(**inputs).logits

    # Determine the category with the highest score
    predicted_class = torch.argmax(logits, dim=1).item()

    # Return the category: 0 for cytosolic, 1 for membrane
    return "cytosolic" if predicted_class == 0 else "membrane"

# Example sequence
new_protein_sequence = "MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIWHEMVETGGSEGVVRSDQVIITDHPGDLTFTVTLENLTADDAGKYRCGIATILQEDGLSGFLPDPFFQVQVLVSSASSTENSVKTPASPTRPSQCQGSLPSSTCFLLLPLLKVPLLLSILGAILWVNRPWRTPWTES"

# Predict the category
category = predict_category(new_protein_sequence)
print(f"The predicted category for the sequence is: {category}")
```
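
The example above returns only the top class. If a confidence score is also useful, the two logits can be passed through a softmax. The following is a minimal, dependency-free sketch of that step; the `LABELS` order is assumed from the comment in the example above, and `softmax`, `label_with_confidence`, and the example logits are hypothetical names and values used for illustration, not part of the model's API.

```python
import math

# Assumed label order, taken from the comment in the example above:
# index 0 = cytosolic, index 1 = membrane.
LABELS = ("cytosolic", "membrane")

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtract the largest logit first for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def label_with_confidence(logits):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]

# Hypothetical logits for illustration only, not real model output.
example_logits = [-1.2, 2.3]
label, confidence = label_with_confidence(example_logits)
print(f"{label} (p = {confidence:.3f})")
```

With the real model, `logits[0].tolist()` from `predict_category` could be passed to `label_with_confidence`; `torch.softmax(logits, dim=1)` computes the same probabilities directly on the tensor.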