# Wally-codeFix / app.py
# Hugging Face Space by SnehaPriyaaMP — commit aee496c ("Update app.py"), 3.25 kB.
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import gradio as gr
import sentencepiece
# Limit the size of cached blocks the CUDA caching allocator may split
# (see PyTorch's CUDA memory-management docs). It should be set before the
# first CUDA allocation, i.e. before the model below is loaded.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:120'

# Hugging Face Hub repo that both the tokenizer and the model are loaded from.
model_id = "thesven/Llama3-8B-SFT-code_bagel-bnb-4bit"
# NOTE(review): not used by the loading code below (the tokenizer is read from
# model_id); kept in case other code refers to it — confirm before removing.
tokenizer_path = "./"

# Markdown header rendered at the top of the Gradio page.
DESCRIPTION = """
# thesven/Llama3-8B-SFT-code_bagel-bnb-4bit
"""

# Check if CUDA is available and set device accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Tokenizers run on the CPU. device_map is a model-placement option, so it is
# not passed here; doing so only forwards a meaningless kwarg to the tokenizer.
# trust_remote_code=True executes code shipped in the Hub repo — only use with
# repos you trust.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# bfloat16 weights on GPU, float32 on CPU.
# NOTE(review): per its repo name the checkpoint appears to be bitsandbytes
# 4-bit quantized, which typically requires CUDA — CPU loading is unverified.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
)
def format_prompt(user_message, system_message="You are an expert developer in all programming languages. Help me with my code. Answer any questions I have with code examples."):
    """Build the plain-text prompt sent to the model.

    The prompt contains three role-labelled sections: the system
    instructions, the user's message, and an open ``assistant`` turn that
    the model continues.

    Args:
        user_message: The user's question or code.
        system_message: Instructions that steer the model's behaviour.

    Returns:
        The formatted prompt string.
    """
    # The instructions are labelled "system" (not "assistant") so they are not
    # presented to the model as something it already said.
    # NOTE(review): these role names are plain text, not Llama 3 special header
    # tokens; tokenizer.apply_chat_template may better match the fine-tuning
    # format — confirm against the model card.
    return f"system\n{system_message}\n\nuser\n{user_message}\nassistant\n"
@spaces.GPU
def predict(message, system_message, max_new_tokens=600, temperature=3.5, top_p=0.9, top_k=40, do_sample=False):
    """Generate one assistant reply to *message* and return it as chat history.

    Args:
        message: The user's message.
        system_message: Instructions placed before the user turn in the prompt.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        temperature: Sampling temperature; only used when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold; only used when ``do_sample`` is True.
        top_k: Top-k cutoff; only used when ``do_sample`` is True.
        do_sample: Sample from the token distribution when True, otherwise
            decode greedily.

    Returns:
        A one-element list ``[(message, response)]`` for a ``gr.Chatbot``:
        the user's message paired with the model's reply.
    """
    # NOTE(review): the default temperature of 3.5 is unusually high; the UI
    # always supplies its own value, so the default only matters for direct calls.
    formatted_prompt = format_prompt(message, system_message)

    # Calling the tokenizer (rather than .encode) also returns an attention
    # mask, so generate() does not have to guess it when pad == eos.
    encoded = tokenizer(formatted_prompt, return_tensors='pt')
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    prompt_length = input_ids.shape[-1]

    # Gradio sliders may deliver floats; generate() expects integer counts.
    generate_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        no_repeat_ngram_size=9,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=do_sample,
    )
    # Sampling controls are meaningless for greedy decoding, and transformers
    # warns when they are set without do_sample, so pass them only when sampling.
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    response_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)

    # Decode only the newly generated tokens, dropping the echoed prompt.
    response = tokenizer.decode(response_ids[0, prompt_length:], skip_special_tokens=True)

    # Pair the reply with the user's own message; the Chatbot renders the first
    # element of each pair as the user's turn.
    return [(message, response)]
# Gradio UI: system prompt, chat display, message box and generation controls,
# wired to predict().
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        system_prompt = gr.Textbox(placeholder='Provide a System Prompt In The First Person', label='System Prompt', lines=2, value="You are an expert developer in all programming languages. Help me with my code. Answer any questions I have with code examples.")
    with gr.Group():
        chatbot = gr.Chatbot(label='thesven/Llama3-8B-SFT-code_bagel-bnb-4bit')
    with gr.Group():
        textbox = gr.Textbox(placeholder='Your Message Here', label='Your Message', lines=2)
        submit_button = gr.Button('Submit', variant='primary')
    with gr.Accordion(label='Advanced options', open=False):
        # NOTE(review): Llama 3 8B models usually have an ~8K-token context
        # window, so values near the 55000 maximum are unlikely to be usable — confirm.
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=55000, step=1, value=512)
        # In transformers, temperature / top-p / top-k only affect generation
        # when sampling is enabled (the checkbox below).
        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=40)
        # Feeds predict()'s do_sample argument: checked = sample, unchecked = greedy.
        # The label names the setting it controls instead of an ambiguous instruction.
        do_sample_checkbox = gr.Checkbox(label='Enable sampling (uncheck for deterministic greedy decoding)', value=True)
    # The input order must match predict(message, system_message, max_new_tokens,
    # temperature, top_p, top_k, do_sample).
    submit_button.click(
        fn=predict,
        inputs=[textbox, system_prompt, max_new_tokens, temperature, top_p, top_k, do_sample_checkbox],
        outputs=chatbot,
    )

demo.launch()