File size: 4,740 Bytes
a3ffd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# oobabooga/text-generation-webui\n",
        "\n",
        "After running both cells, a public Gradio URL will appear at the bottom in a few minutes. You can optionally generate an API link by ticking the `api` checkbox in the second cell before running it.\n",
        "\n",
        "* Project page: https://github.com/oobabooga/text-generation-webui\n",
        "* Gradio server status: https://status.gradio.app/"
      ],
      "metadata": {
        "id": "MFQl6-FjSYtY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 1. Keep this tab alive to prevent Colab from disconnecting you { display-mode: \"form\" }\n",
        "\n",
        "#@markdown Press play on the music player that will appear below:\n",
        "%%html\n",
        "<audio src=\"https://oobabooga.github.io/silence.m4a\" controls></audio>"
      ],
      "metadata": {
        "id": "f7TVVj_z4flw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 2. Launch the web UI\n",
        "\n",
        "#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
        "\n",
        "import shlex\n",
        "from pathlib import Path\n",
        "\n",
        "import torch\n",
        "\n",
        "# One-time setup. It is skipped when the cell is re-run, because the first run\n",
        "# leaves the working directory inside the cloned repository (see %cd below).\n",
        "if Path.cwd().name != 'text-generation-webui':\n",
        "  print(\"Installing the webui...\")\n",
        "\n",
        "  !git clone https://github.com/oobabooga/text-generation-webui\n",
        "  %cd text-generation-webui\n",
        "\n",
        "  torver = torch.__version__\n",
        "  print(f\"TORCH: {torver}\")\n",
        "  is_cuda118 = '+cu118' in torver  # 2.1.0+cu118\n",
        "\n",
        "  # read_text()/write_text() open and close the files themselves, so no file\n",
        "  # handle is left dangling.\n",
        "  textgen_requirements = Path('requirements.txt').read_text().splitlines()\n",
        "  if is_cuda118:\n",
        "    # The installed torch is a CUDA 11.8 build: point the '+cu121'/'+cu122'\n",
        "    # tagged requirements at their '+cu118' counterparts instead.\n",
        "    textgen_requirements = [\n",
        "      req.replace('+cu121', '+cu118').replace('+cu122', '+cu118')\n",
        "      for req in textgen_requirements\n",
        "    ]\n",
        "  Path('temp_requirements.txt').write_text('\\n'.join(textgen_requirements))\n",
        "\n",
        "  !pip install -r extensions/openai/requirements.txt --upgrade\n",
        "  !pip install -r temp_requirements.txt --upgrade\n",
        "\n",
        "  print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",
        "  print(\"\\033[1;32;1m\\n --> There is no need to restart the runtime.\\n\\033[0;37;0m\")\n",
        "\n",
        "  # Remove flash_attn if it cannot be imported. Catch Exception rather than\n",
        "  # using a bare except, so that interrupting the cell (KeyboardInterrupt) is\n",
        "  # not mistaken for a broken install and does not trigger the uninstall.\n",
        "  try:\n",
        "    import flash_attn\n",
        "  except Exception:\n",
        "    !pip uninstall -y flash_attn\n",
        "\n",
        "# Parameters\n",
        "model_url = \"https://huggingface.co/TheBloke/MythoMax-L2-13B-GPTQ\" #@param {type:\"string\"}\n",
        "branch = \"gptq-4bit-32g-actorder_True\" #@param {type:\"string\"}\n",
        "command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n",
        "api = False #@param {type:\"boolean\"}\n",
        "\n",
        "if api:\n",
        "  # Compare whole tokens. A substring test would treat any flag that merely\n",
        "  # starts with '--api' or '--public-api' as if the flag itself were already\n",
        "  # present, and silently skip adding it.\n",
        "  present_flags = command_line_flags.split()\n",
        "  for flag in ['--api', '--public-api']:\n",
        "    if flag not in present_flags:\n",
        "      command_line_flags += f\" {flag}\"\n",
        "\n",
        "model_url = model_url.strip()\n",
        "if model_url != \"\":\n",
        "    # Accept a bare 'user/repo' identifier as well as a full URL.\n",
        "    if not model_url.startswith('http'):\n",
        "        model_url = 'https://huggingface.co/' + model_url\n",
        "\n",
        "    # Download the model\n",
        "    url_parts = model_url.strip('/').strip().split('/')\n",
        "    # The folder passed to --model below is built as '<user>_<repo>[_<branch>]'.\n",
        "    output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
        "    branch = branch.strip('\"\\' ')\n",
        "\n",
        "    # Quote the form values before handing them to the shell, so that spaces or\n",
        "    # shell metacharacters in a pasted value cannot split or alter the command.\n",
        "    # Ordinary URLs and branch names pass through unchanged.\n",
        "    quoted_url = shlex.quote(model_url)\n",
        "    if branch.strip() not in ['', 'main']:\n",
        "        output_folder += f\"_{branch}\"\n",
        "        quoted_branch = shlex.quote(branch)\n",
        "        !python download-model.py {quoted_url} --branch {quoted_branch}\n",
        "    else:\n",
        "        !python download-model.py {quoted_url}\n",
        "else:\n",
        "    # No model requested: start the UI without a --model argument.\n",
        "    output_folder = \"\"\n",
        "\n",
        "# Start the web UI\n",
        "cmd = \"python server.py --share\"\n",
        "if output_folder != \"\":\n",
        "    cmd += f\" --model {output_folder}\"\n",
        "cmd += f\" {command_line_flags}\"\n",
        "print(cmd)\n",
        "!$cmd"
      ],
      "metadata": {
        "id": "LGQ8BiMuXMDG",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}