rpand002 commited on
Commit
73eb3b4
1 Parent(s): 57fc33b

update readme

Browse files
Files changed (1) hide show
  1. README.md +88 -3
README.md CHANGED
@@ -1,3 +1,88 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ ### Granite-20B-FunctionCalling
5
+ #### Model Summary
6
+ Granite-20B-FunctionCalling is a finetuned model based on IBM's granite-20b-code instruct to introduce function calling abilities into Granite model family. The model is trained using a multi-task training approach on seven fundamental tasks encompassed in function calling, those being Nested Function Calling, Function Chaining, Parallel Functions, Function Name Detection, Parameter-Value Pair Detection, Next-Best Function, and Response Generation.
7
+
8
+ - **Developers**: IBM Research
9
+ - **Paper**: [Granite-Function Calling Model: Introducing Function Calling Abilities via Multi-task Learning of Granular Tasks](https://arxiv.org/pdf/2407.00121v1)
10
+ - **Release Date**: July 9th, 2024
11
+ - **License**: [Apache 2.0.](https://www.apache.org/licenses/LICENSE-2.0)
12
+
13
+ ### Usage
14
+ ### Intended use
15
+ The model is designed to respond to function calling related instructions.
16
+
17
+ ### Generation
18
+ This is a simple example of how to use Granite-20B-Code-Instruct model.
19
+ ```python
20
+ import json
21
+ import torch
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer
23
+
24
+ device = "cuda" # or "cpu"
25
+ model_path = "ibm-granite/granite-20b-functioncalling"
26
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
27
+ # drop device_map if running on CPU
28
+ model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
29
+ model.eval()
30
+
31
+ # define the user query and list of available functions
32
+ query = "What's the current weather in New York?"
33
+ functions = [
34
+ {
35
+ "name": "get_current_weather",
36
+ "description": "Get the current weather",
37
+ "parameters": {
38
+ "type": "object",
39
+ "properties": {
40
+ "location": {
41
+ "type": "string",
42
+ "description": "The city and state, e.g. San Francisco, CA"
43
+ }
44
+ },
45
+ "required": ["location"]
46
+ }
47
+ },
48
+ {
49
+ "name": "get_stock_price",
50
+ "description": "Retrieves the current stock price for a given ticker symbol. The ticker symbol must be a valid symbol for a publicly traded company on a major US stock exchange like NYSE or NASDAQ. The tool will return the latest trade price in USD. It should be used when the user asks about the current or most recent price of a specific stock. It will not provide any other information about the stock or company.",
51
+ "parameters": {
52
+ "type": "object",
53
+ "properties": {
54
+ "ticker": {
55
+ "type": "string",
56
+ "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
57
+ }
58
+ },
59
+ "required": ["ticker"]
60
+ }
61
+ }
62
+ ]
63
+
64
+
65
+ # serialize functions and define a payload to generate the input template
66
+ payload = {
67
+ "functions_str": [json.dumps(x) for x in functions],
68
+ "query": query,
69
+ }
70
+
71
+ instruction = tokenizer.apply_chat_template(payload, tokenize=False, add_generation_prompt=True)
72
+
73
+ # tokenize the text
74
+ input_tokens = tokenizer(instruction, return_tensors="pt").to(device)
75
+
76
+ # generate output tokens
77
+ outputs = model.generate(**input_tokens, max_new_tokens=100)
78
+
79
+ # decode output tokens into text
80
+ outputs = tokenizer.batch_decode(outputs)
81
+
82
+ # loop over the batch to print, in this example the batch size is 1
83
+ for output in outputs:
84
+ # Each function call in the output will be preceded by the token "<function_call>" followed by a
85
+ # json serialized function call of the format {"name": $function_name$, "arguments" {$arg_name$: $arg_val$}}
86
+ # In this specific case, the output will be: <function_call> {"name": "get_current_weather", "arguments": {"location": "New York"}}
87
+ print(output)
88
+ ```