sanagnos commited on
Commit
ed47e74
1 Parent(s): c0c2c2c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -4
README.md CHANGED
@@ -25,11 +25,11 @@ model.lm_head = CastOutputToFloat(model.lm_head)
25
 
26
  tokenizer = transformers.AutoTokenizer.from_pretrained(base_path)
27
 
28
- batch = "<question>What are the symptoms of Alzheimer's disease?<answer>"
29
 
30
  with torch.cuda.amp.autocast():
31
  out = model.generate(
32
- input_ids=batch['input_ids'],
33
  max_length=300,
34
  do_sample=True,
35
  top_k=40,
@@ -38,6 +38,5 @@ with torch.cuda.amp.autocast():
38
  eos_token_id=tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index('<question>')]
39
  )
40
 
41
- message = tokenizer.decode(out[0, :-1]).replace('<question>', "User:\n").replace('<answer>', 'Assistant:\n')
42
-
43
  ```
 
25
 
26
  tokenizer = transformers.AutoTokenizer.from_pretrained(base_path)
27
 
28
+ batch = tokenizer.encode("<question>What are the symptoms of Alzheimer's disease?<answer>", return_tensors="pt")
29
 
30
  with torch.cuda.amp.autocast():
31
  out = model.generate(
32
+ input_ids=batch.to(model.device),
33
  max_length=300,
34
  do_sample=True,
35
  top_k=40,
 
38
  eos_token_id=tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index('<question>')]
39
  )
40
 
41
+ print(tokenizer.decode(out[0, :-1]).replace('<question>', "User:\n").replace('<answer>', '\nAssistant:\n'))
 
42
  ```