Commit
·
a4f0c54
1
Parent(s):
2690bc7
Update README.md
Browse files
README.md
CHANGED
@@ -21,13 +21,16 @@ Sample generation script :
|
|
21 |
|
22 |
```python
|
23 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
|
24 |
tokenizer = transformers.AutoTokenizer.from_pretrained(r"path\to\checkpoint")
|
25 |
model = AutoModelForSeq2SeqLM.from_pretrained(r"path\to\checkpoint")
|
|
|
26 |
tokenizer.src_lang = "zh"
|
27 |
tokenizer.tgt_lang = "en"
|
28 |
test_string = "地阶上品遁术,施展后便可立于所持之剑上,以极快的速度自由飞行。"
|
29 |
|
30 |
-
inputs = tokenizer(test_string, return_tensors="pt")
|
31 |
translated_tokens = model.generate(**inputs, num_beams=10, do_sample=True)
|
32 |
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
33 |
|
|
|
21 |
|
22 |
```python
|
23 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
24 |
+
import torch
|
25 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
26 |
tokenizer = transformers.AutoTokenizer.from_pretrained(r"path\to\checkpoint")
|
27 |
model = AutoModelForSeq2SeqLM.from_pretrained(r"path\to\checkpoint")
|
28 |
+
model.to(device)
|
29 |
tokenizer.src_lang = "zh"
|
30 |
tokenizer.tgt_lang = "en"
|
31 |
test_string = "地阶上品遁术,施展后便可立于所持之剑上,以极快的速度自由飞行。"
|
32 |
|
33 |
+
inputs = tokenizer(test_string, return_tensors="pt").to(device)
|
34 |
translated_tokens = model.generate(**inputs, num_beams=10, do_sample=True)
|
35 |
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
36 |
|