pomn commited on
Commit
ba13540
1 Parent(s): 12cc82e
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -1,18 +1,17 @@
1
- from transformers import AutoModel, AutoTokenizer
2
  import torch
3
-
4
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
5
- #model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
6
- model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
7
-
8
  def converse(user_input, chat_history=[]):
9
- user_input_ids = tokenizer(user_input + tokenizer.eos_token, return_tensors='pt').input_ids
10
  # keep history in the tensor
11
  bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
12
  # get response
13
- chat_history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
14
  print (chat_history)
15
- response = tokenizer.decode(chat_history[0]).split("<|endoftext|>")
16
  print("starting to print response")
17
  print(response)
18
  # html for display
 
1
+ from transformers import AutoTokenizer,AutoModel
2
  import torch
3
+ chat_tkn = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
4
+ mdl = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
5
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
6
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
 
7
  def converse(user_input, chat_history=[]):
8
+ user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
9
  # keep history in the tensor
10
  bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
11
  # get response
12
+ chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
13
  print (chat_history)
14
+ response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
15
  print("starting to print response")
16
  print(response)
17
  # html for display