statnlp committed
Commit 126e1f7 · Parent: 83b83d4

Update app.py

Files changed (1)
  1. app.py +5 -6
app.py CHANGED
@@ -15,7 +15,6 @@ model = AutoModelForCausalLM.from_pretrained(
     "PY007/LiteChat-Preview",
     trust_remote_code=True,
     device_map="auto",
-    torch_dtype=torch.float16
 )
 model.eval()
 
@@ -58,14 +57,14 @@ def bot(history):
     # Take only the most recent context up to the max context length and prepend the
     # system prompt with the messages
     max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
-    inputs = BatchEncoding({
-        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
-        for k in msg_tokens
-    }).to('cuda')
     # inputs = BatchEncoding({
     #     k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
     #     for k in msg_tokens
-    # })
+    # }).to('cuda')
+    inputs = BatchEncoding({
+        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
+        for k in msg_tokens
+    })
     # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
     if inputs.get("token_type_ids", None) is not None:
         inputs.pop("token_type_ids")
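
Net effect of this commit: the model is no longer forced to torch.float16, and the assembled inputs are no longer moved with an explicit .to('cuda'), so placement is left to device_map="auto". The sketch below is a minimal, self-contained reconstruction of how the updated pieces of app.py could fit together after the change; the tokenizer setup, the prompt strings, and the concrete values of max_context_length, max_new_tokens, and max_sys_tokens are illustrative assumptions, since the diff does not show them.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding

# Post-commit loading: no torch_dtype=torch.float16, so the default dtype is used;
# device_map="auto" lets accelerate decide where the weights live (GPU if present, else CPU).
model = AutoModelForCausalLM.from_pretrained(
    "PY007/LiteChat-Preview",
    trust_remote_code=True,
    device_map="auto",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("PY007/LiteChat-Preview", trust_remote_code=True)

# Illustrative stand-ins for values defined elsewhere in app.py (not shown in the diff).
max_context_length = 2048
max_new_tokens = 256
system_prompt_tokens = tokenizer("You are a helpful assistant.", return_tensors="pt")
msg_tokens = tokenizer("Hello, what can you do?", return_tensors="pt")
max_sys_tokens = system_prompt_tokens["input_ids"].shape[-1]

# Keep only the most recent message tokens that still fit in the context window,
# then prepend the system prompt. The explicit .to('cuda') from the old code is gone,
# so the same code runs whether or not a GPU is available.
max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
inputs = BatchEncoding({
    k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
    for k in msg_tokens
})

# Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
if inputs.get("token_type_ids", None) is not None:
    inputs.pop("token_type_ids")

# Not part of this diff: a typical generation call using the assembled inputs.
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))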