shishirpatil committed
Commit 11df80f
1 Parent(s): 77d6b1c

Local model inference

Files changed (1): README.md +70 -1
README.md CHANGED
@@ -15,7 +15,7 @@ executable APIs call given natural language instructions and API context.
 |gorilla-openfunctions-v0 | Given a function, and user intent, returns properly formatted json with the right arguments|
 |gorilla-openfunctions-v1 | + Parallel functions, and can choose between functions|
 
-## Example Usage
+## Example Usage (Hosted)
 
 1. OpenFunctions is compatible with OpenAI Functions
 
@@ -63,6 +63,75 @@ get_gorilla_response(query, functions=functions)
 ```bash
 uber.ride(loc="berkeley", type="plus", time=10)
 ```
+
+## Example Usage (Run Locally)
+
+```python
+import json
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+def get_prompt(user_query: str, functions: list = []) -> str:
+    """
+    Generates a conversation prompt based on the user's query and a list of functions.
+
+    Parameters:
+    - user_query (str): The user's query.
+    - functions (list): A list of functions to include in the prompt.
+
+    Returns:
+    - str: The formatted conversation prompt.
+    """
+    if len(functions) == 0:
+        return f"USER: <<question>> {user_query}\nASSISTANT: "
+    functions_string = json.dumps(functions)
+    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "
+
+# Device setup
+device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+# Model and tokenizer setup
+model_id: str = "gorilla-llm/gorilla-openfunctions-v0"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
+
+# Move model to device
+model.to(device)
+
+# Pipeline setup
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=128,
+    batch_size=16,
+    torch_dtype=torch_dtype,
+    device=device,
+)
+
+# Example usage
+query: str = "Call me an Uber ride type \"Plus\" in Berkeley at zipcode 94704 in 10 minutes"
+functions = [
+    {
+        "name": "Uber Carpool",
+        "api_name": "uber.ride",
+        "description": "Find suitable ride for customers given the location, type of ride, and the amount of time the customer is willing to wait as parameters",
+        "parameters": [
+            {"name": "loc", "description": "Location of the starting place of the Uber ride"},
+            {"name": "type", "enum": ["plus", "comfort", "black"], "description": "Types of Uber ride user is ordering"},
+            {"name": "time", "description": "The amount of time in minutes the customer is willing to wait"}
+        ]
+    }
+]
+
+# Generate prompt and obtain model output
+prompt = get_prompt(query, functions=functions)
+output = pipe(prompt)
+
+print(output)
+```
+
 
 ## Contributing
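
For context, the second hunk header references `get_gorilla_response(query, functions=functions)`, the hosted-inference helper defined earlier in the README and untouched by this commit. Below is a minimal sketch of what that helper looks like, assuming the OpenAI-compatible endpoint the Gorilla project publishes and the pre-1.0 `openai` client; the endpoint URL and model name are assumptions for illustration, not part of this diff:

```python
# Sketch of the hosted path (assumed reconstruction; not part of this commit).
import openai

def get_gorilla_response(prompt: str, functions: list = []):
    openai.api_key = "EMPTY"  # assumed: the hosted endpoint does not check keys
    openai.api_base = "http://luigi.millennium.berkeley.edu:8000/v1"  # assumed endpoint
    completion = openai.ChatCompletion.create(
        model="gorilla-openfunctions-v0",
        temperature=0.0,
        messages=[{"role": "user", "content": prompt}],
        functions=functions,
    )
    # The model's reply is the formatted function call, e.g. uber.ride(...)
    return completion.choices[0].message.content
```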
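One usage note on the added local-inference script: a `text-generation` pipeline returns the prompt and the completion together in `generated_text`, so callers typically strip the echoed prompt to recover just the function call. An illustrative continuation of the script above; the exact completion text depends on the model, so the expected output in the final comment is an assumption:

```python
# Continues the local-inference script in the diff above (reuses its
# get_prompt, pipe, query, and functions). Illustrative post-processing only.
prompt = get_prompt(query, functions=functions)
raw = pipe(prompt)[0]["generated_text"]  # pipeline output echoes the prompt
call = raw[len(prompt):].strip()         # keep only the model's completion
print(call)  # expected to resemble: uber.ride(loc="berkeley", type="plus", time=10)
```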