{ "cells": [ { "cell_type": "markdown", "source": [ "# beecoder - basic test notebook\n", "\n", "\n", "- https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA-python\n", "- 🐝" ], "metadata": { "id": "5948Y879XU-V" } }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lIYdn1woOS1n", "outputId": "9075b1fa-bc2e-4e1a-f181-d2869f99317b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m41.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m54.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.3/88.3 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m53.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "pip install transformers==4.33.3 accelerate sentencepiece fire -q" ] }, { "cell_type": "code", "source": [ "\"\"\"\n", "simple script for testing model(s) designed to generate/complete code\n", "\n", "See details/args with the below.\n", " python textgen_inference_code.py --help\n", "\"\"\"\n", "import logging\n", "import random\n", "import time\n", "from pathlib import Path\n", "\n", "import fire\n", "import torch\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", "\n", "\n", "class Timer:\n", " \"\"\"\n", " Basic timer utility.\n", " \"\"\"\n", "\n", " def __enter__(self):\n", "\n", " self.start_time = time.perf_counter()\n", " return self\n", "\n", " def __exit__(self, exc_type, exc_value, traceback):\n", "\n", " self.end_time = time.perf_counter()\n", " self.elapsed_time = self.end_time - self.start_time\n", " print(f\"Elapsed time: {self.elapsed_time:.4f} seconds\")\n", "\n", "\n", "def load_model(model_name, use_fast=False):\n", " \"\"\" util for loading model and tokenizer\"\"\"\n", " print(f\"Loading model: {model_name}\")\n", " tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)\n", " model = AutoModelForCausalLM.from_pretrained(\n", " model_name, torch_dtype=\"auto\", device_map=\"auto\"\n", " )\n", " model = torch.compile(model)\n", " return tokenizer, model\n", "\n", "\n", "def run_inference(prompt, model, tokenizer, max_new_tokens: int = 256):\n", " \"\"\"\n", " run_inference\n", "\n", " Args:\n", " prompt (TYPE): Description\n", " model (TYPE): Description\n", " tokenizer (TYPE): Description\n", " max_new_tokens (int, optional): Description\n", "\n", " Returns:\n", " TYPE: Description\n", " \"\"\"\n", " print(f\"Running inference with max_new_tokens={max_new_tokens} ...\")\n", " with Timer() as timer:\n", " inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", " outputs = model.generate(\n", " **inputs,\n", " max_new_tokens=max_new_tokens,\n", " min_new_tokens=8,\n", " renormalize_logits=True,\n", " no_repeat_ngram_size=8,\n", " repetition_penalty=1.04,\n", " num_beams=4,\n", " early_stopping=True,\n", " )\n", " text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n", " print(f\"Output text:\\n\\n{text}\")\n", " return text\n", "\n", "\n", "def main(\n", " model_name=\"BEE-spoke-data/smol_llama-101M-GQA-python\",\n", " prompt:str=None,\n", " use_fast=False,\n", " n_tokens: int = 256,\n", "):\n", " \"\"\"Summary\n", "\n", " Args:\n", " model_name (str, optional): Description\n", " prompt (None, optional): specify the prompt directly (default: random choice from list)\n", " n_tokens (int, optional): max new tokens to generate\n", " \"\"\"\n", " print(f\"Inference with:\\t{model_name}, max_new_tokens:{n_tokens}\")\n", "\n", " if prompt is None:\n", " prompt_list = [\n", " '''\n", " def print_primes(n: int):\n", " \"\"\"\n", " Print all primes between 1 and n\n", " \"\"\"''',\n", " \"def quantum_analysis(\",\n", " \"def sanitize_filenames(target_dir:str, recursive:False, extension\",\n", " ]\n", " prompt = random.SystemRandom().choice(prompt_list)\n", "\n", " print(f\"Using prompt:\\t{prompt}\")\n", "\n", " tokenizer, model = load_model(model_name, use_fast=use_fast)\n", "\n", " run_inference(prompt, model, tokenizer, n_tokens)\n", "\n", "\n", "# if __name__ == \"__main__\":\n", "# fire.Fire(main)" ], "metadata": { "id": "Qd59E_mlW67e" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "main(\n", " prompt=None, # replace with your custom prompt string\n", ")" ], "metadata": { "id": "mToCRUd7Xk7V", "outputId": "40cd0f7f-6783-4552-dea3-eab2ffd74cd8", "colab": { "base_uri": "https://localhost:8080/", "height": 796, "referenced_widgets": [ "d9028652e32848c48bda98d6437b5c65", "95e066c18f7f4eac897a6c03d912e43e", "00f4d2d0a2424969a0fe1299f022e82a", "87863465112e48c6bd40a9a4d279e631", "13030a123fe744fdb16ad24583df2765", "40b70d660f834f488a414700184e069c", "3340b5ee1b8847b59cb46aab716e31b6", "f34163eb8e4d4eaab807ae5604bab0f3", "c04833f28e5c46789c269bc1de0f78a2", "8e467e3f7fa84efdad7c2342c53ab14f", "63b3635dfc8d47cdb597a3c083ec99a2", "78601e1c9cba4a50a971277e8a261936", "dceede4a055f4e26a47c51b5427dea89", "c7565e9d4f764d6d8a63fbb7ac3af8f6", "ad69924e84394c048b9a5076f2d65db1", "19ad6fe2c6e44422815d33d5afd7055b", "d0ff994892884ab3acad83f03ad024b1", "5a035fa3d0ce460b9df0b3df26eb718d", "f9caa24be78e486f9c96d372b0dc8372", "7d210f04d3b5461a955944e5bd3adf74", "3a802174ad844f02b72f3ff7609b3429", "f40e4816c8bb489a8517d59326b979ea", "5647151fdba34b9bbd5e85d90f1af21b", "9099bd1ad0ec4777bdffef46ac3eac03", "7ff539a4b06f4bf5af7ee37c985883ec", "77d68c43ec344f3187b9dd542abc3418", "9e9056c287624a729dd76232c2a0a08b", "d59dc93587f1462caa1eb55b9062f4fb", "4313e2adcaa64271b1fe53f4a655027e", "3aac699971844b1695061cf8b0bad013", "beb0d22b3b164b14973bb1806725fec7", "0ac94cdd568e40afb7fb017891ac5b6f", "abfeebe5b23340f58851eb768d9bf6fe", "52bb6ba2ad3b470cba369c768baee493", "7d6723564af64d8aabd900128f0e6534", "a0dc9321b6e6460baf4c3ee2dd367dcd", "74329585e523413798008d3811f5c0a7", "95ea5d597e4244e69b18bad25b457385", "5964a90365d94a6ebe9cf9083c18e6dd", "0c7cfe3228764d318a7e522e11f1d5a3", "279e6348bfca4ab389f9b8fecfad172d", "5d09dd15dd534e848f9f31a8d2365011", "888ceea6c82044bfac4a0973483a5f3f", "c17d1a6cca214e4faa77c340187d93ea", "4d0e06759c9e479f896f1234be7df726", "f2023eea57c64061a6e20150d66574b1", "39610c0bbc25412ca1b764488da655b0", "673e777e38e84020bbfba7d88a339630", "3c4f3b77295e4ee08c05b8bc3a62119a", "a6eab9a438cd49ed871c79d150ebc8f5", "5de75f65782840c5888fbbccf3e99590", "020a13abe3154e90abc429c1414cbabb", "a3eda809847f48b8a41887699d82318c", "9b68eb4d0de54e06beecedf847b6aa88", "226b20ca6a2b4b0abb818aababfec5fd", "ca59010c9cd14d55bb043fb89f99ada6", "1045e9d1b9ac421d9eab1b1215c33900", "85d390484eb34ebf9a629c815b426909", "11b3b32be56443339fe0c91c264a00b8", "f8973caa3cce47ecab561489bbbdc0c9", "d9c87adad23c44909ed4135e7f3dce6d", "e46a0009e5c740b8a3fd789b0a75de67", "3bae0e6431744c0197b535d146413c34", "aa981f2d03dd4c4fa041f07ccec95a7f", "230259883228451a8785acc0fcce4605", "5cee37e47e534d2abdff4ac5dd7fb313", "262d77f631224018bf10d826debb417f", "d9f496a61d9a4e238ff0d8860648e0b8", "541d67ebb5f8404481a377912c50d6c1", "b0bde0ee16994678aa9ba8d30a3f1174", "48a639afaea7474ea70f267d91bc5d1a", "0a68b787b01d48e4b2eaecc372dea03b", "724f87c14c2d4b68a9d428738bc3f4c6", "2e302f5c905843b1bc1c087d01963d54", "f827ea74692441c18557d6bf8f356372", "8075cfc29a284824b13eb456ed2a72ac", "70860a2fd76d4867abf81f729e567bb6" ] } }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Inference with:\tBEE-spoke-data/smol_llama-101M-GQA-python, max_new_tokens:256\n", "Using prompt:\tdef quantum_analysis(\n", "Loading model: BEE-spoke-data/smol_llama-101M-GQA-python\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "tokenizer_config.json: 0%| | 0.00/16.3k [00:00