{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eccacc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "# Only clone when missing so Restart & Run All does not trip over an existing checkout.\n",
    "!test -d MiniChain || git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebfe63f6",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide_inp"
    ]
   },
   "outputs": [],
   "source": [
    "\n",
    "desc = \"\"\"\n",
    "### Named Entity Recognition\n",
    "\n",
    "Chain that does named entity recognition with arbitrary labels. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/ner.ipynb)\n",
    "\n",
    "(Adapted from [promptify](https://github.com/promptslab/Promptify/blob/main/promptify/prompts/nlp/templates/ner.jinja)).\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45dd8a11",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "from minichain import prompt, show, OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ada6ebb",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAI(), template_file = \"ner.pmpt.tpl\", parser=\"json\")\n",
    "def ner_extract(model, kwargs):\n",
    "    \"\"\"Query the model through the ner.pmpt.tpl template; the reply goes through parser=\"json\".\n",
    "\n",
    "    `kwargs` holds the template variables; `ner` passes text_input, labels and domain.\n",
    "    \"\"\"\n",
    "    return model(kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6873c42",
   "metadata": {},
   "outputs": [],
   "source": [
    "@prompt(OpenAI())\n",
    "def team_describe(model, inp):\n",
    "    \"\"\"Ask the model to describe every entity labelled \"Team\" in `inp`.\n",
    "\n",
    "    `inp` is the parsed output of `ner_extract`; this assumes a list of\n",
    "    {\"T\": label, \"E\": entity} dicts -- confirm against ner.pmpt.tpl.\n",
    "    Malformed entries (non-dicts, missing \"T\"/\"E\") are skipped instead of\n",
    "    raising, because the list comes from free-form LLM output.\n",
    "    \"\"\"\n",
    "    teams = [\n",
    "        str(ent[\"E\"])\n",
    "        for ent in inp\n",
    "        if isinstance(ent, dict) and ent.get(\"T\") == \"Team\" and \"E\" in ent\n",
    "    ]\n",
    "    # Comma-separate so multi-word names (\"Miami Heat\") stay distinguishable.\n",
    "    query = \"Can you describe these basketball teams? \" + \", \".join(teams)\n",
    "    return model(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a89fa41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ner(text_input, labels, domain):\n",
    "    \"\"\"Extract `labels` entities from `text_input`, then describe the \"Team\" entities found.\n",
    "\n",
    "    Returns the result of `team_describe` on the extraction.\n",
    "    \"\"\"\n",
    "    extract = ner_extract(dict(text_input=text_input, labels=labels, domain=domain))\n",
    "    return team_describe(extract)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634fb50b",
   "metadata": {},
   "outputs": [],
   "source": [
    "gradio = show(ner,\n",
    "              examples=[[\"An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.\", \"Team, Date\", \"Sports\"]],\n",
    "              description=desc,\n",
    "              subprompts=[ner_extract, team_describe],\n",
    "              )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa353224",
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "tags,-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}