diff --git "a/modular.ipynb" "b/modular.ipynb" new file mode 100644--- /dev/null +++ "b/modular.ipynb" @@ -0,0 +1,1079 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import gradio as gr\n", + "import torch\n", + "from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp\n", + "from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask\n", + "from mammal.keys import *\n", + "from mammal.model import Mammal\n", + "from abc import ABC, abstractmethod\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class MammalObjectBroker():\n", + " def __init__(self, model_path: str, name:str= None, task_list: list[str]=None) -> None:\n", + " self.model_path = model_path\n", + " if name is None:\n", + " name = model_path\n", + " self.name = name \n", + " \n", + " if task_list is not None:\n", + " self.tasks=task_list\n", + " else:\n", + " self.task = []\n", + " self._model = None\n", + " self._tokenizer_op = None\n", + " \n", + " \n", + " @property\n", + " def model(self)-> Mammal:\n", + " if self._model is None:\n", + " self._model = Mammal.from_pretrained(self.model_path)\n", + " self._model.eval()\n", + " return self._model\n", + " \n", + " @property\n", + " def tokenizer_op(self):\n", + " if self._tokenizer_op is None:\n", + " self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)\n", + " return self._tokenizer_op\n", + " \n", + " \n", + " \n", + " \n", + "\n", + "class MammalTask(ABC):\n", + " def __init__(self, name:str) -> None:\n", + " self.name = name\n", + " self.description = None\n", + " self._demo = None\n", + "\n", + " @abstractmethod\n", + " def generate_prompt(self, **kwargs) -> str:\n", + " \"\"\"Formatting prompt to match pre-training syntax\n", + "\n", + " Args:\n", + " prot1 (_type_): _description_\n", + " prot2 (_type_): _description_\n", + "\n", + " Raises:\n", + " No: _description_\n", + " \"\"\"\n", + " raise NotImplementedError()\n", + "\n", + " @abstractmethod\n", + " def crate_sample_dict(self, prompt: str, **kwargs) -> dict:\n", + " \"\"\"Formatting prompt to match pre-training syntax\n", + "\n", + " Args:\n", + " prompt (str): _description_\n", + "\n", + " Returns:\n", + " dict: sample_dict for feeding into model\n", + " \"\"\"\n", + " raise NotImplementedError()\n", + "\n", + " @abstractmethod\n", + " def run_model(_, sample_dict, model:Mammal):\n", + " raise NotImplementedError()\n", + " \n", + " def decode_output(self,batch_dict, model):\n", + " pass\n", + "\n", + " @abstractmethod\n", + " def create_demo(self):\n", + " \"\"\"create an gradio demo group\n", + "\n", + " Returns:\n", + " _type_: _description_\n", + " \"\"\"\n", + " raise NotImplementedError()\n", + "\n", + " \n", + " def demo(self,model_dropdown=None):\n", + " if self._demo is None:\n", + " self._demo = self.create_demo(model_dropdown)\n", + " return self._demo\n", + "\n", + " @abstractmethod\n", + " def decode_output(self,batch_dict, model:Mammal):\n", + " raise NotImplementedError()\n", + "\n", + " #self._setup()\n", + " \n", + " # def _setup(self):\n", + " # pass\n", + " \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_tasks = dict()\n", + "all_models= dict()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class PpiTask(MammalTask):\n", + " def __init__(self):\n", + " super().__init__(name=\"PPI\")\n", + " self.description = \"Protein-Protein Interaction (PPI)\"\n", + " self.examples = {\n", + " \"protein_calmodulin\": ,\"MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK\"\n", + " \"protein_calcineurin\": \"MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ\",\n", + " }\n", + " self.markup_text = \"\"\"\n", + " # Mammal based {self.description} demonstration\n", + " \n", + " Given two protein sequences, estimate if the proteins interact or not.\"\"\"\n", + " \n", + " \n", + " \n", + " @staticmethod\n", + " def positive_token_id(model_holder: MammalObjectBroker):\n", + " \"\"\"token for positive binding\n", + "\n", + " Args:\n", + " model (MammalTrainedModel): model holding tokenizer\n", + "\n", + " Returns:\n", + " int: id of positive binding token\n", + " \"\"\"\n", + " return model_holder.tokenizer_op.get_token_id(\"<1>\")\n", + " \n", + " def generate_prompt(self, prot1, prot2):\n", + " \"\"\"Formatting prompt to match pre-training syntax\n", + "\n", + " Args:\n", + " prot1 (str): sequance of protein number 1\n", + " prot2 (str): sequance of protein number 2\n", + "\n", + " Returns:\n", + " str: prompt\n", + " \"\"\" \n", + " prompt = \"<@TOKENIZER-TYPE=AA>\"\\\n", + " \"\"\\\n", + " f\"{prot1}\"\\\n", + " \"\"\\\n", + " f\"{prot2}\"\n", + " return prompt\n", + " \n", + " \n", + " def crate_sample_dict(self,prompt: str, model_holder:MammalObjectBroker):\n", + " # Create and load sample\n", + " sample_dict = dict()\n", + " sample_dict[ENCODER_INPUTS_STR] = prompt\n", + "\n", + " # Tokenize\n", + " sample_dict = model_holder.tokenizer_op(\n", + " sample_dict=sample_dict,\n", + " key_in=ENCODER_INPUTS_STR,\n", + " key_out_tokens_ids=ENCODER_INPUTS_TOKENS,\n", + " key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,\n", + " )\n", + " sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(\n", + " sample_dict[ENCODER_INPUTS_TOKENS]\n", + " )\n", + " sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(\n", + " sample_dict[ENCODER_INPUTS_ATTENTION_MASK]\n", + " )\n", + " return sample_dict\n", + "\n", + " def run_model(_, sample_dict, model: Mammal):\n", + " # Generate Prediction\n", + " batch_dict = model.generate(\n", + " [sample_dict],\n", + " output_scores=True,\n", + " return_dict_in_generate=True,\n", + " max_new_tokens=5,\n", + " )\n", + " return batch_dict\n", + " \n", + " def decode_output(self,batch_dict, model_holder):\n", + "\n", + " # Get output\n", + " generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])\n", + " score = batch_dict[\"model.out.scores\"][0][1][self.positive_token_id(model_holder)].item()\n", + "\n", + " return generated_output, score\n", + "\n", + "\n", + " def create_and_run_prompt(self,model_name,protein1, protein2):\n", + " model_holder = all_models[model_name]\n", + " prompt = self.generate_prompt(protein1, protein2)\n", + " sample_dict = self.crate_sample_dict(prompt=prompt, model_holder=model_holder)\n", + " model_output = self.run_model(sample_dict=sample_dict, model=model_holder.model)\n", + " res = prompt, *model_output\n", + " return res\n", + "\n", + " \n", + " def create_demo(self,model_name_dropdown):\n", + " \n", + " # \"\"\"\n", + " # ### Using the model from\n", + "\n", + " # ```{model} ```\n", + " # \"\"\"\n", + " with gr.Group() as demo:\n", + " gr.Markdown(self.markup_text)\n", + " with gr.Row():\n", + " prot1 = gr.Textbox(\n", + " label=\"Protein 1 sequence\",\n", + " # info=\"standard\",\n", + " interactive=True,\n", + " lines=3,\n", + " value=self.examples[\"protein_calmodulin\"],\n", + " )\n", + " prot2 = gr.Textbox(\n", + " label=\"Protein 2 sequence\",\n", + " # info=\"standard\",\n", + " interactive=True,\n", + " lines=3,\n", + " value=self.examples[\"protein_calcineurin\"],\n", + " )\n", + " with gr.Row():\n", + " run_mammal = gr.Button(\n", + " \"Run Mammal prompt for Protein-Protein Interaction\", variant=\"primary\"\n", + " )\n", + " with gr.Row():\n", + " prompt_box = gr.Textbox(label=\"Mammal prompt\", lines=5)\n", + "\n", + " with gr.Row():\n", + " decoded = gr.Textbox(label=\"Mammal output\")\n", + " run_mammal.click(\n", + " fn=self.create_and_run_prompt,\n", + " inputs=[model_name_dropdown, prot1, prot2],\n", + " outputs=[prompt_box, decoded, gr.Number(label=\"PPI score\")],\n", + " )\n", + " with gr.Row():\n", + " gr.Markdown(\n", + " \"`````` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting\"\n", + " )\n", + " demo.visible = True\n", + " return demo\n", + "\n", + "ppi_task = PpiTask()\n", + "all_tasks[ppi_task.name]=ppi_task\n", + "all_tasks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "### DTI:\n", + "\n", + "#\n", + "dti = \"Drug-Target Binding Affinity\"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "# input\n", + "target_seq = \"NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC\"\n", + "drug_seq = \"CC(=O)NCCC1=CNc2c1cc(OC)cc2\"\n", + "\n", + "\n", + "# token for positive binding\n", + "positive_token_id = tokenizer_op[dti].get_token_id(\"<1>\")\n", + "\n", + "\n", + "def generate_prompt_dti(prot, drug):\n", + " sample_dict = {\"target_seq\": target_seq, \"drug_seq\": drug_seq}\n", + " sample_dict = DtiBindingdbKdTask.data_preprocessing(\n", + " sample_dict=sample_dict,\n", + " tokenizer_op=tokenizer_op[dti],\n", + " target_sequence_key=\"target_seq\",\n", + " drug_sequence_key=\"drug_seq\",\n", + " norm_y_mean=None,\n", + " norm_y_std=None,\n", + " device=models[dti].device,\n", + " )\n", + " return sample_dict\n", + "\n", + "\n", + "def create_and_run_prompt_dtb(prot, drug):\n", + " sample_dict = generate_prompt_dti(prot, drug)\n", + " # Post-process the model's output\n", + " # batch_dict = model_dti.forward_encoder_only([sample_dict])\n", + " batch_dict = models[dti].forward_encoder_only([sample_dict])\n", + " batch_dict = DtiBindingdbKdTask.process_model_output(\n", + " batch_dict,\n", + " scalars_preds_processed_key=\"model.out.dti_bindingdb_kd\",\n", + " norm_y_mean=5.79384684128215,\n", + " norm_y_std=1.33808027428196,\n", + " )\n", + " ans = [\n", + " \"model.out.dti_bindingdb_kd\",\n", + " float(batch_dict[\"model.out.dti_bindingdb_kd\"][0]),\n", + " ]\n", + " res = sample_dict[\"data.query.encoder_input\"], *ans\n", + " return res\n", + "\n", + "\n", + "def create_tdb_demo():\n", + " markup_text = f\"\"\"\n", + "# Mammal based Target-Drug binding affinity demonstration\n", + "\n", + "Given a protein sequence and a drug (in SMILES), estimate the binding affinity.\n", + "\n", + "### Using the model from\n", + "\n", + " ```{model_paths[dti]} ```\n", + "\"\"\"\n", + " with gr.Group() as tdb_demo:\n", + " gr.Markdown(markup_text)\n", + " with gr.Row():\n", + " prot = gr.Textbox(\n", + " label=\"Protein sequence\",\n", + " # info=\"standard\",\n", + " interactive=True,\n", + " lines=3,\n", + " value=target_seq,\n", + " )\n", + " drug = gr.Textbox(\n", + " label=\"drug sequence (SMILES)\",\n", + " # info=\"standard\",\n", + " interactive=True,\n", + " lines=3,\n", + " value=drug_seq,\n", + " )\n", + " with gr.Row():\n", + " run_mammal = gr.Button(\n", + " \"Run Mammal prompt for Target Drug Affinity\", variant=\"primary\"\n", + " )\n", + " with gr.Row():\n", + " prompt_box = gr.Textbox(label=\"Mammal prompt\", lines=5)\n", + "\n", + " with gr.Row():\n", + " decoded = gr.Textbox(label=\"Mammal output\")\n", + " run_mammal.click(\n", + " fn=create_and_run_prompt_dtb,\n", + " inputs=[prot, drug],\n", + " outputs=[prompt_box, decoded, gr.Number(label=\"DTI score\")],\n", + " )\n", + " tdb_demo.visible = False\n", + " return tdb_demo\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "ppi_model = MammalObjectBroker(model_path=\"ibm/biomed.omics.bl.sm.ma-ted-458m\", task_list=[\"PPI\"])\n", + "\n", + "all_models[ppi_model.name]=ppi_model\n", + "# tdi_model = MammalTrainedModel(model_path=\"ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd\") TODO: ## task list still empty\n", + "# all_models.append(tdi_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def create_application():\n", + " def task_change(value):\n", + " choices=[model_name for model_name, model in all_models.items() if value in model.tasks]\n", + " if choices:\n", + " return gr.update(choices=choices, value=choices[0])\n", + " else:\n", + " return\n", + " # return model_dropdown\n", + " \n", + " \n", + " with gr.Blocks() as demo:\n", + " task_dropdown = gr.Dropdown(choices=[\"select demo\"] + list(all_tasks.keys()))\n", + " task_dropdown.interactive = True\n", + " model_dropdown = gr.Dropdown(choices=[model_name for model_name, model in all_models.items() if task_dropdown.value in model.tasks], interactive=True)\n", + " task_dropdown.change(task_change,inputs=[task_dropdown],outputs=[model_dropdown])\n", + " \n", + " \n", + "\n", + "\n", + "\n", + " ppi_demo = all_tasks[\"PPI\"].demo(model_dropdown = model_dropdown)\n", + " ppi_demo.visible = True\n", + " # dtb_demo = create_tdb_demo()\n", + "\n", + " def set_ppi_vis(main_text):\n", + " main_text=main_text\n", + " print(f\"main text is {main_text}\")\n", + " return gr.Group(visible=True)\n", + " #return gr.Group(visible=(main_text == \"PPI\"))\n", + " # , gr.Group( visible=(main_text == \"DTI\") )\n", + "\n", + " task_dropdown.change(\n", + " set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo]\n", + " )\n", + " return demo\n", + "\n", + "full_demo=None\n", + "def main():\n", + " global full_demo\n", + " full_demo = create_application()\n", + " full_demo.launch(show_error=True, share=False)\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for model_name, model_holder in all_models.items():\n", + " print(model_name)\n", + " print(model_holder.tasks, \"PPI\" in model_holder.tasks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "full_demo.blocks[240].EVENTS" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from mammal.examples.tcr_epitope_binding.main_infer import load_model, task_infer\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "tcr_beta_seq = \"NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT\"\n", + "epitope_seq = \"LLQTGIHVRVSQPSL\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Path doesn't exist. Will try to download fron hf hub. pretrained_model_name_or_path='ibm/biomed.omics.bl.sm.ma-ted-458m'\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9d3f97cfb3784a95a974b73b5bdbf0cc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 10 files: 0%| | 0/10 [00:00 7\u001b[0m \u001b[43mall_models\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mregister_model\u001b[49m \u001b[38;5;241m=\u001b[39m register_model\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__get__\u001b[39m(all_models, \u001b[38;5;28mdict\u001b[39m)\n\u001b[1;32m 8\u001b[0m all_models\u001b[38;5;241m.\u001b[39mregister_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel3\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(all_models)\n", + "\u001b[0;31mAttributeError\u001b[0m: 'dict' object has no attribute 'register_model'" + ] + } + ], + "source": [ + "# Assisted by watsonx Code Assistant \n", + "all_models = {'model1': 'model1_path', 'model2': 'model2_path'}\n", + "\n", + "def register_model(self, name):\n", + " self.update({name: f'{name}_path'})\n", + "\n", + "all_models.register_model = register_model.__get__(all_models, dict)\n", + "all_models.register_model(\"model3\")\n", + "print(all_models)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "class AllModels(dict):\n", + " def register_model(self, name):\n", + " self.update({name: f'{name}_path'})\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_models=AllModels()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "all_models.register_model(\"abc\")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'abc': 'abc_path'}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_models" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mammal", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}