# Non-code header text captured along with this file (not part of the program):
#   pszemraj's picture / remove batching / 0f6a513
import gradio as gr
import torch
from gradio.themes.utils import sizes
from transformers import AutoModelForCausalLM, AutoTokenizer
import utils
from constants import END_OF_TEXT
from settings import DEFAULT_PORT
# Load the tokenizer and model.
# Both come from the same checkpoint id, so it is spelled once here.
_CHECKPOINT = "BEE-spoke-data/smol_llama-101M-GQA-python"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, use_fast=False)
# Padding setup: first point the pad id at the EOS id, then set the pad token
# string to END_OF_TEXT (imported from constants).
# NOTE(review): the second assignment may override the first if END_OF_TEXT is
# not the tokenizer's EOS string -- confirm against constants.py.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT

_base_model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT, device_map="auto")
model = torch.compile(_base_model, mode="reduce-overhead")
# UI things: custom stylesheet text handed to gr.Blocks(css=...)
_styles = utils.get_file_as_string("styles.css")

# Loads ./README.md file & splits it on "---" into (up to) five sections,
# which are unpacked in document order below.
readme_file_content = utils.get_file_as_string("README.md", path="./")
_readme_sections = utils.get_sections(readme_file_content, "---", up_to=5)
manifest, description, disclaimer, base_model_info, formats = _readme_sections
# Font stack for the theme: IBM Plex Sans (weights 400 and 600) first,
# followed by generic fallbacks.
_font_stack = [
    gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Theme settings are collected in one mapping and applied to the Soft preset.
_theme_options = {
    "primary_hue": "yellow",
    "secondary_hue": "orange",
    "neutral_hue": "slate",
    "radius_size": sizes.radius_sm,
    "font": _font_stack,
    "text_size": sizes.text_lg,
}
theme = gr.themes.Soft(**_theme_options)
def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate a completion for ``prompt`` with the module-level model.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature. A value <= 0 switches sampling off
            and uses deterministic beam search instead of raising.
        max_new_tokens: Upper bound on the number of generated tokens. A value
            <= 0 returns ``prompt`` unchanged without calling the model.
        top_p: Nucleus-sampling threshold (only used when sampling).
        repetition_penalty: Penalty passed through to ``model.generate``.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens skipped.
    """
    # Slider values may arrive as ints for whole numbers (e.g. 1 or 2), while
    # transformers validates temperature / repetition_penalty as floats, so
    # normalise the numeric types up front.
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    max_new_tokens = int(max_new_tokens)

    # Nothing to generate: avoid asking generate() for zero new tokens while
    # also requesting a positive minimum.
    if max_new_tokens <= 0:
        return prompt

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        # Never demand more tokens than the caller allows.
        "min_new_tokens": min(8, max_new_tokens),
        "renormalize_logits": True,
        "no_repeat_ngram_size": 6,
        "repetition_penalty": repetition_penalty,
        "num_beams": 3,
        "early_stopping": True,
    }
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    else:
        # A non-positive temperature is not a valid sampling temperature;
        # treat it as "no sampling" rather than letting generate() raise.
        generation_kwargs["do_sample"] = False

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Gradio interface wrapper for inference
def gradio_interface(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
):
    """Forward the UI values to ``run_inference`` and return its result as-is."""
    completion = run_inference(
        prompt=prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return completion
import random
# Generation settings shared by every example row, in the order
# (temperature, max_new_tokens, top_p, repetition_penalty).
_EXAMPLE_PARAMS = (0.2, 192, 0.9, 1.2)

# Example prompts. Each one is an unfinished Python snippet; indentation uses
# four spaces so the snippets are syntactically valid up to the cut-off point.
_EXAMPLE_PROMPTS = (
    "def add_numbers(a, b):\n"
    "    return",
    (
        "class Car:\n"
        "    def __init__(self, make, model):\n"
        "        self.make = make\n"
        "        self.model = model\n"
        "\n"
        "    def display_car(self):"
    ),
    (
        "import pandas as pd\n"
        "data = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\n"
        "df = pd.DataFrame(data).convert_dtypes()\n"
        "# eda"
    ),
    (
        "def factorial(n):\n"
        "    if n == 0:\n"
        "        return 1\n"
        "    else:"
    ),
    (
        "def fibonacci(n):\n"
        "    if n <= 0:\n"
        '        raise ValueError("Incorrect input")\n'
        "    elif n == 1:\n"
        "        return 0\n"
        "    elif n == 2:\n"
        "        return 1\n"
        "    else:"
    ),
    (
        "import matplotlib.pyplot as plt\n"
        "import numpy as np\n"
        "x = np.linspace(0, 10, 100)\n"
        "# simple plot"
    ),
    "def reverse_string(s:str) -> str:\n"
    "    return",
    "def is_palindrome(word:str) -> bool:\n"
    "    return",
    (
        "def bubble_sort(lst: list):\n"
        "    n = len(lst)\n"
        "    for i in range(n):\n"
        "        for j in range(0, n-i-1):"
    ),
    (
        "def binary_search(arr, low, high, x):\n"
        "    if high >= low:\n"
        "        mid = (high + low) // 2\n"
        "        if arr[mid] == x:\n"
        "            return mid\n"
        "        elif arr[mid] > x:"
    ),
)

# One row per prompt: [prompt, temperature, max_new_tokens, top_p, repetition_penalty].
examples = [[prompt, *_EXAMPLE_PARAMS] for prompt in _EXAMPLE_PROMPTS]
# Define the Gradio Blocks interface
# Layout, top to bottom: description text, code input + Generate button +
# code output, an "Advanced settings" accordion with the four generation
# sliders, a model-version dropdown, then disclaimer, examples and the
# remaining README sections.
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Input box starts pre-filled with a randomly chosen example prompt.
                instruction = gr.Textbox(
                    value=random.choice([e[0] for e in examples]),
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                # Two side-by-side columns, created first and
                                # populated below.
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=128,
                                        minimum=0,
                                        maximum=512,
                                        step=64,
                                        interactive=True,
                                        info="Number of tokens to generate",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.1,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Single-choice dropdown; its value is not passed to
                        # gradio_interface by the click handler below.
                        version = gr.Dropdown(
                            [
                                "smol_llama-101M-GQA-python",
                            ],
                            value="smol_llama-101M-GQA-python",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                # NOTE(review): `version` is listed as a sixth input here, but
                # every row in `examples` has five values and gradio_interface
                # takes five parameters -- confirm it is meant to be included.
                gr.Examples(
                    examples=examples,
                    inputs=[
                        instruction,
                        temperature,
                        max_new_tokens,
                        top_p,
                        repetition_penalty,
                        version,
                    ],
                    cache_examples=False,
                    fn=gradio_interface,
                    outputs=[output],
                )
                gr.Markdown(base_model_info)
                gr.Markdown(formats)

    # Clicking "Generate" calls gradio_interface with the prompt and the four
    # slider values and writes the returned text into the code output.
    submit.click(
        gradio_interface,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        # preprocess=False,
        batch=False,
        show_progress=True,
    )
# Configure the request queue (max_size=10), then start the app on the
# configured port with debug output enabled.
_queued_demo = demo.queue(max_size=10)
_queued_demo.launch(debug=True, server_port=DEFAULT_PORT)