File size: 3,687 Bytes
5b9c9fd 1ae49c0 5b9c9fd 1ae49c0 5b9c9fd 1ae49c0 5b9c9fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
---
language: ja
license: cc-by-sa-4.0
datasets:
- wikipedia
widget:
- text: Rustで[MASK]を使うことができます。
---
# What is this model?
- 東北大学のBERT large JapaneseをRustで使える様に変換
- [cl-tohoku/bert-large-japanese](https://huggingface.co/cl-tohoku/bert-large-japanese)
# How to Try
### 1. Clone
```
git clone https://huggingface.co/Yokohide031/rust_cl-tohoku_bert-large-japanese
```
### 2. Create Project
```
cargo new <projectName>
```
### 3. Edit main.rs
```
extern crate anyhow;
use rust_bert::bert::{BertConfig, BertForMaskedLM};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
use std::path::PathBuf;
/// Builds the path to a model resource file by joining `item` onto the local
/// clone of the `rust_cl-tohoku_bert-large-japanese` repository, printing the
/// resulting path for visibility.
///
/// Accepts any path-like value (`String`, `&str`, `PathBuf`, ...), which keeps
/// the existing `get_path(String::from(...))` call sites working unchanged.
fn get_path(item: impl AsRef<std::path::Path>) -> PathBuf {
    let resource_dir = PathBuf::from("path/to/rust_cl-tohoku_bert-large-japanese/").join(item);
    println!("{:?}", resource_dir);
    resource_dir
}
/// Prints `display` as a prompt, then reads one line from stdin and returns it
/// with surrounding whitespace (including the trailing newline) trimmed off.
///
/// # Panics
/// Panics if reading from stdin fails (e.g. stdin is closed).
fn input(display: String) -> String {
    let mut text = String::new();
    println!("{}", display);
    // `expect` with a message instead of a bare `unwrap` so a stdin failure
    // is at least diagnosable in this example program.
    std::io::stdin()
        .read_line(&mut text)
        .expect("failed to read a line from stdin");
    text.trim().to_string()
}
/// Example driver: loads the converted BERT masked-LM model, reads a sentence
/// from stdin (with "*" standing in for the masked token), and prints the
/// model's highest-scoring prediction for the masked position.
fn main() -> anyhow::Result<()> {
    // Resource paths inside the cloned model repository.
    let model_path: PathBuf = get_path(String::from("rust_model.ot"));
    let vocab_path: PathBuf = get_path(String::from("vocab.txt"));
    let config_path: PathBuf = get_path(String::from("config.json"));
    // Set up the masked LM model on CPU and load the converted weights.
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let config = BertConfig::from_file(config_path);
    let bert_model = BertForMaskedLM::new(&vs.root(), &config);
    vs.load(model_path)?;
    // Read the input sentence; the user types "*" where [MASK] should go.
    let inp = input(String::from("Input: "));
    let inp = inp.replace("*", "[MASK]");
    let input = [inp];
    // NOTE(review): the two `false` flags presumably disable lowercasing and
    // accent stripping so Japanese text is kept as-is — confirm against the
    // rust_tokenizers BertTokenizer::from_file signature.
    let tokenizer: BertTokenizer =
        BertTokenizer::from_file(vocab_path.to_str().unwrap(), false, false).unwrap();
    let owakatied = &tokenizer.tokenize_list(&input);
    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
    // Locate [MASK] in the token list. The +1 presumably offsets for a special
    // token prepended during encoding ([CLS]) — TODO confirm against
    // encode_list's output. NOTE(review): if no [MASK] is present,
    // mask_index silently stays 0.
    let mut mask_index: usize = 0;
    for (i, m) in owakatied[0].iter().enumerate() {
        if m == "[MASK]" {
            mask_index = i+1;
            break;
        }
    }
    // Pad every sequence with 0s up to the longest one in the batch, then
    // stack them into a single (batch, seq_len) input tensor on the device.
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
    // Forward pass without gradient tracking (inference only).
    let model_output = no_grad(|| {
        bert_model.forward_t(
            Some(&input_tensor),
            None,
            None,
            None,
            None,
            None,
            None,
            false,
        )
    });
    println!("MASK: {}", mask_index);
    // Take the highest-scoring vocabulary id at the masked position of the
    // first (only) sequence and map it back to its surface token.
    let index_1 = model_output
        .prediction_scores
        .get(0)
        .get(mask_index as i64)
        .argmax(0, false);
    let word = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    println!("{}", word);
    Ok(())
}
```
※ 上のコードでは、[MASK]の代わりに "*" を使うことになっています。
## Licenses
The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 4.0](https://creativecommons.org/licenses/by-sa/4.0/).
|