File size: 4,432 Bytes
8896a5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os, sys
import subprocess as sp
import random
import torch
import h5py
from tqdm import tqdm
from .fasta import parse, parse_directory, write
from .pretrained import get_pretrained
from .alphabets import Uniprot21
from .models.embedding import SkipLSTM
from datetime import datetime
def lm_embed(sequence, use_cuda=False):
"""
Embed a single sequence using pre-trained language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
:param sequence: Input sequence to be embedded
:type sequence: str
:param use_cuda: Whether to generate embeddings using GPU device [default: False]
:type use_cuda: bool
:return: Embedded sequence
:rtype: torch.Tensor
"""
model = get_pretrained("lm_v1")
torch.nn.init.normal_(model.proj.weight)
model.proj.bias = torch.nn.Parameter(torch.zeros(100))
if use_cuda:
model = model.cuda()
model.eval()
with torch.no_grad():
alphabet = Uniprot21()
es = torch.from_numpy(alphabet.encode(sequence.encode('utf-8')))
x = es.long().unsqueeze(0)
if use_cuda:
x = x.cuda()
z = model.transform(x)
return z.cpu()
def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False):
"""
Embed sequences using pre-trained language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
:param fastaPath: Input sequence file (``.fasta`` format)
:type fastaPath: str
:param outputPath: Output embedding file (``.h5`` format)
:type outputPath: str
:param device: Compute device to use for embeddings [default: 0]
:type device: int
:param verbose: Print embedding progress
:type verbose: bool
"""
use_cuda = (device >= 0) and torch.cuda.is_available()
if use_cuda:
torch.cuda.set_device(device)
if verbose:
print(f"# Using CUDA device {device} - {torch.cuda.get_device_name(device)}")
else:
if verbose:
print("# Using CPU")
if verbose:
print("# Loading Model...")
model = get_pretrained("lm_v1")
torch.nn.init.normal_(model.proj.weight)
model.proj.bias = torch.nn.Parameter(torch.zeros(100))
if use_cuda:
model = model.cuda()
model.eval()
if verbose:
print("# Loading Sequences...")
names, seqs = parse(open(fastaPath, "rb"))
alphabet = Uniprot21()
encoded_seqs = []
for s in tqdm(seqs):
es = torch.from_numpy(alphabet.encode(s))
if use_cuda:
es = es.cuda()
encoded_seqs.append(es)
if verbose:
num_seqs = len(encoded_seqs)
print("# {} Sequences Loaded".format(num_seqs))
print("# Approximate Storage Required (varies by average sequence length): ~{}GB".format(num_seqs * (1/125)))
h5fi = h5py.File(outputPath, "w")
print("# Storing to {}...".format(outputPath))
with torch.no_grad():
try:
for (n, x) in tqdm(zip(names, encoded_seqs),total=len(names)):
name = n.decode("utf-8")
if not name in h5fi:
x = x.long().unsqueeze(0)
z = model.transform(x)
h5fi.create_dataset(name, data=z.cpu().numpy(), compression="lzf")
except KeyboardInterrupt:
h5fi.close()
sys.exit(1)
h5fi.close()
def embed_from_directory(directory, outputPath, device=0, verbose=False, extension=".seq"):
"""
Embed all files in a directory in ``.fasta`` format using pre-trained language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
:param directory: Input directory (``.fasta`` format)
:type directory: str
:param outputPath: Output embedding file (``.h5`` format)
:type outputPath: str
:param device: Compute device to use for embeddings [default: 0]
:type device: int
:param verbose: Print embedding progress
:type verbose: bool
:param extension: Extension of all files to read in
:type extension: str
"""
nam, seq = parse_directory(directory, extension=extension)
fastaPath = f"{directory}/allSeqs.fa"
if os.path.exists(fastaPath):
fastaPath = f"{fastaPath}.{int(datetime.utcnow().timestamp())}"
write(nam, seq, open(fastaPath, "w"))
embed_from_fasta(fastaPath, outputPath, device, verbose)
|