pszemraj committed
Commit
e01a839
•
1 Parent(s): b690c07

init modified demo

Files changed (7)
  1. README.md +40 -7
  2. app.py +242 -0
  3. constants.py +4 -0
  4. settings.py +16 -0
  5. static/loading-icon.svg +4 -0
  6. static/styles.css +78 -0
  7. utils.py +45 -0
README.md CHANGED
@@ -1,13 +1,46 @@
  ---
- title: Beecoder Playground
- emoji: 🏠
- colorFrom: purple
- colorTo: blue
  sdk: gradio
- sdk_version: 4.0.2
  app_file: app.py
- pinned: false
  license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: BeeCoder Demo
+ emoji: 🐝
+ colorFrom: gray
+ colorTo: yellow
  sdk: gradio
+ sdk_version: 3.28.3
  app_file: app.py
+ pinned: true
  license: apache-2.0
  ---

+ # 🐝 BeeCoder Demo 🐝
+
+ ## Code-Completion Playground 💻 with 🐝 [BeeCoder](https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA-python) Models
+
+ This is a demo playground for generating Python code with 🐝 [BeeCoder](https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA-python), a version of the tiny [101M base model](https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA) **fine-tuned** on a dataset of PyPI packages.
+
+ ℹ️ This is not an instruction-tuned model, just a code-completion tool.
+
+ ---
+
+ **Intended Use**: This app and its [supporting model](https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA-python) are provided for demonstration purposes only, not as a replacement for human expertise. For more details on the model, please refer to the [model card](https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA-python).
+
+ In our country, we say: _"Letting a 100M-parameter model generate a Python script without validating it is like letting a monkey fly a plane."_ So please be careful with the generated code.
+
+ ---
+
+ ## Base Model Information
+
+ The base model, smol_llama-101M-GQA, was pretrained on a relatively small number (< ~20B) of high-quality tokens. It is tiny (101M parameters) yet performs well for its size. The training for the base model included datasets such as:
+
+ - [JeanKaddour/minipile](https://huggingface.co/datasets/JeanKaddour/minipile)
+ - [pszemraj/simple_wikipedia_LM](https://huggingface.co/datasets/pszemraj/simple_wikipedia_LM)
+ - [BEE-spoke-data/wikipedia-20230901.en-deduped](https://huggingface.co/datasets/BEE-spoke-data/wikipedia-20230901.en-deduped)
+ - [mattymchen/refinedweb-3m](https://huggingface.co/datasets/mattymchen/refinedweb-3m)
+
+ You can find more information about the base model [here](https://huggingface.co/BEE-spoke-data/smol_llama-101M-GQA).
+
+ ---
+
+ ### Credits
+
+ This app is modified from a demo playground originally built for [StarCoder](https://huggingface.co/bigcode/starcoder) by [BigCode](https://huggingface.co/bigcode). You can find the original demo [here](https://huggingface.co/spaces/bigcode/bigcode-playground).
+
+ ---
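
For readers who want to try the model outside this Space, here is a minimal, illustrative sketch using `transformers` (the checkpoint name comes from the README above; the generation settings are arbitrary choices for the example, not the demo's exact configuration):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "BEE-spoke-data/smol_llama-101M-GQA-python"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# complete a Python snippet; sampling settings here are illustrative
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.2,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```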
app.py ADDED
@@ -0,0 +1,242 @@
+ import random
+
+ import gradio as gr
+ import torch
+ from gradio.themes.utils import sizes
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ import utils
+ from constants import END_OF_TEXT
+ from settings import DEFAULT_PORT
+
+ # Load the tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained(
+     "BEE-spoke-data/smol_llama-101M-GQA-python",
+     use_fast=False,
+ )
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ tokenizer.pad_token = END_OF_TEXT
+ model = AutoModelForCausalLM.from_pretrained(
+     "BEE-spoke-data/smol_llama-101M-GQA-python",
+     device_map="auto",
+ )
+ model = torch.compile(model, mode="reduce-overhead")
+
+ # UI things
+
+ _styles = utils.get_file_as_string("styles.css")
+
+ # Loads ./README.md file & splits it into sections
+ readme_file_content = utils.get_file_as_string("README.md", path="./")
+ (
+     manifest,
+     description,
+     disclaimer,
+     base_model_info,
+     formats,
+ ) = utils.get_sections(readme_file_content, "---", up_to=5)
+
+ theme = gr.themes.Soft(
+     primary_hue="yellow",
+     secondary_hue="orange",
+     neutral_hue="slate",
+     radius_size=sizes.radius_sm,
+     font=[
+         gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
+         "ui-sans-serif",
+         "system-ui",
+         "sans-serif",
+     ],
+     text_size=sizes.text_lg,
+ )
+
+
+ def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=max_new_tokens,
+         min_new_tokens=8,
+         renormalize_logits=True,
+         no_repeat_ngram_size=6,
+         repetition_penalty=repetition_penalty,
+         num_beams=3,
+         early_stopping=True,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+     )
+     text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+     return text
+
+
+ # Gradio interface wrapper for inference
+ def gradio_interface(
+     prompt: str,
+     temperature: float,
+     max_new_tokens: int,
+     top_p: float,
+     repetition_penalty: float,
+ ):
+     return run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty)
+
+
+ examples = [
+     ["def add_numbers(a, b):\n    return", 0.2, 192, 0.9, 1.2],
+     [
+         "class Car:\n    def __init__(self, make, model):\n        self.make = make\n        self.model = model\n\n    def display_car(self):",
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+     [
+         "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+     [
+         "def factorial(n):\n    if n == 0:\n        return 1\n    else:",
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+     [
+         'def fibonacci(n):\n    if n <= 0:\n        raise ValueError("Incorrect input")\n    elif n == 1:\n        return 0\n    elif n == 2:\n        return 1\n    else:',
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+     [
+         "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+     ["def reverse_string(s:str) -> str:\n    return", 0.2, 192, 0.9, 1.2],
+     ["def is_palindrome(word:str) -> bool:\n    return", 0.2, 192, 0.9, 1.2],
+     [
+         "def bubble_sort(lst: list):\n    n = len(lst)\n    for i in range(n):\n        for j in range(0, n-i-1):",
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+     [
+         "def binary_search(arr, low, high, x):\n    if high >= low:\n        mid = (high + low) // 2\n        if arr[mid] == x:\n            return mid\n        elif arr[mid] > x:",
+         0.2,
+         192,
+         0.9,
+         1.2,
+     ],
+ ]
+
+ # Define the Gradio Blocks interface
+ with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
+     with gr.Column():
+         gr.Markdown(description)
+         with gr.Row():
+             with gr.Column():
+                 instruction = gr.Textbox(
+                     value=random.choice([e[0] for e in examples]),
+                     placeholder="Enter your code here",
+                     label="Code",
+                     elem_id="q-input",
+                 )
+                 submit = gr.Button("Generate", variant="primary")
+                 output = gr.Code(elem_id="q-output", language="python", lines=10)
+         with gr.Row():
+             with gr.Column():
+                 with gr.Accordion("Advanced settings", open=False):
+                     with gr.Row():
+                         column_1, column_2 = gr.Column(), gr.Column()
+                         with column_1:
+                             temperature = gr.Slider(
+                                 label="Temperature",
+                                 value=0.2,
+                                 minimum=0.0,
+                                 maximum=1.0,
+                                 step=0.05,
+                                 interactive=True,
+                                 info="Higher values produce more diverse outputs",
+                             )
+                             max_new_tokens = gr.Slider(
+                                 label="Max new tokens",
+                                 value=128,
+                                 minimum=0,
+                                 maximum=512,
+                                 step=64,
+                                 interactive=True,
+                                 info="Number of tokens to generate",
+                             )
+                         with column_2:
+                             top_p = gr.Slider(
+                                 label="Top-p (nucleus sampling)",
+                                 value=0.90,
+                                 minimum=0.0,
+                                 maximum=1,
+                                 step=0.05,
+                                 interactive=True,
+                                 info="Higher values sample more low-probability tokens",
+                             )
+                             repetition_penalty = gr.Slider(
+                                 label="Repetition penalty",
+                                 value=1.1,
+                                 minimum=1.0,
+                                 maximum=2.0,
+                                 step=0.05,
+                                 interactive=True,
+                                 info="Penalize repeated tokens",
+                             )
+             with gr.Column():
+                 version = gr.Dropdown(
+                     [
+                         "smol_llama-101M-GQA-python",
+                     ],
+                     value="smol_llama-101M-GQA-python",
+                     label="Version",
+                     info="",
+                 )
+         gr.Markdown(disclaimer)
+         gr.Examples(
+             examples=examples,
+             inputs=[
+                 instruction,
+                 temperature,
+                 max_new_tokens,
+                 top_p,
+                 repetition_penalty,
+                 # `version` is not an example input: each example row has 5 values
+             ],
+             cache_examples=False,
+             fn=gradio_interface,
+             outputs=[output],
+         )
+         gr.Markdown(base_model_info)
+         gr.Markdown(formats)
+
+     submit.click(
+         gradio_interface,
+         inputs=[
+             instruction,
+             temperature,
+             max_new_tokens,
+             top_p,
+             repetition_penalty,
+         ],
+         outputs=[output],
+         # preprocess=False,
+         max_batch_size=2,
+         show_progress=True,
+     )
+
+ demo.queue(max_size=10).launch(
+     debug=True,
+     server_port=DEFAULT_PORT,
+     max_threads=utils.get_workers(),
+ )
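
As a quick sanity check, `run_inference` can also be called directly in a Python shell once the module is loaded; a hypothetical invocation mirroring the settings used by the bundled examples:

```python
# illustrative direct call; the values mirror the examples list above
completion = run_inference(
    "def add_numbers(a, b):\n    return",
    temperature=0.2,
    max_new_tokens=192,
    top_p=0.9,
    repetition_penalty=1.2,
)
print(completion)
```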
constants.py ADDED
@@ -0,0 +1,4 @@
+ END_OF_TEXT = "<|endoftext|>"
+
+ # Near zero temperature to avoid division by zero
+ MIN_TEMPERATURE = 1e-4
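
`MIN_TEMPERATURE` suggests user-supplied temperatures are meant to be clamped before generation; a minimal sketch of such a guard (the helper itself is hypothetical, not part of this commit):

```python
from constants import MIN_TEMPERATURE

def safe_temperature(value: float) -> float:
    # keep temperature strictly positive so logit scaling never divides by zero
    return max(value, MIN_TEMPERATURE)
```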
settings.py ADDED
@@ -0,0 +1,16 @@
+ # URLs for the StarCoder Models/APIs
+ DEFAULT_HUGGINGFACE_MODELS_API_BASE_URL = "https://api-inference.huggingface.co/models/"
+ DEFAULT_STARCODER_API_PATH = "bigcode/starcoder/"
+ DEFAULT_STARCODER_BASE_API_PATH = "bigcode/starcoderbase/"
+ FIM_INDICATOR = "<FILL_HERE>"
+ DEFAULT_PORT = 7860
+
+ STATIC_PATH = "static"
+
+ DEFAULT_SETTINGS = dict(
+     temperature=0.9,
+     max_new_tokens=256,
+     top_p=0.95,
+     repetition_penalty=1.0,
+     version="StarCoder",
+ )
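
The StarCoder API constants are carried over from the original playground and are not used by this demo's local inference path. For reference only, a hedged sketch of how they could be combined to query the hosted model through the HF Inference API (the `HF_TOKEN` environment variable and the request shape are assumptions, not part of this commit):

```python
# illustrative only: querying the hosted StarCoder endpoint with these constants
import os
import requests

from settings import (
    DEFAULT_HUGGINGFACE_MODELS_API_BASE_URL,
    DEFAULT_SETTINGS,
    DEFAULT_STARCODER_API_PATH,
)

url = DEFAULT_HUGGINGFACE_MODELS_API_BASE_URL + DEFAULT_STARCODER_API_PATH.rstrip("/")
response = requests.post(
    url,
    headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},  # assumed env var
    json={
        "inputs": "def hello_world():",
        "parameters": {
            "temperature": DEFAULT_SETTINGS["temperature"],
            "max_new_tokens": DEFAULT_SETTINGS["max_new_tokens"],
            "top_p": DEFAULT_SETTINGS["top_p"],
            "repetition_penalty": DEFAULT_SETTINGS["repetition_penalty"],
        },
    },
    timeout=30,
)
print(response.json())
```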
static/loading-icon.svg ADDED
static/styles.css ADDED
@@ -0,0 +1,78 @@
+ @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600;700&display=swap');
+
+ h1, h2 {
+     font-family: 'IBM Plex Mono', sans-serif;
+ }
+
+ .generating {
+     visibility: hidden
+ }
+
+ .gradio-container {
+     color: black
+ }
+
+ /* monospace_css */
+ #q-input textarea {
+     font-family: 'Consolas', Courier, monospace;
+ }
+
+ /* Share Button */
+
+ /* it was hidden directly inside the svg xml content */
+ #share-btn-loading-icon {
+     display: none;
+ }
+
+ a {
+     text-decoration-line: underline;
+     font-weight: 600;
+ }
+
+ .animate-spin {
+     animation: spin 1s linear infinite;
+ }
+
+ @keyframes spin {
+     from {
+         transform: rotate(0deg);
+     }
+     to {
+         transform: rotate(360deg);
+     }
+ }
+
+ #share-btn-container {
+     display: flex;
+     padding-left: 0.5rem !important;
+     padding-right: 0.5rem !important;
+     background-color: #000000;
+     justify-content: center;
+     align-items: center;
+     border-radius: 9999px !important;
+     width: 15rem;
+ }
+
+ #share-btn {
+     all: initial;
+     color: #ffffff;
+     font-weight: 600;
+     cursor: pointer;
+     font-family: 'IBM Plex Sans', sans-serif;
+     margin-left: 0.5rem !important;
+     padding-top: 0.25rem !important;
+     padding-bottom: 0.25rem !important;
+ }
+
+ #share-btn * {
+     all: unset;
+ }
+
+ #share-btn-container div:nth-child(-n+2) {
+     width: auto !important;
+     min-height: 0px !important;
+ }
+
+ #share-btn-container .wrap {
+     display: none !important;
+ }
utils.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ from typing import List, Optional
+
+ from settings import STATIC_PATH
+
+
+ def get_file_as_string(file_name: str, path: str = STATIC_PATH) -> str:
+     """Load a file's content and return it as a single string.
+
+     If a ``path`` is given, it is used instead of the default
+     static path from settings.
+
+     Args:
+         file_name (str): The name of the file to load.
+         path (str, optional): The directory containing the file.
+             Defaults to STATIC_PATH (from settings).
+
+     Returns:
+         str: The content of the file as a single string
+     """
+     with open(os.path.join(path, file_name), mode="r", encoding="UTF-8") as f:
+         return f.read()
+
+
+ def get_sections(string: str, delimiter: str, up_to: Optional[int] = None) -> List[str]:
+     """Split a string into sections given a delimiter.
+
+     Args:
+         string (str): The string to split
+         delimiter (str): The delimiter to use
+         up_to (int, optional): The maximum number of sections to return.
+             Defaults to None (which means all sections)
+
+     Returns:
+         List[str]: The list of sections (up to the given limit, if any provided)
+     """
+     return [
+         section.strip()
+         for section in string.split(delimiter)
+         if (section and not section.isspace())
+     ][:up_to]
+
+
+ def get_workers(safety: int = 4) -> int:
+     """Return the number of CPU cores on the current system, minus a safety margin."""
+     return max(1, (os.cpu_count() or 1) - safety)
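
For context, a short illustrative snippet of how `app.py` uses these helpers to carve `README.md` on its `---` separators into the five blocks it renders (front matter, description, disclaimer, base-model info, credits):

```python
import utils

readme = utils.get_file_as_string("README.md", path="./")
sections = utils.get_sections(readme, "---", up_to=5)
print(len(sections))  # 5 non-empty sections
print(sections[1].splitlines()[0])  # first line of the description block
```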