Spaces:
Sleeping
Sleeping
import json | |
from dataclasses import dataclass | |
from io import StringIO | |
from typing import Literal, Optional, TypedDict, cast | |
from zipfile import ZipFile | |
import polars as pl | |
import solara | |
from Bio.PDB import MMCIFParser, Structure | |
from ipymolstar import PDBeMolstar | |
from ipymolstar.widget import QueryParam | |
from polarify import polarify | |
from solara.components.file_drop import FileInfo | |
class ColorData(TypedDict):
    """Colour payload passed to PDBeMolstar's ``color_data`` argument.

    The key names mirror the dicts this module actually builds
    (``NO_COLOR_DATA`` and the pLDDT colour data assembled in
    ``load_result``), including the camelCase ``nonSelectedColor`` key.
    """

    data: list[QueryParam]  # per-residue selector dicts, each with a "color" entry
    nonSelectedColor: None  # always None in this module
# Tooltip payload passed to PDBeMolstar's ``tooltips`` argument: a list of
# per-residue selector dicts, each carrying a "tooltip" string.
TooltipData = TypedDict("TooltipData", {"data": list[QueryParam]})
# Inline structure payload passed to PDBeMolstar's ``custom_data`` argument:
# the raw mmCIF text, tagged as the "cif" format and as non-binary.
CustomData = TypedDict(
    "CustomData",
    {
        "data": str,
        "format": Literal["cif"],
        "binary": Literal[False],
    },
)
@dataclass
class AlphaFoldData:
    """One loaded AlphaFold3 model plus the tables and viewer payloads derived from it.

    ``load_result`` builds instances with keyword arguments
    (``AlphaFoldData(name=..., structure=..., ...)``). ``@dataclass``
    generates the keyword ``__init__`` that call requires.
    """

    name: str  # name derived from the structure file; used in download file names
    structure: Structure  # Bio.PDB structure parsed from the mmCIF text
    atom_data: pl.DataFrame  # per-atom columns: name, resn, chain, plddt
    residue_data: pl.DataFrame  # per-residue mean_plddt and confidence band
    custom_data: CustomData  # raw mmCIF payload for PDBeMolstar
    color_data: ColorData  # pLDDT-based colour queries
    tooltip_data: TooltipData  # per-residue tooltip queries

    def write_atoms(self):
        """Return ``atom_data`` serialised as CSV text."""
        return self.atom_data.write_csv()

    def write_residues(self):
        """Return ``residue_data`` serialised as CSV text."""
        return self.residue_data.write_csv()
# RGB colour for each pLDDT confidence band. The band names are the labels
# produced by assign_confidence; each colour is a {"r", "g", "b"} dict.
COLOR_LUT = {
    band: dict(zip("rgb", rgb))
    for band, rgb in (
        ("very-high", (16, 109, 255)),
        ("confident", (16, 207, 241)),
        ("low", (246, 237, 18)),
        ("very-low", (239, 130, 30)),
    )
}
# Empty payloads: no per-residue colour overrides / no tooltips.
NO_COLOR_DATA = dict(data=[], nonSelectedColor=None)
NO_TOOLTIP_DATA = dict(data=[])
# Reactive UI state read by load_result and Page: the index of the result to
# display, and the uploaded zip's FileInfo (None until a file is dropped).
result_index = solara.reactive(0)
file_info = solara.reactive(cast(Optional[FileInfo], None))

# mmCIF parser that load_result uses to turn the model text into a Bio.PDB structure.
PARSER = MMCIFParser()
def assign_confidence(x: pl.Expr) -> pl.Expr:
    """Map a pLDDT expression to its confidence-band label expression.

    Bands: ``x < 50`` -> "very-low", ``x < 70`` -> "low",
    ``x < 90`` -> "confident", otherwise "very-high". The labels match the
    keys of ``COLOR_LUT``.

    ``x`` is a lazy polars expression, so Python ``if x < 50:`` cannot be
    used. Taking the truth value of an ``Expr`` raises ``TypeError``.
    The branching is therefore written as a ``when/then/otherwise`` chain,
    which polars evaluates per row.
    """
    return (
        pl.when(x < 50)
        .then(pl.lit("very-low"))
        .when(x < 70)
        .then(pl.lit("low"))
        .when(x < 90)
        .then(pl.lit("confident"))
        .otherwise(pl.lit("very-high"))
    )
@solara.lab.task
def load_result() -> Optional[AlphaFoldData]:
    """Load the selected model from the uploaded AlphaFold3 result zip.

    Picks the ``result_index``-th ``.cif`` structure and the
    ``result_index``-th ``full_data`` JSON, both lists sorted by name. It then
    builds per-atom and per-residue pLDDT tables and the colour/tooltip
    payloads for PDBeMolstar.

    Declared as a ``solara.lab.task`` because ``Page`` reads its
    ``not_called`` / ``pending`` / ``finished`` / ``value`` state and calls it
    as a button callback.

    Raises:
        ValueError: if no file has been uploaded, or if the zip does not
            contain a structure/full_data pair for ``result_index``.
    """
    f_idx = result_index.value
    info = file_info.value
    if info is None:
        raise ValueError("No result zip file has been uploaded.")

    with ZipFile(info["file_obj"]) as zf:
        files = zf.namelist()
        # Structures and full_data JSONs are paired by position in name order.
        structure_files = sorted(f for f in files if f.endswith(".cif"))
        json_data_files = sorted(f for f in files if "full_data" in f)
        n_results = min(len(structure_files), len(json_data_files))
        if not 0 <= f_idx < n_results:
            raise ValueError(
                f"Result index {f_idx} is out of range: the zip contains "
                f"{n_results} structure/full_data pair(s)."
            )
        structure_file = structure_files[f_idx]
        with zf.open(json_data_files[f_idx]) as json_f:
            json_load = json.load(json_f)
        cif_str = zf.read(structure_file).decode("utf-8")

    # removesuffix, not rstrip: rstrip(".cif") strips any trailing run of
    # '.', 'c', 'i', 'f' characters (e.g. "fold_abc.cif" -> "fold_ab").
    alphafold_name = structure_file.removesuffix(".cif")
    # NOTE(review): PARSER is a single module-level MMCIFParser shared by every
    # call; confirm that is safe if loads from several sessions can overlap.
    structure = PARSER.get_structure(alphafold_name, StringIO(cif_str))

    # Walk the atoms once. The JSON's per-atom chain ids and pLDDTs are assumed
    # to follow Bio.PDB's atom iteration order (TODO confirm).
    atoms = list(structure.get_atoms())
    atoms_df = pl.DataFrame(
        {
            "name": pl.Series(
                [atom.get_parent().resname for atom in atoms],
                dtype=pl.Categorical,
            ),
            "resn": pl.Series([atom.get_parent().id[1] for atom in atoms]),
            "chain": pl.Series(json_load["atom_chain_ids"], dtype=pl.Categorical),
            "plddt": json_load["atom_plddts"],
        }
    )

    # One row per residue: mean atom pLDDT and its confidence band.
    residue_df = (
        atoms_df.group_by(["chain", "resn", "name"])
        .agg(pl.col("plddt").mean().alias("mean_plddt"))
        .sort(["chain", "resn"])
        .with_columns(
            assign_confidence(pl.col("mean_plddt"))
            .alias("confidence")
            .cast(pl.Categorical)
        )
    )

    custom_data = {
        "data": cif_str,
        "format": "cif",
        "binary": False,
    }

    # Per-residue colour and tooltip queries share the same residue selector.
    color_query = []
    tooltip_query = []
    for row in residue_df.iter_rows(named=True):
        selector = {
            "struct_asym_id": row["chain"],
            "residue_number": row["resn"],
        }
        color_query.append({**selector, "color": COLOR_LUT[row["confidence"]]})
        tooltip_query.append(
            {
                **selector,
                "tooltip": f"Confidence: {row['confidence']}; plddt: {row['mean_plddt']:.2f}",
            }
        )

    return AlphaFoldData(
        name=alphafold_name,
        structure=structure,
        atom_data=atoms_df,
        residue_data=residue_df,
        custom_data=custom_data,
        color_data={"data": color_query, "nonSelectedColor": None},
        tooltip_data={"data": tooltip_query},
    )
@solara.component
def Page():
    """Main page: upload an AlphaFold3 result zip and view it in PDBeMolstar.

    The sidebar always holds the file drop and the load button. Once a load
    has been started, it also shows:

    - the result-index and colour-mode selectors,
    - the spin toggle,
    - the CSV download buttons.

    The main area shows instructions, a progress bar, the load error, or the
    viewer.
    """
    color_mode = solara.use_reactive("chain")
    spin = solara.use_reactive(True)
    dark_effective = solara.lab.use_dark_effective()

    def set_result_index(value: int):
        # Changing the model index immediately reloads the selected result.
        result_index.set(value)
        load_result()

    # Only a successfully finished load has data. A failed or in-flight load
    # must not be dereferenced (the old code crashed on ``.name`` after a
    # failure, because the controls were enabled whenever it was not pending).
    fold_data: Optional[AlphaFoldData] = (
        load_result.value if load_result.finished else None
    )

    solara.Title("Solarafold result viewer")
    with solara.AppBar():
        solara.lab.ThemeToggle()

    with solara.Sidebar():
        solara.FileDrop(label="Upload zip file", on_file=file_info.set, lazy=True)
        solara.Button(
            label="Load result",
            on_click=load_result,
            block=True,
            disabled=file_info.value is None,
        )
        if not load_result.not_called:
            disabled = load_result.pending
            no_data = fold_data is None
            solara.Select(
                label="Result index",
                value=result_index.value,
                on_value=set_result_index,
                values=list(range(5)),
                disabled=disabled,
            )
            solara.Select(
                label="Color mode",
                values=["chain", "plddt"],
                value=color_mode.value,
                on_value=color_mode.set,
                disabled=disabled,
            )
            solara.Checkbox(label="Spin", value=spin)

            def write_atoms():
                return load_result.value.atom_data.write_csv()

            solara.FileDownload(
                write_atoms,
                filename="NA" if no_data else f"{fold_data.name}_atoms.csv",
                children=[
                    solara.Button(
                        "Download atom plddt",
                        block=True,
                        disabled=no_data,
                    )
                ],
            )
            solara.Div(style={"height": "20px"})

            def write_residues():
                return load_result.value.residue_data.write_csv()

            solara.FileDownload(
                write_residues,
                filename="NA" if no_data else f"{fold_data.name}_residues.csv",
                children=[
                    solara.Button(
                        "Download residue plddt",
                        block=True,
                        disabled=no_data,
                    )
                ],
            )

    if load_result.not_called:
        solara.HTML(
            tag="p",
            unsafe_innerHTML='Drag and drop an alphafold3 result .zip file to get started. You can download an example file <a href="https://www.gstatic.com/alphafoldserver/examplefold_pdb_8aw3/examplefold_pdb_8aw3.zip">here</a>.',
        )
    elif load_result.pending:
        solara.ProgressLinear(load_result.pending)
    elif load_result.error:
        # Surface load failures (bad zip, missing result index, ...) instead
        # of rendering nothing.
        solara.Error(f"Could not load result: {load_result.exception}")
    elif fold_data is not None:
        color_data = (
            NO_COLOR_DATA if color_mode.value == "chain" else fold_data.color_data
        )
        with solara.Card():
            theme = "dark" if dark_effective else "light"
            # Keying on the theme remounts the viewer when dark mode toggles.
            PDBeMolstar.element(
                height="calc(100vh - 150px)",
                custom_data=fold_data.custom_data,
                color_data=color_data,
                tooltips=fold_data.tooltip_data,
                show_water=False,
                spin=spin.value,
                theme=theme,
            ).key(f"pdbemolstar-{dark_effective}")
@solara.component
def Layout(children):
    """App shell: wrap the page in an ``AppLayout`` whose toolbar follows dark mode.

    Args:
        children: the page elements to render inside the layout.
    """
    dark_effective = solara.lab.use_dark_effective()
    # toolbar_dark follows the effective theme. color=None is passed
    # explicitly instead of a theme colour such as "primary".
    return solara.AppLayout(
        children=children, toolbar_dark=dark_effective, color=None
    )