---
language: 
  - ko
tags:
- trocr
- image-to-text
license: mit
metrics:
- wer
- cer
widget:
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/random_2.jpg
  example_title: Random sentence 1
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/random_6.jpg
  example_title: Random sentence 2
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/chatbot_3.jpg
  example_title: Chatbot 1
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/chatbot_5.jpg
  example_title: Chatbot 2
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/news_1.jpg
  example_title: News 1
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/news_3.jpg
  example_title: News 2
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/nsmc_1.jpg
  example_title: Movie review 1
- src: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/nsmc_2.jpg
  example_title: Movie review 2
---

# TrOCR for Korean Language (PoC)

## Overview

TrOCR has not yet released a multilingual model that includes Korean, so we trained a Korean model as a proof of concept (PoC). Starting from this model, we recommend collecting more data and then either continuing stage-1 pre-training or running stage-2 fine-tuning.

## Collecting data

### Text data
We built the training text by processing three datasets:

- News summarization dataset: https://huggingface.co/datasets/daekeun-ml/naver-news-summarization-ko
- Naver Movie Sentiment Classification: https://github.com/e9t/nsmc
- Chatbot dataset: https://github.com/songys/Chatbot_data
  
To collect data efficiently, we split each text into sentences with a sentence segmentation library (kiwipiepy, the Python wrapper for Kiwi; https://github.com/bab2min/kiwipiepy), which yielded 637,401 samples.
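
The sentence-splitting step can be sketched as follows. This is a minimal sketch, not the author's pipeline: it assumes kiwipiepy's `Kiwi.split_into_sents` API, and the whitespace normalization, de-duplication, and length filter (`min_chars`/`max_chars`) are illustrative assumptions. `build_corpus` and `write_corpus` are hypothetical helper names. The output file holds one sentence per line, which is the layout trdg reads through its `-i` option.

```python
import re
from functools import lru_cache
from typing import Callable, Iterable, List


@lru_cache(maxsize=1)
def _get_kiwi():
    # Imported lazily so the rest of this sketch runs without kiwipiepy installed.
    from kiwipiepy import Kiwi

    return Kiwi()


def kiwi_split(text: str) -> List[str]:
    """Split Korean text into sentences with kiwipiepy (pip install kiwipiepy)."""
    return [sent.text for sent in _get_kiwi().split_into_sents(text)]


def build_corpus(
    texts: Iterable[str],
    split_fn: Callable[[str], List[str]] = kiwi_split,
    min_chars: int = 2,
    max_chars: int = 64,
) -> List[str]:
    """Split raw texts into sentences, normalize whitespace, drop duplicates and outliers."""
    seen = set()
    corpus = []
    for text in texts:
        for sent in split_fn(text):
            sent = re.sub(r"\s+", " ", sent).strip()
            # Hypothetical length filter; the card does not say whether one was applied.
            if not min_chars <= len(sent) <= max_chars:
                continue
            if sent in seen:
                continue
            seen.add(sent)
            corpus.append(sent)
    return corpus


def write_corpus(sentences: Iterable[str], path: str) -> None:
    """Write one sentence per line, the layout read by trdg's -i option."""
    with open(path, "w", encoding="utf-8") as f:
        for sent in sentences:
            f.write(sent + "\n")


if __name__ == "__main__":
    raw = ["오늘은 날씨가 좋네요. 산책하러 갈까요?"]
    write_corpus(build_corpus(raw), "ocr_dataset_poc.txt")
```

Passing `split_fn` explicitly makes it easy to swap in a simpler splitter for testing, or a different segmenter if kiwipiepy is not available.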

### Image Data

Image data was generated with TextRecognitionDataGenerator (trdg; https://github.com/Belval/TextRecognitionDataGenerator), which is also referenced in the TrOCR paper.
The command used to generate the images is shown below.
```shell
python3 ./trdg/run.py -i ocr_dataset_poc.txt -w 5 -t {num_cores} -f 64 -l ko -c {num_samples} -na 2 --output_dir {dataset_dir}
```
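
To pair the generated images with their text for training, the trdg output has to be read back. The sketch below assumes trdg's `-na 2` naming format, which (in the trdg versions we are aware of) saves images as `[ID].[EXT]` and writes a single `labels.txt` in the output directory with one `<file_name> <label>` entry per line. `parse_label_line` and `load_trdg_labels` are hypothetical helper names, and `"dataset"` stands in for the `{dataset_dir}` placeholder.

```python
import os
from typing import List, Tuple


def parse_label_line(line: str) -> Tuple[str, str]:
    """Split a trdg labels.txt line into (file_name, text).

    Only the first space separates the two fields, because the text itself may contain spaces.
    """
    file_name, _, text = line.rstrip("\n").partition(" ")
    return file_name, text


def load_trdg_labels(dataset_dir: str, labels_file: str = "labels.txt") -> List[Tuple[str, str]]:
    """Return (image_path, text) pairs for every non-empty line of labels.txt."""
    pairs = []
    with open(os.path.join(dataset_dir, labels_file), encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            file_name, text = parse_label_line(line)
            pairs.append((os.path.join(dataset_dir, file_name), text))
    return pairs


if __name__ == "__main__":
    # "dataset" is a placeholder for the actual output directory.
    for path, text in load_trdg_labels("dataset")[:3]:
        print(path, text)
```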

## Training

### Base model
The encoder is initialized from `facebook/deit-base-distilled-patch16-384` and the decoder from `klue/roberta-base`. We found this easier than starting from the `microsoft/trocr-base-stage1` weights.
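
A minimal sketch of how such an encoder-decoder pair can be assembled with Hugging Face `transformers`. It follows the common pattern from the library's TrOCR fine-tuning examples, not necessarily the author's script; `build_model` and `set_special_tokens` are hypothetical helpers. The `klue/roberta-base` tokenizer uses BERT-style `[CLS]`/`[SEP]`/`[PAD]` tokens, so the sketch assumes `[CLS]` as the decoder start token and `[SEP]` as end-of-sequence.

```python
ENCODER_ID = "facebook/deit-base-distilled-patch16-384"
DECODER_ID = "klue/roberta-base"


def set_special_tokens(config, tokenizer):
    """Copy the decoder tokenizer's special-token ids into the encoder-decoder config.

    Assumes a BERT-style vocabulary: [CLS] starts decoding, [SEP] ends it.
    """
    config.decoder_start_token_id = tokenizer.cls_token_id
    config.pad_token_id = tokenizer.pad_token_id
    config.eos_token_id = tokenizer.sep_token_id
    config.vocab_size = config.decoder.vocab_size
    return config


def build_model():
    """Pair a pretrained DeiT encoder with a pretrained KLUE RoBERTa decoder."""
    # Imported lazily so set_special_tokens can be reused without loading transformers.
    from transformers import AutoTokenizer, VisionEncoderDecoderModel

    tokenizer = AutoTokenizer.from_pretrained(DECODER_ID)
    # The decoder's cross-attention layers are newly (randomly) initialized at this step.
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(ENCODER_ID, DECODER_ID)
    set_special_tokens(model.config, tokenizer)
    return model, tokenizer


if __name__ == "__main__":
    model, tokenizer = build_model()
    print(type(model.encoder).__name__, type(model.decoder).__name__)
```

Setting `decoder_start_token_id` and `pad_token_id` matters for training: the model uses them to shift the labels right when building decoder inputs.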

### Parameters
We used heuristically chosen values without separate hyperparameter tuning:
- learning_rate = 4e-5
- epochs = 25
- fp16 = True
- max_length = 64
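
These values could be plugged into a `transformers` training setup roughly as sketched below. This is an assumption-laden sketch: mapping `max_length = 64` to both `generation_max_length` and the label truncation length is our reading, and `mask_padding`, `encode_labels`, and `make_training_args` are hypothetical helpers. Note that `fp16=True` requires a CUDA GPU.

```python
from typing import List

# Values from the list above; the max_length mapping is an assumption of this sketch.
HPARAMS = {
    "learning_rate": 4e-5,
    "num_train_epochs": 25,
    "fp16": True,  # mixed precision; requires a CUDA GPU
    "generation_max_length": 64,
}
MAX_LENGTH = HPARAMS["generation_max_length"]


def mask_padding(label_ids: List[int], pad_token_id: int) -> List[int]:
    """Replace padding ids with -100, the index ignored by the cross-entropy loss."""
    return [-100 if tok == pad_token_id else tok for tok in label_ids]


def encode_labels(tokenizer, text: str) -> List[int]:
    """Tokenize a target sentence to a fixed length and mask its padding."""
    ids = tokenizer(text, padding="max_length", truncation=True, max_length=MAX_LENGTH).input_ids
    return mask_padding(ids, tokenizer.pad_token_id)


def make_training_args(output_dir: str, **overrides):
    """Build Seq2SeqTrainingArguments from HPARAMS; other fields keep library defaults."""
    from transformers import Seq2SeqTrainingArguments  # lazy import

    kwargs = {**HPARAMS, "predict_with_generate": True, **overrides}
    return Seq2SeqTrainingArguments(output_dir=output_dir, **kwargs)
```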

## Usage

### inference.py

```python
from io import BytesIO

import requests
from PIL import Image
from transformers import AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel

# Only the image preprocessing part of this processor is used below;
# decoding uses the Korean tokenizer published with this model.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("daekeun-ml/ko-trocr-base-nsmc-news-chatbot")
tokenizer = AutoTokenizer.from_pretrained("daekeun-ml/ko-trocr-base-nsmc-news-chatbot")

url = "https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/news_1.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
# Convert to RGB so grayscale or RGBA inputs also produce 3-channel pixel values.
img = Image.open(BytesIO(response.content)).convert("RGB")

pixel_values = processor(images=img, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values, max_length=64)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
```
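
The card metadata lists WER and CER as metrics. Below is a minimal plain-Python sketch of both, using the standard Levenshtein-based definitions aggregated over a set of predictions (total edits divided by total reference characters or words). It is not necessarily the implementation used to evaluate this model, and it assumes the references are non-empty. Because WER splits on spaces, for Korean it counts errors per space-delimited word (eojeol).

```python
from typing import List, Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance where insertions, deletions, and substitutions each cost 1."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,  # deletion
                curr[j - 1] + 1,  # insertion
                prev[j - 1] + (r != h),  # substitution (0 if the tokens match)
            )
        prev = curr
    return prev[-1]


def cer(refs: List[str], hyps: List[str]) -> float:
    """Character error rate over all reference/prediction pairs."""
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    return errors / sum(len(r) for r in refs)


def wer(refs: List[str], hyps: List[str]) -> float:
    """Word error rate over all pairs, with words split on whitespace."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    return errors / sum(len(r.split()) for r in refs)
```

For example, `cer([reference_text], [generated_text])` scores the inference output against a known ground-truth string.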

All code for data collection and model training is available in the author's GitHub repository:
- https://github.com/daekeun-ml/sm-kornlp-usecases/tree/main/trocr