| from math import ceil |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| from re import match |
| import seaborn as sns |
|
|
| from model import Model |
|
|
class Data:
    """Container for input and output data.

    Parses a protein sequence and a description of the substitutions to score,
    hands itself to the model (:meth:`calculate` calls ``model.run_model(self)``)
    and formats the resulting scores into a CSV file plus either a PNG heatmap
    (TMS mode) or a colour-graded pandas Styler (DMS / MUT modes).
    """

    # Default model, built once when the class body executes and shared by every
    # instance whose requested model name matches its ``model_name``.
    model = Model()

    # The 20 standard amino acids, in the order used both for generating
    # substitutions and for labelling the rows of the TMS heatmap.
    _AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
        """Initialise data.

        Args:
            src: input protein sequence, parsed by :meth:`parse_seq`.
            trg: substitutions to score, parsed by :meth:`parse_sub`.
            model_name: name passed to ``Model``; a new model is only loaded
                when it differs from the shared default model's ``model_name``.
            scoring_strategy: stored on the instance; presumably read by
                ``Model.run_model`` (TODO confirm).
            out_file: base name of the outputs; '.png' and '.csv' are appended.

        Raises:
            RuntimeError: propagated from :meth:`parse_seq` / :meth:`parse_sub`.
        """
        # Always record the requested name: the output columns and the sorting
        # steps read self.model_name even when the shared default model is reused.
        self.model_name = model_name
        if self.model.model_name != model_name:
            # Per-instance model; the shared class-level default is left untouched.
            self.model = Model(model_name)
        self.parse_seq(src)
        self.offset = 0
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.token_probs = None
        # One row per substitution; the score column is filled in later by the model.
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_img = f'{out_file}.png'
        self.out_csv = f'{out_file}.csv'

    def parse_seq(self, src: str):
        """Parse input sequence.

        Strips surrounding whitespace, upper-cases the text and removes embedded
        newlines, storing the result in ``self.seq``.

        Raises:
            RuntimeError: if any character is not in ``self.model.alphabet``.
        """
        self.seq = src.strip().upper().replace('\n', '')
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def _check_position(self, resi: int) -> None:
        """Raise RuntimeError unless ``resi`` is a valid 1-based index into ``self.seq``."""
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Residue position {resi} outside sequence of length {len(self.seq)}")

    def parse_sub(self, trg: str):
        """Parse input substitutions.

        ``trg`` is split on whitespace and interpreted as one of, recorded in
        ``self.mode``:

        * ``'MUT'`` - a single token as long as the sequence (a mutant sequence;
          every differing position becomes a substitution), or a list of
          substitutions such as ``A12G``;
        * ``'DMS'`` - a list of 1-based residue positions, each expanded to every
          other standard amino acid;
        * ``'TMS'`` - anything else: every position of the sequence is expanded.

        Sets ``self.trg`` (upper-cased tokens), ``self.resi`` (residue position of
        each substitution) and ``self.sub`` (a DataFrame with one column ``'0'``
        holding the substitution strings).

        Raises:
            RuntimeError: if a listed position lies outside the sequence, or a
                listed substitution's wild-type residue does not match ``self.seq``.
        """
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper().split()
        self.resi = list()

        if len(self.trg) == 1 and len(self.trg[0]) == len(self.seq) and match(r'^\w+$', self.trg[0]):
            # A full-length mutant sequence: diff it against the input.
            self.mode = 'MUT'
            for resi, (wt, mt) in enumerate(zip(self.seq, self.trg[0]), 1):
                if wt != mt:
                    self.sub.append(f"{wt}{resi}{mt}")
                    self.resi.append(resi)
        elif all(match(r'\d+', x) for x in self.trg):
            # Residue positions: score every alternative amino acid at each one.
            # NOTE(review): an empty ``trg`` also lands here (all() of nothing is
            # True) and yields no substitutions.
            self.mode = 'DMS'
            for resi in map(int, self.trg):
                self._check_position(resi)
                wt = self.seq[resi-1]
                for mt in self._AMINO_ACIDS.replace(wt, ''):
                    self.sub.append(f"{wt}{resi}{mt}")
                    self.resi.append(resi)
        elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
            # Explicit substitutions such as "A12G"; verify each wild-type residue.
            self.mode = 'MUT'
            self.sub = self.trg
            self.resi = [int(x[1:-1]) for x in self.trg]
            for token, resi in zip(self.trg, self.resi):
                self._check_position(resi)
                # resi is 1-based, so the residue actually present is seq[resi-1].
                if self.seq[resi-1] != token[0]:
                    raise RuntimeError(f"Unrecognised input substitution {self.seq[resi-1]}{resi} /= {token[0]}{resi}")
        else:
            # Anything else: scan the whole sequence.
            self.mode = 'TMS'
            for resi, wt in enumerate(self.seq, 1):
                for mt in self._AMINO_ACIDS.replace(wt, ''):
                    self.sub.append(f"{wt}{resi}{mt}")
                    self.resi.append(resi)

        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def parse_output(self) -> None:
        """Format output data for visualisation and write the CSV file.

        * TMS: builds and saves the normalised heatmap (PNG at ``self.out_img``)
          and writes the amino-acid x residue table, with index and header.
        * DMS / MUT: sorts the scores, renames the columns to '0', '1', ...,
          replaces ``self.out_img`` with a colour-graded pandas Styler of the
          table, and writes the table without index or header.

        Raises:
            RuntimeError: if ``self.mode`` is not 'TMS', 'DMS' or 'MUT'.
        """
        if self.mode == 'TMS':
            self.process_tms_mode()
            self.out.to_csv(self.out_csv, float_format='%.2f')
            return
        if self.mode == 'DMS':
            self.sort_by_residue_and_score()
        elif self.mode == 'MUT':
            self.sort_by_score()
        else:
            raise RuntimeError(f"Unrecognised mode {self.mode}")
        self.out.columns = [str(i) for i in range(self.out.shape[1])]
        # Note: out_img stops being a file path here and becomes a Styler object.
        self.out_img = (self.out.style
                        .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                        .hide(axis=0)
                        .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8))
        self.out.to_csv(self.out_csv, float_format='%.2f', index=False, header=False)

    def sort_by_score(self):
        """Sort the output rows by model score, highest first."""
        self.out = self.out.sort_values(self.model_name, ascending=False)

    def sort_by_residue_and_score(self):
        """Lay DMS results out side by side, one residue per block of columns.

        Rows are ordered by residue position and, within a residue, by descending
        score; at most 19 rows are kept per residue. Each consecutive block of 19
        rows becomes its own set of columns; the final columns are renumbered
        0 .. 2*blocks-1, i.e. this assumes two columns (substitution, score) per
        block - TODO confirm run_model adds no further columns.
        """
        self.out = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                    .sort_values(['resi', self.model_name], ascending=[True, False])
                    .groupby(['resi'])
                    .head(19)
                    .drop(['resi'], axis=1))
        n_blocks = self.out.shape[0] // 19
        blocks = [self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(n_blocks)]
        self.out = pd.concat(blocks, axis=1).set_axis(range(n_blocks*2), axis='columns')

    def process_tms_mode(self):
        """Build and plot the normalised TMS heatmap.

        Reshapes the scores into an amino-acid x residue table, scales it by its
        largest absolute value (skipped when that is 0, which would otherwise turn
        the whole table into NaN), picks a row width and saves the plot via
        :meth:`plot_heatmap`.
        """
        self.out = self.assign_resi_and_group()
        self.out = self.concat_and_set_axis()
        scale = self.out.abs().max().max()
        if scale:
            self.out /= scale
        divs = self.calculate_divs()
        # Prefer the candidate row width closest to 60 residues.
        ncols = min(divs, key=lambda x: abs(x-60))
        nrows = ceil(self.out.shape[1]/ncols)
        ncols = self.adjust_ncols(ncols, nrows)
        self.plot_heatmap(ncols, nrows)

    def assign_resi_and_group(self):
        """Return ``self.out`` with a ``resi`` column (position parsed from '0'),
        keeping at most the first 19 rows per residue."""
        return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                .groupby(['resi'])
                .head(19))

    def concat_and_set_axis(self):
        """Reshape per-residue rows into a 20 x len(seq) float table.

        Each block of 19 rows gets a wild-type "self substitution" row prepended
        (score 0, see :meth:`create_dataframe`), is sorted by substitution string
        so rows follow the target amino acid, and keeps only its score column.
        Rows are labelled with the 20 amino acids, columns as e.g. 'A1', 'C2'.
        NOTE(review): row labels assume the wild type is one of the 20 standard
        amino acids; other alphabet symbols would misalign the labels.
        """
        return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
                            .pipe(self.create_dataframe)
                            .sort_values(['0'], ascending=[True])
                            .drop(['resi', '0'], axis=1)
                            .set_axis(list(self._AMINO_ACIDS))
                            .astype(float)
                            ) for x in range(self.out.shape[0]//19)],
                          axis=1)
                .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))

    def create_dataframe(self, df):
        """Prepend to ``df`` a row for the wild-type-to-wild-type substitution.

        The new substitution string is the first row's string with its last
        character replaced by its first (e.g. 'A12C' -> 'A12A'); every other
        column of the new row is 0.
        """
        first = df.iloc[0, 0]
        wt_row = pd.Series([first[:-1]+first[0], 0, 0], index=df.columns).to_frame().T
        return pd.concat([wt_row, df], axis=0, ignore_index=True)

    def calculate_divs(self):
        """Return the divisors of the column count lying in [30, 60], or [60] if none."""
        return [x for x in range(1, self.out.shape[1]+1) if self.out.shape[1] % x == 0 and 30 <= x and x <= 60] or [60]

    def adjust_ncols(self, ncols, nrows):
        """Shrink the heatmap row width while ``nrows`` rows would still have slack.

        Lowers ``ncols`` while ncols*nrows exceeds the column count and ncols > 45,
        then returns one more than where it stopped; the result times ``nrows``
        always exceeds the column count, so every residue is plotted.
        """
        while self.out.shape[1]/ncols < nrows and ncols > 45 and ncols*nrows >= self.out.shape[1]:
            ncols -= 1
        return ncols + 1

    def plot_heatmap(self, ncols, nrows):
        """Draw the heatmap (split into ``nrows`` rows of ``ncols`` residues when
        nrows >= 2) and save it to ``self.out_img`` as a 300 dpi PNG."""
        if nrows < 2:
            self.plot_single_heatmap()
        else:
            self.plot_multiple_heatmaps(ncols, nrows)
        plt.savefig(self.out_img, format='png', dpi=300)
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls in a long-running process accumulate open figures.
        plt.close()

    def plot_single_heatmap(self):
        """Draw the whole table as one heatmap on a new 12x6 figure.

        Zero cells (e.g. the wild-type row) are annotated with a dot.
        """
        fig = plt.figure(figsize=(12, 6))
        sns.heatmap(self.out,
                    cmap='RdBu',
                    cbar=False,
                    square=True,
                    xticklabels=1,
                    yticklabels=1,
                    center=0,
                    annot=self.out.map(lambda x: ' ' if x != 0 else '·'),
                    fmt='s',
                    annot_kws={'size': 'xx-large'})
        fig.tight_layout()

    def plot_multiple_heatmaps(self, ncols, nrows):
        """Draw the table as ``nrows`` stacked heatmaps of up to ``ncols`` residues each.

        Zero cells are annotated with a dot; y tick labels are horizontal and
        x tick labels vertical.
        """
        fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
        for i in range(nrows):
            tmp = self.out.iloc[:, i*ncols:(i+1)*ncols]
            label = tmp.map(lambda x: ' ' if x != 0 else '·')
            sns.heatmap(tmp,
                        ax=ax[i],
                        cmap='RdBu',
                        cbar=False,
                        square=True,
                        xticklabels=1,
                        yticklabels=1,
                        center=0,
                        annot=label,
                        fmt='s',
                        annot_kws={'size': 'xx-large'})
            ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
            ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
        fig.tight_layout()

    def calculate(self):
        """Run the model on this data, format the output and return ``self``."""
        self.model.run_model(self)
        self.parse_output()
        return self

    def csv(self):
        """Return the path of the output CSV file."""
        return self.out_csv

    def image(self):
        """Return the image output: the PNG path (TMS mode), or the pandas
        Styler built by :meth:`parse_output` in DMS / MUT modes."""
        return self.out_img
|
|