pgarbacki commited on
Commit
5a86ab7
1 Parent(s): 199b267

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -37,7 +37,7 @@ import json
37
 
38
  device = "cuda" # the device to load the model onto
39
 
40
- model = AutoModelForCausalLM.from_pretrained("fireworks-ai/firefunction-v1")
41
  tokenizer = AutoTokenizer.from_pretrained("fireworks-ai/firefunction-v1")
42
 
43
  function_spec = [
@@ -87,7 +87,7 @@ messages = [
87
  {'role': 'user', 'content': 'Hi, can you tell me the current stock price of AAPL?'}
88
  ]
89
 
90
- model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
91
 
92
  generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
93
  decoded = tokenizer.batch_decode(generated_ids)
 
37
 
38
  device = "cuda" # the device to load the model onto
39
 
40
+ model = AutoModelForCausalLM.from_pretrained("fireworks-ai/firefunction-v1", device_map="auto")
41
  tokenizer = AutoTokenizer.from_pretrained("fireworks-ai/firefunction-v1")
42
 
43
  function_spec = [
 
87
  {'role': 'user', 'content': 'Hi, can you tell me the current stock price of AAPL?'}
88
  ]
89
 
90
+ model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
91
 
92
  generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
93
  decoded = tokenizer.batch_decode(generated_ids)