abondrn commited on
Commit
4d84fae
1 Parent(s): bc4ccb5

Added msa.py utils

Browse files
Files changed (2) hide show
  1. app.py +6 -1
  2. msa.py +62 -0
app.py CHANGED
@@ -64,12 +64,17 @@ def msa_embed(msa):
64
  msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
65
  msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
66
 
67
- temp = msa_transformer(msa_transformer_batch_tokens,repr_layers=[12])['representations']
 
68
  temp = temp[12][:,:,0,:]
69
  temp = torch.mean(temp,(0,1))
70
  return temp
71
 
72
 
 
 
 
 
73
  def download_data_if_required():
74
  url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
75
  fps = [pg.trained_model_fp]
 
64
  msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
65
  msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
66
 
67
+ with torch.no_grad():
68
+ temp = msa_transformer(msa_transformer_batch_tokens,repr_layers=[12])['representations']
69
  temp = temp[12][:,:,0,:]
70
  temp = torch.mean(temp,(0,1))
71
  return temp
72
 
73
 
74
+ def go_embed(terms):
75
+ pass
76
+
77
+
78
  def download_data_if_required():
79
  url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
80
  fps = [pg.trained_model_fp]
msa.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import itertools
3
+ from pathlib import Path
4
+ from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
5
+ import string
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy.spatial.distance import squareform, pdist, cdist
10
+ from Bio import SeqIO
11
+ #import biotite.structure as bs
12
+ #from biotite.structure.io.pdbx import PDBxFile, get_structure
13
+ #from biotite.database import rcsb
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+
18
+ # This is an efficient way to delete lowercase characters and insertion characters from a string
19
+ deletekeys = dict.fromkeys(string.ascii_lowercase)
20
+ deletekeys["."] = None
21
+ deletekeys["*"] = None
22
+ translation = str.maketrans(deletekeys)
23
+
24
+
25
+ def read_sequence(filename: str) -> Tuple[str, str]:
26
+ """ Reads the first (reference) sequences from a fasta or MSA file."""
27
+ record = next(SeqIO.parse(filename, "fasta"))
28
+ return record.description, str(record.seq)
29
+
30
+ def remove_insertions(sequence: str) -> str:
31
+ """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
32
+ return sequence.translate(translation)
33
+
34
+ def read_msa(filename: str) -> List[Tuple[str, str]]:
35
+ """ Reads the sequences from an MSA file, automatically removes insertions."""
36
+ return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
37
+
38
+
39
+ def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
40
+ """
41
+ Select sequences from the MSA to maximize the hamming distance
42
+ Alternatively, can use hhfilter
43
+ """
44
+ assert mode in ("max", "min")
45
+ if len(msa) <= num_seqs:
46
+ return msa
47
+
48
+ array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)
49
+
50
+ optfunc = np.argmax if mode == "max" else np.argmin
51
+ all_indices = np.arange(len(msa))
52
+ indices = [0]
53
+ pairwise_distances = np.zeros((0, len(msa)))
54
+ for _ in range(num_seqs - 1):
55
+ dist = cdist(array[indices[-1:]], array, "hamming")
56
+ pairwise_distances = np.concatenate([pairwise_distances, dist])
57
+ shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
58
+ shifted_index = optfunc(shifted_distance)
59
+ index = np.delete(all_indices, indices)[shifted_index]
60
+ indices.append(index)
61
+ indices = sorted(indices)
62
+ return [msa[idx] for idx in indices]