import gradio as gr
import torch
from gradio.themes.utils import sizes
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from constants import END_OF_TEXT, MIN_TEMPERATURE
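# `utils` and `constants` are local modules bundled with this Space; here
# END_OF_TEXT is assumed to be the model's end-of-text token string and
# MIN_TEMPERATURE a floor that keeps sampling away from zero temperature.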

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    use_fast=False,
)
# Reuse the end-of-text token for padding so batched inputs have a valid pad id
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT

model = AutoModelForCausalLM.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    device_map="auto",
)
model = torch.compile(model, mode="reduce-overhead")
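# NB: "reduce-overhead" mode trades a one-time warm-up compilation for lower
# per-call launch overhead, which suits repeated short generation requests.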

# UI things
_styles = utils.get_file_as_string("styles.css")

# Load ./README.md and split it into sections on "---" separators
readme_file_content = utils.get_file_as_string("README.md", path="./")
(
    manifest,
    description,
    disclaimer,
    base_model_info,
    formats,
) = utils.get_sections(readme_file_content, "---", up_to=5)

theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="orange",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)


def run_inference(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
) -> str:
    """Generate a completion of `prompt` with the given sampling settings."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        epsilon_cutoff=1e-3,
        max_new_tokens=max_new_tokens,
        min_new_tokens=2,
        no_repeat_ngram_size=6,
        renormalize_logits=True,
        repetition_penalty=repetition_penalty,
        # Clamp so the effective temperature never reaches zero while sampling
        temperature=max(temperature, MIN_TEMPERATURE),
        top_p=top_p,
    )
    text = tokenizer.batch_decode(
        outputs,
        skip_special_tokens=True,
    )[0]
    return text
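
# Hedged smoke test (illustrative prompt, not part of the Space): uncomment to
# sanity-check generation from the command line before launching the UI.
# if __name__ == "__main__":
#     print(run_inference("def add(a: int, b: int) -> int:", 0.2, 32, 0.9, 1.2))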

examples = [
    [
        'def greet(name: str) -> None:\n    """\n    Greets the user\n    """\n    print(f"Hello,',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'for i in range(5):\n    """\n    Loop through 0 to 4\n    """\n    print(i,',
        0.2,
        64,
        0.9,
        1.2,
    ],
    ['x = 10\n"""Check if x is greater than 5"""\nif x > 5:', 0.2, 64, 0.9, 1.2],
    ["def square(x: int) -> int:\n    return", 0.2, 64, 0.9, 1.2],
    ['import math\n"""Math operations"""\nmath.', 0.2, 64, 0.9, 1.2],
    [
        'def is_even(n) -> bool:\n    """\n    Check if a number is even\n    """\n    if n % 2 == 0:',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'while True:\n    """Infinite loop example"""\n    print("Infinite loop,',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        "def sum_list(lst: list[int]) -> int:\n    total = 0\n    for item in lst:",
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'try:\n    """\n    Exception handling\n    """\n    x = int(input("Enter a number: "))\nexcept ValueError:',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'def divide(a: float, b: float) -> float:\n    """\n    Divide a by b\n    """\n    if b != 0:',
        0.2,
        64,
        0.9,
        1.2,
    ],
]
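
# Each row above supplies [prompt, temperature, max_new_tokens, top_p,
# repetition_penalty], in the same order as the gr.Examples inputs below.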

# Define the Gradio Blocks interface
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    value=examples[0][0],
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=64,
                                        minimum=32,
                                        maximum=512,
                                        step=32,
                                        interactive=True,
                                        info="Number of tokens to generate",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        version = gr.Dropdown(
                            [
                                "smol_llama-101M-GQA-python",
                            ],
                            value="smol_llama-101M-GQA-python",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                gr.Examples(
                    examples=examples,
                    # One input per value in each example row; `version` is not
                    # part of the rows, so it is not listed here.
                    inputs=[
                        instruction,
                        temperature,
                        max_new_tokens,
                        top_p,
                        repetition_penalty,
                    ],
                    cache_examples=False,
                    fn=run_inference,
                    outputs=[output],
                )
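                # With cache_examples=False, selecting an example only fills
                # the inputs; generation runs via the Generate button.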
                gr.Markdown(base_model_info)
                gr.Markdown(formats)

    submit.click(
        run_inference,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        # preprocess=False,
        # batch=False,
        show_progress=True,
    )

# .queue(max_size=10, api_open=False)
demo.launch(
    debug=True,
    show_api=False,
    share=utils.is_google_colab(),
)