{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "1. Download the repo from GitHub (https://github.com/clovaai/donut) with `git clone` or as a direct download.\n",
        "2. The base model configs for the document classification, document parsing, and document Q&A tasks are stored under /config.\n",
        "3. Make a copy of one of the YAML files, rename it as you like, and set your parameters.\n",
        "4. Prepare your dataset (train, validation, test) together with its JSONL metadata files in the /dataset folder. You can write a small script to generate the JSONL files from CSV files. Follow the format carefully: one line per sample, and one JSONL file in each folder (train/validation/test).\n",
        "5. Follow donut_training.ipynb to train your model. Use an A100 or V100 GPU to avoid extra configuration and slow training. The trained model is stored under the /result folder.\n",
        "6. Run the trained model with this notebook.\n",
        "7. Do not change the versions of transformers and timm. Other versions can cause hard-to-debug compatibility errors unless you know exactly what you are changing."
      ],
      "metadata": {
        "id": "L5U1ACZZBxfh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive and change to the donut folder\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "%cd /content/drive/MyDrive/donut"
      ],
      "metadata": {
        "id": "-BZ2HFB9OtWP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SJpD4AAj7qeZ"
      },
      "outputs": [],
      "source": [
        "# Install the required packages. Do not change the pinned versions!\n",
        "!pip install transformers==4.25.1\n",
        "!pip install timm==0.5.4\n",
        "!pip install donut-python"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Import the required modules\n",
        "from donut import DonutModel\n",
        "from PIL import Image\n",
        "import torch\n",
        "import argparse"
      ],
      "metadata": {
        "id": "gSatjcDn5S89"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create an argument parser (it is not used further in this notebook)\n",
        "parser = argparse.ArgumentParser()"
      ],
      "metadata": {
        "id": "RZSmy3Riz7ia"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
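        "# Load the fine-tuned model from the training output folder\n",
        "model = DonutModel.from_pretrained(\"./result/train_Booking/donut-booking-extract\")\n",
        "if torch.cuda.is_available():\n",
        "    # Use half precision on the GPU\n",
        "    model.half()\n",
        "    device = torch.device(\"cuda\")\n",
        "    model.to(device)\n",
        "else:\n",
        "    # On CPU, cast only the encoder to bfloat16\n",
        "    model.encoder.to(torch.bfloat16)\n",
        "\n",
        "model.eval()\n",
        "\n",
        "# Load a test image and make sure it has three RGB channels\n",
        "image = Image.open(\"/content/drive/MyDrive/donut/test/4.jpg\").convert(\"RGB\")\n",
        "\n",
        "# The prompt is the task start token; it should match the task name used for training\n",
        "with torch.no_grad():\n",
        "    output = model.inference(image=image, prompt=\"<s_Booking>\")\n",
        "\n",
        "# The parsed result is stored in output[\"predictions\"]\n",
        "output"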
      ],
      "metadata": {
        "id": "dFfm72T93Z8G"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "V100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}