{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import re\n",
    "import openai\n",
    "from config import OPENAI_API_KEY\n",
    "import dask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "openai.api_key = OPENAI_API_KEY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_df = pd.read_csv('wiki_intro_processed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunkify(lst, n):\n",
    "    \"\"\"Split ``lst`` into consecutive slices of at most ``n`` items.\n",
    "\n",
    "    The last slice holds the remainder and may be shorter than ``n``; an\n",
    "    empty ``lst`` yields an empty list.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``n`` is zero or negative. A negative step would\n",
    "            otherwise make ``range`` empty and silently drop every item.\n",
    "    \"\"\"\n",
    "    if n <= 0:\n",
    "        raise ValueError(f'chunk size must be positive, got {n}')\n",
    "    return [lst[i:i + n] for i in range(0, len(lst), n)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_prompt(title, starter_text):\n",
    "    \"\"\"Build the completion prompt for one article.\n",
    "\n",
    "    The prompt is an instruction line naming ``title``, a newline, and then\n",
    "    ``starter_text``.\n",
    "    \"\"\"\n",
    "    # The whitespace in front of {starter_text} sits inside the string\n",
    "    # literal, so it is sent to the model as part of the prompt.\n",
    "    return f'''200 word wikipedia style introduction on '{title}'\n",
    " {starter_text}'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_openai_response(dct):\n",
    "    \"\"\"Request a completion for one record and key it by the record's title.\n",
    "\n",
    "    Args:\n",
    "        dct: mapping with 'title' and 'starter_text' entries.\n",
    "\n",
    "    Returns:\n",
    "        A one-item dict mapping the title to the value returned by\n",
    "        ``openai.Completion.create``.\n",
    "    \"\"\"\n",
    "    title = dct['title']\n",
    "    starter_text = dct['starter_text']\n",
    "\n",
    "    prompt = create_prompt(title, starter_text)\n",
    "\n",
    "    completion = openai.Completion.create(\n",
    "        model=\"text-curie-001\",\n",
    "        prompt=prompt,\n",
    "        temperature=0.7,\n",
    "        max_tokens=300,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0.4,\n",
    "        presence_penalty=0.1\n",
    "    )\n",
    "    return {title: completion}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fake_get_openai_response(dct):\n",
    "    \"\"\"Offline stand-in for ``get_openai_response``.\n",
    "\n",
    "    Builds the same prompt but makes no API call: the returned one-item dict\n",
    "    maps the title to the prompt followed by a fixed filler suffix.\n",
    "    \"\"\"\n",
    "    title = dct['title']\n",
    "    starter_text = dct['starter_text']\n",
    "\n",
    "    prompt = create_prompt(title, starter_text)\n",
    "\n",
    "    return {title: f'{prompt} blah blah blah'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_dask(lst_dct, response_fn=None):\n",
    "    \"\"\"Apply ``response_fn`` to every record through dask and collect the results.\n",
    "\n",
    "    Args:\n",
    "        lst_dct: iterable of record dicts; each one is passed to ``response_fn``.\n",
    "        response_fn: callable invoked once per record. When omitted,\n",
    "            ``get_openai_response`` is used; pass ``fake_get_openai_response``\n",
    "            for a dry run that makes no API calls.\n",
    "\n",
    "    Returns:\n",
    "        The tuple produced by ``dask.compute``, one entry per record.\n",
    "    \"\"\"\n",
    "    # Resolve the default at call time so that re-running the cell that\n",
    "    # defines get_openai_response is picked up without redefining run_dask.\n",
    "    if response_fn is None:\n",
    "        response_fn = get_openai_response\n",
    "    delayed_calls = [dask.delayed(response_fn)(dct) for dct in lst_dct]\n",
    "    return dask.compute(*delayed_calls)"
   ]
}, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "2\n", "3\n", "4\n", "5\n", "6\n", "7\n", "8\n", "9\n", "10\n", "11\n", "12\n", "13\n", "14\n", "15\n", "16\n", "17\n", "18\n", "19\n", "20\n", "21\n", "22\n", "23\n", "24\n", "25\n", "26\n", "27\n", "28\n", "29\n", "30\n", "31\n", "32\n", "33\n", "34\n", "35\n", "36\n", "37\n", "38\n", "39\n", "40\n", "41\n", "42\n", "43\n", "44\n", "45\n", "46\n", "47\n", "48\n", "49\n", "50\n", "51\n", "52\n", "53\n", "54\n", "55\n", "56\n", "57\n", "58\n", "59\n", "60\n", "61\n", "62\n", "63\n", "64\n", "65\n", "66\n", "67\n", "68\n", "69\n", "70\n", "71\n", "72\n", "73\n", "74\n", "75\n", "76\n", "77\n", "78\n", "79\n", "80\n", "81\n", "82\n", "83\n", "84\n", "85\n", "86\n", "87\n", "88\n", "89\n", "90\n", "91\n", "92\n", "93\n", "94\n", "95\n", "96\n", "97\n", "98\n", "99\n", "100\n", "101\n", "102\n", "103\n", "104\n", "105\n", "106\n", "107\n", "108\n", "109\n", "110\n", "111\n", "112\n", "113\n", "114\n", "115\n", "116\n", "117\n", "118\n", "119\n", "120\n", "121\n", "122\n", "123\n", "124\n", "125\n", "126\n", "127\n", "128\n", "129\n", "130\n", "131\n", "132\n", "133\n", "134\n", "135\n", "136\n", "137\n", "138\n", "139\n", "140\n", "141\n", "142\n", "143\n", "144\n", "145\n", "146\n", "147\n", "148\n", "149\n" ] } ], "source": [ "chunk_n = 0\n", "for chunk in chunkify(data_df[['title', 'starter_text']].to_dict('records'), 1000):\n", " if chunk_n == 0:\n", " chunk_n += 1\n", " continue\n", " print(chunk_n)\n", " result = run_dask(chunk)\n", " with open(f'data/result-{chunk_n}.pkl', 'wb') as file:\n", " pickle.dump(result, file)\n", " chunk_n += 1\n" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "3f100d68d9cf80676b1a4c3ace5430b03ae266a1d88e3f101eb196b64b263632" } } }, "nbformat": 4, "nbformat_minor": 2 }