{ "cells": [ { "cell_type": "markdown", "id": "2e16f61f", "metadata": {}, "source": [ "# NER" ] }, { "cell_type": "markdown", "id": "904e43dd", "metadata": {}, "source": [ "Notebook implementation of named entity recognition.\n", "Adapted from [promptify](https://github.com/promptslab/Promptify/blob/main/promptify/prompts/nlp/templates/ner.jinja)." ] }, { "cell_type": "code", "execution_count": 2, "id": "b4b1d58e", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.445242Z", "iopub.status.busy": "2023-03-13T23:43:12.444962Z", "iopub.status.idle": "2023-03-13T23:43:12.450741Z", "shell.execute_reply": "2023-03-13T23:43:12.450139Z" } }, "outputs": [], "source": [ "import json" ] }, { "cell_type": "code", "execution_count": 3, "id": "fdb154d0", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.453113Z", "iopub.status.busy": "2023-03-13T23:43:12.452884Z", "iopub.status.idle": "2023-03-13T23:43:12.649309Z", "shell.execute_reply": "2023-03-13T23:43:12.648483Z" }, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "import minichain" ] }, { "cell_type": "markdown", "id": "d5665917", "metadata": {}, "source": [ "Prompt to extract NER tags as json" ] }, { "cell_type": "code", "execution_count": 4, "id": "1cfe0e75", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.654908Z", "iopub.status.busy": "2023-03-13T23:43:12.653463Z", "iopub.status.idle": "2023-03-13T23:43:12.660078Z", "shell.execute_reply": "2023-03-13T23:43:12.659313Z" }, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "class NERPrompt(minichain.TemplatePrompt):\n", " template_file = \"ner.pmpt.tpl\"\n", "\n", " def parse(self, response, inp):\n", " return json.loads(response)" ] }, { "cell_type": "markdown", "id": "11619d3d", "metadata": {}, "source": [ "Use NER to ask a simple question." 
] }, { "cell_type": "code", "execution_count": 5, "id": "584bef0d", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.667113Z", "iopub.status.busy": "2023-03-13T23:43:12.665599Z", "iopub.status.idle": "2023-03-13T23:43:12.673456Z", "shell.execute_reply": "2023-03-13T23:43:12.672558Z" }, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "class TeamPrompt(minichain.Prompt):\n", " def prompt(self, inp):\n", " return \"Can you describe these basketball teams? \" + \\\n", " \" \".join([i[\"E\"] for i in inp if i[\"T\"] ==\"Team\"])\n", "\n", " def parse(self, response, inp):\n", " return response" ] }, { "cell_type": "markdown", "id": "6ea6c161", "metadata": {}, "source": [ "Run the system." ] }, { "cell_type": "code", "execution_count": 6, "id": "a8ee77f4", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.678805Z", "iopub.status.busy": "2023-03-13T23:43:12.678446Z", "iopub.status.idle": "2023-03-13T23:43:12.682592Z", "shell.execute_reply": "2023-03-13T23:43:12.682060Z" } }, "outputs": [], "source": [ "with minichain.start_chain(\"ner\") as backend:\n", " ner_prompt = NERPrompt(backend.OpenAI())\n", " team_prompt = TeamPrompt(backend.OpenAI())\n", " prompt = ner_prompt.chain(team_prompt)\n", " # results = prompt(\n", " # {\"text_input\": \"An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.\",\n", " # \"labels\" : [\"Team\", \"Date\"],\n", " # \"domain\": \"Sports\"\n", " # }\n", " # )\n", " # print(results)" ] }, { "cell_type": "code", "execution_count": 7, "id": "55b9ce94", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.684777Z", "iopub.status.busy": "2023-03-13T23:43:12.684591Z", "iopub.status.idle": "2023-03-13T23:43:12.687815Z", "shell.execute_reply": "2023-03-13T23:43:12.687194Z" } }, "outputs": [], "source": [ "ner_prompt.set_display_options(markdown=True)\n", 
"team_prompt.set_display_options(markdown=True) " ] }, { "cell_type": "code", "execution_count": 8, "id": "fe56c4ba", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:12.690233Z", "iopub.status.busy": "2023-03-13T23:43:12.689776Z", "iopub.status.idle": "2023-03-13T23:43:19.799186Z", "shell.execute_reply": "2023-03-13T23:43:19.798652Z" }, "lines_to_next_cell": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt.to_gradio(fields =[\"text_input\", \"labels\", \"domain\"],\n", " examples=[[\"An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.\", \"Team, Date\", \"Sports\"]]).launch()" ] }, { "cell_type": "markdown", "id": "0c81d136", "metadata": {}, "source": [ "View prompt examples." ] }, { "cell_type": "code", "execution_count": 8, "id": "d75cba8c", "metadata": { "execution": { "iopub.execute_input": "2023-03-13T23:43:19.802519Z", "iopub.status.busy": "2023-03-13T23:43:19.802098Z", "iopub.status.idle": "2023-03-13T23:43:19.805558Z", "shell.execute_reply": "2023-03-13T23:43:19.804994Z" }, "tags": [ "hide_inp" ] }, "outputs": [], "source": [ "# NERPrompt().show(\n", "# {\n", "# \"input\": \"I went to New York\",\n", "# \"domain\": \"Travel\",\n", "# \"labels\": [\"City\"]\n", "# },\n", "# '[{\"T\": \"City\", \"E\": \"New York\"}]',\n", "# )\n", "# # -\n", "\n", "# # View log.\n", "\n", "# minichain.show_log(\"ner.log\")" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "tags,-all" }, "kernelspec": { "display_name": "minichain", "language": "python", "name": "minichain" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }