{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "247c85dd",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d36498d1",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "desc = \"\"\"\n",
    "### Typed Extraction\n",
    "\n",
    "Information extraction that is automatically generated from a typed specification. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/stats.ipynb)\n",
    "\n",
    "(Novel to MiniChain)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e894e4fb",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "from minichain import prompt, show, type_to_prompt, OpenAI\n",
    "from dataclasses import dataclass\n",
    "from typing import List\n",
    "from enum import Enum"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ac1ff70",
   "metadata": {},
   "source": [
    "Data specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a00e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kinds of statistic that can be attached to a player.\n",
    "class StatType(Enum):\n",
    "    POINTS = 1\n",
    "    REBOUNDS = 2\n",
    "    ASSISTS = 3\n",
    "\n",
    "# A single statistic: an integer value tagged with its StatType.\n",
    "@dataclass\n",
    "class Stat:\n",
    "    value: int\n",
    "    stat: StatType\n",
    "\n",
    "# A player name together with the list of stats extracted for that player.\n",
    "@dataclass\n",
    "class Player:\n",
    "    player: str\n",
    "    stats: List[Stat]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e0ed85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt that fills `stats.pmpt.tpl` with the passage and a rendering of the\n",
    "# `Player` type, then builds one `Player` per item of the JSON-parsed output.\n",
    "# NOTE(review): `Player(**j)` passes the parsed values through unchanged, so\n",
    "# nested `stats` entries are presumably plain dicts rather than `Stat`\n",
    "# instances -- confirm before relying on attribute access.\n",
    "@prompt(OpenAI(), template_file=\"stats.pmpt.tpl\", parser=\"json\")\n",
    "def stats(model, passage):\n",
    "    out = model(dict(passage=passage, typ=type_to_prompt(Player)))\n",
    "    return [Player(**j) for j in out]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7241bdc6",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Read the example article inside a context manager so the file handle is\n",
    "# closed as soon as the text has been loaded.\n",
    "with open(\"sixers.txt\") as article_file:\n",
    "    article = article_file.read()\n",
    "\n",
    "gradio = show(\n",
    "    lambda passage: stats(passage),\n",
    "    examples=[article],\n",
    "    subprompts=[stats],\n",
    "    out_type=\"json\",\n",
    "    description=desc,\n",
    ")\n",
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18ed1b96",
   "metadata": {},
   "source": [
    "Legacy example (not executed; `ExtractionPrompt` is not defined in this notebook):\n",
    "\n",
    "```python\n",
    "ExtractionPrompt().show({\"passage\": \"Harden had 10 rebounds.\"},\n",
    "                        '[{\"player\": \"Harden\", \"stats\": {\"value\": 10, \"stat\": 2}}]')\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02c96e7d",
   "metadata": {},
   "source": [
    "# View the run log."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ea427dd",
   "metadata": {},
   "source": [
    "```python\n",
    "minichain.show_log(\"bash.log\")\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}