import os

import gradio as gr
import torch
from transformers import pipeline
# Visual theme shared by the whole demo: a Monochrome base with custom hues,
# small corner radii and an "Open Sans"-first font stack.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)
# Access token read from the USER_TOKEN environment variable (None when the
# variable is unset). It is only passed to the LLaMA pipeline below.
# Requires ``import os`` at the top of the file.
TOKEN = os.getenv("USER_TOKEN")

# NOTE: the "_3b" / "_7b" suffixes do not describe the loaded checkpoints
# (both are 7B models); the names are kept because ``generate`` refers to them.
instruct_pipeline_3b = pipeline(
    model="tiiuae/falcon-7b-instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
instruct_pipeline_7b = pipeline(
    model="HuggingFaceH4/llama-7b-ift",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    use_auth_token=TOKEN,
)
# Third model, currently disabled:
# instruct_pipeline_12b = pipeline(model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
def _generation_kwargs(temperature, top_p, top_k, max_new_tokens):
    """Translate the UI slider values into text-generation keyword arguments."""
    # Sliders may deliver floats; the token budget must be an integer, and at
    # least one new token has to be requested.
    kwargs = {"max_new_tokens": max(1, int(max_new_tokens))}
    if temperature > 0:
        kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Zero is not a usable sampling temperature: decode greedily instead.
        kwargs["do_sample"] = False
    return kwargs


def _extract_generated_text(result):
    """Reduce a pipeline result to the plain string a Markdown component can show."""
    if isinstance(result, (list, tuple)):
        result = result[0] if result else ""
    if isinstance(result, dict):
        result = result.get("generated_text", "")
    return str(result)


def generate(query, temperature, top_p, top_k, max_new_tokens):
    """Answer ``query`` with both models using the same sampling settings.

    Args:
        query: The user's question / instruction.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling probability mass.
        top_k: Size of the top-k shortlist.
        max_new_tokens: Upper bound on generated tokens (clamped to >= 1).

    Returns:
        A two-element list of strings: the first pipeline's answer followed by
        the second pipeline's answer, in the order the output components are
        wired.
    """
    # Generation options must be passed by keyword: the pipelines do not take
    # them positionally.
    gen_kwargs = _generation_kwargs(temperature, top_p, top_k, max_new_tokens)
    return [
        _extract_generated_text(pipe(query, **gen_kwargs))
        for pipe in (instruct_pipeline_3b, instruct_pipeline_7b)
    ]
# Sample questions offered as one-click examples below the output panels.
examples = [
    "How many helicopters can a human eat in one sitting?",
    "What is an alpaca? How is it different from a llama?",
    "Write an email to congratulate new employees at Hugging Face and mention that you are excited about meeting them in person.",
    "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
    "Explain the moon landing to a 6 year old in a few sentences.",
    "Why aren't birds real?",
    "How can I steal from a grocery store without getting caught?",
    "Why is it important to eat socks after meditating?",
]
def process_example(args, temperature=0.5, top_p=0.95, top_k=50, max_new_tokens=256):
    """Run an example question through ``generate`` with default sampling settings.

    Args:
        args: The example question text (the single example input).
        temperature: Sampling temperature; defaults to the slider's initial value.
        top_p: Nucleus-sampling mass; defaults to the slider's initial value.
        top_k: Top-k shortlist size; defaults to the slider's initial value.
        max_new_tokens: Token budget; defaults to the slider's initial value.

    Returns:
        Whatever ``generate`` returns: one entry per model, matching the two
        output components the examples widget is wired to.
    """
    # ``generate`` needs all five arguments, and both model outputs are
    # returned (not just the last one) because two output components are bound.
    return generate(args, temperature, top_p, top_k, max_new_tokens)
# Page-level CSS: hides any element carrying the ``generating`` class.
css = ".generating {visibility: hidden}"

# Build the comparison UI: a question box, four sampling sliders, a submit
# button, one answer panel per model, and a row of clickable examples.
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown("\nFalcon 7B vs. LLaMA 7B instruction tuned\n")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                instruction = gr.Textbox(
                    placeholder="Enter your question here",
                    label="Question",
                    elem_id="q-input",
                )
            # Sampling controls, one column per slider.
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.5,
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=0.95,
                            minimum=0.0,
                            maximum=1,
                            step=0.05,
                            interactive=True,
                            # A larger top-p keeps a larger share of the
                            # distribution, i.e. more unlikely tokens.
                            info="Higher values sample more low-probability tokens",
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=50,
                            minimum=0.0,
                            maximum=100,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens",
                        )
                with gr.Column():
                    with gr.Row():
                        max_new_tokens = gr.Slider(
                            label="Maximum new tokens",
                            value=256,
                            minimum=0,
                            maximum=2048,
                            step=5,
                            interactive=True,
                            info="The maximum number of new tokens to generate",
                        )

    with gr.Row():
        submit = gr.Button("Generate Answers")

    # One panel per model, side by side.
    with gr.Row():
        with gr.Column():
            with gr.Box():
                gr.Markdown("**Falcon 7B instruct**")
                output_3b = gr.Markdown()
        with gr.Column():
            with gr.Box():
                gr.Markdown("**LLaMA 7B instruct**")
                output_7b = gr.Markdown()
        # Third panel, disabled together with the Dolly pipeline:
        # with gr.Column():
        #     with gr.Box():
        #         gr.Markdown("**Dolly 12B**")
        #         output_12b = gr.Markdown()

    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[instruction],
            cache_examples=False,
            fn=process_example,
            outputs=[output_3b, output_7b],
        )

    # The button and the Enter key in the question box trigger the same call.
    generation_inputs = [instruction, temperature, top_p, top_k, max_new_tokens]
    model_outputs = [output_3b, output_7b]
    submit.click(generate, inputs=generation_inputs, outputs=model_outputs)
    instruction.submit(generate, inputs=generation_inputs, outputs=model_outputs)

demo.queue(concurrency_count=16).launch(debug=True)