Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hub ID of the Bonito checkpoint; the model and tokenizer come from the same repo.
_BONITO_CHECKPOINT = "BatsResearch/bonito-v1"

model = AutoModelForCausalLM.from_pretrained(_BONITO_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_BONITO_CHECKPOINT)
# Generation in respond() feeds CUDA tensors, so the weights must live on the GPU.
model.to("cuda")
# On ZeroGPU hardware (this Space runs "on Zero") a GPU is only attached while a
# function decorated with @spaces.GPU is executing; `spaces` was imported but
# otherwise unused.
@spaces.GPU
def respond(
    message,
    task_type,
    max_tokens,
    temperature,
    top_p,
):
    """Generate a synthetic instruction/response pair for ``message`` with Bonito.

    Args:
        message: Unannotated passage placed in the ``<|context|>`` field of the
            prompt; it also replaces ``{{context}}`` placeholders in the result.
        task_type: Task-type label (any casing); lower-cased and stripped
            before being placed in the ``<|tasktype|>`` field.
        max_tokens: Maximum number of new tokens to generate (coerced to int).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        ``(instruction, response)``. The generated text is split on
        ``<|pipe|>``: the part before it is the instruction, the part after it
        the response. If no ``<|pipe|>`` was generated, the whole generation is
        returned as the instruction together with a "please regenerate" message.
    """
    prompt = (
        "<|tasktype|>\n"
        + task_type.lower().strip()
        + "\n<|context|>\n"
        + message.strip()
        + "\n<|task|>\n"
    )
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    output = model.generate(
        input_ids,
        # Slider values may arrive as floats; max_new_tokens must be an int.
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = int(input_ids.shape[-1])
    generated = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    pieces = generated.split("<|pipe|>")
    instruction = pieces[0].strip().replace("{{context}}", message)
    if len(pieces) > 1:
        return instruction, pieces[1].strip()
    # Fallback: no <|pipe|> separator was generated, so there is no response
    # part; return what was generated as the instruction and ask for a retry.
    return instruction, "Unable to generate response. Please regenerate."
# Task types offered in the dropdown. respond() lower-cases the chosen label
# before writing it into the <|tasktype|> prompt field.
task_types = [
    "extractive question answering",
    "multiple-choice question answering",
    "question generation",
    "question answering without choices",
    "yes-no question answering",
    "coreference resolution",
    "paraphrase generation",
    "paraphrase identification",
    "sentence completion",
    "sentiment",
    "summarization",
    "text generation",
    "topic classification",
    "word sense disambiguation",
    "textual entailment",
    "natural language inference",
]
# Capitalize the first letter for readability. Sentence case (not str.title())
# keeps the labels identical to the dropdown default and the example rows,
# e.g. "Natural language inference" rather than "Natural Language Inference".
task_types = [task_type.capitalize() for task_type in task_types]
# Markdown shown under the demo title. The model link points at the same
# checkpoint the app loads ("BatsResearch/bonito-v1").
description = """
This is a demo for Bonito, an open-source model for conditional task generation: the task of converting unannotated text into task-specific synthetic instruction tuning data.
### More details on Bonito
- Model: https://huggingface.co/BatsResearch/bonito-v1
- Paper: https://arxiv.org/abs/2402.18334
- GitHub: https://github.com/BatsResearch/bonito
### Instructions
Try out the model by entering a context and selecting a task type from the dropdown. The model will generate a task instruction based on the context and task type you provide.
"""
# Example rows as (passage, task type) pairs. They cover only the two main
# inputs; the sampling sliders in additional_inputs are not given values.
examples = [
    (
        "Providence was one of the first cities in the country to industrialize "
        "and became noted for its textile manufacturing and subsequent machine "
        "tool, jewelry, and silverware industries. Today, the city of Providence "
        "is home to eight hospitals and eight institutions of higher learning "
        "which have shifted the city's economy into service industries, though "
        "it still retains some manufacturing activity.",
        "Natural language inference",
    ),
    (
        "John Wick (Keanu Reeves) uncovers a path to defeating The High Table. "
        "But before he can earn his freedom, Wick must face off against a new "
        "enemy with powerful alliances across the globe and forces that turn "
        "old friends into foes.",
        "Yes-no question answering",
    ),
    (
        "In 2013, American singer-songwriter Taylor Swift purchased High Watch "
        "for US$17.75 million. From 2013 to 2016, she received widespread press "
        "coverage for hosting annual American Independence Day parties on the "
        "estate, featuring numerous celebrity guests and lavish decorations "
        "often depicted on Instagram.",
        "Extractive question answering",
    ),
]
# The same rows converted to lists; this is what gr.Interface(examples=...) receives.
examples_with_additional = [list(row) for row in examples]
# Main inputs: the passage to convert and the kind of task to generate from it.
_demo_inputs = [
    gr.Textbox(label="Passage", lines=5, placeholder="Enter context here.."),
    gr.Dropdown(
        choices=task_types,
        value="Natural language inference",
        label="Task Type",
    ),
]

# Outputs, in the order respond() returns them: instruction, then response.
_demo_outputs = [
    gr.Textbox(label="Instruction", lines=5),
    gr.Textbox(label="Response"),
]

# Sampling controls, forwarded to respond() as max_tokens, temperature, top_p.
_demo_sampling_controls = [
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

demo = gr.Interface(
    fn=respond,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    additional_inputs=_demo_sampling_controls,
    title="Bonito",
    description=description,
    examples=examples_with_additional,
)
def main():
    """Launch the Gradio demo defined in this module."""
    demo.launch()


if __name__ == "__main__":
    main()