# Transforming spoken text into written text
![Model](https://raw.githubusercontent.com/nguyenvulebinh/spoken-norm/main/spoken_norm_model.svg)
```python
import os

# Hide all GPUs so the example runs on CPU; set this before torch initializes CUDA
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import torch
import model_handling
from data_handling import DataCollatorForNormSeq2Seq
from model_handling import EncoderDecoderSpokenNorm
```
# Initialize the tokenizer and model
```python
tokenizer = model_handling.init_tokenizer()
model = EncoderDecoderSpokenNorm.from_pretrained('nguyenvulebinh/spoken-norm', cache_dir=model_handling.cache_dir)
data_collator = DataCollatorForNormSeq2Seq(tokenizer)
```
# Run inference on a sample
```python
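# Bias phrases: written forms for words that the spoken input spells phonetically
# (e.g. 'sờ cốt lờn' -> 'scotland', 'cô vít' -> 'covid')
bias_list = ['scotland', 'covid', 'delta', 'beta']
# Vietnamese spoken form, roughly: "on April 28 covid broke out in Scotland,
# with the delta and beta variants accounting for eighty percent"
input_str = 'ngày hai tám tháng tư cô vít bùng phát ở sờ cốt lờn chiếm tám mươi phần trăm là biến chủng đen ta và bê ta'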
```
```python
# Tokenize the spoken-form input
inputs = tokenizer([input_str])
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
# Encode the bias phrases (input_ids + attention_mask), or pass None when there are none
if len(bias_list) > 0:
    bias = data_collator.encode_list_string(bias_list)
    bias_input_ids = bias['input_ids']
    bias_attention_mask = bias['attention_mask']
else:
    bias_input_ids = None
    bias_attention_mask = None
inputs = {
    "input_ids": torch.tensor(input_ids),
    "attention_mask": torch.tensor(attention_mask),
    "bias_input_ids": bias_input_ids,
    "bias_attention_mask": bias_attention_mask,
}
```
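Because several bias phrases of different lengths are encoded together, they end up as one rectangular batch with an attention mask marking real tokens. The sketch below illustrates that general padding pattern in plain Python. `pad_bias_phrases` is a hypothetical stand-in, not the repository's `encode_list_string` (which also tokenizes the strings and may pad differently), and the token ids are made up.

```python
from typing import List, Optional, Tuple


def pad_bias_phrases(
    phrase_ids: List[List[int]], pad_id: int = 0
) -> Tuple[Optional[List[List[int]]], Optional[List[List[int]]]]:
    """Right-pad tokenized bias phrases into a rectangular batch.

    Hypothetical illustration only; not data_collator.encode_list_string.
    """
    # Mirror the README's else-branch: no bias phrases -> no bias inputs.
    if not phrase_ids:
        return None, None
    max_len = max(len(ids) for ids in phrase_ids)
    # Append pad_id until every phrase has the same length.
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in phrase_ids]
    # 1 marks a real token, 0 marks padding that should be ignored.
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in phrase_ids]
    return input_ids, attention_mask


# Made-up token ids for four phrases of different lengths.
ids, mask = pad_bias_phrases([[5, 9], [7], [3, 4, 8], [6]])
print(ids)   # [[5, 9, 0], [7, 0, 0], [3, 4, 8], [6, 0, 0]]
print(mask)  # [[1, 1, 0], [1, 0, 0], [1, 1, 1], [1, 0, 0]]
```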
## Format input text **with** bias phrases
```python
# Greedy decoding (num_beams=1) with the bias phrases attached
outputs = model.generate(**inputs, output_attentions=True, num_beams=1, num_return_sequences=1)
for output in outputs.cpu().detach().numpy().tolist():
    # decode() yields space-separated SentencePiece pieces; DecodePieces joins them into text
    print(tokenizer.sp_model.DecodePieces(tokenizer.decode(output, skip_special_tokens=True).split()))
```
Output: `28/4 covid bùng phát ở scotland chiếm 80 % là biến chủng delta và beta`
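The print loop above hands a list of SentencePiece pieces to `tokenizer.sp_model.DecodePieces`. As a rough mental model, decoding concatenates the pieces and turns the `▁` (U+2581) word-boundary marker back into spaces. The helper below is a simplified approximation written for illustration, not SentencePiece's implementation (which also handles unknown pieces and byte fallback), and the piece segmentation in the example is made up.

```python
from typing import List

SPIECE_UNDERLINE = "\u2581"  # "▁", SentencePiece's word-boundary marker


def approx_decode_pieces(pieces: List[str]) -> str:
    """Simplified approximation of SentencePiece DecodePieces (illustration only)."""
    # Join pieces, turn each boundary marker into a space,
    # then drop the space produced by the leading marker.
    return "".join(pieces).replace(SPIECE_UNDERLINE, " ").lstrip(" ")


# Hypothetical segmentation of the start of the biased output.
print(approx_decode_pieces(["▁28", "/", "4", "▁co", "vid", "▁bùng"]))
```

With these pieces the call prints `28/4 covid bùng`: pieces without a leading `▁` (such as `/`, `4`, `vid`) attach to the previous word.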
## Format input text **without** bias phrases
```python
# Same input, but with the bias inputs set to None
outputs = model.generate(**{
    "input_ids": torch.tensor(input_ids),
    "attention_mask": torch.tensor(attention_mask),
    "bias_input_ids": None,
    "bias_attention_mask": None,
}, output_attentions=True, num_beams=1, num_return_sequences=1)
for output in outputs.cpu().detach().numpy().tolist():
    print(tokenizer.sp_model.DecodePieces(tokenizer.decode(output, skip_special_tokens=True).split()))
```
Output: `28/4 cô vít bùng phát ở sờ cốt lờn chiếm 80 % là biến chủng đen ta và bê ta`
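A word-level diff of the two outputs makes the effect of the bias list concrete: the date and number normalization (`28/4`, `80 %`) is identical in both runs, and only the phonetically spelled spans change. This is a small sketch using Python's standard `difflib`; the strings are copied from the outputs shown in this README, and `changed_spans` is a hypothetical helper, not part of the repository.

```python
import difflib
from typing import List, Tuple

with_bias = "28/4 covid bùng phát ở scotland chiếm 80 % là biến chủng delta và beta"
without_bias = "28/4 cô vít bùng phát ở sờ cốt lờn chiếm 80 % là biến chủng đen ta và bê ta"


def changed_spans(before: str, after: str) -> List[Tuple[str, str]]:
    """Return the (before, after) word spans that differ between two outputs."""
    a, b = before.split(), after.split()
    matcher = difflib.SequenceMatcher(a=a, b=b, autojunk=False)
    # Keep every non-equal opcode (replace/insert/delete) as a pair of text spans.
    return [
        (" ".join(a[i1:i2]), " ".join(b[j1:j2]))
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != "equal"
    ]


for old, new in changed_spans(without_bias, with_bias):
    print(f"{old} -> {new}")
# cô vít -> covid
# sờ cốt lờn -> scotland
# đen ta -> delta
# bê ta -> beta
```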
## About
*Built by Binh Nguyen*
[![Follow](https://img.shields.io/twitter/follow/nguyenvulebinh?style=social)](https://twitter.com/intent/follow?screen_name=nguyenvulebinh)
For more details, visit the project repository.
[![GitHub stars](https://img.shields.io/github/stars/nguyenvulebinh/spoken-norm?style=social)](https://github.com/nguyenvulebinh/spoken-norm)