CadenzaBaron committed on
Commit
a4f0c54
·
1 Parent(s): 2690bc7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -1
README.md CHANGED
@@ -21,13 +21,16 @@ Sample generation script :
21
 
22
  ```python
23
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
24
 tokenizer = AutoTokenizer.from_pretrained(r"path\to\checkpoint")
25
  model = AutoModelForSeq2SeqLM.from_pretrained(r"path\to\checkpoint")
 
26
  tokenizer.src_lang = "zh"
27
  tokenizer.tgt_lang = "en"
28
  test_string = "地阶上品遁术,施展后便可立于所持之剑上,以极快的速度自由飞行。"
29
 
30
- inputs = tokenizer(test_string, return_tensors="pt")
31
  translated_tokens = model.generate(**inputs, num_beams=10, do_sample=True)
32
  translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
33
 
 
21
 
22
  ```python
23
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
24
+ import torch
25
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26
 tokenizer = AutoTokenizer.from_pretrained(r"path\to\checkpoint")
27
  model = AutoModelForSeq2SeqLM.from_pretrained(r"path\to\checkpoint")
28
+ model.to(device)
29
  tokenizer.src_lang = "zh"
30
  tokenizer.tgt_lang = "en"
31
  test_string = "地阶上品遁术,施展后便可立于所持之剑上,以极快的速度自由飞行。"
32
 
33
+ inputs = tokenizer(test_string, return_tensors="pt").to(device)
34
  translated_tokens = model.generate(**inputs, num_beams=10, do_sample=True)
35
  translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
36