Yokohide031 committed
Commit 1ae49c0
1 Parent(s): 5b9c9fd

Update README

Files changed (1)
  1. README.md +125 -2
README.md CHANGED
@@ -4,14 +4,137 @@ license: cc-by-sa-4.0
 datasets:
 - wikipedia
 widget:
-- text: 東北大学で[MASK]の研究をしています。
+- text: Rustで[MASK]を使うことができます。
 ---
 
 # What is this model?
 - Tohoku University's BERT large Japanese model, converted so that it can be used from Rust
 - [cl-tohoku/bert-large-japanese](https://huggingface.co/cl-tohoku/bert-large-japanese)
 
-## Licenses
+# How to Try
+
+### 1. Clone
+
+```
+git clone https://huggingface.co/Yokohide031/rust_cl-tohoku_bert-large-japanese
+```
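+
+Note: `rust_model.ot` is a binary weight file stored with Git LFS on the Hugging Face Hub, so you may need to install Git LFS before cloning (a setup assumption about your environment):
+
+```
+git lfs install
+```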
+
+### 2. Create Project
+
+```
+cargo new <projectName>
+```
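+
+The code in the next step uses the `rust-bert`, `rust_tokenizers`, `tch`, and `anyhow` crates, so declare them in the new project's `Cargo.toml`. A minimal sketch; the versions shown are assumptions (pick a `rust-bert` release and the `tch` and `rust_tokenizers` versions it is built against), and note that `tch` links against a local libtorch installation:
+
+```
+[dependencies]
+anyhow = "1.0"
+rust-bert = "0.19"
+rust_tokenizers = "8.0"
+tch = "0.8"
+```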
+
+### 3. Edit main.rs
+
+```
+extern crate anyhow;
+
+use rust_bert::bert::{BertConfig, BertForMaskedLM};
+use rust_bert::Config;
+use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
+use rust_tokenizers::vocab::Vocab;
+use tch::{nn, no_grad, Device, Tensor};
+
+use std::path::PathBuf;
+
+// Resolve a file inside the model repository cloned in step 1.
+fn get_path(item: String) -> PathBuf {
+    let mut resource_dir = PathBuf::from("path/to/rust_cl-tohoku_bert-large-japanese/");
+    resource_dir.push(&item);
+    println!("{:?}", resource_dir);
+    resource_dir
+}
+
+// Read one line from standard input.
+fn input(display: String) -> String {
+    let mut text = String::new();
+    println!("{}", display);
+    std::io::stdin().read_line(&mut text).unwrap();
+    text.trim().to_string()
+}
+
+fn main() -> anyhow::Result<()> {
+    // Resource paths
+    let model_path: PathBuf = get_path(String::from("rust_model.ot"));
+    let vocab_path: PathBuf = get_path(String::from("vocab.txt"));
+    let config_path: PathBuf = get_path(String::from("config.json"));
+
+    // Set up the masked LM model
+    let device = Device::Cpu;
+    let mut vs = nn::VarStore::new(device);
+    let config = BertConfig::from_file(config_path);
+    let bert_model = BertForMaskedLM::new(&vs.root(), &config);
+    vs.load(model_path)?;
+
+    // Define the input; "*" is replaced with the [MASK] token
+    let inp = input(String::from("Input: "));
+    let inp = inp.replace("*", "[MASK]");
+    let input = [inp];
+
+    let tokenizer: BertTokenizer =
+        BertTokenizer::from_file(vocab_path.to_str().unwrap(), false, false).unwrap();
+
+    let tokens = &tokenizer.tokenize_list(&input);
+
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
+
+    // Locate [MASK]; the +1 accounts for the leading [CLS] token
+    let mut mask_index: usize = 0;
+    for (i, m) in tokens[0].iter().enumerate() {
+        if m == "[MASK]" {
+            mask_index = i + 1;
+            break;
+        }
+    }
+
+    // Pad every sequence to the same length and stack them into a single tensor
+    let max_len = tokenized_input
+        .iter()
+        .map(|input| input.token_ids.len())
+        .max()
+        .unwrap();
+    let tokenized_input = tokenized_input
+        .iter()
+        .map(|input| input.token_ids.clone())
+        .map(|mut input| {
+            input.extend(vec![0; max_len - input.len()]);
+            input
+        })
+        .map(|input| Tensor::of_slice(&input))
+        .collect::<Vec<_>>();
+    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
+
+    // Forward pass
+    let model_output = no_grad(|| {
+        bert_model.forward_t(
+            Some(&input_tensor),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            false,
+        )
+    });
+    println!("MASK: {}", mask_index);
+
+    // Pick the highest-scoring token at the masked position and print it
+    let index_1 = model_output
+        .prediction_scores
+        .get(0)
+        .get(mask_index as i64)
+        .argmax(0, false);
+
+    let word = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
+    println!("{}", word);
+
+    Ok(())
+}
+```
+
+Note: in the code above, "*" is used in place of [MASK].
+
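+To try it, build and run the project from its directory and type a sentence at the prompt, writing "*" where the masked word should go; the program prints the mask position and the model's best-scoring token for it. A minimal session sketch (the sentence is an arbitrary example):
+
+```
+cargo run --release
+Input: Rustで*を使うことができます。
+```
+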
+## Licenses
 The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/).