|
from dataclasses import dataclass |
|
import numpy as np |
|
import torch |
|
|
|
from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment |
|
|
|
|
|
class AbAlignment:
    """Number antibody sequences and align per-residue encodings to that numbering.

    The actual numbering and alignment work is delegated to the
    ``extra_utils`` helpers (``paired_msa_numbering`` and
    ``create_alignment``); this class wires them together and reshapes
    batched outputs into a single result.
    """

    def __init__(self, device='cpu', ncpu=1):
        """Store run-time settings.

        Args:
            device: Device identifier. It is only stored on the instance by
                this class.
            ncpu: Number of jobs forwarded as ``n_jobs`` to the numbering
                helper.
        """
        self.device = device
        self.ncpu = ncpu

    def number_sequences(self, seqs, chain='H', fragmented=False):
        """Number ``seqs`` and build the shared numbering alignment.

        Only paired input (``chain == 'HL'``) is supported at present; any
        other ``chain`` value is rejected before numbering is attempted.

        Args:
            seqs: Sequences to number; passed straight to
                ``paired_msa_numbering``.
            chain: Chain selector. Must be ``'HL'``.
            fragmented: Forwarded unchanged to ``paired_msa_numbering``.

        Returns:
            The tuple ``(numbered_seqs, seqs, number_alignment)`` produced by
            ``paired_msa_numbering``.

        Raises:
            AssertionError: If ``chain`` is not ``'HL'``.
        """
        if chain != 'HL':
            # Raised explicitly instead of via ``assert`` so the guard is not
            # stripped when Python runs with ``-O``. The exception type and
            # message are kept so existing callers see the same failure.
            raise AssertionError(
                'Currently "Align==True" only works for paired sequences. \nPlease use paired sequences or Align=False.'
            )

        numbered_seqs, seqs, number_alignment = paired_msa_numbering(
            seqs, fragmented=fragmented, n_jobs=self.ncpu
        )
        return numbered_seqs, seqs, number_alignment

    def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment):
        """Align each per-sequence encoding and stack the results.

        Args:
            encodings: Iterable of per-sequence encodings.
            numbered_seqs: Numbered sequences, in the same order as
                ``encodings``.
            seqs: Sequences, in the same order as ``encodings``.
            number_alignment: Shared numbering alignment passed to every
                ``create_alignment`` call.

        Returns:
            A NumPy array with one entry per input sequence along axis 0.
        """
        aligned_list = [
            create_alignment(res_embed, numbered_seq, seq, number_alignment)
            for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs)
        ]
        # Wrapping the list makes ``concatenate`` treat it as one array-like,
        # so the per-sequence results end up stacked along a leading axis.
        # NOTE(review): assumes every create_alignment result has the same
        # shape -- confirm against extra_utils.
        return np.concatenate([aligned_list], axis=0)

    def reformat_subsets(
        self,
        subset_list,
        mode='seqcoding',
        align=False,
        numbered_seqs=None,
        seqs=None,
        number_alignment=None,
    ):
        """Merge per-batch outputs into one result.

        Args:
            subset_list: List of per-batch outputs, in input order.
            mode: Output mode. For ``'seqcoding'``, ``'restore'``,
                ``'pseudo_log_likelihood'`` and ``'confidence'`` the batches
                are concatenated directly and ``align`` is ignored.
            align: For any other mode, whether to align the encodings to the
                shared numbering before merging.
            numbered_seqs: Numbered sequences for all batches, concatenated
                in batch order. Required when aligning.
            seqs: Sequences for all batches, concatenated in batch order.
                Required when aligning.
            number_alignment: Shared numbering alignment. Required when
                aligning.

        Returns:
            * a NumPy array (``np.concatenate`` of the batches) for the four
              modes listed above;
            * an ``aligned_results`` instance when ``align`` is true;
            * otherwise a flat list holding every item of every batch.
        """
        if mode in (
            'seqcoding',
            'restore',
            'pseudo_log_likelihood',
            'confidence',
        ):
            return np.concatenate(subset_list)

        if align:
            aligned_subsets = []
            # Track a running offset instead of ``num * len(subset)``: the
            # latter selects the wrong rows as soon as one batch has a
            # different size from the ones before it (e.g. a short last batch).
            start = 0
            for subset in subset_list:
                stop = start + len(subset)
                aligned_subsets.append(
                    self.align_encodings(
                        subset,
                        numbered_seqs[start:stop],
                        seqs[start:stop],
                        number_alignment,
                    )
                )
                start = stop

            subset = np.concatenate(aligned_subsets)

            return aligned_results(
                # The last channel carries the aligned residue characters ...
                aligned_seqs=[''.join(alist) for alist in subset[:, :, -1]],
                # ... and the remaining channels are the numeric embedding.
                aligned_embeds=subset[:, :, :-1].astype(float),
                # NOTE(review): presumably a pandas DataFrame whose column 0
                # holds a pair that is joined into one position label per
                # row -- confirm against extra_utils.
                number_alignment=number_alignment.apply(
                    lambda x: '{}{}'.format(*x[0]), axis=1
                ).values,
            )

        # Not aligned: flatten the list of per-batch lists in a single pass.
        return [item for subset in subset_list for item in subset]
|
|
|
|
|
@dataclass
class aligned_results:
    """Dataclass used to store the output of an aligned encoding run.

    All three fields are required and may be passed positionally in the
    order they are declared.

    Attributes:
        aligned_seqs: Aligned sequences, one string per input sequence, as
            built by ``AbAlignment.reformat_subsets``.
        aligned_embeds: Float array holding the aligned embeddings.
        number_alignment: Array of position labels shared by all aligned
            sequences.
    """

    # The annotations describe what AbAlignment.reformat_subsets stores here;
    # dataclasses do not enforce them at runtime.
    aligned_seqs: list
    aligned_embeds: np.ndarray
    number_alignment: np.ndarray