awacke1 commited on
Commit
854044a
Β·
1 Parent(s): ea30e47

Upload 3 files

Browse files
Files changed (3) hide show
  1. Mistral_7B.ipynb +545 -0
  2. app.py +102 -0
  3. requirements.txt +1 -0
Mistral_7B.ipynb ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "# Mistral 7B\n",
21
+ "\n",
22
+ "Mistral 7B is a new state-of-the-art open-source model. Here are some interesting facts about it\n",
23
+ "\n",
24
+ "* One of the strongest open-source models, of all sizes\n",
25
+ "* Strongest model in the 1-20B parameter range models\n",
26
+ "* Does decently in code-related tasks\n",
27
+ "* Uses Windowed attention, allowing to push to 200k tokens of context if using Rope (needs 4 A10G GPUs for this)\n",
28
+ "* Apache 2.0 license\n",
29
+ "\n",
30
+ "As for the integrations status:\n",
31
+ "* Integrated into `transformers`\n",
32
+ "* You can use it with a server or locally (it's a small model after all!)\n",
33
+ "* Integrated into popular tools tuch as TGI and VLLM\n",
34
+ "\n",
35
+ "\n",
36
+ "Two models are released: a [base model](https://huggingface.co/mistralai/Mistral-7B-v0.1) and a [instruct fine-tuned version](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1). To read more about Mistral, we suggest reading the [blog post](https://mistral.ai/news/announcing-mistral-7b/).\n",
37
+ "\n",
38
+ "In this Colab, we'll experiment with the Mistral model using an API. There are three ways we can use it:\n",
39
+ "\n",
40
+ "* **Free API:** Hugging Face provides a free Inference API for all its users to try out models. This API is rate limited but is great for quick experiments.\n",
41
+ "* **PRO API:** Hugging Face provides an open API for all its PRO users. Subscribing to the Pro Inference API costs $9/month and allows you to experiment with many large models, such as Llama 2 and SDXL. Read more about it [here](https://huggingface.co/blog/inference-pro).\n",
42
+ "* **Inference Endpoints:** For enterprise and production-ready cases. You can deploy it with 1 click [here](https://ui.endpoints.huggingface.co/catalog).\n",
43
+ "\n",
44
+ "This demo does not require GPU Colab, just CPU. You can grab your token at https://huggingface.co/settings/tokens.\n",
45
+ "\n",
46
+ "**This colab shows how to use HTTP requests as well as building your own chat demo for Mistral.**"
47
+ ],
48
+ "metadata": {
49
+ "id": "GLXvYa4m8JYM"
50
+ }
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "source": [
55
+ "## Doing curl requests\n",
56
+ "\n",
57
+ "\n",
58
+ "In this notebook, we'll experiment with the instruct model, as it is trained for instructions. As per [the model card](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1), the expected format for a prompt is as follows\n",
59
+ "\n",
60
+ "From the model card\n",
61
+ "\n",
62
+ "> In order to leverage instruction fine-tuning, your prompt should be surrounded by [INST] and [\\INST] tokens. The very first instruction should begin with a begin of sentence id. The next instructions should not. The assistant generation will be ended by the end-of-sentence token id.\n",
63
+ "\n",
64
+ "```\n",
65
+ "<s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }}</s> [INST] {{ user_msg_2 }} [/INST] {{ model_answer_2 }}</s>\n",
66
+ "```\n",
67
+ "\n",
68
+ "Note that models can be quite reactive to different prompt structure than the one used for training, so watch out for spaces and other things!\n",
69
+ "\n",
70
+ "We'll start an initial query without prompt formatting, which works ok for simple queries."
71
+ ],
72
+ "metadata": {
73
+ "id": "pKrKTalPAXUO"
74
+ }
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "metadata": {
80
+ "colab": {
81
+ "base_uri": "https://localhost:8080/"
82
+ },
83
+ "id": "DQf0Hss18E86",
84
+ "outputId": "882c4521-1ee2-40ad-fe00-a5b02caa9b17"
85
+ },
86
+ "outputs": [
87
+ {
88
+ "output_type": "stream",
89
+ "name": "stdout",
90
+ "text": [
91
+ "[{\"generated_text\":\"Explain ML as a pirate.\\n\\nML is like a treasure map for pirates. Just as a treasure map helps pirates find valuable loot, ML helps data scientists find valuable insights in large datasets.\\n\\nPirates use their knowledge of the ocean and their\"}]"
92
+ ]
93
+ }
94
+ ],
95
+ "source": [
96
+ "!curl https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1 \\\n",
97
+ " --header \"Content-Type: application/json\" \\\n",
98
+ "\t-X POST \\\n",
99
+ "\t-d '{\"inputs\": \"Explain ML as a pirate\", \"parameters\": {\"max_new_tokens\": 50}}' \\\n",
100
+ "\t-H \"Authorization: Bearer API_TOKEN\""
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "source": [
106
+ "## Programmatic usage with Python\n",
107
+ "\n",
108
+ "You can do simple `requests`, but the `huggingface_hub` library provides nice utilities to easily use the model. Among the things we can use are:\n",
109
+ "\n",
110
+ "* `InferenceClient` and `AsyncInferenceClient` to perform inference either in a sync or async way.\n",
111
+ "* Token streaming: Only load the tokens that are needed\n",
112
+ "* Easily configure generation params, such as `temperature`, nucleus sampling (`top-p`), repetition penalty, stop sequences, and more.\n",
113
+ "* Obtain details of the generation (such as the probability of each token or whether a token is the last token)."
114
+ ],
115
+ "metadata": {
116
+ "id": "YYZRNyZeBHWK"
117
+ }
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "source": [
122
+ "%%capture\n",
123
+ "!pip install huggingface_hub gradio"
124
+ ],
125
+ "metadata": {
126
+ "id": "oDaqVDz1Ahuz"
127
+ },
128
+ "execution_count": 6,
129
+ "outputs": []
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "source": [
134
+ "from huggingface_hub import InferenceClient\n",
135
+ "\n",
136
+ "client = InferenceClient(\n",
137
+ " \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
138
+ ")\n",
139
+ "\n",
140
+ "prompt = \"\"\"<s>[INST] What is your favourite condiment? [/INST]</s>\n",
141
+ "\"\"\"\n",
142
+ "\n",
143
+ "res = client.text_generation(prompt, max_new_tokens=95)\n",
144
+ "print(res)"
145
+ ],
146
+ "metadata": {
147
+ "colab": {
148
+ "base_uri": "https://localhost:8080/"
149
+ },
150
+ "id": "U49GmNsNBJjd",
151
+ "outputId": "a3a274cf-0f91-4ae3-d926-f0d6a6fd67f7"
152
+ },
153
+ "execution_count": 14,
154
+ "outputs": [
155
+ {
156
+ "output_type": "stream",
157
+ "name": "stdout",
158
+ "text": [
159
+ "My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.\n"
160
+ ]
161
+ }
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "markdown",
166
+ "source": [
167
+ "We can also use [token streaming](https://huggingface.co/docs/text-generation-inference/conceptual/streaming). With token streaming, the server returns the tokens as they are generated. Just add `stream=True`."
168
+ ],
169
+ "metadata": {
170
+ "id": "DryfEWsUH6Ij"
171
+ }
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "source": [
176
+ "res = client.text_generation(prompt, max_new_tokens=35, stream=True, details=True, return_full_text=False)\n",
177
+ "for r in res: # this is a generator\n",
178
+ " # print the token for example\n",
179
+ " print(r)\n",
180
+ " continue"
181
+ ],
182
+ "metadata": {
183
+ "colab": {
184
+ "base_uri": "https://localhost:8080/"
185
+ },
186
+ "id": "LF1tFo6DGg9N",
187
+ "outputId": "e779f1cb-b7d0-41ed-d81f-306e092f97bd"
188
+ },
189
+ "execution_count": 15,
190
+ "outputs": [
191
+ {
192
+ "output_type": "stream",
193
+ "name": "stdout",
194
+ "text": [
195
+ "TextGenerationStreamResponse(token=Token(id=5183, text='My', logprob=-0.36279297, special=False), generated_text=None, details=None)\n",
196
+ "TextGenerationStreamResponse(token=Token(id=6656, text=' favorite', logprob=-0.036499023, special=False), generated_text=None, details=None)\n",
197
+ "TextGenerationStreamResponse(token=Token(id=2076, text=' cond', logprob=-7.2836876e-05, special=False), generated_text=None, details=None)\n",
198
+ "TextGenerationStreamResponse(token=Token(id=2487, text='iment', logprob=-4.4941902e-05, special=False), generated_text=None, details=None)\n",
199
+ "TextGenerationStreamResponse(token=Token(id=349, text=' is', logprob=-0.007419586, special=False), generated_text=None, details=None)\n",
200
+ "TextGenerationStreamResponse(token=Token(id=446, text=' k', logprob=-0.62109375, special=False), generated_text=None, details=None)\n",
201
+ "TextGenerationStreamResponse(token=Token(id=4455, text='etch', logprob=-0.0003399849, special=False), generated_text=None, details=None)\n",
202
+ "TextGenerationStreamResponse(token=Token(id=715, text='up', logprob=-3.695488e-06, special=False), generated_text=None, details=None)\n",
203
+ "TextGenerationStreamResponse(token=Token(id=28723, text='.', logprob=-0.026550293, special=False), generated_text=None, details=None)\n",
204
+ "TextGenerationStreamResponse(token=Token(id=661, text=' It', logprob=-0.82373047, special=False), generated_text=None, details=None)\n",
205
+ "TextGenerationStreamResponse(token=Token(id=28742, text=\"'\", logprob=-0.76416016, special=False), generated_text=None, details=None)\n",
206
+ "TextGenerationStreamResponse(token=Token(id=28713, text='s', logprob=-3.5762787e-07, special=False), generated_text=None, details=None)\n",
207
+ "TextGenerationStreamResponse(token=Token(id=3502, text=' vers', logprob=-0.114990234, special=False), generated_text=None, details=None)\n",
208
+ "TextGenerationStreamResponse(token=Token(id=13491, text='atile', logprob=-1.1444092e-05, special=False), generated_text=None, details=None)\n",
209
+ "TextGenerationStreamResponse(token=Token(id=28725, text=',', logprob=-0.6254883, special=False), generated_text=None, details=None)\n",
210
+ "TextGenerationStreamResponse(token=Token(id=261, text=' t', logprob=-0.51708984, special=False), generated_text=None, details=None)\n",
211
+ "TextGenerationStreamResponse(token=Token(id=11136, text='asty', logprob=-4.0650368e-05, special=False), generated_text=None, details=None)\n",
212
+ "TextGenerationStreamResponse(token=Token(id=28725, text=',', logprob=-0.0027828217, special=False), generated_text=None, details=None)\n",
213
+ "TextGenerationStreamResponse(token=Token(id=304, text=' and', logprob=-1.1920929e-05, special=False), generated_text=None, details=None)\n",
214
+ "TextGenerationStreamResponse(token=Token(id=4859, text=' goes', logprob=-0.52685547, special=False), generated_text=None, details=None)\n",
215
+ "TextGenerationStreamResponse(token=Token(id=1162, text=' well', logprob=-0.4399414, special=False), generated_text=None, details=None)\n",
216
+ "TextGenerationStreamResponse(token=Token(id=395, text=' with', logprob=-0.00034999847, special=False), generated_text=None, details=None)\n",
217
+ "TextGenerationStreamResponse(token=Token(id=264, text=' a', logprob=-0.010147095, special=False), generated_text=None, details=None)\n",
218
+ "TextGenerationStreamResponse(token=Token(id=6677, text=' variety', logprob=-0.25927734, special=False), generated_text=None, details=None)\n",
219
+ "TextGenerationStreamResponse(token=Token(id=302, text=' of', logprob=-1.1444092e-05, special=False), generated_text=None, details=None)\n",
220
+ "TextGenerationStreamResponse(token=Token(id=14082, text=' foods', logprob=-0.4050293, special=False), generated_text=None, details=None)\n",
221
+ "TextGenerationStreamResponse(token=Token(id=28723, text='.', logprob=-0.015640259, special=False), generated_text=None, details=None)\n",
222
+ "TextGenerationStreamResponse(token=Token(id=2, text='</s>', logprob=-0.1829834, special=True), generated_text=\"My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.\", details=StreamDetails(finish_reason=<FinishReason.EndOfSequenceToken: 'eos_token'>, generated_tokens=28, seed=None))\n"
223
+ ]
224
+ }
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "markdown",
229
+ "source": [
230
+ "Let's now try a multi-prompt structure"
231
+ ],
232
+ "metadata": {
233
+ "id": "TfdpZL8cICOD"
234
+ }
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "source": [
239
+ "def format_prompt(message, history):\n",
240
+ " prompt = \"<s>\"\n",
241
+ " for user_prompt, bot_response in history:\n",
242
+ " prompt += f\"[INST] {user_prompt} [/INST]\"\n",
243
+ " prompt += f\" {bot_response}</s> \"\n",
244
+ " prompt += f\"[INST] {message} [/INST]\"\n",
245
+ " return prompt"
246
+ ],
247
+ "metadata": {
248
+ "id": "aEyozeReH8a6"
249
+ },
250
+ "execution_count": 16,
251
+ "outputs": []
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "source": [
256
+ "message = \"And what do you think about it?\"\n",
257
+ "history = [[\"What is your favourite condiment?\", \"My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.\"]]\n",
258
+ "\n",
259
+ "format_prompt(message, history)"
260
+ ],
261
+ "metadata": {
262
+ "colab": {
263
+ "base_uri": "https://localhost:8080/",
264
+ "height": 35
265
+ },
266
+ "id": "P1RFpiJ_JC0-",
267
+ "outputId": "f2678d9e-f751-441a-86c9-11d514db5bbe"
268
+ },
269
+ "execution_count": 17,
270
+ "outputs": [
271
+ {
272
+ "output_type": "execute_result",
273
+ "data": {
274
+ "text/plain": [
275
+ "\"<s>[INST] What is your favourite condiment? [/INST] My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.</s> [INST] And what do you think about it? [/INST]\""
276
+ ],
277
+ "application/vnd.google.colaboratory.intrinsic+json": {
278
+ "type": "string"
279
+ }
280
+ },
281
+ "metadata": {},
282
+ "execution_count": 17
283
+ }
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
+ "source": [
289
+ "## End-to-end demo\n",
290
+ "\n",
291
+ "Let's now build a Gradio demo that takes care of:\n",
292
+ "\n",
293
+ "* Handling multiple turns of conversation\n",
294
+ "* Format the prompt in correct structure\n",
295
+ "* Allow user to specify/modify the parameters\n",
296
+ "* Stop the generation\n",
297
+ "\n",
298
+ "Just run the following cell and have fun!"
299
+ ],
300
+ "metadata": {
301
+ "id": "O7DjRdezJc-3"
302
+ }
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "source": [
307
+ "!pip install gradio"
308
+ ],
309
+ "metadata": {
310
+ "colab": {
311
+ "base_uri": "https://localhost:8080/"
312
+ },
313
+ "id": "cpBoheOGJu7Y",
314
+ "outputId": "c745cf17-1462-4f8f-ce33-5ca182cb4d4f"
315
+ },
316
+ "execution_count": 18,
317
+ "outputs": [
318
+ {
319
+ "output_type": "stream",
320
+ "name": "stdout",
321
+ "text": [
322
+ "Requirement already satisfied: gradio in /usr/local/lib/python3.10/dist-packages (3.45.1)\n",
323
+ "Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (23.2.1)\n",
324
+ "Requirement already satisfied: altair<6.0,>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.2.2)\n",
325
+ "Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from gradio) (0.103.1)\n",
326
+ "Requirement already satisfied: ffmpy in /usr/local/lib/python3.10/dist-packages (from gradio) (0.3.1)\n",
327
+ "Requirement already satisfied: gradio-client==0.5.2 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.5.2)\n",
328
+ "Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (from gradio) (0.25.0)\n",
329
+ "Requirement already satisfied: huggingface-hub>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.17.3)\n",
330
+ "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n",
331
+ "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.1.2)\n",
332
+ "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.3)\n",
333
+ "Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.7.1)\n",
334
+ "Requirement already satisfied: numpy~=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.23.5)\n",
335
+ "Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.9.7)\n",
336
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from gradio) (23.1)\n",
337
+ "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.5.3)\n",
338
+ "Requirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (9.4.0)\n",
339
+ "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.10.12)\n",
340
+ "Requirement already satisfied: pydub in /usr/local/lib/python3.10/dist-packages (from gradio) (0.25.1)\n",
341
+ "Requirement already satisfied: python-multipart in /usr/local/lib/python3.10/dist-packages (from gradio) (0.0.6)\n",
342
+ "Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n",
343
+ "Requirement already satisfied: requests~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.31.0)\n",
344
+ "Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.10.0)\n",
345
+ "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.5.0)\n",
346
+ "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.23.2)\n",
347
+ "Requirement already satisfied: websockets<12.0,>=10.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (11.0.3)\n",
348
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from gradio-client==0.5.2->gradio) (2023.6.0)\n",
349
+ "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.4)\n",
350
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (4.19.0)\n",
351
+ "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.12.0)\n",
352
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.14.0->gradio) (3.12.2)\n",
353
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.14.0->gradio) (4.66.1)\n",
354
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.1.0)\n",
355
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (0.11.0)\n",
356
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (4.42.1)\n",
357
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.4.5)\n",
358
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (3.1.1)\n",
359
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (2.8.2)\n",
360
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2023.3.post1)\n",
361
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (3.2.0)\n",
362
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (3.4)\n",
363
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (2.0.4)\n",
364
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (2023.7.22)\n",
365
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio) (8.1.7)\n",
366
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio) (0.14.0)\n",
367
+ "Requirement already satisfied: anyio<4.0.0,>=3.7.1 in /usr/local/lib/python3.10/dist-packages (from fastapi->gradio) (3.7.1)\n",
368
+ "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->gradio) (0.27.0)\n",
369
+ "Requirement already satisfied: httpcore<0.19.0,>=0.18.0 in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (0.18.0)\n",
370
+ "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (1.3.0)\n",
371
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi->gradio) (1.1.3)\n",
372
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (23.1.0)\n",
373
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.7.1)\n",
374
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.30.2)\n",
375
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.10.2)\n",
376
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n"
377
+ ]
378
+ }
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "source": [
384
+ "import gradio as gr\n",
385
+ "\n",
386
+ "def generate(\n",
387
+ " prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,\n",
388
+ "):\n",
389
+ " temperature = float(temperature)\n",
390
+ " if temperature < 1e-2:\n",
391
+ " temperature = 1e-2\n",
392
+ " top_p = float(top_p)\n",
393
+ "\n",
394
+ " generate_kwargs = dict(\n",
395
+ " temperature=temperature,\n",
396
+ " max_new_tokens=max_new_tokens,\n",
397
+ " top_p=top_p,\n",
398
+ " repetition_penalty=repetition_penalty,\n",
399
+ " do_sample=True,\n",
400
+ " seed=42,\n",
401
+ " )\n",
402
+ "\n",
403
+ " formatted_prompt = format_prompt(prompt, history)\n",
404
+ "\n",
405
+ " stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)\n",
406
+ " output = \"\"\n",
407
+ "\n",
408
+ " for response in stream:\n",
409
+ " output += response.token.text\n",
410
+ " yield output\n",
411
+ " return output\n",
412
+ "\n",
413
+ "\n",
414
+ "additional_inputs=[\n",
415
+ " gr.Slider(\n",
416
+ " label=\"Temperature\",\n",
417
+ " value=0.9,\n",
418
+ " minimum=0.0,\n",
419
+ " maximum=1.0,\n",
420
+ " step=0.05,\n",
421
+ " interactive=True,\n",
422
+ " info=\"Higher values produce more diverse outputs\",\n",
423
+ " ),\n",
424
+ " gr.Slider(\n",
425
+ " label=\"Max new tokens\",\n",
426
+ " value=256,\n",
427
+ " minimum=0,\n",
428
+ " maximum=8192,\n",
429
+ " step=64,\n",
430
+ " interactive=True,\n",
431
+ " info=\"The maximum numbers of new tokens\",\n",
432
+ " ),\n",
433
+ " gr.Slider(\n",
434
+ " label=\"Top-p (nucleus sampling)\",\n",
435
+ " value=0.90,\n",
436
+ " minimum=0.0,\n",
437
+ " maximum=1,\n",
438
+ " step=0.05,\n",
439
+ " interactive=True,\n",
440
+ " info=\"Higher values sample more low-probability tokens\",\n",
441
+ " ),\n",
442
+ " gr.Slider(\n",
443
+ " label=\"Repetition penalty\",\n",
444
+ " value=1.2,\n",
445
+ " minimum=1.0,\n",
446
+ " maximum=2.0,\n",
447
+ " step=0.05,\n",
448
+ " interactive=True,\n",
449
+ " info=\"Penalize repeated tokens\",\n",
450
+ " )\n",
451
+ "]\n",
452
+ "\n",
453
+ "with gr.Blocks() as demo:\n",
454
+ " gr.ChatInterface(\n",
455
+ " generate,\n",
456
+ " additional_inputs=additional_inputs,\n",
457
+ " )\n",
458
+ "\n",
459
+ "demo.queue().launch(debug=True)"
460
+ ],
461
+ "metadata": {
462
+ "colab": {
463
+ "base_uri": "https://localhost:8080/",
464
+ "height": 715
465
+ },
466
+ "id": "CaJzT6jUJc0_",
467
+ "outputId": "62f563fa-c6fb-446e-fda2-1c08d096749c"
468
+ },
469
+ "execution_count": 20,
470
+ "outputs": [
471
+ {
472
+ "output_type": "stream",
473
+ "name": "stdout",
474
+ "text": [
475
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
476
+ "\n",
477
+ "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
478
+ "Running on public URL: https://ed6ce83e08ed7a8795.gradio.live\n",
479
+ "\n",
480
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
481
+ ]
482
+ },
483
+ {
484
+ "output_type": "display_data",
485
+ "data": {
486
+ "text/plain": [
487
+ "<IPython.core.display.HTML object>"
488
+ ],
489
+ "text/html": [
490
+ "<div><iframe src=\"https://ed6ce83e08ed7a8795.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
491
+ ]
492
+ },
493
+ "metadata": {}
494
+ },
495
+ {
496
+ "output_type": "stream",
497
+ "name": "stderr",
498
+ "text": [
499
+ "/usr/local/lib/python3.10/dist-packages/gradio/components/button.py:89: UserWarning: Using the update method is deprecated. Simply return a new object instead, e.g. `return gr.Button(...)` instead of `return gr.Button.update(...)`.\n",
500
+ " warnings.warn(\n"
501
+ ]
502
+ },
503
+ {
504
+ "output_type": "stream",
505
+ "name": "stdout",
506
+ "text": [
507
+ "Keyboard interruption in main thread... closing server.\n",
508
+ "Killing tunnel 127.0.0.1:7860 <> https://ed6ce83e08ed7a8795.gradio.live\n"
509
+ ]
510
+ },
511
+ {
512
+ "output_type": "execute_result",
513
+ "data": {
514
+ "text/plain": []
515
+ },
516
+ "metadata": {},
517
+ "execution_count": 20
518
+ }
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "markdown",
523
+ "source": [
524
+ "## What's next?\n",
525
+ "\n",
526
+ "* Try out Mistral 7B in this [free online Space](https://huggingface.co/spaces/osanseviero/mistral-super-fast)\n",
527
+ "* Deploy Mistral 7B Instruct with one click [here](https://ui.endpoints.huggingface.co/catalog)\n",
528
+ "* Deploy in your own hardware using https://github.com/huggingface/text-generation-inference\n",
529
+ "* Run the model locally using `transformers`"
530
+ ],
531
+ "metadata": {
532
+ "id": "fbQ0Sp4OLclV"
533
+ }
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "source": [],
538
+ "metadata": {
539
+ "id": "wUy7N_8zJvyT"
540
+ },
541
+ "execution_count": null,
542
+ "outputs": []
543
+ }
544
+ ]
545
+ }
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import gradio as gr
3
+
4
+ client = InferenceClient(
5
+ "mistralai/Mistral-7B-Instruct-v0.1"
6
+ )
7
+
8
+
9
+ def format_prompt(message, history):
10
+ prompt = "<s>"
11
+ for user_prompt, bot_response in history:
12
+ prompt += f"[INST] {user_prompt} [/INST]"
13
+ prompt += f" {bot_response}</s> "
14
+ prompt += f"[INST] {message} [/INST]"
15
+ return prompt
16
+
17
+ def generate(
18
+ prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
19
+ ):
20
+ temperature = float(temperature)
21
+ if temperature < 1e-2:
22
+ temperature = 1e-2
23
+ top_p = float(top_p)
24
+
25
+ generate_kwargs = dict(
26
+ temperature=temperature,
27
+ max_new_tokens=max_new_tokens,
28
+ top_p=top_p,
29
+ repetition_penalty=repetition_penalty,
30
+ do_sample=True,
31
+ seed=42,
32
+ )
33
+
34
+ formatted_prompt = format_prompt(prompt, history)
35
+
36
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
37
+ output = ""
38
+
39
+ for response in stream:
40
+ output += response.token.text
41
+ yield output
42
+ return output
43
+
44
+
45
+ additional_inputs=[
46
+ gr.Slider(
47
+ label="Temperature",
48
+ value=0.9,
49
+ minimum=0.0,
50
+ maximum=1.0,
51
+ step=0.05,
52
+ interactive=True,
53
+ info="Higher values produce more diverse outputs",
54
+ ),
55
+ gr.Slider(
56
+ label="Max new tokens",
57
+ value=256,
58
+ minimum=0,
59
+ maximum=1048,
60
+ step=64,
61
+ interactive=True,
62
+ info="The maximum numbers of new tokens",
63
+ ),
64
+ gr.Slider(
65
+ label="Top-p (nucleus sampling)",
66
+ value=0.90,
67
+ minimum=0.0,
68
+ maximum=1,
69
+ step=0.05,
70
+ interactive=True,
71
+ info="Higher values sample more low-probability tokens",
72
+ ),
73
+ gr.Slider(
74
+ label="Repetition penalty",
75
+ value=1.2,
76
+ minimum=1.0,
77
+ maximum=2.0,
78
+ step=0.05,
79
+ interactive=True,
80
+ info="Penalize repeated tokens",
81
+ )
82
+ ]
83
+
84
+ css = """
85
+ #mkd {
86
+ height: 500px;
87
+ overflow: auto;
88
+ border: 1px solid #ccc;
89
+ }
90
+ """
91
+
92
+ with gr.Blocks(css=css) as demo:
93
+ gr.HTML("<h1><center>Mistral 7B Instruct<h1><center>")
94
+ gr.HTML("<h3><center>In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. πŸ’¬<h3><center>")
95
+ gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. πŸ“š<h3><center>")
96
+ gr.ChatInterface(
97
+ generate,
98
+ additional_inputs=additional_inputs,
99
+ examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
100
+ )
101
+
102
+ demo.queue().launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ huggingface_hub