import random

import gradio as gr
import torch
from gradio.themes.utils import sizes
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from constants import END_OF_TEXT
from settings import DEFAULT_PORT

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    use_fast=False,
)
# The checkpoint ships without a pad token, so reuse the end-of-text token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT

model = AutoModelForCausalLM.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    device_map="auto",
)
# "reduce-overhead" trades a slower first call for faster subsequent calls
model = torch.compile(model, mode="reduce-overhead")
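
# Quick smoke test, left commented out (hypothetical snippet, not part of the app):
# _ids = tokenizer("def add_numbers(a, b):\n    return", return_tensors="pt").to(model.device)
# print(tokenizer.batch_decode(model.generate(**_ids, max_new_tokens=16), skip_special_tokens=True)[0])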

# UI things
_styles = utils.get_file_as_string("styles.css")

# Load ./README.md and split it into sections on the "---" separators
readme_file_content = utils.get_file_as_string("README.md", path="./")
(
    manifest,
    description,
    disclaimer,
    base_model_info,
    formats,
) = utils.get_sections(readme_file_content, "---", up_to=5)
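# NOTE: the unpacking above assumes README.md contains at least five
# "---"-delimited sections: front-matter manifest, description, disclaimer,
# base-model info, and supported formats.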

theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="orange",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)


def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate a completion for `prompt`; returns prompt + completion as one string."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        min_new_tokens=8,
        renormalize_logits=True,
        no_repeat_ngram_size=6,
        repetition_penalty=repetition_penalty,
        num_beams=3,  # beam-sampling: 3 beams combined with temperature/top-p sampling
        early_stopping=True,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
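
# Example usage (hypothetical values; assumes the model has finished loading):
#   run_inference("def add(a, b):\n    return", temperature=0.2,
#                 max_new_tokens=64, top_p=0.9, repetition_penalty=1.1)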


# Gradio interface wrapper for inference
def gradio_interface(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
):
    return run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty)

# Each row: [prompt, temperature, max_new_tokens, top_p, repetition_penalty]
examples = [
    ["def add_numbers(a, b):\n    return", 0.2, 192, 0.9, 1.2],
    [
        "class Car:\n    def __init__(self, make, model):\n        self.make = make\n        self.model = model\n\n    def display_car(self):",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "import pandas as pd\ndata = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\ndf = pd.DataFrame(data).convert_dtypes()\n# eda",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "def factorial(n):\n    if n == 0:\n        return 1\n    else:",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        'def fibonacci(n):\n    if n <= 0:\n        raise ValueError("Incorrect input")\n    elif n == 1:\n        return 0\n    elif n == 2:\n        return 1\n    else:',
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "import matplotlib.pyplot as plt\nimport numpy as np\nx = np.linspace(0, 10, 100)\n# simple plot",
        0.2,
        192,
        0.9,
        1.2,
    ],
    ["def reverse_string(s:str) -> str:\n    return", 0.2, 192, 0.9, 1.2],
    ["def is_palindrome(word:str) -> bool:\n    return", 0.2, 192, 0.9, 1.2],
    [
        "def bubble_sort(lst: list):\n    n = len(lst)\n    for i in range(n):\n        for j in range(0, n-i-1):",
        0.2,
        192,
        0.9,
        1.2,
    ],
    [
        "def binary_search(arr, low, high, x):\n    if high >= low:\n        mid = (high + low) // 2\n        if arr[mid] == x:\n            return mid\n        elif arr[mid] > x:",
        0.2,
        192,
        0.9,
        1.2,
    ],
]
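# NOTE: every example row must match the order and length of the `inputs`
# list passed to gr.Examples below; Gradio raises an error if they differ.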

# Define the Gradio Blocks interface
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    value=random.choice([e[0] for e in examples]),
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Advanced settings", open=False):
                    with gr.Row():
                        column_1, column_2 = gr.Column(), gr.Column()
                        with column_1:
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                            max_new_tokens = gr.Slider(
                                label="Max new tokens",
                                value=128,
                                minimum=0,
                                maximum=512,
                                step=64,
                                interactive=True,
                                info="Number of tokens to generate",
                            )
                        with column_2:
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.90,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                interactive=True,
                                info="Higher values sample more low-probability tokens",
                            )
                            repetition_penalty = gr.Slider(
                                label="Repetition penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.05,
                                interactive=True,
                                info="Penalize repeated tokens",
                            )
            with gr.Column():
                version = gr.Dropdown(
                    ["smol_llama-101M-GQA-python"],
                    value="smol_llama-101M-GQA-python",
                    label="Version",
                )
        gr.Markdown(disclaimer)
        gr.Examples(
            examples=examples,
            # `version` is display-only, so the example rows feed exactly
            # these five inputs (matching run_inference's signature)
            inputs=[
                instruction,
                temperature,
                max_new_tokens,
                top_p,
                repetition_penalty,
            ],
            cache_examples=False,
            fn=gradio_interface,
            outputs=[output],
        )
        gr.Markdown(base_model_info)
        gr.Markdown(formats)

    submit.click(
        gradio_interface,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        # preprocess=False,
        # batch=False,
        show_progress=True,
    )

# .queue(max_size=10, api_open=False)
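# If enabled, the queue above would cap pending requests at 10 and keep the
# queue's API endpoint closed to external callers.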
demo.launch(
    debug=True,
    server_port=DEFAULT_PORT,
    show_api=False,
    share=utils.is_google_colab(),
)
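
# To run locally (assuming this file is saved as app.py, with utils.py,
# constants.py, settings.py, and styles.css alongside it):
#   python app.py
# The server binds to DEFAULT_PORT from settings.py; on Google Colab a
# public share link is created automatically.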