{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MAMMAL demos: protein-protein interaction and drug-target binding affinity\n", "\n", "1. **Reference example** - drug-target binding-affinity (pKd) inference with `ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd`. Its steps are kept in markdown cells (as fenced code) so they do not run on *Run All*.\n", "2. **Gradio app** - an interactive UI exposing a protein-protein interaction (PPI) demo and a drug-target interaction (DTI) demo." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!uv pip list|grep mammal" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp\n" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask\n", "from mammal.keys import CLS_PRED, SCORES\n", "from mammal.model import Mammal\n" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Reference example (not executed)\n", "\n", "Input protein (amino-acid sequence) and drug (SMILES):\n", "\n", "```python\n", "# input\n", "target_seq = \"NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC\"\n", "drug_seq = \"CC(=O)NCCC1=CNc2c1cc(OC)cc2\"\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# Load Model\n", "model = Mammal.from_pretrained(\"ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd\")\n", "model.eval()\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# Load Tokenizer\n", "tokenizer_op = ModularTokenizerOp.from_pretrained(\"ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd\")\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# convert to MAMMAL style\n", "sample_dict = {\"target_seq\": target_seq, \"drug_seq\": drug_seq}\n", "sample_dict = DtiBindingdbKdTask.data_preprocessing(\n", "    sample_dict=sample_dict,\n", "    tokenizer_op=tokenizer_op,\n", "    target_sequence_key=\"target_seq\",\n", "    drug_sequence_key=\"drug_seq\",\n", "    norm_y_mean=None,\n", "    norm_y_std=None,\n", "    device=model.device,\n", ")\n", "\n", "sample_dict\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Prompt template excerpt:\n", "\n", "```python\n", "f\"<@TOKENIZER-TYPE=AA>\"\\\n", "    f\"<@TOKENIZER-TYPE=AA@MAX-LEN={target_max_seq_length}>{target_sequence}\" \\\n", "    f\"<@TOKENIZER-TYPE=SMILES@MAX-LEN={drug_max_seq_length}>{drug_sequence}\" \\\n", "    \"\"\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# forward pass - encoder_only mode which supports scalar predictions\n", "batch_dict = model.forward_encoder_only([sample_dict])\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# Post-process the model's output\n", "batch_dict = DtiBindingdbKdTask.process_model_output(\n", "    batch_dict,\n", "    scalars_preds_processed_key=\"model.out.dti_bindingdb_kd\",\n", "    norm_y_mean=5.79384684128215,\n", "    norm_y_std=1.33808027428196,\n", ")\n", "ans = {\n", "    \"model.out.dti_bindingdb_kd\": float(batch_dict[\"model.out.dti_bindingdb_kd\"][0])\n", "}\n", "\n", "# Print prediction\n", "print(f\"{ans=}\")\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "sample_dict\n", "```" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradio app" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "\n", "import torch\n", "from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp\n", "from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask\n", "from mammal.keys import *  # noqa: F403 - provides ENCODER_INPUTS_STR, ENCODER_INPUTS_TOKENS, ENCODER_INPUTS_ATTENTION_MASK and CLS_PRED used below\n", "from mammal.model import Mammal\n", "\n" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Path doesn't exist. Will try to download fron hf hub. 
pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd'\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8afb628360ae435395dd34ba644945b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 9 files: 0%| | 0/9 [00:00\")\n",
 "\n",
 "# Default input proteins for the PPI demo\n",
 "protein_calmodulin = \"MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK\"\n",
 "protein_calcineurin = \"MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ\"\n",
 "\n",
 "# Default inputs for the drug-target binding affinity (DTI) demo\n",
 "target_seq = \"NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC\"\n",
 "drug_seq = \"CC(=O)NCCC1=CNc2c1cc(OC)cc2\"\n",
 "\n",
 "# Normalization statistics passed to DtiBindingdbKdTask.process_model_output for the DTI prediction\n",
 "DTI_NORM_Y_MEAN = 5.79384684128215\n",
 "DTI_NORM_Y_STD = 1.33808027428196\n",
 "\n",
 "\n",
 "def format_prompt_ppi(prot1, prot2):\n",
 "    \"\"\"Format two protein sequences as a MAMMAL PPI prompt string.\"\"\"\n",
 "    # Formatting prompt to match pre-training syntax\n",
 "    # NOTE(review): in this copy the template holds no task/sentinel tokens between the tokenizer tag and the sequences; verify against the MAMMAL PPI example.\n",
 "    return f\"<@TOKENIZER-TYPE=AA>{prot1}{prot2}\"\n",
 "\n",
 "\n",
 "def format_prompt_dti(prot, drug):\n",
 "    \"\"\"Preprocess a protein sequence and a SMILES drug into a MAMMAL DTI sample dict.\"\"\"\n",
 "    sample_dict = {\"target_seq\": prot, \"drug_seq\": drug}\n",
 "    return DtiBindingdbKdTask.data_preprocessing(\n",
 "        sample_dict=sample_dict,\n",
 "        tokenizer_op=tokenizer_op[dti],\n",
 "        target_sequence_key=\"target_seq\",\n",
 "        drug_sequence_key=\"drug_seq\",\n",
 "        norm_y_mean=None,\n",
 "        norm_y_std=None,\n",
 "        device=models[dti].device,\n",
 "    )\n",
 "\n",
 "\n",
 "def run_prompt(prompt):\n",
 "    \"\"\"Run the PPI model on a formatted prompt; return (decoded generation, score of positive_token_id).\"\"\"\n",
 "    # Create and load sample\n",
 "    sample_dict = {ENCODER_INPUTS_STR: prompt}\n",
 "\n",
 "    # Tokenize\n",
 "    sample_dict = tokenizer_op[ppi](\n",
 "        sample_dict=sample_dict,\n",
 "        key_in=ENCODER_INPUTS_STR,\n",
 "        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,\n",
 "        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,\n",
 "    )\n",
 "    sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(sample_dict[ENCODER_INPUTS_TOKENS])\n",
 "    sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(sample_dict[ENCODER_INPUTS_ATTENTION_MASK])\n",
 "\n",
 "    # Generate Prediction\n",
 "    batch_dict = models[ppi].generate(\n",
 "        [sample_dict],\n",
 "        output_scores=True,\n",
 "        return_dict_in_generate=True,\n",
 "        max_new_tokens=5,\n",
 "    )\n",
 "\n",
 "    # Get output\n",
 "    generated_output = tokenizer_op[ppi]._tokenizer.decode(batch_dict[CLS_PRED][0])\n",
 "    # NOTE(review): [0][1] presumably selects batch item 0 and generation step 1 (the class token); confirm\n",
 "    score = batch_dict['model.out.scores'][0][1][positive_token_id].item()\n",
 "    return generated_output, score\n",
 "\n",
 "\n",
 "def create_and_run_prompt(prot1, prot2):\n",
 "    \"\"\"Gradio handler for the PPI demo: returns (prompt, decoded output, score).\"\"\"\n",
 "    prompt = format_prompt_ppi(prot1, prot2)\n",
 "    return (prompt, *run_prompt(prompt=prompt))\n",
 "\n",
 "\n",
 "def create_and_run_prompt_dtb(prot, drug):\n",
 "    \"\"\"Gradio handler for the DTI demo: returns (encoder prompt, output key, predicted value).\"\"\"\n",
 "    sample_dict = format_prompt_dti(prot, drug)\n",
 "    # Inference only: no autograd graph is needed for the forward pass.\n",
 "    with torch.no_grad():\n",
 "        batch_dict = models[dti].forward_encoder_only([sample_dict])\n",
 "    # Post-process the model's output with the DTI normalization statistics\n",
 "    batch_dict = DtiBindingdbKdTask.process_model_output(\n",
 "        batch_dict,\n",
 "        scalars_preds_processed_key=\"model.out.dti_bindingdb_kd\",\n",
 "        norm_y_mean=DTI_NORM_Y_MEAN,\n",
 "        norm_y_std=DTI_NORM_Y_STD,\n",
 "    )\n",
 "    prediction = float(batch_dict[\"model.out.dti_bindingdb_kd\"][0])\n",
 "    return sample_dict[\"data.query.encoder_input\"], \"model.out.dti_bindingdb_kd\", prediction\n",
 "\n" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def create_ppi_demo():\n",
 "    \"\"\"Build the (initially hidden) Gradio group for the protein-protein interaction demo.\"\"\"\n",
 "    markup_text = f\"\"\"\n",
 "# Mammal based Protein-Protein Interaction (PPI) demonstration\n",
 "\n",
 "Given two protein sequences, estimate if the proteins interact or not.\n",
 "\n",
 "### Using the model from \n",
 "\n",
 " ```{model_paths[ppi]} ```\n",
 "\"\"\"\n",
 "    with gr.Group(visible=False) as ppi_demo:\n",
 "        gr.Markdown(markup_text)\n",
 "        with gr.Row():\n",
 "            prot1 = gr.Textbox(\n",
 "                label=\"Protein 1 sequence\",\n",
 "                interactive=True,\n",
 "                lines=3,\n",
 "                value=protein_calmodulin,\n",
 "            )\n",
 "            prot2 = gr.Textbox(\n",
 "                label=\"Protein 2 sequence\",\n",
 "                interactive=True,\n",
 "                lines=3,\n",
 "                value=protein_calcineurin,\n",
 "            )\n",
 "        with gr.Row():\n",
 "            run_mammal = gr.Button(\"Run Mammal prompt for Protein-Protein Interaction\", variant='primary')\n",
 "        with gr.Row():\n",
 "            prompt_box = gr.Textbox(label=\"Mammal prompt\", lines=5)\n",
 "        with gr.Row():\n",
 "            decoded = gr.Textbox(label=\"Mammal output\")\n",
 "            score = gr.Number(label='PPI score')\n",
 "        run_mammal.click(\n",
 "            fn=create_and_run_prompt,\n",
 "            inputs=[prot1, prot2],\n",
 "            outputs=[prompt_box, decoded, score],\n",
 "        )\n",
 "        with gr.Row():\n",
 "            gr.Markdown(\"`````` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting\")\n",
 "    return ppi_demo\n",
 "\n",
 "\n",
 "def create_tdb_demo():\n",
 "    \"\"\"Build the (initially hidden) Gradio group for the drug-target binding affinity demo.\"\"\"\n",
 "    markup_text = f\"\"\"\n",
 "# Mammal based Target-Drug binding affinity demonstration\n",
 "\n",
 "Given a protein sequence and a drug (in SMILES), estimate the binding affinity.\n",
 "\n",
 "### Using the model from \n",
 "\n",
 " ```{model_paths[dti]} ```\n",
 "\"\"\"\n",
 "    with gr.Group(visible=False) as tdb_demo:\n",
 "        gr.Markdown(markup_text)\n",
 "        with gr.Row():\n",
 "            prot = gr.Textbox(\n",
 "                label=\"Protein sequence\",\n",
 "                interactive=True,\n",
 "                lines=3,\n",
 "                value=target_seq,\n",
 "            )\n",
 "            drug = gr.Textbox(\n",
 "                label=\"drug sequence (SMILES)\",\n",
 "                interactive=True,\n",
 "                lines=3,\n",
 "                value=drug_seq,\n",
 "            )\n",
 "        with gr.Row():\n",
 "            run_mammal = gr.Button(\"Run Mammal prompt for Drug-Target binding affinity\", variant='primary')\n",
 "        with gr.Row():\n",
 "            prompt_box = gr.Textbox(label=\"Mammal prompt\", lines=5)\n",
 "        with gr.Row():\n",
 "            decoded = gr.Textbox(label=\"Mammal output\")\n",
 "            score = gr.Number(label='DTI score')\n",
 "        run_mammal.click(\n",
 "            fn=create_and_run_prompt_dtb,\n",
 "            inputs=[prot, drug],\n",
 "            outputs=[prompt_box, decoded, score],\n",
 "        )\n",
 "        with gr.Row():\n",
 "            gr.Markdown(\"The DTI score is the model's predicted binding affinity (pKd): a regression value, not an interacting / non-interacting class.\")\n",
 "    return tdb_demo\n"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def create_application():\n",
 "    \"\"\"Assemble the Gradio app: a dropdown that reveals either the PPI or the DTI demo.\"\"\"\n",
 "    with gr.Blocks() as demo:\n",
 "        main_dropdown = gr.Dropdown(choices=[\"select demo\", ppi, dti], interactive=True)\n",
 "        ppi_demo = create_ppi_demo()\n",
 "        dtb_demo = create_tdb_demo()\n",
 "\n",
 "        def set_demo_visibility(selected):\n",
 "            # Show only the group matching the selection; any other choice (e.g. \"select demo\") hides both.\n",
 "            return gr.Group(visible=selected == ppi), gr.Group(visible=selected == dti)\n",
 "\n",
 "        main_dropdown.change(set_demo_visibility, inputs=main_dropdown, outputs=[ppi_demo, dtb_demo])\n",
 "    return demo\n"
 ] }, { "cell_type": "code",
"execution_count": null, "metadata": {}, "outputs": [], "source": [
 "demo = create_application()\n",
 "demo.launch(show_error=True, share=False)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
 "In a notebook, `demo.launch()` returns while the server keeps running. Stop it with `demo.close()` (or `gr.close_all()` to stop every Gradio app in this kernel) when you are done. These are deliberately not code cells, so that *Run All* does not shut the app down right after launching it."
 ] }
 ], "metadata": { "kernelspec": { "display_name": "mammal", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 2 }