---
license: apache-2.0
base_model: google-bert/bert-base-uncased
tags:
  - generated_from_trainer
metrics:
  - accuracy
model-index:
  - name: tool-bert
    results: []
---

# tool-bert

This model is a fine-tuned version of google-bert/bert-base-uncased.

It was trained on a custom-made dataset of sample user instructions, each labeled with one of a small number of local assistant function-calling endpoints.

Given an input query, tool-bert predicts which tool should be used to augment a downstream LLM's generated output.

More information on these tools will follow, but example tools are "play music", "check the weather", "get the news", "take a photo", or no tool at all.

In short, this model is meant to give very small LLMs (e.g. 8B parameters and below) a way to use function calling: the classifier picks the tool, so the LLM does not have to.

All limitations and biases are inherited from the parent model.

## Example Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

key_tools = ['take_picture', 'no_tool_needed', 'check_news', 'check_weather', 'play_spotify']


def get_id2tool_name(id, key_tools):
    return key_tools[id]


def remove_any_non_alphanumeric_characters(text):
    return ''.join(e for e in text if e.isalnum() or e.isspace())


def load_model():
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("nkasmanoff/tool-bert")
    model.eval()
    return model, tokenizer


def predict_tool(question, model, tokenizer):
    question = remove_any_non_alphanumeric_characters(question)
    inputs = tokenizer(question, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    return get_id2tool_name(logits.argmax().item(), key_tools)


model, tokenizer = load_model()

question = "What's the weather outside?"

predict_tool(question, model, tokenizer)
# check_weather
```
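
tool-bert only predicts the label; acting on it is up to the calling application. One way to wire the prediction into a local assistant is a dispatch table mapping each label to a handler whose output is then passed to the downstream LLM as context. The sketch below illustrates this pattern with hypothetical placeholder handlers (the `route` helper and the handler implementations are illustrative assumptions, not part of this repository):

```python
# Hypothetical dispatch layer on top of tool-bert's predicted label.
# Each handler stands in for a real integration (weather API, Spotify, camera, ...).

def check_weather(query):
    # Placeholder: a real handler would call a weather service here.
    return "72F and sunny"

def no_tool_needed(query):
    # Nothing to fetch; the LLM answers from the query alone.
    return None

handlers = {
    "check_weather": check_weather,
    "no_tool_needed": no_tool_needed,
    # "play_spotify", "check_news", and "take_picture" would be registered similarly.
}

def route(query, predicted_tool):
    """Run the handler for the predicted tool; fall back to no tool."""
    handler = handlers.get(predicted_tool, no_tool_needed)
    return handler(query)

# Context returned here would be prepended to the small LLM's prompt.
print(route("What's the weather outside?", "check_weather"))
```

The LLM never sees the tool-selection step, which is what keeps this workable for models too small to do reliable function calling on their own.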

## Framework versions

  • Transformers 4.41.1
  • Pytorch 2.3.0
  • Datasets 2.19.1
  • Tokenizers 0.19.1