{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "a914a238", "metadata": {}, "outputs": [], "source": [ "# Install MiniChain into the running kernel's environment (%pip rather than !pip),\n", "# then copy the repository's examples next to this notebook; they presumably\n", "# include the math.pmpt.tpl template used below -- confirm.\n", "%pip install -q git+https://github.com/srush/MiniChain\n", "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " ] }, { "cell_type": "code", "execution_count": null, "id": "aa7b8a98", "metadata": { "lines_to_next_cell": 2, "tags": [ "hide_inp" ] }, "outputs": [], "source": [ "desc = \"\"\"\n", "### Word Problem Solver\n", "\n", "Chain that solves a math word problem by first generating and then running Python code. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/math_demo.ipynb)\n", "\n", "(Adapted from Dust [maths-generate-code](https://dust.tt/spolu/a/d12ac33169))\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "1a34507a", "metadata": {}, "outputs": [], "source": [ "from minichain import show, prompt, OpenAI, Python" ] }, { "cell_type": "code", "execution_count": null, "id": "43cdbfc0", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "@prompt(OpenAI(), template_file=\"math.pmpt.tpl\")\n", "def math_prompt(model, question):\n", "    \"Prompt to call GPT with a Jinja template\"\n", "    # The question is handed to the model as the template variable `question`.\n", "    return model(dict(question=question))" ] }, { "cell_type": "code", "execution_count": null, "id": "14d6eb95", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ "@prompt(Python(), template=\"import math\\n{{code}}\")\n", "def python(model, code):\n", "    \"Prompt to call Python interpreter\"\n", "    # Drop the first and last lines of the generated text before running it --\n", "    # presumably the Markdown code-fence lines; confirm against math.pmpt.tpl.\n", "    code = \"\\n\".join(code.strip().split(\"\\n\")[1:-1])\n", "    # Whatever the interpreter backend returns is coerced to int.\n", "    return int(model(dict(code=code)))" ] }, { "cell_type": "code", "execution_count": null, "id": "dac2ba43", "metadata": {}, "outputs": [], "source": [ "def math_demo(question):\n", "    \"Chain them together\"\n", "    # Generate code for the question, then run that code.\n", "    return python(math_prompt(question))" ] }, { "cell_type": "code", "execution_count": null, "id": "4a4f5bbf", 
"metadata": { "lines_to_next_cell": 0, "tags": [ "hide_inp" ] }, "outputs": [], "source": [ "gradio = show(math_demo,\n", " examples=[\"What is the sum of the powers of 3 (3^i) that are smaller than 100?\",\n", " \"What is the sum of the 10 first positive integers?\",],\n", " # \"Carla is downloading a 200 GB file. She can download 2 GB/minute, but 40% of the way through the download, the download fails. Then Carla has to restart the download from the beginning. How long did it take her to download the file in minutes?\"],\n", " subprompts=[math_prompt, python],\n", " out_type=\"json\",\n", " description=desc,\n", " )\n", "if __name__ == \"__main__\":\n", " gradio.launch()" ] }, { "cell_type": "code", "execution_count": null, "id": "a478587c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "cell_metadata_filter": "tags,-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }