File size: 6,067 Bytes
5263bd3
 
 
 
 
 
2243c0c
 
 
 
 
5263bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a61f4be
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
a61f4be
 
2243c0c
 
 
 
 
 
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
ea34b0d
5263bd3
2243c0c
 
 
ea34b0d
2243c0c
 
ea34b0d
 
 
 
 
 
 
 
 
a61f4be
2243c0c
 
 
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
a61f4be
2243c0c
 
 
5263bd3
2243c0c
5263bd3
2243c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5263bd3
 
 
 
 
 
2243c0c
 
 
 
 
 
 
 
 
 
 
 
5263bd3
ea34b0d
5263bd3
 
ea34b0d
7a5c7ee
301b9cf
2243c0c
5263bd3
ea34b0d
 
 
 
 
 
5263bd3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import logging

# Module-wide logger; basicConfig requests INFO level on the root logger.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class VirusClassifier(nn.Module):
    """Small fully-connected classifier producing two output scores.

    The layers are held in ``self.network`` (an ``nn.Sequential``) in this
    order: Linear(input_shape, 64) -> GELU -> BatchNorm1d -> Dropout(0.3),
    Linear(64, 32) -> GELU -> BatchNorm1d -> Dropout(0.3),
    Linear(32, 32) -> GELU, and a final Linear(32, 2).

    Args:
        input_shape: Number of features in each input row (the ``in_features``
            of the first linear layer).
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # Order matters: positions inside the Sequential become the
        # parameter names ("network.0.weight", ...) in the state dict.
        stack = [
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack; no softmax is applied here."""
        scores = self.network(x)
        return scores

def sequence_to_kmer_vector(sequence: str, k: int = 6) -> np.ndarray:
    """Count how often each length-k string over "ACGT" occurs in `sequence`.

    The result has one entry per possible k-mer, in the order produced by
    ``itertools.product("ACGT", repeat=k)``. Windows that are not one of
    those k-mers (e.g. containing another character) are not counted.
    Any exception is logged and re-raised.
    """
    try:
        # dict preserves insertion order, so values() follows product() order
        counts = dict.fromkeys(map(''.join, product("ACGT", repeat=k)), 0)

        window_total = len(sequence) - k + 1
        for start in range(window_total):
            window = sequence[start:start + k]
            if window not in counts:
                continue  # skip anything that is not a pure-ACGT k-mer
            counts[window] += 1

        return np.array(list(counts.values()))
    except Exception as e:
        logger.error(f"Error in sequence_to_kmer_vector: {str(e)}")
        raise

def parse_fasta(content: str) -> list:
    """Split FASTA text into a list of ``(header, sequence)`` tuples.

    A line starting with ``>`` opens a new record; its header is the rest of
    that line. Every other non-blank line is upper-cased and appended to the
    current record's sequence. Lines are stripped of surrounding whitespace
    and blank lines are ignored. Sequence lines that appear before the first
    header are not returned. Any exception is logged and re-raised.
    """
    try:
        logger.info(f"Received file content length: {len(content)}")

        records = []
        header = None
        chunks = []

        stripped_lines = (raw.strip() for raw in content.split('\n'))
        for text in filter(None, stripped_lines):
            if not text.startswith('>'):
                chunks.append(text.upper())
                continue
            # New header: close the record in progress, if there is one.
            if header is not None:
                records.append((header, ''.join(chunks)))
            header, chunks = text[1:], []

        # Close the final record, which has no following header to trigger it.
        if header is not None:
            records.append((header, ''.join(chunks)))

        logger.info(f"Parsed {len(records)} sequences from FASTA")
        return records
    except Exception as e:
        logger.error(f"Error parsing FASTA: {str(e)}")
        raise

def predict_sequence(fasta_file) -> str:
    """Classify every sequence in a FASTA input and return a text report.

    Args:
        fasta_file: The FASTA input. Accepted forms:
            * ``None`` - nothing was provided;
            * a ``str`` that is the path of an existing file - the file is
              read as UTF-8 text;
            * any other ``str`` - treated as the FASTA text itself;
            * an object with a ``.name`` attribute holding a file path
              (e.g. an uploaded temp-file wrapper) - that file is read.

    Returns:
        One report section per sequence (prediction, confidence and both
        class probabilities), joined by newlines, or a human-readable
        message describing what went wrong. Failures are reported in the
        returned string rather than raised.
    """
    import os  # local import: only used to tell a file path from raw text

    try:
        logger.info("Starting prediction process")

        if fasta_file is None:
            return "Please upload a FASTA file"

        # Get file content - handle both string and file inputs.
        # An upload arrives as a path (either a str or an object whose
        # ``.name`` is the path), so the file has to be opened and read;
        # the path itself is not FASTA content.
        try:
            if isinstance(fasta_file, str):
                path = fasta_file if os.path.isfile(fasta_file) else None
            else:
                path = fasta_file.name

            if path is None:
                content = fasta_file  # raw FASTA text passed directly
            else:
                with open(path, "r", encoding="utf-8") as handle:
                    content = handle.read()
        except Exception as e:
            logger.error(f"Error reading file: {str(e)}")
            return f"Error reading file: {str(e)}"

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {device}")
        k = 4

        # Load model and scaler
        try:
            logger.info("Loading model and scaler")
            # One input feature per possible k-mer: 4^k (256 for 4-mers).
            model = VirusClassifier(4 ** k).to(device)
            model.load_state_dict(torch.load('model.pt', map_location=device))
            scaler = joblib.load('scaler.pkl')
            # eval mode: disables dropout and uses BatchNorm running stats
            model.eval()
        except Exception as e:
            logger.error(f"Error loading model or scaler: {str(e)}")
            return f"Error loading model: {str(e)}"

        # Process sequences
        try:
            sequences = parse_fasta(content)
        except Exception as e:
            logger.error(f"Error parsing FASTA file: {str(e)}")
            return f"Error parsing FASTA file: {str(e)}"

        if not sequences:
            # Say so explicitly instead of returning an empty report.
            logger.info("No sequences found in input")
            return "No sequences found in the FASTA input"

        results = []

        for header, seq in sequences:
            logger.info(f"Processing sequence: {header}")
            try:
                # Convert sequence to k-mer vector
                kmer_vector = sequence_to_kmer_vector(seq, k)
                # The scaler expects a 2-D array: one row per sample.
                kmer_vector = scaler.transform(kmer_vector.reshape(1, -1))

                # Get prediction
                with torch.no_grad():
                    output = model(torch.FloatTensor(kmer_vector).to(device))
                    probs = torch.softmax(output, dim=1)

                # Format result: index 1 is reported as 'human',
                # index 0 as 'non-human'.
                pred_class = 1 if probs[0][1] > probs[0][0] else 0
                pred_label = 'human' if pred_class == 1 else 'non-human'

                result = f"""
Sequence: {header}
Prediction: {pred_label}
Confidence: {float(max(probs[0])):0.4f}
Human probability: {float(probs[0][1]):0.4f}
Non-human probability: {float(probs[0][0]):0.4f}
"""
                results.append(result)
                logger.info(f"Processed sequence {header} successfully")

            except Exception as e:
                # One bad sequence must not abort the rest of the batch.
                logger.error(f"Error processing sequence {header}: {str(e)}")
                results.append(f"Error processing sequence {header}: {str(e)}")

        return "\n".join(results)

    except Exception as e:
        logger.error(f"Unexpected error in predict_sequence: {str(e)}")
        return f"An unexpected error occurred: {str(e)}"

# Create Gradio interface: one FASTA file upload in, a text report out.
iface = gr.Interface(
    fn=predict_sequence,
    inputs=[
        gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"])
    ],  # trailing comma was missing here, which made the file a SyntaxError
    outputs=gr.Textbox(label="Prediction Results", lines=10),
    title="Virus Host Classifier",
    description="""Upload a FASTA file to predict whether a virus sequence is likely to infect human or non-human hosts.
    
Example format:
>sequence_name
ATCGATCGATCG...""",
    # One value per example row, matching the single input component.
    examples=[["example.fasta"]],
    cache_examples=True
)

# Launch the interface
iface.launch()