qianmuuq commited on
Commit
183839c
1 Parent(s): b50eeab

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -2
main.py CHANGED
@@ -115,9 +115,9 @@ def getChat(text,model,tokenizer):
115
  text_ids = tokenizer.convert_tokens_to_ids(text)
116
  # print(text_ids)
117
 
118
- input_ids = torch.tensor(text_ids).long().to(device)
119
  input_ids = input_ids.unsqueeze(0)
120
- mask_input = torch.ones_like(input_ids).long().to(device)
121
  # print(input_ids.size())
122
  response = [] # 根据context,生成的response
123
  # 最多生成max_len个token
 
115
  text_ids = tokenizer.convert_tokens_to_ids(text)
116
  # print(text_ids)
117
 
118
+ input_ids = torch.tensor(text_ids).long()
119
  input_ids = input_ids.unsqueeze(0)
120
+ mask_input = torch.ones_like(input_ids).long()
121
  # print(input_ids.size())
122
  response = [] # 根据context,生成的response
123
  # 最多生成max_len个token