๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต(Document Question Answering) [[document_question_answering]]

[[open-in-colab]]

Document Question Answering, also referred to as Document Visual Question Answering, is a task of answering questions about document images. The input to models supporting this task is typically a combination of an image and a question, and the output is an answer expressed in natural language. These models utilize multiple modalities, including text, the positions of words (bounding boxes), and the image itself.

์ด ๊ฐ€์ด๋“œ๋Š” ๋‹ค์Œ ๋‚ด์šฉ์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค:

  • Fine-tune LayoutLMv2 on the DocVQA dataset
  • Use your fine-tuned model for inference

The task illustrated in this tutorial is supported by the following model architectures:

LayoutLM, LayoutLMv2, LayoutLMv3

LayoutLMv2 solves the document question answering task by adding a question answering head on top of the final hidden states of the tokens, to predict the positions of the start token and end token of the answer. In other words, the problem is treated as extractive question answering: given the context, extract the piece of information that answers the question. The context comes from the output of an OCR engine, here it is Google's Tesseract.
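
To make the span-prediction idea concrete, here is a purely illustrative sketch with toy tensors; it mirrors how such a question answering head is typically implemented (a linear layer scoring every token as a possible answer start and end), not the actual LayoutLMv2 code. A complete inference example with the real model comes at the end of this guide.

>>> import torch

>>> # toy "last hidden states": batch of 1 sequence with 10 tokens and hidden size 768 (hypothetical sizes)
>>> hidden_states = torch.randn(1, 10, 768)
>>> qa_head = torch.nn.Linear(768, 2)  # two scores per token: one for "start", one for "end"
>>> start_logits, end_logits = qa_head(hidden_states).split(1, dim=-1)
>>> start_idx = start_logits.squeeze(-1).argmax(-1)  # predicted start token position
>>> end_idx = end_logits.squeeze(-1).argmax(-1)  # predicted end token position
>>> # the answer would then be decoded from input_ids[start_idx : end_idx + 1]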

Before you begin, make sure you have all the necessary libraries installed. LayoutLMv2 depends on detectron2, torchvision and tesseract.

pip install -q transformers datasets
pip install 'git+https://github.com/facebookresearch/detectron2.git'
pip install torchvision
sudo apt install tesseract-ocr
pip install -q pytesseract

Once you have installed all of the libraries, restart your runtime.

We encourage you to share your model with the community. Log in to your Hugging Face account to upload it to the 🤗 Hub. When prompted, enter your token to log in:

>>> from huggingface_hub import notebook_login

>>> notebook_login()

Let's define some global variables.

>>> model_checkpoint = "microsoft/layoutlmv2-base-uncased"
>>> batch_size = 4

๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ [[load-the-data]]

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๐Ÿค— Hub์—์„œ ์ฐพ์„ ์ˆ˜ ์žˆ๋Š” ์ „์ฒ˜๋ฆฌ๋œ DocVQA์˜ ์ž‘์€ ์ƒ˜ํ”Œ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. DocVQA์˜ ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด, DocVQA homepage์— ๊ฐ€์ž… ํ›„ ๋‹ค์šด๋กœ๋“œ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ๋‹ค์šด๋กœ๋“œ ํ–ˆ๋‹ค๋ฉด, ์ด ๊ฐ€์ด๋“œ๋ฅผ ๊ณ„์† ์ง„ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ๐Ÿค— dataset์— ํŒŒ์ผ์„ ๊ฐ€์ ธ์˜ค๋Š” ๋ฐฉ๋ฒ•์„ ํ™•์ธํ•˜์„ธ์š”.

>>> from datasets import load_dataset

>>> dataset = load_dataset("nielsr/docvqa_1200_examples")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
        num_rows: 200
    })
})

As you can see, the dataset is already split into a train set and a test set. Take a look at a random example to familiarize yourself with the features.

>>> dataset["train"].features

๊ฐ ํ•„๋“œ๊ฐ€ ๋‚˜ํƒ€๋‚ด๋Š” ๋‚ด์šฉ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

  • id: the example's id
  • image: a PIL.Image.Image object containing the document image
  • query: the question string - a natural language question, in several languages
  • answers: a list of correct answers provided by human annotators
  • words and bounding_boxes: the results of OCR, which we will not use in this guide
  • answer: an answer matched by a different model, which we will not use in this guide

Let's keep only the English questions, and drop the answer feature, which contains predictions by another model. We'll also take the first of the answers provided by the annotators. Alternatively, you could randomly sample it.

>>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"])
>>> updated_dataset = updated_dataset.map(
...     lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"]
... )

์ด ๊ฐ€์ด๋“œ์—์„œ ์‚ฌ์šฉํ•˜๋Š” LayoutLMv2 ์ฒดํฌํฌ์ธํŠธ๋Š” max_position_embeddings = 512๋กœ ํ›ˆ๋ จ๋˜์—ˆ์Šต๋‹ˆ๋‹ค(์ด ์ •๋ณด๋Š” ์ฒดํฌํฌ์ธํŠธ์˜ config.json ํŒŒ์ผ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค). ๋ฐ”๋กœ ์˜ˆ์ œ๋ฅผ ์ž˜๋ผ๋‚ผ ์ˆ˜๋„ ์žˆ์ง€๋งŒ, ๊ธด ๋ฌธ์„œ์˜ ๋์— ๋‹ต๋ณ€์ด ์žˆ์–ด ์ž˜๋ฆฌ๋Š” ์ƒํ™ฉ์„ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด ์—ฌ๊ธฐ์„œ๋Š” ์ž„๋ฒ ๋”ฉ์ด 512๋ณด๋‹ค ๊ธธ์–ด์งˆ ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ๋Š” ๋ช‡ ๊ฐ€์ง€ ์˜ˆ์ œ๋ฅผ ์ œ๊ฑฐํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์žˆ๋Š” ๋Œ€๋ถ€๋ถ„์˜ ๋ฌธ์„œ๊ฐ€ ๊ธด ๊ฒฝ์šฐ ์Šฌ๋ผ์ด๋”ฉ ์œˆ๋„์šฐ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค - ์ž์„ธํ•œ ๋‚ด์šฉ์„ ํ™•์ธํ•˜๊ณ  ์‹ถ์œผ๋ฉด ์ด ๋…ธํŠธ๋ถ์„ ํ™•์ธํ•˜์„ธ์š”.

>>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512)
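
For reference, here is only a rough sketch of the sliding-window idea with made-up words and boxes; it assumes the fast LayoutLMv2 tokenizer can return overlapping chunks via return_overflowing_tokens and stride, and it leaves out the per-window labeling of start/end positions, which the linked notebook covers in full.

>>> from transformers import AutoProcessor

>>> sw_tokenizer = AutoProcessor.from_pretrained(model_checkpoint).tokenizer
>>> question = "who is the sender?"
>>> words = ["internal", "correspondence", "to:", "r.", "h.", "honeycutt"]  # toy OCR words
>>> boxes = [[60, 40, 200, 60]] * len(words)  # toy boxes in the 0-1000 coordinate space
>>> windows = sw_tokenizer(
...     question,
...     words,
...     boxes=boxes,
...     max_length=512,
...     stride=128,
...     truncation="only_second",
...     return_overflowing_tokens=True,
... )
>>> # each entry in windows["input_ids"] is one (overlapping) chunk of the document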

At this point let's also remove the OCR features from this dataset. These were produced for fine-tuning a different model, and since they do not match the input requirements of the model we use in this guide, they would need some processing before they could be used. Instead, we can apply the [LayoutLMv2Processor] to the original data for both OCR and tokenization, which gives us inputs that match the model's expectations. If you want to process the images manually, check out the LayoutLMv2 model documentation to learn what input format the model expects.

>>> updated_dataset = updated_dataset.remove_columns("words")
>>> updated_dataset = updated_dataset.remove_columns("bounding_boxes")

๋งˆ์ง€๋ง‰์œผ๋กœ, ๋ฐ์ดํ„ฐ ํƒ์ƒ‰์„ ์™„๋ฃŒํ•˜๊ธฐ ์œ„ํ•ด ์ด๋ฏธ์ง€ ์˜ˆ์‹œ๋ฅผ ์‚ดํŽด๋ด…์‹œ๋‹ค.

>>> updated_dataset["train"][11]["image"]
DocVQA Image Example

๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ [[preprocess-the-data]]

๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต ํƒœ์Šคํฌ๋Š” ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ํƒœ์Šคํฌ์ด๋ฉฐ, ๊ฐ ๋ชจ๋‹ฌ๋ฆฌํ‹ฐ์˜ ์ž…๋ ฅ์ด ๋ชจ๋ธ์˜ ์š”๊ตฌ์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌ ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋Š” ์ด๋ฏธ์ง€ ํ”„๋กœ์„ธ์„œ์™€ ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์ธ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ๋Š” ํ† ํฌ๋‚˜์ด์ €๋ฅผ ๊ฒฐํ•ฉํ•œ [LayoutLMv2Processor]๋ฅผ ๊ฐ€์ ธ์˜ค๋Š” ๊ฒƒ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

>>> from transformers import AutoProcessor

>>> processor = AutoProcessor.from_pretrained(model_checkpoint)

๋ฌธ์„œ ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ [[preprocessing-document-images]]

First, let's prepare the document images for the model with the help of the processor's image_processor. By default, the image processor resizes the images to 224x224, makes sure they have the correct order of color channels, and applies OCR with tesseract to get the words and normalized bounding boxes. In this tutorial, these defaults are exactly what we need. Write a function that applies the default image processing to a batch of images and returns the results of OCR.

>>> image_processor = processor.image_processor


>>> def get_ocr_words_and_boxes(examples):
...     images = [image.convert("RGB") for image in examples["image"]]
...     encoded_inputs = image_processor(images)

...     examples["image"] = encoded_inputs.pixel_values
...     examples["words"] = encoded_inputs.words
...     examples["boxes"] = encoded_inputs.boxes

...     return examples

์ด ์ „์ฒ˜๋ฆฌ๋ฅผ ๋ฐ์ดํ„ฐ ์„ธํŠธ ์ „์ฒด์— ๋น ๋ฅด๊ฒŒ ์ ์šฉํ•˜๋ ค๋ฉด [~datasets.Dataset.map]๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”.

>>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)

Preprocessing text data [[preprocessing-text-data]]

Once we have applied OCR to the images, we need to encode the text part of the dataset to prepare it for the model. This involves converting the words and boxes obtained in the previous step into token-level input_ids, attention_mask, token_type_ids and bbox. For preprocessing text, we'll need the tokenizer from the processor.

>>> tokenizer = processor.tokenizer

์œ„์—์„œ ์–ธ๊ธ‰ํ•œ ์ „์ฒ˜๋ฆฌ ์™ธ์—๋„ ๋ชจ๋ธ์„ ์œ„ํ•ด ๋ ˆ์ด๋ธ”์„ ์ถ”๊ฐ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๐Ÿค— Transformers์˜ xxxForQuestionAnswering ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ, ๋ ˆ์ด๋ธ”์€ start_positions์™€ end_positions๋กœ ๊ตฌ์„ฑ๋˜๋ฉฐ ์–ด๋–ค ํ† ํฐ์ด ๋‹ต๋ณ€์˜ ์‹œ์ž‘๊ณผ ๋์— ์žˆ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.

๋ ˆ์ด๋ธ” ์ถ”๊ฐ€๋ฅผ ์œ„ํ•ด์„œ, ๋จผ์ € ๋” ํฐ ๋ฆฌ์ŠคํŠธ(๋‹จ์–ด ๋ฆฌ์ŠคํŠธ)์—์„œ ํ•˜์œ„ ๋ฆฌ์ŠคํŠธ(๋‹จ์–ด๋กœ ๋ถ„ํ• ๋œ ๋‹ต๋ณ€)์„ ์ฐพ์„ ์ˆ˜ ์žˆ๋Š” ํ—ฌํผ ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

This function takes two lists as input, words_list and answer_list. It then iterates over words_list and checks whether the current word in words_list (words_list[i]) equals the first word of answer_list (answer_list[0]), and whether the sublist of words_list starting from the current word and of the same length as answer_list equals answer_list. If this condition is true, a match has been found, and the function records the match, its starting index (idx), and its ending index (idx + len(answer_list) - 1). If more than one match is found, the function returns only the first one. If no match is found, the function returns (None, 0, 0).

>>> def subfinder(words_list, answer_list):
...     matches = []
...     start_indices = []
...     end_indices = []
...     for idx, i in enumerate(range(len(words_list))):
...         if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list:
...             matches.append(answer_list)
...             start_indices.append(idx)
...             end_indices.append(idx + len(answer_list) - 1)
...     if matches:
...         return matches[0], start_indices[0], end_indices[0]
...     else:
...         return None, 0, 0

To illustrate how this function finds the position of the answer, let's use it on an example:

>>> example = dataset_with_ocr["train"][1]
>>> words = [word.lower() for word in example["words"]]
>>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split())
>>> print("Question: ", example["question"])
>>> print("Words:", words)
>>> print("Answer: ", example["answer"])
>>> print("start_index", word_idx_start)
>>> print("end_index", word_idx_end)
Question:  Who is in  cc in this letter?
Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', 'ยซshort', 'cigarette,', 'tobacco', 'section', '30', 'mm.', 'ยซextremely', 'fast', 'buming', 'cigarette.', 'ยซnovel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', 'ยซmore', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498']
Answer:  T.F. Riehl
start_index 17
end_index 18

Once the example above is encoded, however, it will look like this:

>>> encoding = tokenizer(example["question"], example["words"], example["boxes"])
>>> tokenizer.decode(encoding["input_ids"])
[CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ...

์ด์ œ ์ธ์ฝ”๋”ฉ๋œ ์ž…๋ ฅ์—์„œ ์ •๋‹ต์˜ ์œ„์น˜๋ฅผ ์ฐพ์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.

  • token_type_ids๋Š” ์–ด๋–ค ํ† ํฐ์ด ์งˆ๋ฌธ์— ์†ํ•˜๋Š”์ง€, ๊ทธ๋ฆฌ๊ณ  ์–ด๋–ค ํ† ํฐ์ด ๋ฌธ์„œ์˜ ๋‹จ์–ด์— ํฌํ•จ๋˜๋Š”์ง€๋ฅผ ์•Œ๋ ค์ค๋‹ˆ๋‹ค.
  • tokenizer.cls_token_id ์ž…๋ ฅ์˜ ์‹œ์ž‘ ๋ถ€๋ถ„์— ์žˆ๋Š” ํŠน์ˆ˜ ํ† ํฐ์„ ์ฐพ๋Š” ๋ฐ ๋„์›€์„ ์ค๋‹ˆ๋‹ค.
  • word_ids๋Š” ์›๋ณธ words์—์„œ ์ฐพ์€ ๋‹ต๋ณ€์„ ์ „์ฒด ์ธ์ฝ”๋”ฉ๋œ ์ž…๋ ฅ์˜ ๋™์ผํ•œ ๋‹ต๊ณผ ์ผ์น˜์‹œํ‚ค๊ณ  ์ธ์ฝ”๋”ฉ๋œ ์ž…๋ ฅ์—์„œ ๋‹ต๋ณ€์˜ ์‹œ์ž‘/๋ ์œ„์น˜๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
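
As a quick, purely illustrative check on the example encoded above (the slice length of 14 is arbitrary), you can peek at these fields directly:

>>> # position of the [CLS] token (it sits at the very beginning of the input)
>>> print(encoding["input_ids"].index(tokenizer.cls_token_id))
>>> # 0 for question tokens, 1 for tokens coming from the document's words
>>> print(encoding["token_type_ids"][:14])
>>> # index of the original word each token comes from (None for special tokens)
>>> print(encoding.word_ids()[:14])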

With that in mind, let's create a function to encode a batch of examples in the dataset:

>>> def encode_dataset(examples, max_length=512):
...     questions = examples["question"]
...     words = examples["words"]
...     boxes = examples["boxes"]
...     answers = examples["answer"]

...     # encode the batch of examples and initialize the start_positions and end_positions
...     encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
...     start_positions = []
...     end_positions = []

...     # loop through the examples in the batch
...     for i in range(len(questions)):
...         cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)

...         # ์˜ˆ์ œ์˜ words์—์„œ ๋‹ต๋ณ€์˜ ์œ„์น˜๋ฅผ ์ฐพ์Šต๋‹ˆ๋‹ค
...         words_example = [word.lower() for word in words[i]]
...         answer = answers[i]
...         match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())

...         if match:
...             # if a match is found, use `token_type_ids` to find where the words start in the encoding
...             token_type_ids = encoding["token_type_ids"][i]
...             token_start_index = 0
...             while token_type_ids[token_start_index] != 1:
...                 token_start_index += 1

...             token_end_index = len(encoding["input_ids"][i]) - 1
...             while token_type_ids[token_end_index] != 1:
...                 token_end_index -= 1

...             word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]
...             start_position = cls_index
...             end_position = cls_index

...             # loop over word_ids and increase `token_start_index` until it matches the answer position in words,
...             # once it matches, save the `token_start_index` as the `start_position` of the answer in the encoding
...             for id in word_ids:
...                 if id == word_idx_start:
...                     start_position = token_start_index
...                 else:
...                     token_start_index += 1

...             # similarly, loop over `word_ids` starting from the end to find the `end_position` of the answer
...             for id in word_ids[::-1]:
...                 if id == word_idx_end:
...                     end_position = token_end_index
...                 else:
...                     token_end_index -= 1

...             start_positions.append(start_position)
...             end_positions.append(end_position)

...         else:
...             start_positions.append(cls_index)
...             end_positions.append(cls_index)

...     encoding["image"] = examples["image"]
...     encoding["start_positions"] = start_positions
...     encoding["end_positions"] = end_positions

...     return encoding

์ด์ œ ์ด ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๊ฐ€ ์žˆ์œผ๋‹ˆ ์ „์ฒด ๋ฐ์ดํ„ฐ ์„ธํŠธ๋ฅผ ์ธ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> encoded_train_dataset = dataset_with_ocr["train"].map(
...     encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names
... )
>>> encoded_test_dataset = dataset_with_ocr["test"].map(
...     encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names
... )

์ธ์ฝ”๋”ฉ๋œ ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ํŠน์„ฑ์ด ์–ด๋–ป๊ฒŒ ์ƒ๊ฒผ๋Š”์ง€ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

>>> encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'start_positions': Value(dtype='int64', id=None),
 'end_positions': Value(dtype='int64', id=None)}

Evaluation [[evaluation]]

๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต์„ ํ‰๊ฐ€ํ•˜๋ ค๋ฉด ์ƒ๋‹นํ•œ ์–‘์˜ ํ›„์ฒ˜๋ฆฌ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์‹œ๊ฐ„์ด ๋„ˆ๋ฌด ๋งŽ์ด ๊ฑธ๋ฆฌ์ง€ ์•Š๋„๋ก ์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ํ‰๊ฐ€ ๋‹จ๊ณ„๋ฅผ ์ƒ๋žตํ•ฉ๋‹ˆ๋‹ค. [Trainer]๊ฐ€ ํ›ˆ๋ จ ๊ณผ์ •์—์„œ ํ‰๊ฐ€ ์†์‹ค(evaluation loss)์„ ๊ณ„์† ๊ณ„์‚ฐํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ๋Œ€๋žต์ ์œผ๋กœ ์•Œ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ถ”์ถœ์ (Extractive) ์งˆ์˜ ์‘๋‹ต์€ ๋ณดํ†ต F1/exact match ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•ด ํ‰๊ฐ€๋ฉ๋‹ˆ๋‹ค. ์ง์ ‘ ๊ตฌํ˜„ํ•ด๋ณด๊ณ  ์‹ถ์œผ์‹œ๋‹ค๋ฉด, Hugging Face course์˜ Question Answering chapter์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.

Train [[train]]

Congratulations! You've successfully navigated the toughest part of this guide and now you are ready to train your own model. Training involves the following steps:

  • ์ „์ฒ˜๋ฆฌ์—์„œ์˜ ๋™์ผํ•œ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด [AutoModelForDocumentQuestionAnswering]์œผ๋กœ ๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
  • [TrainingArguments]๋กœ ํ›ˆ๋ จ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ •ํ•ฉ๋‹ˆ๋‹ค.
  • ์˜ˆ์ œ๋ฅผ ๋ฐฐ์น˜ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ๋Š” [DefaultDataCollator]๊ฐ€ ์ ๋‹นํ•ฉ๋‹ˆ๋‹ค.
  • ๋ชจ๋ธ, ๋ฐ์ดํ„ฐ ์„ธํŠธ, ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ(Data collator)์™€ ํ•จ๊ป˜ [Trainer]์— ํ›ˆ๋ จ ์ธ์ˆ˜๋“ค์„ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
  • [~Trainer.train]์„ ํ˜ธ์ถœํ•ด์„œ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
>>> from transformers import AutoModelForDocumentQuestionAnswering

>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)

In [TrainingArguments] use output_dir to specify where to save your model, and set the hyperparameters as you see fit. If you'd like to share your model with the community, set push_to_hub to True (you must be signed in to Hugging Face to upload your model). In this case output_dir will also be the name of the repo where your model checkpoint will be pushed.

>>> from transformers import TrainingArguments

>>> # ๋ณธ์ธ์˜ ๋ ˆํฌ์ง€ํ† ๋ฆฌ ID๋กœ ๋ฐ”๊พธ์„ธ์š”
>>> repo_id = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa"

>>> training_args = TrainingArguments(
...     output_dir=repo_id,
...     per_device_train_batch_size=4,
...     num_train_epochs=20,
...     save_steps=200,
...     logging_steps=50,
...     evaluation_strategy="steps",
...     learning_rate=5e-5,
...     save_total_limit=2,
...     remove_unused_columns=False,
...     push_to_hub=True,
... )

Define a simple data collator to batch examples together.

>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator()

Finally, bring everything together, and call [~Trainer.train]:

>>> from transformers import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=encoded_train_dataset,
...     eval_dataset=encoded_test_dataset,
...     tokenizer=processor,
... )

>>> trainer.train()

To add the final model to the 🤗 Hub, create a model card and call push_to_hub:

>>> trainer.create_model_card()
>>> trainer.push_to_hub()

Inference [[inference]]

Now that you have fine-tuned a LayoutLMv2 model and uploaded it to the 🤗 Hub, you can use it for inference. The simplest way to try out your fine-tuned model for inference is to use it in a [Pipeline].

Let's take an example:

>>> example = dataset["test"][2]
>>> question = example["query"]["en"]
>>> image = example["image"]
>>> print(question)
>>> print(example["answers"])
'Who is ‘presiding’ TRRF GENERAL SESSION (PART 1)?'
['TRRF Vice President', 'lee a. waller']

๊ทธ ๋‹ค์Œ, ๋ชจ๋ธ๋กœ ๋ฌธ์„œ ์งˆ์˜ ์‘๋‹ต์„ ํ•˜๊ธฐ ์œ„ํ•ด ํŒŒ์ดํ”„๋ผ์ธ์„ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๊ณ  ์ด๋ฏธ์ง€ + ์งˆ๋ฌธ ์กฐํ•ฉ์„ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.

>>> from transformers import pipeline

>>> qa_pipeline = pipeline("document-question-answering", model="MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> qa_pipeline(image, question)
[{'score': 0.9949808120727539,
  'answer': 'Lee A. Waller',
  'start': 55,
  'end': 57}]

You can also manually replicate the results of the pipeline if you'd like:

  1. Take an image and a question, and prepare them for the model using the processor from your model.
  2. Forward the result of the preprocessing through the model.
  3. The model returns start_logits and end_logits, which indicate which token is at the start of the answer and which token is at the end of the answer. Both have shape (batch_size, sequence_length).
  4. Take an argmax on the last dimension of both the start_logits and end_logits to get the predicted start_idx and end_idx.
  5. Decode the answer with the tokenizer.

>>> import torch
>>> from transformers import AutoProcessor
>>> from transformers import AutoModelForDocumentQuestionAnswering

>>> processor = AutoProcessor.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")

>>> with torch.no_grad():
...     encoding = processor(image.convert("RGB"), question, return_tensors="pt")
...     outputs = model(**encoding)
...     start_logits = outputs.start_logits
...     end_logits = outputs.end_logits
...     predicted_start_idx = start_logits.argmax(-1).item()
...     predicted_end_idx = end_logits.argmax(-1).item()

>>> processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1])
'lee a. waller'