pvduy committed on
Commit
0a1707e
1 Parent(s): 54fe16b

Change device: use CUDA when available, otherwise fall back to CPU

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -18,7 +18,7 @@ def parse_args():
18
  return parser.parse_args()
19
 
20
  def predict(message, history, system_prompt, temperature, max_tokens):
21
- global model, tokenizer
22
  instruction = "<|im_start|>system\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|im_end|>\n"
23
  for human, assistant in history:
24
  instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
@@ -33,8 +33,8 @@ def predict(message, history, system_prompt, temperature, max_tokens):
33
  if input_ids.shape[1] > MAX_MAX_NEW_TOKENS:
34
  input_ids = input_ids[:, -MAX_MAX_NEW_TOKENS:]
35
 
36
- input_ids = input_ids.cuda()
37
- attention_mask = attention_mask.cuda()
38
  generate_kwargs = dict(
39
  {"input_ids": input_ids, "attention_mask": attention_mask},
40
  streamer=streamer,
@@ -59,7 +59,8 @@ if __name__ == "__main__":
59
  args = parse_args()
60
  tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-instruct-3b")
61
  model = AutoModelForCausalLM.from_pretrained("stabilityai/stable-code-instruct-3b")
62
- model = model.cuda()
 
63
  gr.ChatInterface(
64
  predict,
65
  title="Stable Code Instruct Chat - Demo",
 
18
  return parser.parse_args()
19
 
20
  def predict(message, history, system_prompt, temperature, max_tokens):
21
+ global model, tokenizer, device
22
  instruction = "<|im_start|>system\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|im_end|>\n"
23
  for human, assistant in history:
24
  instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
 
33
  if input_ids.shape[1] > MAX_MAX_NEW_TOKENS:
34
  input_ids = input_ids[:, -MAX_MAX_NEW_TOKENS:]
35
 
36
+ input_ids = input_ids.to(device)
37
+ attention_mask = attention_mask.to(device)
38
  generate_kwargs = dict(
39
  {"input_ids": input_ids, "attention_mask": attention_mask},
40
  streamer=streamer,
 
59
  args = parse_args()
60
  tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-instruct-3b")
61
  model = AutoModelForCausalLM.from_pretrained("stabilityai/stable-code-instruct-3b")
62
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63
+ model = model.to(device)
64
  gr.ChatInterface(
65
  predict,
66
  title="Stable Code Instruct Chat - Demo",