---
license: apache-2.0
base_model: google-bert/bert-base-uncased
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: tool-bert
  results: []
---
# tool-bert
This model is a fine-tuned version of [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased).
It is trained on a custom-made dataset of sample user instructions, each labeled with one of a small set of local assistant function-calling endpoints.
Given an input query, tool-bert predicts which tool (if any) should be used to augment a downstream LLM's generated output.
More information on these tools is to follow, but example tools include "play music", "check the weather", "get the news", "take a photo", and a "no tool" fallback.
In short, this model is meant to give very small LLMs (i.e. 8B parameters and below) a simple way to use function calling.
All limitations and biases are inherited from the parent model.
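To make the routing idea concrete, here is a minimal sketch of how the predicted label could gate a downstream LLM call. `run_tool` and `ask_llm` are hypothetical placeholders for your own tool handlers and LLM client (they are not part of this repository); `predict_tool` is defined in the Example Usage section below.

```python
# Hypothetical routing sketch: dispatch tool-bert's prediction to a handler,
# then hand the handler's output to a downstream LLM as extra context.
def route(question, model, tokenizer):
    tool = predict_tool(question, model, tokenizer)   # defined below
    if tool == "no_tool_needed":
        return ask_llm(question)                      # plain LLM answer
    tool_output = run_tool(tool, question)            # e.g. fetch the forecast
    return ask_llm(f"{question}\n\nTool output: {tool_output}")
```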
## Example Usage
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Tool labels, indexed by the model's output class IDs
key_tools = ['take_picture', 'no_tool_needed', 'check_news', 'check_weather', 'play_spotify']

def get_id2tool_name(id, key_tools):
    return key_tools[id]

def remove_any_non_alphanumeric_characters(text):
    return ''.join(e for e in text if e.isalnum() or e.isspace())

def load_model():
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("nkasmanoff/tool-bert")
    model.eval()
    return model, tokenizer

def predict_tool(question, model, tokenizer):
    # Clean the text, tokenize it, and return the highest-scoring tool label
    question = remove_any_non_alphanumeric_characters(question)
    inputs = tokenizer(question, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    return get_id2tool_name(logits.argmax().item(), key_tools)

model, tokenizer = load_model()

question = "What's the weather outside?"
predict_tool(question, model, tokenizer)
```
```
check_weather
```
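Because `predict_tool` takes a raw argmax over the logits, one optional refinement is to softmax the logits and fall back to `no_tool_needed` when no class is confident. A minimal sketch, reusing the helpers above (the 0.5 threshold is an arbitrary illustration, not a tuned value):

```python
import torch

def predict_tool_with_confidence(question, model, tokenizer, threshold=0.5):
    question = remove_any_non_alphanumeric_characters(question)
    inputs = tokenizer(question, return_tensors="pt")
    with torch.no_grad():
        probs = model(**inputs).logits.softmax(dim=-1).squeeze(0)
    confidence, idx = probs.max(dim=-1)
    # Route to "no tool" when the top class is not confident enough
    if confidence.item() < threshold:
        return "no_tool_needed", confidence.item()
    return get_id2tool_name(idx.item(), key_tools), confidence.item()
```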
## Framework versions
- Transformers 4.41.1
- PyTorch 2.3.0
- Datasets 2.19.1
- Tokenizers 0.19.1