# multiTAP/cytof/utils.py
import os
import pickle as pkl
import skimage.io
import skimage.measure
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
import scipy.stats
from typing import List, Dict
import itertools
from multiprocessing import Pool
from tqdm import tqdm
from readimc import MCDFile, TXTFile
import warnings
def load_CytofImage(savename):
    """Load a pickled CytofImage object saved at "savename"."""
    with open(savename, "rb") as f:
        cytof_img = pkl.load(f)
    return cytof_img

def load_CytofCohort(savename):
    """Load a pickled CytofCohort object saved at "savename"."""
    with open(savename, "rb") as f:
        cytof_cohort = pkl.load(f)
    return cytof_cohort
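# Usage sketch (hypothetical paths; both helpers simply unpickle a previously
# saved object):
#   cytof_img = load_CytofImage("results/slide1-roi1.pkl")
#   cytof_cohort = load_CytofCohort("results/cohort.pkl")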
def process_mcd(filename: str,
params: Dict):
"""
A function to process a whole slide .mcd file
"""
from classes import CytofImageTiff, CytofCohort
    quality_control_thres = params.get("quality_control_thres", None)  # reserved; the quality-control step is currently disabled below
channels_remove = params.get("channels_remove", None)
channels_dict = params.get("channels_dict", None)
use_membrane = params.get("use_membrane", False)
cell_radius = params.get("cell_radius", 5)
normalize_qs = params.get("normalize_qs", 75)
    df_cohort = pd.DataFrame(columns=['Slide', 'ROI', 'input file'])
cytof_images = {}
corrupted = []
with MCDFile(filename) as f:
for slide in f.slides:
sid = f"{slide.description}{slide.id}"
print(sid)
for roi in slide.acquisitions:
rid = roi.description
print(f'processing slide_id-roi: {sid}-{rid}')
if roi.metadata["DataStartOffset"] < roi.metadata["DataEndOffset"]:
img_roi = f.read_acquisition(roi) # array, shape: (c, y, x), dtype: float3
img_roi = np.transpose(img_roi, (1, 2, 0))
cytof_img = CytofImageTiff(slide=sid, roi = rid, image=img_roi, filename=f"{sid}-{rid}")
# cytof_img.quality_control(thres=quality_control_thres)
channels = [f"{mk}({cn})" for (mk, cn) in zip(roi.channel_labels, roi.channel_names)]
cytof_img.set_markers(markers=roi.channel_labels, labels=roi.channel_names, channels=channels) # targets, metals
# known corrupted channels, e.g. nan-nan1
if channels_remove is not None and len(channels_remove) > 0:
cytof_img.remove_special_channels(channels_remove)
# maps channel names to nuclei/membrane
if channels_dict is not None:
# remove nuclei channel for segmentation
channels_rm = cytof_img.define_special_channels(channels_dict, rm_key='nuclei')
cytof_img.remove_special_channels(channels_rm)
cytof_img.get_seg(radius=cell_radius, use_membrane=use_membrane)
cytof_img.extract_features(cytof_img.filename)
cytof_img.feature_quantile_normalization(qs=normalize_qs)
df_cohort = pd.concat([df_cohort, pd.DataFrame.from_dict([{'Slide': sid,
'ROI': rid,
'input file': filename}])])
cytof_images[f"{sid}-{rid}"] = cytof_img
else:
corrupted.append(f"{sid}-{rid}")
print(f"This cohort now contains {len(cytof_images)} ROIs, after excluding {len(corrupted)} corrupted ones from the original MCD.")
cytof_cohort = CytofCohort(cytof_images=cytof_images, df_cohort=df_cohort)
if channels_dict is not None:
cytof_cohort.batch_process_feature()
else:
        warnings.warn("Batch feature processing was skipped because no nuclei channels were defined via 'channels_dict'!")
    return corrupted, cytof_cohort
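# Usage sketch: a minimal, hypothetical call to process_mcd. The file path,
# channel names, and the exact "channels_dict" layout are placeholders, not
# guaranteed by this module:
#   params = {
#       "channels_remove": ["nan-nan1(nan1)"],           # known-corrupted channels
#       "channels_dict": {"nuclei": ["DNA1", "DNA2"]},   # channels defining nuclei
#       "cell_radius": 5,
#       "normalize_qs": 75,
#   }
#   corrupted, cohort = process_mcd("slide.mcd", params)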
def save_multi_channel_img(img, savename):
"""
A helper function to save multi-channel images
"""
skimage.io.imsave(savename, img)
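# Usage sketch (hypothetical array; saves a (y, x, c) float stack as a multi-channel TIFF):
#   save_multi_channel_img(img.astype(np.float32), "slide1-roi1.tiff")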
def generate_color_dict(names: List,
                        sort_names: bool = True,
                        ):
    """
    Generate a dictionary of colors keyed by the provided "names", assigning colors
    in order from matplotlib's "tab20" colormap (at most 20 distinct names)
    """
    if sort_names:
        names = sorted(names)  # sort a copy to avoid mutating the caller's list
    color_dict = dict((n, plt.get_cmap('tab20').colors[i]) for (i, n) in enumerate(names))
    return color_dict
def show_color_table(color_dict: dict,
                     title: str = "",
                     maxcols: int = 4,
                     emptycols: int = 0,
                     dpi: int = 72,
                     cell_width: int = 212,
                     cell_height: int = 22,
                     swatch_width: int = 48,
                     margin: int = 12,
                     topmargin: int = 40,
                     show: bool = True
                     ):
    """
    Show the color table for a color dictionary
    reference: https://matplotlib.org/stable/gallery/color/named_colors.html
    args:
        color_dict = a dictionary of colors. key: color legend name - value: RGB representation of color
        title (optional) = title for the color table (default="")
        maxcols (optional) = maximum number of columns in the visualization (default=4)
        emptycols (optional) = number of empty columns for a maxcols-column figure,
            i.e. maxcols=4 and emptycols=3 means presenting a single-column plot (default=0)
        show (optional) = a flag indicating whether to display the figure (default=True)
    """
    names = list(color_dict.keys())
    n = len(names)
    ncols = maxcols - emptycols
    nrows = n // ncols + int(n % ncols > 0)
    width = cell_width * ncols + 2 * margin
    height = cell_height * nrows + margin + topmargin
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin / width, margin / height,
                        (width - margin) / width, (height - topmargin) / height)
    ax.set_xlim(0, cell_width * ncols)
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=16, loc="left", pad=10)
    for i, name in enumerate(names):
        row = i % nrows
        col = i // nrows
        y = row * cell_height
        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7
        ax.text(text_pos_x, y, name, fontsize=12,
                horizontalalignment='left',
                verticalalignment='center')
        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y - 9), width=swatch_width,
                      height=18, facecolor=color_dict[name], edgecolor='0.7')
        )
    if show:
        plt.show()
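# Usage sketch (hypothetical names; generate_color_dict assigns "tab20" colors
# in sorted-name order, so it supports at most 20 distinct names):
#   color_dict = generate_color_dict(["Bcell", "Tcell", "Tumor"])
#   show_color_table(color_dict, title="Cell types")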
def _extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename, morphology, nuclei_morphology, cell_morphology,
                                channels, raw_image, sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell):
    """Extract morphology and marker expression features for a single nucleus (and its
    matching cell) identified by "nuclei_id" in the segmentation masks; returns an empty
    dict if the id is missing from either mask."""
    regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)
    if len(regions) >= 1:
        this_nucleus = regions[0]
    else:
        return {}
    regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)
    if len(regions) >= 1:
        this_cell = regions[0]
    else:
        return {}
    centroid_y, centroid_x = this_nucleus.centroid  # y: rows; x: columns
res = {"filename": filename,
"id": nuclei_id,
"coordinate_x": centroid_x,
"coordinate_y": centroid_y}
# morphology
for i, feature in enumerate(morphology[:-1]):
res[nuclei_morphology[i]] = getattr(this_nucleus, feature)
res[cell_morphology[i]] = getattr(this_cell, feature)
    res[nuclei_morphology[-1]] = 1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area  # pa_ratio: perimeter^2 / area
    res[cell_morphology[-1]] = 1.0 * this_cell.perimeter ** 2 / this_cell.filled_area  # pa_ratio: perimeter^2 / area
# markers
for ch, marker in enumerate(channels):
res[sum_exp_nuclei[ch]] = np.sum(raw_image[nuclei_seg == nuclei_id, ch])
res[ave_exp_nuclei[ch]] = np.average(raw_image[nuclei_seg == nuclei_id, ch])
res[sum_exp_cell[ch]] = np.sum(raw_image[cell_seg == nuclei_id, ch])
res[ave_exp_cell[ch]] = np.average(raw_image[cell_seg == nuclei_id, ch])
return res
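# Each record returned above is a flat dict keyed by the column names built in
# extract_feature(), roughly:
#   {"filename": ..., "id": ..., "coordinate_x": ..., "coordinate_y": ...,
#    "area_nuclei": ..., "pa_ratio_cell": ...,
#    "<channel>_nuclei_sum": ..., "<channel>_cell_ave": ..., ...}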
def extract_feature(channels: List,
raw_image: np.ndarray,
nuclei_seg: np.ndarray,
cell_seg: np.ndarray,
filename: str,
use_parallel: bool = True,
show_sample: bool = False) -> pd.DataFrame:
""" Extract nuclei and cell level feature from cytof image based on nuclei segmentation and cell segmentation
results
Inputs:
channels = channels to extract feature from
raw_image = raw cytof image
nuclei_seg = nuclei segmentation result
cell_seg = cell segmentation result
filename = filename of current cytof image
Returns:
feature_summary_df = a dataframe containing summary of extracted features
morphology = names of morphology features extracted
:param channels: list
:param raw_image: numpy.ndarray
:param nuclei_seg: numpy.ndarray
:param cell_seg: numpy.ndarray
:param filename: string
:param morpholoty: list
:return feature_summary_df: pandas.core.frame.DataFrame
"""
assert (len(channels) == raw_image.shape[-1])
# morphology features to be extracted
morphology = ["area", "convex_area", "eccentricity", "extent",
"filled_area", "major_axis_length", "minor_axis_length",
"orientation", "perimeter", "solidity", "pa_ratio"]
## morphology features
nuclei_morphology = [_ + '_nuclei' for _ in morphology] # morphology - nuclei level
cell_morphology = [_ + '_cell' for _ in morphology] # morphology - cell level
## single cell features
# nuclei level
sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels] # sum expression over nuclei
ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels] # average expression over nuclei
# cell level
sum_exp_cell = [_ + '_cell_sum' for _ in channels] # sum expression over cell
ave_exp_cell = [_ + '_cell_ave' for _ in channels] # average expression over cell
# column names of final result dataframe
column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
sum_exp_cell + ave_exp_cell + cell_morphology
    # Initiate (nuclei ids are iterated starting from label 2)
    n_nuclei = np.max(nuclei_seg)
    if use_parallel:
        nuclei_ids = range(2, n_nuclei + 1)
with Pool() as mp_pool:
res = mp_pool.starmap(_extract_feature_one_nuclei,
zip(nuclei_ids,
itertools.repeat(nuclei_seg),
itertools.repeat(cell_seg),
itertools.repeat(filename),
itertools.repeat(morphology),
itertools.repeat(nuclei_morphology),
itertools.repeat(cell_morphology),
itertools.repeat(channels),
itertools.repeat(raw_image),
itertools.repeat(sum_exp_nuclei),
itertools.repeat(ave_exp_nuclei),
itertools.repeat(sum_exp_cell),
itertools.repeat(ave_exp_cell)
))
else:
res = []
for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
res.append(_extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename,
morphology, nuclei_morphology, cell_morphology,
channels, raw_image,
sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell))
    # drop empty records (ids absent from a mask) and keep a consistent column order
    feature_summary_df = pd.DataFrame([r for r in res if r], columns=column_names)
    if show_sample:
        print(feature_summary_df.sample(min(5, len(feature_summary_df))))
return feature_summary_df
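# Usage sketch (hypothetical inputs; raw_image is a (y, x, c) array and the two
# masks are integer-labeled images sharing ids per nucleus/cell pair):
#   feature_df = extract_feature(channels=["CD3(La139)", "CD20(Dy161)"],
#                                raw_image=img, nuclei_seg=nuclei_mask,
#                                cell_seg=cell_mask, filename="slide1-roi1")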
def check_feature_distribution(feature_summary_df, features):
""" Visualize feature distribution for each feature
Inputs:
feature_summary_df = dataframe of extracted feature summary
features = features to check distribution
Returns:
None
:param feature_summary_df: pandas.core.frame.DataFrame
:param features: list
"""
    for feature in features:
        print(feature)
        fig, ax = plt.subplots(1, 1, figsize=(3, 2))
        ax.hist(np.log2(feature_summary_df[feature] + 0.0001), 100)  # small offset avoids log2(0)
        ax.set_xlim(-15, 15)
        plt.show()
def visualize_scatter(data, communities, n_community, title, figsize=(5,5), savename=None, show=False, ax=None):
    """
    data = data to visualize (N, 2)
    communities = group indices corresponding to each sample in data (N, 1) or (N, )
    n_community = total number of groups in the cohort (n_community >= unique number of communities)
    """
    close_fig = not show and ax is None
    show = show and ax is None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None
    ax.set_title(title)
    sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community), ax=ax)
    ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if savename is not None:
        print("saving plot to {}".format(savename))
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_fig:
        plt.close('all')
    return fig
def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
    """
    data = normalized mean expression to visualize, shape (n_groups, n_markers)
    markers = marker names, used as x tick labels
    group_ids = group (cluster) ids, used as y tick labels
    """
    close_fig = not show and ax is None
    show = show and ax is None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None
sns.heatmap(data,
cmap='magma',
xticklabels=markers,
yticklabels=group_ids,
ax=ax
)
ax.set_xlabel("Markers")
ax.set_ylabel("Phenograph clusters")
ax.set_title("normalized expression - {}".format(title))
ax.xaxis.set_tick_params(labelsize=8)
if savename is not None:
plt.tight_layout()
plt.savefig(savename)
    if show:
        plt.show()
    if close_fig:
        plt.close('all')
    return fig
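# Usage sketch (hypothetical inputs; "embedding" is an (N, 2) projection such as
# UMAP, "expr" is an (n_clusters, n_markers) matrix of normalized means):
#   fig1 = visualize_scatter(embedding, communities=labels, n_community=10,
#                            title="Phenograph clusters")
#   fig2 = visualize_expression(expr, markers=marker_names,
#                               group_ids=range(10), title="cohort")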
def _get_thresholds(df_feature: pd.DataFrame,
features: List[str],
thres_bg: float = 0.3,
visualize: bool = True,
verbose: bool = False):
"""Calculate thresholds for each feature by assuming a Gaussian Mixture Model
Inputs:
df_feature = dataframe of extracted feature summary
features = a list of features to calculate thresholds from
thres_bg = a threshold such that the component with the mixing weight greater than the threshold would
be considered as background. (Default=0.3)
visualize = a flag indicating whether to visualize the feature distributions and thresholds or not.
(Default=True)
verbose = a flag indicating whether to print calculated values on screen or not. (Default=False)
Outputs:
thresholds = a dictionary of calculated threshold values
:param df_feature: pandas.core.frame.DataFrame
:param features: list
:param visualize: bool
:param verbose: bool
:return thresholds: dict
"""
    thresholds = {}
    for f, feat_name in enumerate(features):
        X = df_feature[feat_name].values.reshape(-1, 1)
        gm = GaussianMixture(n_components=2, random_state=0, n_init=2).fit(X)
        # background component: the smallest mean among components whose mixing weight exceeds thres_bg
        mu = np.min(gm.means_[gm.weights_ > thres_bg])
        which_component = np.argmax(gm.means_ == mu)
        if verbose:
            print(f"GMM mean values: {gm.means_}")
            print(f"GMM weights: {gm.weights_}")
            print(f"GMM covariances: {gm.covariances_}")
        X = df_feature[feat_name].values
        hist = np.histogram(X, 150)
        sigma = np.sqrt(gm.covariances_[which_component, 0, 0])
        background_ratio = gm.weights_[which_component]
        thres = mu + 2.5 * sigma  # background mean + 2.5 std
        thresholds[feat_name] = thres
        n = sum(X > thres)
        percentage = n / len(X)
## visualize
if visualize:
fig, ax = plt.subplots(1, 1)
ax.hist(X, 150, density=True)
ax.set_xlabel("log2({})".format(feat_name))
ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], mu, sigma) * background_ratio, c='red')
_which_component = np.argmin(gm.means_ == mu)
_mu = gm.means_[_which_component]
_sigma = np.sqrt(gm.covariances_[_which_component, 0, 0])
ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], _mu, _sigma) * (1 - background_ratio), c='orange')
ax.axvline(x=thres, c='red')
ax.text(0.7, 0.9, "n={}, percentage={}".format(n, np.round(percentage, 3)), ha='center', va='center',
transform=ax.transAxes)
ax.text(0.3, 0.9, "mu={}, sigma={}".format(np.round(mu, 2), np.round(sigma, 2)), ha='center', va='center',
transform=ax.transAxes)
ax.text(0.3, 0.8, "background ratio={}".format(np.round(background_ratio, 2)), ha='center', va='center',
transform=ax.transAxes)
ax.set_title(feat_name)
plt.show()
return thresholds
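# Usage sketch (hypothetical column name; assumes the feature columns hold
# log2-transformed values, as the axis labels above suggest):
#   thresholds = _get_thresholds(df, ["CD3(La139)_cell_ave"], visualize=False)
#   positive = df["CD3(La139)_cell_ave"] > thresholds["CD3(La139)_cell_ave"]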
def _generate_summary(df_feature: pd.DataFrame, features: List[str], thresholds: dict) -> pd.DataFrame:
"""Generate (cell level) summary table for each feature in features: feature name, total number (of cells),
calculated GMM threshold for this feature, number of individuals (cells) with greater than threshold values,
ratio of individuals (cells) with greater than threshold values
Inputs:
df_feature = dataframe of extracted feature summary
features = a list of features to generate summary table
thresholds = (calculated GMM-based) thresholds for each feature
Outputs:
df_info = summary table for each feature
:param df_feature: pandas.core.frame.DataFrame
:param features: list
:param thresholds: dict
:return df_info: pandas.core.frame.DataFrame
"""
    rows = []
    for feature in features:  # loop over each feature
        thres = thresholds[feature]  # fetch the threshold for this feature
        X = df_feature[feature].values
        n = sum(X > thres)
        N = len(X)
        rows.append({'feature': feature, 'total number': N, 'threshold': thres,
                     'positive counts': n, 'positive ratio': n / N})
    df_info = pd.DataFrame(rows, columns=['feature', 'total number', 'threshold',
                                          'positive counts', 'positive ratio'])
    return df_info.reset_index(drop=True)
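# Usage sketch chaining the two helpers above (hypothetical feature names):
#   feats = ["CD3(La139)_cell_ave", "CD20(Dy161)_cell_ave"]
#   thresholds = _get_thresholds(df, feats, visualize=False)
#   df_info = _generate_summary(df, feats, thresholds)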