Spaces:
Running
Running
| import json | |
| from dataclasses import dataclass | |
| from io import StringIO | |
| from typing import Literal, Optional, TypedDict, cast | |
| from zipfile import ZipFile | |
| import polars as pl | |
| import solara | |
| from Bio.PDB import MMCIFParser, Structure | |
| from ipymolstar import PDBeMolstar | |
| from ipymolstar.widget import QueryParam | |
| from polarify import polarify | |
| from solara.components.file_drop import FileInfo | |
class ColorData(TypedDict):
    """Payload for the PDBeMolstar ``color_data`` argument.

    ``data`` holds one query per residue to colour. The key for the fallback
    colour must be spelled ``nonSelectedColor`` (camelCase), matching the
    dicts built in this module and the PDBe Molstar option name.
    """

    data: list[QueryParam]
    # Colour applied to residues not listed in ``data``; this module always passes None.
    nonSelectedColor: Optional[dict[str, int]]
# Payload for the PDBeMolstar ``tooltips`` argument: one query per residue,
# each carrying the tooltip text to show for that residue.
TooltipData = TypedDict("TooltipData", {"data": list[QueryParam]})
# Payload for the PDBeMolstar ``custom_data`` argument: the raw structure
# text, declared as non-binary mmCIF.
CustomData = TypedDict(
    "CustomData",
    {
        "data": str,
        "format": Literal["cif"],
        "binary": Literal[False],
    },
)
@dataclass
class AlphaFoldData:
    """Everything loaded for one AlphaFold result model.

    A dataclass, so it can be built with keyword arguments
    (``AlphaFoldData(name=..., structure=..., ...)``). Without the decorator
    the class has no ``__init__`` accepting those fields.
    """

    name: str
    structure: Structure  # Biopython structure parsed from the model's mmCIF
    atom_data: pl.DataFrame  # per-atom pLDDT table
    residue_data: pl.DataFrame  # per-residue aggregated pLDDT table
    custom_data: CustomData  # raw mmCIF payload for PDBeMolstar
    color_data: ColorData  # per-residue confidence colouring for PDBeMolstar
    tooltip_data: TooltipData  # per-residue tooltips for PDBeMolstar

    def write_atoms(self):
        """Return the per-atom table serialized as CSV text."""
        # polars returns the CSV as a string when no file target is given.
        return self.atom_data.write_csv()
# RGB colour (0-255 per channel) used for each pLDDT confidence band.
COLOR_LUT = {
    band: dict(zip("rgb", rgb))
    for band, rgb in (
        ("very-high", (16, 109, 255)),
        ("confident", (16, 207, 241)),
        ("low", (246, 237, 18)),
        ("very-low", (239, 130, 30)),
    )
}

# Empty PDBeMolstar payloads, used when no per-residue colouring/tooltips apply.
NO_COLOR_DATA = dict(data=[], nonSelectedColor=None)
NO_TOOLTIP_DATA = dict(data=[])
# Module-wide reactive state shared by the loader task and the UI.
result_index = solara.reactive(0)  # which model (0-based) to load from the zip
file_info = solara.reactive(cast(Optional[FileInfo], None))  # None until a file is dropped

# A single mmCIF parser instance, reused for every load.
PARSER = MMCIFParser()
def assign_confidence(x: pl.Expr) -> pl.Expr:
    """Map a pLDDT expression to a confidence-band label expression.

    Bands: ``x < 50`` -> "very-low", ``x < 70`` -> "low", ``x < 90`` ->
    "confident", otherwise "very-high".

    Built directly with ``pl.when/then/otherwise``. A plain Python ``if`` on a
    ``pl.Expr`` cannot work, because an expression has no truth value.

    Args:
        x: numeric polars expression (e.g. ``pl.col("mean_plddt")``).

    Returns:
        A string-valued polars expression. Null inputs fail every ``when``
        test and so fall through to "very-high".
    """
    return (
        pl.when(x < 50)
        .then(pl.lit("very-low"))
        .when(x < 70)
        .then(pl.lit("low"))
        .when(x < 90)
        .then(pl.lit("confident"))
        .otherwise(pl.lit("very-high"))
    )
def _read_result_files(f_idx: int) -> tuple[str, str, dict]:
    """Read the ``f_idx``-th mmCIF and matching ``full_data`` JSON from the uploaded zip.

    Both candidate lists are sorted by name, so the same index selects the
    corresponding model in each.

    Returns:
        ``(structure file name, mmCIF text, parsed JSON dict)``.

    Raises:
        IndexError: if the archive has fewer than ``f_idx + 1`` matching files.
    """
    with ZipFile(file_info.value["file_obj"]) as zf:
        files = zf.namelist()
        structure_file = sorted(f for f in files if f.endswith(".cif"))[f_idx]
        json_data_file = sorted(f for f in files if "full_data" in f)[f_idx]
        with zf.open(json_data_file) as json_f:
            json_load = json.load(json_f)
        cif_str = zf.read(structure_file).decode("utf-8")
    return structure_file, cif_str, json_load


def _residue_queries(residue_df: pl.DataFrame) -> tuple[ColorData, TooltipData]:
    """Build PDBeMolstar colour and tooltip payloads, one query per residue row."""
    color_query = []
    tooltip_query = []
    for row in residue_df.iter_rows(named=True):
        chain_id = row["chain"]
        resn = row["resn"]
        confidence = row["confidence"]
        color_query.append(
            {
                "struct_asym_id": chain_id,
                "residue_number": resn,
                "color": COLOR_LUT[confidence],
            }
        )
        tooltip_query.append(
            {
                "struct_asym_id": chain_id,
                "residue_number": resn,
                "tooltip": f"Confidence: {confidence}; plddt: {row['mean_plddt']:.2f}",
            }
        )
    return {"data": color_query, "nonSelectedColor": None}, {"data": tooltip_query}


@solara.lab.task
def load_result() -> Optional[AlphaFoldData]:
    """Load the model selected by ``result_index`` from the uploaded zip.

    Runs as a Solara task so the UI can inspect ``.pending`` / ``.finished`` /
    ``.value``. Builds the per-atom and per-residue pLDDT tables plus the
    PDBeMolstar payloads.

    Returns:
        The loaded ``AlphaFoldData``, or None when no file has been uploaded.

    Raises:
        IndexError: if ``result_index`` exceeds the number of models in the zip.
        KeyError: if the JSON lacks ``atom_chain_ids`` / ``atom_plddts``.
    """
    if file_info.value is None:
        return None

    structure_file, cif_str, json_load = _read_result_files(result_index.value)
    # removesuffix, not rstrip: rstrip(".cif") strips any trailing '.', 'c', 'i', 'f' chars.
    alphafold_name = structure_file.removesuffix(".cif")
    structure = PARSER.get_structure(alphafold_name, StringIO(cif_str))

    # NOTE(review): assumes the JSON per-atom arrays follow the same atom order
    # as structure.get_atoms() — confirm against the AlphaFold3 output spec.
    atoms = list(structure.get_atoms())
    atoms_df = pl.DataFrame(
        {
            "name": pl.Series(
                [atom.get_parent().resname for atom in atoms], dtype=pl.Categorical
            ),
            "resn": pl.Series([atom.get_parent().id[1] for atom in atoms]),
            "chain": pl.Series(json_load["atom_chain_ids"], dtype=pl.Categorical),
            "plddt": json_load["atom_plddts"],
        }
    )

    # Mean pLDDT per residue, labelled with its confidence band.
    residue_df = (
        atoms_df.group_by(["chain", "resn", "name"])
        .agg(pl.col("plddt").mean().alias("mean_plddt"))
        .sort(["chain", "resn"])
        .with_columns(
            assign_confidence(pl.col("mean_plddt"))
            .alias("confidence")
            .cast(pl.Categorical)
        )
    )

    custom_data = {
        "data": cif_str,
        "format": "cif",
        "binary": False,
    }
    plddt_color_data, plddt_tooltip_data = _residue_queries(residue_df)

    return AlphaFoldData(
        name=alphafold_name,
        structure=structure,
        atom_data=atoms_df,
        residue_data=residue_df,
        custom_data=custom_data,
        color_data=plddt_color_data,
        tooltip_data=plddt_tooltip_data,
    )
def Page():
    """Render the result viewer: upload/controls sidebar plus the Molstar view.

    Only a successfully finished ``load_result`` task with a non-None value is
    treated as loaded data. After a failed load, ``load_result.value`` is not
    an AlphaFoldData, so every access is guarded and the error is displayed.
    """
    color_mode = solara.use_reactive("chain")
    spin = solara.use_reactive(True)
    dark_effective = solara.lab.use_dark_effective()

    def on_color_mode(value: str):
        color_mode.set(value)

    def set_result_index(value: int):
        # Selecting another model index re-runs the loader for that index.
        result_index.set(value)
        load_result()

    fold_data: Optional[AlphaFoldData] = (
        load_result.value if load_result.finished else None
    )

    solara.Title("Solarafold result viewer")
    with solara.AppBar():
        solara.lab.ThemeToggle()
    with solara.Sidebar():
        solara.FileDrop(label="Upload zip file", on_file=file_info.set, lazy=True)
        solara.Button(
            label="Load result",
            on_click=load_result,
            block=True,
            disabled=file_info.value is None,
        )
        if not load_result.not_called:
            busy = load_result.pending
            no_result = fold_data is None
            solara.Select(
                label="Result index",
                value=result_index.value,
                on_value=set_result_index,
                values=list(range(5)),
                disabled=busy,
            )
            solara.Select(
                label="Color mode",
                values=["chain", "plddt"],
                value=color_mode.value,
                on_value=on_color_mode,
                disabled=busy,
            )
            solara.Checkbox(label="Spin", value=spin)

            def write_atoms():
                return load_result.value.atom_data.write_csv()

            solara.FileDownload(
                write_atoms,
                filename="NA" if no_result else f"{fold_data.name}_atoms.csv",
                children=[
                    solara.Button(
                        "Download atom plddt",
                        block=True,
                        disabled=no_result,
                    )
                ],
            )
            solara.Div(style={"height": "20px"})

            def write_residues():
                return load_result.value.residue_data.write_csv()

            solara.FileDownload(
                write_residues,
                filename="NA" if no_result else f"{fold_data.name}_residues.csv",
                children=[
                    solara.Button(
                        "Download residue plddt",
                        block=True,
                        disabled=no_result,
                    )
                ],
            )

    if load_result.not_called:
        solara.HTML(
            tag="p",
            unsafe_innerHTML='Drag and drop an alphafold3 result .zip file to get started. You can download an example file <a href="https://www.gstatic.com/alphafoldserver/examplefold_pdb_8aw3/examplefold_pdb_8aw3.zip">here</a>.',
        )
    elif load_result.pending:
        solara.ProgressLinear(load_result.pending)
    elif load_result.error:
        solara.Error(f"Failed to load result: {load_result.exception}")
    elif fold_data is not None:
        color_data = (
            NO_COLOR_DATA if color_mode.value == "chain" else fold_data.color_data
        )
        with solara.Card():
            theme = "dark" if dark_effective else "light"
            # Keying on the theme forces a fresh viewer when the theme flips.
            PDBeMolstar.element(
                height="calc(100vh - 150px)",
                custom_data=fold_data.custom_data,
                color_data=color_data,
                tooltips=fold_data.tooltip_data,
                show_water=False,
                spin=spin.value,
                theme=theme,
            ).key(f"pdbemolstar-{dark_effective}")
def Layout(children):
    """Wrap page content in an AppLayout whose toolbar follows the active theme."""
    toolbar_is_dark = solara.lab.use_dark_effective()
    # The app-bar colour is passed as None in both light and dark themes.
    layout = solara.AppLayout(
        children=children,
        toolbar_dark=toolbar_is_dark,
        color=None,
    )
    return layout