Den4ikAI commited on
Commit
c32fe02
1 Parent(s): 7571ecb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +12 -18
README.md CHANGED
@@ -18,41 +18,35 @@ use_cuda = torch.cuda.is_available()
18
  device = torch.device("cuda" if use_cuda else "cpu")
19
 
20
 
21
- t5_tokenizer = transformers.GPT2Tokenizer.from_pretrained("Den4ikAI/FRED-T5-XL-chitchat")
22
- t5_model = transformers.T5ForConditionalGeneration.from_pretrained("Den4ikAI/FRED-T5-XL-chitchat")
23
- t5_model.to(device)
24
- t5_model.eval()
25
 
26
  while True:
27
  print('-'*80)
28
  dialog = []
29
  while True:
30
- msg = input('Вопрос :> ').strip()
31
  if len(msg) == 0:
32
  break
33
 
34
- msg = msg[0].upper() + msg[1:]
35
- dialog.append('человек: ' + msg)
36
- prompt = '<SC1>' + '\n'.join(dialog) + '\nчатбот: <extra_id_0>'
37
-
38
- input_ids = t5_tokenizer(prompt, return_tensors='pt').input_ids
39
- out_ids = t5_model.generate(input_ids=input_ids.to(device),
40
  max_length=200,
41
  eos_token_id=t5_tokenizer.eos_token_id,
42
  early_stopping=True,
43
  do_sample=True,
44
  temperature=1.0,
45
  top_k=0,
46
- top_p=0.95)
47
-
48
- t5_output = t5_tokenizer.decode(out_ids[0][1:])
49
  if '</s>' in t5_output:
50
  t5_output = t5_output[:t5_output.find('</s>')].strip()
51
 
52
- t5_output = t5_output.replace('<extra_id_0>', '').strip()
53
-
54
- print('Ответ :> {}'.format(t5_output))
55
- dialog.append('чатбот: ' + t5_output)
56
  ```
57
  # Citation
58
  ```
 
18
  device = torch.device("cuda" if use_cuda else "cpu")
19
 
20
 
21
+ t5_tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_dir)
22
+ t5_model = transformers.T5ForConditionalGeneration.from_pretrained(model_dir)
 
 
23
 
24
  while True:
25
  print('-'*80)
26
  dialog = []
27
  while True:
28
+ msg = input('H:> ').strip()
29
  if len(msg) == 0:
30
  break
31
 
32
+ dialog.append('- ' + msg)
33
+ dialog.append('- <extra_id_0>')
34
+ input_ids = t5_tokenizer('<SC1>'+'\n'.join(dialog), return_tensors='pt').input_ids
35
+ out_ids = t5_model.generate(input_ids=input_ids,
 
 
36
  max_length=200,
37
  eos_token_id=t5_tokenizer.eos_token_id,
38
  early_stopping=True,
39
  do_sample=True,
40
  temperature=1.0,
41
  top_k=0,
42
+ top_p=0.85)
43
+ dialog.pop(-1)
44
+ t5_output = t5_tokenizer.decode(out_ids[0][1:]).replace('<extra_id_0>','')
45
  if '</s>' in t5_output:
46
  t5_output = t5_output[:t5_output.find('</s>')].strip()
47
 
48
+ print('B:> {}'.format(t5_output))
49
+ dialog.append('- '+t5_output)
 
 
50
  ```
51
  # Citation
52
  ```