Training in progress, step 1000
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +171 -0
- .gitignore~ +170 -0
- README.md +7 -0
- adapter_config.json +35 -0
- adapter_model.safetensors +3 -0
- config.json +177 -0
- data.py +152 -0
- finetune_phi3_vision.py +263 -0
- finetuner_usloath.py +174 -0
- idefics2/adapter_config.json +26 -0
- idefics2/adapter_model.safetensors +3 -0
- idefics2/checkpoint-10000/adapter_config.json +26 -0
- idefics2/checkpoint-10000/adapter_model.safetensors +3 -0
- idefics2/checkpoint-10000/generation_config.json +7 -0
- idefics2/checkpoint-10000/optimizer.pt +3 -0
- idefics2/checkpoint-10000/rng_state.pth +3 -0
- idefics2/checkpoint-10000/scheduler.pt +3 -0
- idefics2/checkpoint-10000/trainer_state.json +0 -0
- idefics2/checkpoint-10000/training_args.bin +3 -0
- idefics2/checkpoint-8000/adapter_config.json +26 -0
- idefics2/checkpoint-8000/adapter_model.safetensors +3 -0
- idefics2/checkpoint-8000/generation_config.json +18 -0
- idefics2/checkpoint-8000/optimizer.pt +3 -0
- idefics2/checkpoint-8000/rng_state.pth +3 -0
- idefics2/checkpoint-8000/scheduler.pt +3 -0
- idefics2/checkpoint-8000/trainer_state.json +0 -0
- idefics2/checkpoint-8000/training_args.bin +3 -0
- idefics2/checkpoint-9000/adapter_config.json +26 -0
- idefics2/checkpoint-9000/adapter_model.safetensors +3 -0
- idefics2/checkpoint-9000/generation_config.json +18 -0
- idefics2/checkpoint-9000/optimizer.pt +3 -0
- idefics2/checkpoint-9000/rng_state.pth +3 -0
- idefics2/checkpoint-9000/scheduler.pt +3 -0
- idefics2/checkpoint-9000/trainer_state.json +0 -0
- idefics2/checkpoint-9000/training_args.bin +3 -0
- idefics2/training_args.bin +3 -0
- inference.py +98 -0
- inference_idefics2.py +97 -0
- model.py +204 -0
- model.safetensors +3 -0
- model_sft.py +217 -0
- phi3/checkpoint-25/adapter_config.json +26 -0
- phi3/checkpoint-25/adapter_model.safetensors +3 -0
- phi3/checkpoint-25/generation_config.json +18 -0
- phi3/checkpoint-25/optimizer.pt +3 -0
- phi3/checkpoint-25/rng_state.pth +3 -0
- phi3/checkpoint-25/scheduler.pt +3 -0
- phi3/checkpoint-25/trainer_state.json +84 -0
- phi3/checkpoint-25/training_args.bin +3 -0
- phi3_ocr.py +176 -0
.gitignore
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
.git/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
#.idea/
|
162 |
+
|
163 |
+
|
164 |
+
# nohup output
|
165 |
+
nohup.out
|
166 |
+
|
167 |
+
# wandb and output
|
168 |
+
wandb/
|
169 |
+
output/
|
170 |
+
trl/
|
171 |
+
|
.gitignore~
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
.git/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
#.idea/
|
162 |
+
|
163 |
+
|
164 |
+
# nohup output
|
165 |
+
nohup.out
|
166 |
+
|
167 |
+
# wandb and output
|
168 |
+
wandb/
|
169 |
+
output/
|
170 |
+
|
README.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Alphapen
|
2 |
+
|
3 |
+
This project aims to develop an OCR model for instantaneous text extraction from handwritten documents. The ultimate goal is to seamlessly integrate such a model into computers or mobile phones, allowing for the direct digitalization of handwritten documents using a proprietary pen manufactured by a startup company named [Alphapen](https://alphapen.fr/views/index.html).
|
4 |
+
|
5 |
+
# Fine-tuning the TrOCR model
|
6 |
+
|
7 |
+
python model.py --log_with wandb --push_to_hub True --hub_model_id hadrakey/alphapen_trocr
|
adapter_config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha_pattern": {},
|
3 |
+
"auto_mapping": {
|
4 |
+
"base_model_class": "VisionEncoderDecoderModel",
|
5 |
+
"parent_library": "transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder"
|
6 |
+
},
|
7 |
+
"base_model_name_or_path": "microsoft/trocr-large-handwritten",
|
8 |
+
"bias": "none",
|
9 |
+
"fan_in_fan_out": false,
|
10 |
+
"inference_mode": true,
|
11 |
+
"init_lora_weights": true,
|
12 |
+
"layer_replication": null,
|
13 |
+
"layers_pattern": null,
|
14 |
+
"layers_to_transform": null,
|
15 |
+
"loftq_config": {},
|
16 |
+
"lora_alpha": 8,
|
17 |
+
"lora_dropout": 0.1,
|
18 |
+
"megatron_config": null,
|
19 |
+
"megatron_core": "megatron.core",
|
20 |
+
"modules_to_save": null,
|
21 |
+
"peft_type": "LORA",
|
22 |
+
"r": 1,
|
23 |
+
"rank_pattern": {},
|
24 |
+
"revision": null,
|
25 |
+
"target_modules": [
|
26 |
+
"intermediate.dense",
|
27 |
+
"key",
|
28 |
+
"output.dense",
|
29 |
+
"value",
|
30 |
+
"query"
|
31 |
+
],
|
32 |
+
"task_type": null,
|
33 |
+
"use_dora": false,
|
34 |
+
"use_rslora": false
|
35 |
+
}
|
adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f832825236eb8acbcbcbb821ede7f8dcdd64857560ac68e6a4431adbf3f4bc95
|
3 |
+
size 1811016
|
config.json
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "microsoft/trocr-base-handwritten",
|
3 |
+
"architectures": [
|
4 |
+
"VisionEncoderDecoderModel"
|
5 |
+
],
|
6 |
+
"decoder": {
|
7 |
+
"_name_or_path": "",
|
8 |
+
"activation_dropout": 0.0,
|
9 |
+
"activation_function": "gelu",
|
10 |
+
"add_cross_attention": true,
|
11 |
+
"architectures": null,
|
12 |
+
"attention_dropout": 0.0,
|
13 |
+
"bad_words_ids": null,
|
14 |
+
"begin_suppress_tokens": null,
|
15 |
+
"bos_token_id": 0,
|
16 |
+
"chunk_size_feed_forward": 0,
|
17 |
+
"classifier_dropout": 0.0,
|
18 |
+
"cross_attention_hidden_size": 768,
|
19 |
+
"d_model": 1024,
|
20 |
+
"decoder_attention_heads": 16,
|
21 |
+
"decoder_ffn_dim": 4096,
|
22 |
+
"decoder_layerdrop": 0.0,
|
23 |
+
"decoder_layers": 12,
|
24 |
+
"decoder_start_token_id": 2,
|
25 |
+
"diversity_penalty": 0.0,
|
26 |
+
"do_sample": false,
|
27 |
+
"dropout": 0.1,
|
28 |
+
"early_stopping": false,
|
29 |
+
"encoder_no_repeat_ngram_size": 0,
|
30 |
+
"eos_token_id": 2,
|
31 |
+
"exponential_decay_length_penalty": null,
|
32 |
+
"finetuning_task": null,
|
33 |
+
"forced_bos_token_id": null,
|
34 |
+
"forced_eos_token_id": null,
|
35 |
+
"id2label": {
|
36 |
+
"0": "LABEL_0",
|
37 |
+
"1": "LABEL_1"
|
38 |
+
},
|
39 |
+
"init_std": 0.02,
|
40 |
+
"is_decoder": true,
|
41 |
+
"is_encoder_decoder": false,
|
42 |
+
"label2id": {
|
43 |
+
"LABEL_0": 0,
|
44 |
+
"LABEL_1": 1
|
45 |
+
},
|
46 |
+
"layernorm_embedding": true,
|
47 |
+
"length_penalty": 1.0,
|
48 |
+
"max_length": 20,
|
49 |
+
"max_position_embeddings": 512,
|
50 |
+
"min_length": 0,
|
51 |
+
"model_type": "trocr",
|
52 |
+
"no_repeat_ngram_size": 0,
|
53 |
+
"num_beam_groups": 1,
|
54 |
+
"num_beams": 1,
|
55 |
+
"num_return_sequences": 1,
|
56 |
+
"output_attentions": false,
|
57 |
+
"output_hidden_states": false,
|
58 |
+
"output_scores": false,
|
59 |
+
"pad_token_id": 1,
|
60 |
+
"prefix": null,
|
61 |
+
"problem_type": null,
|
62 |
+
"pruned_heads": {},
|
63 |
+
"remove_invalid_values": false,
|
64 |
+
"repetition_penalty": 1.0,
|
65 |
+
"return_dict": true,
|
66 |
+
"return_dict_in_generate": false,
|
67 |
+
"scale_embedding": false,
|
68 |
+
"sep_token_id": null,
|
69 |
+
"suppress_tokens": null,
|
70 |
+
"task_specific_params": null,
|
71 |
+
"temperature": 1.0,
|
72 |
+
"tf_legacy_loss": false,
|
73 |
+
"tie_encoder_decoder": false,
|
74 |
+
"tie_word_embeddings": true,
|
75 |
+
"tokenizer_class": null,
|
76 |
+
"top_k": 50,
|
77 |
+
"top_p": 1.0,
|
78 |
+
"torch_dtype": null,
|
79 |
+
"torchscript": false,
|
80 |
+
"typical_p": 1.0,
|
81 |
+
"use_bfloat16": false,
|
82 |
+
"use_cache": false,
|
83 |
+
"use_learned_position_embeddings": true,
|
84 |
+
"vocab_size": 50265
|
85 |
+
},
|
86 |
+
"decoder_start_token_id": 0,
|
87 |
+
"early_stopping": true,
|
88 |
+
"encoder": {
|
89 |
+
"_name_or_path": "",
|
90 |
+
"add_cross_attention": false,
|
91 |
+
"architectures": null,
|
92 |
+
"attention_probs_dropout_prob": 0.0,
|
93 |
+
"bad_words_ids": null,
|
94 |
+
"begin_suppress_tokens": null,
|
95 |
+
"bos_token_id": null,
|
96 |
+
"chunk_size_feed_forward": 0,
|
97 |
+
"cross_attention_hidden_size": null,
|
98 |
+
"decoder_start_token_id": null,
|
99 |
+
"diversity_penalty": 0.0,
|
100 |
+
"do_sample": false,
|
101 |
+
"early_stopping": false,
|
102 |
+
"encoder_no_repeat_ngram_size": 0,
|
103 |
+
"encoder_stride": 16,
|
104 |
+
"eos_token_id": null,
|
105 |
+
"exponential_decay_length_penalty": null,
|
106 |
+
"finetuning_task": null,
|
107 |
+
"forced_bos_token_id": null,
|
108 |
+
"forced_eos_token_id": null,
|
109 |
+
"hidden_act": "gelu",
|
110 |
+
"hidden_dropout_prob": 0.0,
|
111 |
+
"hidden_size": 768,
|
112 |
+
"id2label": {
|
113 |
+
"0": "LABEL_0",
|
114 |
+
"1": "LABEL_1"
|
115 |
+
},
|
116 |
+
"image_size": 384,
|
117 |
+
"initializer_range": 0.02,
|
118 |
+
"intermediate_size": 3072,
|
119 |
+
"is_decoder": false,
|
120 |
+
"is_encoder_decoder": false,
|
121 |
+
"label2id": {
|
122 |
+
"LABEL_0": 0,
|
123 |
+
"LABEL_1": 1
|
124 |
+
},
|
125 |
+
"layer_norm_eps": 1e-12,
|
126 |
+
"length_penalty": 1.0,
|
127 |
+
"max_length": 20,
|
128 |
+
"min_length": 0,
|
129 |
+
"model_type": "vit",
|
130 |
+
"no_repeat_ngram_size": 0,
|
131 |
+
"num_attention_heads": 12,
|
132 |
+
"num_beam_groups": 1,
|
133 |
+
"num_beams": 1,
|
134 |
+
"num_channels": 3,
|
135 |
+
"num_hidden_layers": 12,
|
136 |
+
"num_return_sequences": 1,
|
137 |
+
"output_attentions": false,
|
138 |
+
"output_hidden_states": false,
|
139 |
+
"output_scores": false,
|
140 |
+
"pad_token_id": null,
|
141 |
+
"patch_size": 16,
|
142 |
+
"prefix": null,
|
143 |
+
"problem_type": null,
|
144 |
+
"pruned_heads": {},
|
145 |
+
"qkv_bias": false,
|
146 |
+
"remove_invalid_values": false,
|
147 |
+
"repetition_penalty": 1.0,
|
148 |
+
"return_dict": true,
|
149 |
+
"return_dict_in_generate": false,
|
150 |
+
"sep_token_id": null,
|
151 |
+
"suppress_tokens": null,
|
152 |
+
"task_specific_params": null,
|
153 |
+
"temperature": 1.0,
|
154 |
+
"tf_legacy_loss": false,
|
155 |
+
"tie_encoder_decoder": false,
|
156 |
+
"tie_word_embeddings": true,
|
157 |
+
"tokenizer_class": null,
|
158 |
+
"top_k": 50,
|
159 |
+
"top_p": 1.0,
|
160 |
+
"torch_dtype": null,
|
161 |
+
"torchscript": false,
|
162 |
+
"typical_p": 1.0,
|
163 |
+
"use_bfloat16": false
|
164 |
+
},
|
165 |
+
"eos_token_id": 2,
|
166 |
+
"is_encoder_decoder": true,
|
167 |
+
"length_penalty": 2.0,
|
168 |
+
"max_length": 64,
|
169 |
+
"model_type": "vision-encoder-decoder",
|
170 |
+
"no_repeat_ngram_size": 3,
|
171 |
+
"num_beams": 4,
|
172 |
+
"pad_token_id": 1,
|
173 |
+
"processor_class": "TrOCRProcessor",
|
174 |
+
"tie_word_embeddings": false,
|
175 |
+
"torch_dtype": "float32",
|
176 |
+
"transformers_version": "4.44.2"
|
177 |
+
}
|
data.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from PIL import Image
|
4 |
+
import json
|
5 |
+
from transformers import TrOCRProcessor
|
6 |
+
import pandas as pd
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
import glob
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
def prepare_data_frame(root_dir):
|
13 |
+
with open(root_dir) as f:
|
14 |
+
d = json.load(f)
|
15 |
+
filename = [d[i]["word_id"]+ ".png" for i in range(len(d))]
|
16 |
+
text = [d[i]["text"] for i in range(len(d))]
|
17 |
+
data = {'filename': filename, 'text': text}
|
18 |
+
df = pd.DataFrame(data=data)
|
19 |
+
return df
|
20 |
+
|
21 |
+
|
22 |
+
class AphaPenDataset(Dataset):
|
23 |
+
def __init__(self, root_dir, df, processor, transform=None, max_target_length=128):
|
24 |
+
self.root_dir = root_dir
|
25 |
+
self.df= df
|
26 |
+
# self.filename, self.text = self.prepare_data()
|
27 |
+
self.processor = processor
|
28 |
+
self.max_target_length = max_target_length
|
29 |
+
self.transform = transform
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.df)
|
33 |
+
|
34 |
+
def __getitem__(self, idx):
|
35 |
+
# get file name + text
|
36 |
+
file_name = self.df.filename[idx]
|
37 |
+
text = self.df.text[idx]
|
38 |
+
# prepare image (i.e. resize + normalize)
|
39 |
+
image = Image.open(self.root_dir + file_name).convert("RGB")
|
40 |
+
if self.transform is not None:
|
41 |
+
image = self.transform(image)
|
42 |
+
img=transforms.ToPILImage()(image)
|
43 |
+
img.save("/mnt/data1/Datasets/AlphaPen/transformed_images/" + file_name)
|
44 |
+
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
45 |
+
# add labels (input_ids) by encoding the text
|
46 |
+
labels = self.processor.tokenizer(text,
|
47 |
+
padding="max_length",
|
48 |
+
max_length=self.max_target_length).input_ids
|
49 |
+
# important: make sure that PAD tokens are ignored by the loss function
|
50 |
+
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
|
51 |
+
|
52 |
+
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
|
53 |
+
return encoding
|
54 |
+
|
55 |
+
def prepare_data(self):
|
56 |
+
with open(self.path_json) as f:
|
57 |
+
d = json.load(f)
|
58 |
+
filename = [d[i]["image_id"]+ ".png" for i in range(len(d))]
|
59 |
+
text = [d[i]["text"] for i in range(len(d))]
|
60 |
+
return filename, text
|
61 |
+
|
62 |
+
|
63 |
+
class AlphaPenPhi3Dataset(Dataset):
|
64 |
+
def __init__(self, root_dir, dataframe, tokenizer, max_length, image_size):
|
65 |
+
self.dataframe = dataframe
|
66 |
+
self.tokenizer = tokenizer
|
67 |
+
self.tokenizer.padding_side = 'left'
|
68 |
+
self.max_length = max_length
|
69 |
+
self.root_dir = root_dir
|
70 |
+
self.transform = transforms.Compose([
|
71 |
+
transforms.Resize((image_size, image_size)),
|
72 |
+
transforms.ToTensor()
|
73 |
+
])
|
74 |
+
|
75 |
+
def __len__(self):
|
76 |
+
return len(self.dataframe)
|
77 |
+
|
78 |
+
|
79 |
+
def __getitem__(self, idx):
|
80 |
+
row = self.dataframe.iloc[idx]
|
81 |
+
text = f"<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\n {row['text']} <|end|>"
|
82 |
+
image_path = self.root_dir + row['filename']
|
83 |
+
|
84 |
+
# Tokenize text
|
85 |
+
encodings = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length)
|
86 |
+
|
87 |
+
try:
|
88 |
+
# Load and transform image
|
89 |
+
image = Image.open(image_path).convert("RGB")
|
90 |
+
image = self.image_transform_function(image)
|
91 |
+
except (FileNotFoundError, IOError):
|
92 |
+
# Skip the sample if the image is not found
|
93 |
+
return None
|
94 |
+
|
95 |
+
labels = self.tokenizer(row['text'],
|
96 |
+
padding="max_length",
|
97 |
+
max_length=self.max_length).input_ids
|
98 |
+
# important: make sure that PAD tokens are ignored by the loss function
|
99 |
+
labels = [label if label != self.tokenizer.pad_token_id else -100 for label in labels]
|
100 |
+
encodings['pixel_values'] = image
|
101 |
+
encodings['labels'] = labels
|
102 |
+
|
103 |
+
return {key: torch.tensor(val) for key, val in encodings.items()}
|
104 |
+
|
105 |
+
|
106 |
+
def image_transform_function(self, image):
|
107 |
+
image = self.transform(image)
|
108 |
+
return image
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
json_path = "/mnt/data1/Datasets/OCR/Alphapen/label_check/"
|
115 |
+
json_path_b2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/label_check/"
|
116 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
117 |
+
root_dir_b2 = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
118 |
+
json_files = glob.glob(json_path + "*.json")
|
119 |
+
json_files_b2 = glob.glob(json_path_b2 + "*.json")
|
120 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
121 |
+
df_list_b1 = [prepare_data_frame(file) for file in json_files]
|
122 |
+
df_list_b2 = [prepare_data_frame(file) for file in json_files_b2]
|
123 |
+
# df_list = df_list_b1 + df_list_b2
|
124 |
+
df_b1 = pd.concat(df_list_b1)
|
125 |
+
df_b2 = pd.concat(df_list_b2)
|
126 |
+
|
127 |
+
df_b1.to_csv("/mnt/data1/Datasets/AlphaPen/" + "testing_data_b1.csv")
|
128 |
+
df_b2.to_csv("/mnt/data1/Datasets/AlphaPen/" + "testing_data_b2.csv")
|
129 |
+
# train_df, test_df = train_test_split(df, test_size=0.15)
|
130 |
+
# # we reset the indices to start from zero
|
131 |
+
# train_df.reset_index(drop=True, inplace=True)
|
132 |
+
# test_df.reset_index(drop=True, inplace=True)
|
133 |
+
# processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
|
134 |
+
# train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
|
135 |
+
# eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
|
136 |
+
# print("Number of training examples:", len(train_dataset))
|
137 |
+
# print("Number of validation examples:", len(eval_dataset))
|
138 |
+
|
139 |
+
# encoding = train_dataset[0]
|
140 |
+
# for k,v in encoding.items():
|
141 |
+
# print(k, v.shape)
|
142 |
+
|
143 |
+
# image = Image.open(train_dataset.root_dir + df.filename[0]).convert("RGB")
|
144 |
+
# print('Label: '+df.text[0])
|
145 |
+
# print(image)
|
146 |
+
|
147 |
+
# labels = encoding['labels']
|
148 |
+
# print(labels)
|
149 |
+
|
150 |
+
# labels[labels == -100] = processor.tokenizer.pad_token_id
|
151 |
+
# label_str = processor.decode(labels, skip_special_tokens=True)
|
152 |
+
# print('Decoded Label:', label_str)
|
finetune_phi3_vision.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import Dataset, DatasetDict, Image
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from peft import LoraConfig
|
7 |
+
from transformers import AutoProcessor, BitsAndBytesConfig
|
8 |
+
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq
|
9 |
+
from datetime import datetime
|
10 |
+
import evaluate
|
11 |
+
from transformers import TrainingArguments, Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
12 |
+
from sklearn.model_selection import train_test_split
|
13 |
+
|
14 |
+
import random
|
15 |
+
|
16 |
+
class MyDataCollator:
|
17 |
+
def __init__(self, processor):
|
18 |
+
self.processor = processor
|
19 |
+
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
|
20 |
+
processor.tokenizer.additional_special_tokens.index("<image>")
|
21 |
+
]
|
22 |
+
|
23 |
+
def __call__(self, examples):
|
24 |
+
texts = []
|
25 |
+
images = []
|
26 |
+
for example in examples:
|
27 |
+
image = example["image"]
|
28 |
+
# print(example["query"])
|
29 |
+
question = example["query"]
|
30 |
+
answer = example["answers"]
|
31 |
+
messages = [
|
32 |
+
{
|
33 |
+
"role": "user",
|
34 |
+
"content": [
|
35 |
+
{"type": "text", "text": "OCR the text in the image."},
|
36 |
+
{"type": "image"},
|
37 |
+
{"type": "text", "text": question}
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"role": "assistant",
|
42 |
+
"content": [
|
43 |
+
{"type": "text", "text": answer}
|
44 |
+
]
|
45 |
+
}
|
46 |
+
]
|
47 |
+
text = processor.apply_chat_template(messages, add_generation_prompt=False)
|
48 |
+
texts.append(text.strip())
|
49 |
+
images.append([image])
|
50 |
+
|
51 |
+
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
52 |
+
|
53 |
+
labels = batch["input_ids"].clone()
|
54 |
+
# labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
|
55 |
+
batch["labels"] = labels
|
56 |
+
|
57 |
+
return batch
|
58 |
+
|
59 |
+
# Define train and test size.
|
60 |
+
TRAIN_SAMPLES = 1000
|
61 |
+
TEST_SAMPLES = 200
|
62 |
+
TEST_SIZE = 0.166 #
|
63 |
+
samp_list = [1, 15000, 30000, 45000, 60000, 70000]
|
64 |
+
|
65 |
+
# Define the directory containing the images.
|
66 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
|
67 |
+
df = pd.read_csv(df_path)
|
68 |
+
df.dropna(inplace=True)
|
69 |
+
df["id"] = range(df.shape[0])
|
70 |
+
df["query"] = "What is shown in this image?"
|
71 |
+
train_df, test_df = train_test_split(df, test_size=0.02, random_state=0)
|
72 |
+
|
73 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
74 |
+
image_paths_train = [root_dir + img for img in train_df.filename]
|
75 |
+
image_paths_test = [root_dir + img for img in test_df.filename]
|
76 |
+
|
77 |
+
# New batch
|
78 |
+
df_path_2 = "/mnt/data1/Datasets/AlphaPen/" + "training_b2.csv"
|
79 |
+
df_2 = pd.read_csv(df_path_2)
|
80 |
+
df_2.dropna(inplace=True)
|
81 |
+
df_2["id"] = range(df_2.shape[0])
|
82 |
+
df_2["query"] = "What is shown in this image?"
|
83 |
+
train_df_b2, test_df_b2 = train_test_split(df_2, test_size=0.01, random_state=0)
|
84 |
+
|
85 |
+
root_dir_2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/clean_data/cropped_data/cropped_"
|
86 |
+
image_paths_2_train = [root_dir_2 + img for img in train_df_b2.filename]
|
87 |
+
image_paths_2_test = [root_dir_2 + img for img in test_df_b2.filename]
|
88 |
+
|
89 |
+
|
90 |
+
ids_test = range(test_df.shape[0] + test_df_b2.shape[0])
|
91 |
+
queries_test = test_df['query'].tolist() + test_df_b2['query'].tolist()
|
92 |
+
answers_test = test_df['text'].tolist() + test_df_b2['text'].tolist()
|
93 |
+
|
94 |
+
# Create the dataset dictionary.
|
95 |
+
|
96 |
+
|
97 |
+
eval_dataset_dict = {
|
98 |
+
'id': ids_test,
|
99 |
+
'image': image_paths_test + image_paths_2_test,
|
100 |
+
'query': queries_test,
|
101 |
+
'answers': answers_test
|
102 |
+
}
|
103 |
+
|
104 |
+
# Create the dataset.
|
105 |
+
|
106 |
+
eval_dataset = Dataset.from_dict(eval_dataset_dict)
|
107 |
+
|
108 |
+
# Cast the 'image' column to Image type.
|
109 |
+
|
110 |
+
eval_dataset = eval_dataset.cast_column("image", Image())
|
111 |
+
|
112 |
+
# Split the dataset into train and test.
|
113 |
+
# split_dataset = dataset.train_test_split(test_size=TEST_SIZE, shuffle=False)
|
114 |
+
|
115 |
+
# train_dataset = split_dataset["train"]
|
116 |
+
# eval_dataset = split_dataset["test"]
|
117 |
+
print(len(eval_dataset))
|
118 |
+
# Push the dataset on Hugging Face Hub.
|
119 |
+
# split_dataset.push_to_hub("NSTiwari/DocumentIDEFICS_QA")
|
120 |
+
|
121 |
+
|
122 |
+
# Define model ID
|
123 |
+
# model_id = "microsoft/Phi-3-vision-128k-instruct"
|
124 |
+
model_id = "HuggingFaceM4/idefics2-8b"
|
125 |
+
|
126 |
+
DEVICE = "cuda:0"
|
127 |
+
USE_LORA = False
|
128 |
+
USE_QLORA = True
|
129 |
+
|
130 |
+
processor = AutoProcessor.from_pretrained(
|
131 |
+
model_id,
|
132 |
+
do_image_splitting=False
|
133 |
+
)
|
134 |
+
|
135 |
+
# print(processor.tokenizer.additional_special_tokens.index("<image>"))
|
136 |
+
if USE_QLORA or USE_LORA:
|
137 |
+
lora_config = LoraConfig(
|
138 |
+
r=64,
|
139 |
+
lora_alpha=16,
|
140 |
+
lora_dropout=0.1,
|
141 |
+
# target_modules= [
|
142 |
+
# "q_proj",
|
143 |
+
# "k_proj",
|
144 |
+
# "v_proj",
|
145 |
+
# "o_proj",
|
146 |
+
# "gate_proj",
|
147 |
+
# "up_proj",
|
148 |
+
# # "down_proj",
|
149 |
+
# ],
|
150 |
+
target_modules = '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
|
151 |
+
use_dora=False if USE_QLORA else True,
|
152 |
+
init_lora_weights="gaussian"
|
153 |
+
)
|
154 |
+
if USE_QLORA:
|
155 |
+
bnb_config = BitsAndBytesConfig(
|
156 |
+
load_in_4bit=True,
|
157 |
+
bnb_4bit_quant_type="nf4",
|
158 |
+
bnb_4bit_compute_dtype=torch.float16
|
159 |
+
)
|
160 |
+
model = AutoModelForVision2Seq.from_pretrained(
|
161 |
+
model_id,
|
162 |
+
torch_dtype=torch.float16,
|
163 |
+
quantization_config=bnb_config if USE_QLORA else None,
|
164 |
+
trust_remote_code=True
|
165 |
+
)
|
166 |
+
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
167 |
+
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
168 |
+
model.config.max_length= 128
|
169 |
+
model.add_adapter(lora_config)
|
170 |
+
model.enable_adapters()
|
171 |
+
else:
|
172 |
+
model = AutoModelForVision2Seq.from_pretrained(
|
173 |
+
model_id,
|
174 |
+
torch_dtype=torch.float16,
|
175 |
+
_attn_implementation="flash_attention_2", # Need GPUs like A100 or H100.
|
176 |
+
trust_remote_code=True
|
177 |
+
).to(DEVICE)
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
data_collator = MyDataCollator(processor)
|
184 |
+
|
185 |
+
|
186 |
+
for samp in samp_list:
|
187 |
+
os.environ["WANDB_PROJECT"]="Alphapen"
|
188 |
+
# Create a list of other columns such as id, query, and answer.
|
189 |
+
ids_train = range(train_df.shape[0] + train_df_b2.shape[0])
|
190 |
+
queries_train = train_df['query'].tolist() + train_df_b2['query'].tolist()
|
191 |
+
answers_train = train_df['text'].tolist() + train_df_b2['text'].tolist()
|
192 |
+
|
193 |
+
train_dataset_dict = {
|
194 |
+
'id': ids_train,
|
195 |
+
'image': image_paths_train + image_paths_2_train,
|
196 |
+
'query': queries_train,
|
197 |
+
'answers': answers_train
|
198 |
+
}
|
199 |
+
|
200 |
+
train_dataset = Dataset.from_dict(train_dataset_dict)
|
201 |
+
train_dataset = train_dataset.cast_column("image", Image())
|
202 |
+
|
203 |
+
training_args = Seq2SeqTrainingArguments(
|
204 |
+
predict_with_generate=True,
|
205 |
+
output_dir = "idefics2",
|
206 |
+
learning_rate = 2e-4,
|
207 |
+
fp16 = True,
|
208 |
+
per_device_train_batch_size = 8,
|
209 |
+
per_device_eval_batch_size = 8,
|
210 |
+
gradient_accumulation_steps = 2,
|
211 |
+
dataloader_pin_memory = False,
|
212 |
+
save_total_limit = 3,
|
213 |
+
eval_strategy ="steps",
|
214 |
+
save_strategy = "steps",
|
215 |
+
eval_steps = 500,
|
216 |
+
save_steps = 1000,
|
217 |
+
max_steps = 5000,
|
218 |
+
logging_steps = 10,
|
219 |
+
remove_unused_columns = False,
|
220 |
+
push_to_hub=True,
|
221 |
+
label_names = ["labels"],
|
222 |
+
load_best_model_at_end = False,
|
223 |
+
report_to = "wandb",
|
224 |
+
optim = "paged_adamw_8bit",
|
225 |
+
# run_name=f"idefics2-vision-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%s')}",
|
226 |
+
run_name="idefics2-vision-LoRA-" + str(samp),
|
227 |
+
hub_model_id="hadrakey/alphapen_idefics2_" + str(samp),
|
228 |
+
)
|
229 |
+
|
230 |
+
def compute_metrics(pred):
|
231 |
+
# accuracy_metric = evaluate.load("precision")
|
232 |
+
cer_metric = evaluate.load("cer")
|
233 |
+
|
234 |
+
labels_ids = pred.label_ids
|
235 |
+
pred_ids = pred.predictions
|
236 |
+
# print(pred_ids)
|
237 |
+
# print(labels_ids)
|
238 |
+
# max_length = max(pred_ids.shape[1], labels_ids.shape[1])
|
239 |
+
# generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
|
240 |
+
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
|
241 |
+
pred_str = [word.lower() for word in pred_str]
|
242 |
+
# print(pred_str)
|
243 |
+
# pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
|
244 |
+
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
|
245 |
+
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
|
246 |
+
label_str = [word.lower() for word in label_str]
|
247 |
+
# print(label_str)
|
248 |
+
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
249 |
+
# accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())
|
250 |
+
|
251 |
+
return {"cer": cer}
|
252 |
+
|
253 |
+
|
254 |
+
trainer = Seq2SeqTrainer(
|
255 |
+
model = model,
|
256 |
+
args = training_args,
|
257 |
+
data_collator = data_collator,
|
258 |
+
train_dataset = train_dataset,
|
259 |
+
eval_dataset = eval_dataset,
|
260 |
+
compute_metrics=compute_metrics,
|
261 |
+
)
|
262 |
+
|
263 |
+
trainer.train()
|
finetuner_usloath.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Example inspired from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct
|
2 |
+
|
3 |
+
# Import necessary libraries
|
4 |
+
from PIL import Image
|
5 |
+
import requests
|
6 |
+
from transformers import AutoModelForCausalLM
|
7 |
+
from transformers import AutoProcessor
|
8 |
+
from transformers import BitsAndBytesConfig
|
9 |
+
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
|
10 |
+
import torch
|
11 |
+
import pandas as pd
|
12 |
+
from torchmetrics.text import CharErrorRate
|
13 |
+
from peft import LoraConfig, get_peft_model
|
14 |
+
from data import AlphaPenPhi3Dataset
|
15 |
+
from sklearn.model_selection import train_test_split
|
16 |
+
from datetime import datetime
|
17 |
+
import os
|
18 |
+
import evaluate
|
19 |
+
# tqdm.pandas()
|
20 |
+
os.environ["WANDB_PROJECT"]="Alphapen"
|
21 |
+
|
22 |
+
# Define model ID
|
23 |
+
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
24 |
+
# Load data
|
25 |
+
|
26 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
|
27 |
+
df = pd.read_csv(df_path)
|
28 |
+
df.dropna(inplace=True)
|
29 |
+
train_df, test_df = train_test_split(df, test_size=0.15, random_state=0)
|
30 |
+
# we reset the indices to start from zero
|
31 |
+
train_df.reset_index(drop=True, inplace=True)
|
32 |
+
test_df.reset_index(drop=True, inplace=True)
|
33 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
34 |
+
|
35 |
+
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
36 |
+
tokenizer = processor.tokenizer
|
37 |
+
|
38 |
+
train_dataset = AlphaPenPhi3Dataset(root_dir=root_dir, dataframe=train_df, tokenizer=tokenizer, max_length=128, image_size=128)
|
39 |
+
eval_dataset = AlphaPenPhi3Dataset(root_dir=root_dir, dataframe=test_df.iloc[:10,], tokenizer=tokenizer, max_length=128, image_size=128)
|
40 |
+
|
41 |
+
print(train_dataset[0])
|
42 |
+
nf4_config = BitsAndBytesConfig(
|
43 |
+
load_in_4bit=True,
|
44 |
+
bnb_4bit_quant_type="nf4",
|
45 |
+
bnb_4bit_use_double_quant=True,
|
46 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
47 |
+
)
|
48 |
+
|
49 |
+
# Load model with 4-bit quantization and map to CUDA
|
50 |
+
model = AutoModelForCausalLM.from_pretrained(
|
51 |
+
model_id,
|
52 |
+
device_map="auto",
|
53 |
+
trust_remote_code=True,
|
54 |
+
torch_dtype="auto",
|
55 |
+
quantization_config=nf4_config,
|
56 |
+
)
|
57 |
+
|
58 |
+
# set special tokens used for creating the decoder_input_ids from the labels
|
59 |
+
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
60 |
+
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
61 |
+
# make sure vocab size is set correctly
|
62 |
+
# model.config.vocab_size = model.config.decoder.vocab_size
|
63 |
+
# for peft
|
64 |
+
# model.vocab_size = model.config.decoder.vocab_size
|
65 |
+
|
66 |
+
# set beam search parameters
|
67 |
+
model.config.eos_token_id = processor.tokenizer.sep_token_id
|
68 |
+
model.config.max_new_tokens= 128
|
69 |
+
model.config.early_stopping = True
|
70 |
+
model.config.no_repeat_ngram_size = 3
|
71 |
+
model.config.length_penalty = 2.0
|
72 |
+
model.config.num_beams = 4
|
73 |
+
|
74 |
+
|
75 |
+
# LoRa
|
76 |
+
lora_config = LoraConfig(
|
77 |
+
r=64,
|
78 |
+
lora_alpha=16,
|
79 |
+
lora_dropout=0.1,
|
80 |
+
# target_modules = 'all-linear'
|
81 |
+
target_modules=[
|
82 |
+
"q_proj",
|
83 |
+
"k_proj",
|
84 |
+
"v_proj",
|
85 |
+
"o_proj",
|
86 |
+
# "gate_proj",
|
87 |
+
# "up_proj",
|
88 |
+
# "down_proj",
|
89 |
+
],
|
90 |
+
)
|
91 |
+
# print(model)
|
92 |
+
# import torch
|
93 |
+
# from transformers import Conv1D
|
94 |
+
|
95 |
+
# def get_specific_layer_names(model):
|
96 |
+
# # Create a list to store the layer names
|
97 |
+
# layer_names = []
|
98 |
+
|
99 |
+
# # Recursively visit all modules and submodules
|
100 |
+
# for name, module in model.named_modules():
|
101 |
+
# # Check if the module is an instance of the specified layers
|
102 |
+
# if isinstance(module, (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, Conv1D)):
|
103 |
+
# # model name parsing
|
104 |
+
|
105 |
+
# layer_names.append('.'.join(name.split('.')[4:]).split('.')[0])
|
106 |
+
|
107 |
+
# return layer_names
|
108 |
+
|
109 |
+
# print(list(set(get_specific_layer_names(model))))
|
110 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
111 |
+
# model.to(device)
|
112 |
+
|
113 |
+
model = get_peft_model(model, lora_config)
|
114 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
115 |
+
model = model.to(device)
|
116 |
+
# print(model.vocab_size)
|
117 |
+
# run_name=f"Mistral-7B-SQL-QLoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%s')}"
|
118 |
+
|
119 |
+
# # Step 3: Define the training arguments
|
120 |
+
training_args = Seq2SeqTrainingArguments(
|
121 |
+
predict_with_generate=True,
|
122 |
+
evaluation_strategy="steps",
|
123 |
+
per_device_train_batch_size=8,
|
124 |
+
per_device_eval_batch_size=8,
|
125 |
+
bf16=True,
|
126 |
+
bf16_full_eval=True,
|
127 |
+
output_dir="./",
|
128 |
+
logging_steps=100,
|
129 |
+
save_steps=1000,
|
130 |
+
eval_steps=100,
|
131 |
+
report_to="wandb",
|
132 |
+
run_name=f"phi3-vision-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%s')}",
|
133 |
+
optim="adamw_torch_fused",
|
134 |
+
lr_scheduler_type="cosine",
|
135 |
+
gradient_accumulation_steps=2,
|
136 |
+
learning_rate=1.0e-4,
|
137 |
+
max_steps=10000,
|
138 |
+
push_to_hub=True,
|
139 |
+
hub_model_id="hadrakey/alphapen_phi3",
|
140 |
+
)
|
141 |
+
|
142 |
+
def compute_metrics(pred):
|
143 |
+
# accuracy_metric = evaluate.load("precision")
|
144 |
+
cer_metric = evaluate.load("cer")
|
145 |
+
|
146 |
+
labels_ids = pred.label_ids
|
147 |
+
pred_ids = pred.predictions
|
148 |
+
print(labels_ids.shape, pred_ids.shape)
|
149 |
+
max_length = max(pred_ids.shape[1], labels_ids.shape[1])
|
150 |
+
|
151 |
+
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
152 |
+
print(pred_str)
|
153 |
+
# pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
|
154 |
+
labels_ids[labels_ids == -100] = tokenizer.pad_token_id
|
155 |
+
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
|
156 |
+
print(label_str)
|
157 |
+
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
158 |
+
# accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())
|
159 |
+
|
160 |
+
return {"cer": cer}
|
161 |
+
|
162 |
+
|
163 |
+
# # Step 5: Define the Trainer
|
164 |
+
trainer = Seq2SeqTrainer(
|
165 |
+
model=model,
|
166 |
+
tokenizer=tokenizer,
|
167 |
+
args=training_args,
|
168 |
+
compute_metrics=compute_metrics,
|
169 |
+
train_dataset=train_dataset,
|
170 |
+
eval_dataset=eval_dataset,
|
171 |
+
data_collator=default_data_collator
|
172 |
+
)
|
173 |
+
|
174 |
+
trainer.train()
|
idefics2/adapter_config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha_pattern": {},
|
3 |
+
"auto_mapping": null,
|
4 |
+
"base_model_name_or_path": "HuggingFaceM4/idefics2-8b",
|
5 |
+
"bias": "none",
|
6 |
+
"fan_in_fan_out": false,
|
7 |
+
"inference_mode": false,
|
8 |
+
"init_lora_weights": "gaussian",
|
9 |
+
"layer_replication": null,
|
10 |
+
"layers_pattern": null,
|
11 |
+
"layers_to_transform": null,
|
12 |
+
"loftq_config": {},
|
13 |
+
"lora_alpha": 16,
|
14 |
+
"lora_dropout": 0.1,
|
15 |
+
"megatron_config": null,
|
16 |
+
"megatron_core": "megatron.core",
|
17 |
+
"modules_to_save": null,
|
18 |
+
"peft_type": "LORA",
|
19 |
+
"r": 64,
|
20 |
+
"rank_pattern": {},
|
21 |
+
"revision": null,
|
22 |
+
"target_modules": ".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
|
23 |
+
"task_type": null,
|
24 |
+
"use_dora": false,
|
25 |
+
"use_rslora": false
|
26 |
+
}
|
idefics2/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e38855b7b26c79a86d6bc42985348143f602714b85923b6fcf6793830f400de
|
3 |
+
size 746528304
|
idefics2/checkpoint-10000/adapter_config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha_pattern": {},
|
3 |
+
"auto_mapping": null,
|
4 |
+
"base_model_name_or_path": "HuggingFaceM4/idefics2-8b",
|
5 |
+
"bias": "none",
|
6 |
+
"fan_in_fan_out": false,
|
7 |
+
"inference_mode": false,
|
8 |
+
"init_lora_weights": "gaussian",
|
9 |
+
"layer_replication": null,
|
10 |
+
"layers_pattern": null,
|
11 |
+
"layers_to_transform": null,
|
12 |
+
"loftq_config": {},
|
13 |
+
"lora_alpha": 16,
|
14 |
+
"lora_dropout": 0.1,
|
15 |
+
"megatron_config": null,
|
16 |
+
"megatron_core": "megatron.core",
|
17 |
+
"modules_to_save": null,
|
18 |
+
"peft_type": "LORA",
|
19 |
+
"r": 64,
|
20 |
+
"rank_pattern": {},
|
21 |
+
"revision": null,
|
22 |
+
"target_modules": ".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
|
23 |
+
"task_type": null,
|
24 |
+
"use_dora": false,
|
25 |
+
"use_rslora": false
|
26 |
+
}
|
idefics2/checkpoint-10000/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a529961f33fd865061f2d504f10e9dbda5d36ac583ca54c807b178a3eef0a02
|
3 |
+
size 746528304
|
idefics2/checkpoint-10000/generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 1,
|
3 |
+
"eos_token_id": 2,
|
4 |
+
"max_length": 128,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.42.3"
|
7 |
+
}
|
idefics2/checkpoint-10000/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e4f18fec102e19e47ad7a4dc5a698a67204bd5a3f9a5e592c8b3c510be2357ad
|
3 |
+
size 374548180
|
idefics2/checkpoint-10000/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c4cb5638cdcc03d44751491c48ef9325c702b5d355c5ef610bd485b897821f63
|
3 |
+
size 14244
|
idefics2/checkpoint-10000/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5351a4d2054ce412bb25cc143ed6232593a8af839dadb43b7080a505b08f3f6
|
3 |
+
size 1064
|
idefics2/checkpoint-10000/trainer_state.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
idefics2/checkpoint-10000/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:42f9baac86ca26a549fd61c40400a3efd2e95f6a3486ca7a7482e10ccfbb4ac6
|
3 |
+
size 5368
|
idefics2/checkpoint-8000/adapter_config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha_pattern": {},
|
3 |
+
"auto_mapping": null,
|
4 |
+
"base_model_name_or_path": "HuggingFaceM4/idefics2-8b",
|
5 |
+
"bias": "none",
|
6 |
+
"fan_in_fan_out": false,
|
7 |
+
"inference_mode": false,
|
8 |
+
"init_lora_weights": "gaussian",
|
9 |
+
"layer_replication": null,
|
10 |
+
"layers_pattern": null,
|
11 |
+
"layers_to_transform": null,
|
12 |
+
"loftq_config": {},
|
13 |
+
"lora_alpha": 16,
|
14 |
+
"lora_dropout": 0.1,
|
15 |
+
"megatron_config": null,
|
16 |
+
"megatron_core": "megatron.core",
|
17 |
+
"modules_to_save": null,
|
18 |
+
"peft_type": "LORA",
|
19 |
+
"r": 64,
|
20 |
+
"rank_pattern": {},
|
21 |
+
"revision": null,
|
22 |
+
"target_modules": ".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
|
23 |
+
"task_type": null,
|
24 |
+
"use_dora": false,
|
25 |
+
"use_rslora": false
|
26 |
+
}
|
idefics2/checkpoint-8000/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9beb18aad93c70c527415b17f7e7cfc2d3142ad6521c26cf5b27642f6cfd1d68
|
3 |
+
size 746528304
|
idefics2/checkpoint-8000/generation_config.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bad_words_ids": [
|
4 |
+
[
|
5 |
+
32000
|
6 |
+
],
|
7 |
+
[
|
8 |
+
32001
|
9 |
+
]
|
10 |
+
],
|
11 |
+
"bos_token_id": 1,
|
12 |
+
"eos_token_id": [
|
13 |
+
2,
|
14 |
+
32002
|
15 |
+
],
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"transformers_version": "4.42.3"
|
18 |
+
}
|
idefics2/checkpoint-8000/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b52cc7ac0241e0f119dba0c6b72dc5f4d5a429b38bd94eb75edb5a358b4b644
|
3 |
+
size 374548180
|
idefics2/checkpoint-8000/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca311e524f1174ed34780c49c50e3110c40901c89b57036e41617328bbe51608
|
3 |
+
size 14244
|
idefics2/checkpoint-8000/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b0b4f2b37c59e9aaf35b1999b2d2c84957e88f5548679884092460d7d6d53d20
|
3 |
+
size 1064
|
idefics2/checkpoint-8000/trainer_state.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
idefics2/checkpoint-8000/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81357016d48c35c8aa3b0681c2713c513fb1ee31efc707674d85a48ef9eee341
|
3 |
+
size 5368
|
idefics2/checkpoint-9000/adapter_config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha_pattern": {},
|
3 |
+
"auto_mapping": null,
|
4 |
+
"base_model_name_or_path": "HuggingFaceM4/idefics2-8b",
|
5 |
+
"bias": "none",
|
6 |
+
"fan_in_fan_out": false,
|
7 |
+
"inference_mode": false,
|
8 |
+
"init_lora_weights": "gaussian",
|
9 |
+
"layer_replication": null,
|
10 |
+
"layers_pattern": null,
|
11 |
+
"layers_to_transform": null,
|
12 |
+
"loftq_config": {},
|
13 |
+
"lora_alpha": 16,
|
14 |
+
"lora_dropout": 0.1,
|
15 |
+
"megatron_config": null,
|
16 |
+
"megatron_core": "megatron.core",
|
17 |
+
"modules_to_save": null,
|
18 |
+
"peft_type": "LORA",
|
19 |
+
"r": 64,
|
20 |
+
"rank_pattern": {},
|
21 |
+
"revision": null,
|
22 |
+
"target_modules": ".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
|
23 |
+
"task_type": null,
|
24 |
+
"use_dora": false,
|
25 |
+
"use_rslora": false
|
26 |
+
}
|
idefics2/checkpoint-9000/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:073760f03d80782f132b9c3b74c826df46d8654df042a472117a01290cb7e44f
|
3 |
+
size 746528304
|
idefics2/checkpoint-9000/generation_config.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bad_words_ids": [
|
4 |
+
[
|
5 |
+
32000
|
6 |
+
],
|
7 |
+
[
|
8 |
+
32001
|
9 |
+
]
|
10 |
+
],
|
11 |
+
"bos_token_id": 1,
|
12 |
+
"eos_token_id": [
|
13 |
+
2,
|
14 |
+
32002
|
15 |
+
],
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"transformers_version": "4.42.3"
|
18 |
+
}
|
idefics2/checkpoint-9000/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bdd19fe724219197bfb17e4e705cf8801986dcb1e617f05fabedaf8ec38279ee
|
3 |
+
size 374548180
|
idefics2/checkpoint-9000/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a6e481b24fb2e50ba449b0f66b256e83542f9ceba4a6efa543ab9acb0848a1b
|
3 |
+
size 14244
|
idefics2/checkpoint-9000/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:23554f3e4062472629006034b4ce40b01e49bd325d1fb48661fcf4c6868ee807
|
3 |
+
size 1064
|
idefics2/checkpoint-9000/trainer_state.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
idefics2/checkpoint-9000/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81357016d48c35c8aa3b0681c2713c513fb1ee31efc707674d85a48ef9eee341
|
3 |
+
size 5368
|
idefics2/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c886ec66a448f0680d0a46cd28b697b6899ecc0627e105de6d1eac26f3c78140
|
3 |
+
size 5368
|
inference.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
2 |
+
import pandas as pd
|
3 |
+
from PIL import Image
|
4 |
+
from torchmetrics.text import CharErrorRate
|
5 |
+
|
6 |
+
# Finetuned model
|
7 |
+
model_finetune_1 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_1")
|
8 |
+
model_finetune_2 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_15000")
|
9 |
+
model_finetune_3 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_30000")
|
10 |
+
model_finetune_4 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_45000")
|
11 |
+
model_finetune_5 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_60000")
|
12 |
+
model_finetune_6 = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_new_large_70000")
|
13 |
+
|
14 |
+
#Baseline
|
15 |
+
model_base = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
|
16 |
+
|
17 |
+
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
|
18 |
+
|
19 |
+
# Checked label
|
20 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
|
21 |
+
data = pd.read_csv(df_path)
|
22 |
+
data.dropna(inplace=True)
|
23 |
+
data.reset_index(inplace=True)
|
24 |
+
sample = data.iloc[:50,:]
|
25 |
+
|
26 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"
|
27 |
+
|
28 |
+
inf_baseline = []
|
29 |
+
inf_finetune_1 = []
|
30 |
+
inf_finetune_2 = []
|
31 |
+
inf_finetune_3 = []
|
32 |
+
inf_finetune_4 = []
|
33 |
+
inf_finetune_5 = []
|
34 |
+
inf_finetune_6 = []
|
35 |
+
|
36 |
+
cer_fine_1 = []
|
37 |
+
cer_fine_2 = []
|
38 |
+
cer_fine_3 = []
|
39 |
+
cer_fine_4 = []
|
40 |
+
cer_fine_5 = []
|
41 |
+
cer_fine_6 = []
|
42 |
+
cer_base = []
|
43 |
+
|
44 |
+
cer_metric = CharErrorRate()
|
45 |
+
|
46 |
+
for idx in range(len(sample)):
|
47 |
+
image = Image.open(root_dir + "final_cropped_rotated_" + sample.filename[idx]).convert("RGB")
|
48 |
+
|
49 |
+
pixel_values = processor(image, return_tensors="pt").pixel_values
|
50 |
+
generated_ids_base = model_base.generate(pixel_values)
|
51 |
+
generated_ids_fine_1 = model_finetune_1.generate(pixel_values)
|
52 |
+
generated_ids_fine_2= model_finetune_2.generate(pixel_values)
|
53 |
+
generated_ids_fine_3 = model_finetune_3.generate(pixel_values)
|
54 |
+
generated_ids_fine_4 = model_finetune_4.generate(pixel_values)
|
55 |
+
generated_ids_fine_5 = model_finetune_5.generate(pixel_values)
|
56 |
+
generated_ids_fine_6 = model_finetune_6.generate(pixel_values)
|
57 |
+
|
58 |
+
generated_text_base = processor.batch_decode(generated_ids_base, skip_special_tokens=True)[0]
|
59 |
+
generated_text_fine_1= processor.batch_decode(generated_ids_fine_1, skip_special_tokens=True)[0]
|
60 |
+
generated_text_fine_2= processor.batch_decode(generated_ids_fine_2, skip_special_tokens=True)[0]
|
61 |
+
generated_text_fine_3= processor.batch_decode(generated_ids_fine_3, skip_special_tokens=True)[0]
|
62 |
+
generated_text_fine_4= processor.batch_decode(generated_ids_fine_4, skip_special_tokens=True)[0]
|
63 |
+
generated_text_fine_5= processor.batch_decode(generated_ids_fine_5, skip_special_tokens=True)[0]
|
64 |
+
generated_text_fine_6= processor.batch_decode(generated_ids_fine_6, skip_special_tokens=True)[0]
|
65 |
+
|
66 |
+
cer_fine_1.append(cer_metric(generated_text_fine_1.lower(), sample.text[idx].lower()).detach().numpy())
|
67 |
+
cer_fine_2.append(cer_metric(generated_text_fine_2.lower(), sample.text[idx].lower()).detach().numpy())
|
68 |
+
cer_fine_3.append(cer_metric(generated_text_fine_3.lower(), sample.text[idx].lower()).detach().numpy())
|
69 |
+
cer_fine_4.append(cer_metric(generated_text_fine_4.lower(), sample.text[idx].lower()).detach().numpy())
|
70 |
+
cer_fine_5.append(cer_metric(generated_text_fine_5.lower(), sample.text[idx].lower()).detach().numpy())
|
71 |
+
cer_fine_6.append(cer_metric(generated_text_fine_6.lower(), sample.text[idx].lower()).detach().numpy())
|
72 |
+
cer_base.append(cer_metric(generated_text_base.lower(), sample.text[idx].lower()).detach().numpy())
|
73 |
+
|
74 |
+
inf_baseline.append(generated_text_base)
|
75 |
+
inf_finetune_1.append(generated_text_fine_1)
|
76 |
+
inf_finetune_2.append(generated_text_fine_2)
|
77 |
+
inf_finetune_3.append(generated_text_fine_3)
|
78 |
+
inf_finetune_4.append(generated_text_fine_4)
|
79 |
+
inf_finetune_5.append(generated_text_fine_5)
|
80 |
+
inf_finetune_6.append(generated_text_fine_6)
|
81 |
+
|
82 |
+
sample["Baseline"]=inf_baseline
|
83 |
+
sample["Finetune_1"]=inf_finetune_1
|
84 |
+
sample["Finetune_2"]=inf_finetune_2
|
85 |
+
sample["Finetune_3"]=inf_finetune_3
|
86 |
+
sample["Finetune_4"]=inf_finetune_4
|
87 |
+
sample["Finetune_5"]=inf_finetune_5
|
88 |
+
sample["Finetune_6"]=inf_finetune_6
|
89 |
+
|
90 |
+
sample["cer_1"]=cer_fine_1
|
91 |
+
sample["cer_2"]=cer_fine_2
|
92 |
+
sample["cer_3"]=cer_fine_3
|
93 |
+
sample["cer_4"]=cer_fine_4
|
94 |
+
sample["cer_5"]=cer_fine_5
|
95 |
+
sample["cer_6"]=cer_fine_6
|
96 |
+
sample["cer_base"]=cer_base
|
97 |
+
|
98 |
+
sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "inference_results.csv")
|
inference_idefics2.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import requests
|
3 |
+
from transformers import AutoModelForCausalLM
|
4 |
+
from transformers import AutoProcessor
|
5 |
+
from transformers import BitsAndBytesConfig
|
6 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForVision2Seq
|
7 |
+
import torch
|
8 |
+
import pandas as pd
|
9 |
+
from torchmetrics.text import CharErrorRate
|
10 |
+
from peft import PeftModel, PeftConfig
|
11 |
+
from torchmetrics.text import CharErrorRate
|
12 |
+
from datasets import Dataset, DatasetDict, Image
|
13 |
+
# Define train and test size.
|
14 |
+
TRAIN_SAMPLES = 1000
|
15 |
+
TEST_SAMPLES = 200
|
16 |
+
TEST_SIZE = 0.166 #
|
17 |
+
DEVICE = "cuda:0"
|
18 |
+
peft_model_id = "hadrakey/alphapen_idefics2_finetune_v1"
|
19 |
+
|
20 |
+
config = PeftConfig.from_pretrained(peft_model_id)
|
21 |
+
processor = AutoProcessor.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
|
22 |
+
base_model = AutoModelForVision2Seq.from_pretrained(config.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
|
23 |
+
model = PeftModel.from_pretrained(base_model, peft_model_id)
|
24 |
+
model = model.to(DEVICE)
|
25 |
+
|
26 |
+
# Define the directory containing the images.
|
27 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
|
28 |
+
df = pd.read_csv(df_path)
|
29 |
+
df.dropna(inplace=True)
|
30 |
+
sample = df.iloc[:5000,:]
|
31 |
+
sample.reset_index(inplace=True)
|
32 |
+
sample["id"] = range(sample.shape[0])
|
33 |
+
sample["query"] = "What is shown in this image?"
|
34 |
+
|
35 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
36 |
+
image_paths = [root_dir + img for img in sample.filename]
|
37 |
+
# Create a list of other columns such as id, query, and answer.
|
38 |
+
ids = sample['id'].tolist()
|
39 |
+
queries = sample['query'].tolist()
|
40 |
+
answers = sample['text'].tolist()
|
41 |
+
|
42 |
+
# Create the dataset dictionary.
|
43 |
+
dataset_dict = {
|
44 |
+
'id': ids,
|
45 |
+
'image': image_paths,
|
46 |
+
'query': queries,
|
47 |
+
'answers': answers
|
48 |
+
}
|
49 |
+
|
50 |
+
# Create the dataset.
|
51 |
+
dataset = Dataset.from_dict(dataset_dict)
|
52 |
+
|
53 |
+
# Cast the 'image' column to Image type.
|
54 |
+
dataset = dataset.cast_column("image", Image())
|
55 |
+
|
56 |
+
# Split the dataset into train and test.
|
57 |
+
# split_dataset = dataset.train_test_split(test_size=TEST_SIZE, shuffle=False)
|
58 |
+
|
59 |
+
# train_dataset = split_dataset["train"]
|
60 |
+
# eval_dataset = split_dataset["test"]
|
61 |
+
|
62 |
+
cer_metric = CharErrorRate()
|
63 |
+
cer_idefics = []
|
64 |
+
idefics_output = []
|
65 |
+
|
66 |
+
for idx in range(len(dataset)):
|
67 |
+
|
68 |
+
test_example = dataset[idx]
|
69 |
+
|
70 |
+
image = test_example["image"]
|
71 |
+
query = test_example["query"]
|
72 |
+
|
73 |
+
|
74 |
+
messages = [
|
75 |
+
{
|
76 |
+
"role": "user",
|
77 |
+
"content": [
|
78 |
+
{"type": "text", "text": "Answer briefly."},
|
79 |
+
{"type": "image"},
|
80 |
+
{"type": "text", "text": query}
|
81 |
+
]
|
82 |
+
}
|
83 |
+
]
|
84 |
+
|
85 |
+
|
86 |
+
text = processor.apply_chat_template(messages, add_generation_prompt=True)
|
87 |
+
inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True)
|
88 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
89 |
+
generated_ids = model.generate(**inputs, max_new_tokens=64)
|
90 |
+
generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
|
91 |
+
idefics_output.append(generated_texts[0])
|
92 |
+
cer_idefics.append(cer_metric(generated_texts[0].lower(), test_example["answers"].lower()).detach().numpy())
|
93 |
+
# print(generated_texts, test_example["answers"], cer_idefics)
|
94 |
+
|
95 |
+
sample["idefics"] = idefics_output
|
96 |
+
sample["cer"] = cer_idefics
|
97 |
+
sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "sample_idefics_v1.csv")
|
model.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from typing import Optional
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from datasets import load_dataset, Dataset, load_metric
|
9 |
+
from peft import LoraConfig
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, VisionEncoderDecoderModel, TrOCRProcessor, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, EarlyStoppingCallback
|
12 |
+
|
13 |
+
|
14 |
+
from trl import SFTTrainer, is_xpu_available
|
15 |
+
from data import AphaPenDataset
|
16 |
+
import evaluate
|
17 |
+
from sklearn.model_selection import train_test_split
|
18 |
+
|
19 |
+
import torchvision.transforms as transforms
|
20 |
+
# from utils import compute_metrics
|
21 |
+
from src.calibrator import EncoderDecoderCalibrator
|
22 |
+
from src.loss import MarginLoss, KLRegularization
|
23 |
+
from src.similarity import CERSimilarity
|
24 |
+
import os
|
25 |
+
tqdm.pandas()
|
26 |
+
|
27 |
+
os.environ["WANDB_PROJECT"]="Alphapen"
|
28 |
+
# Define and parse arguments.
|
29 |
+
@dataclass
|
30 |
+
class ScriptArguments:
|
31 |
+
"""
|
32 |
+
The name of the OCR model we wish to fine with Seq2SeqTrainer
|
33 |
+
"""
|
34 |
+
|
35 |
+
model_name: Optional[str] = field(default="microsoft/trocr-base-handwritten", metadata={"help": "the model name"})
|
36 |
+
dataset_name: Optional[str] = field(
|
37 |
+
default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}
|
38 |
+
)
|
39 |
+
log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
|
40 |
+
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
|
41 |
+
batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
|
42 |
+
seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
|
43 |
+
gradient_accumulation_steps: Optional[int] = field(
|
44 |
+
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
45 |
+
)
|
46 |
+
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
|
47 |
+
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
|
48 |
+
use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
|
49 |
+
trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
|
50 |
+
output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
|
51 |
+
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
|
52 |
+
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
|
53 |
+
logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
|
54 |
+
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
|
55 |
+
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
|
56 |
+
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
|
57 |
+
max_length: Optional[int] = field(default=10, metadata={"help": "the maximum length"})
|
58 |
+
no_repeat_ngram_size: Optional[int] = field(default=3, metadata={"help": "the number of repeat"})
|
59 |
+
length_penalty: Optional[float] = field(default=2.0, metadata={"help": "the length of penalty"})
|
60 |
+
num_beams: Optional[int] = field(default=3, metadata={"help": "the number of beam search"})
|
61 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "Early stopping"})
|
62 |
+
save_steps: Optional[int] = field(
|
63 |
+
default=1000, metadata={"help": "Number of updates steps before two checkpoint saves"}
|
64 |
+
)
|
65 |
+
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
|
66 |
+
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
|
67 |
+
gradient_checkpointing: Optional[bool] = field(
|
68 |
+
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
|
69 |
+
)
|
70 |
+
gradient_checkpointing_kwargs: Optional[dict] = field(
|
71 |
+
default=None,
|
72 |
+
metadata={
|
73 |
+
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
|
74 |
+
},
|
75 |
+
)
|
76 |
+
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
|
77 |
+
|
78 |
+
parser = HfArgumentParser(ScriptArguments)
|
79 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
80 |
+
|
81 |
+
# # Step 1: Load the dataset
|
82 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
|
83 |
+
df = pd.read_csv(df_path)
|
84 |
+
df.dropna(inplace=True)
|
85 |
+
train_df, test_df = train_test_split(df, test_size=0.15, random_state=0)
|
86 |
+
# we reset the indices to start from zero
|
87 |
+
train_df.reset_index(drop=True, inplace=True)
|
88 |
+
test_df.reset_index(drop=True, inplace=True)
|
89 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
90 |
+
processor = TrOCRProcessor.from_pretrained(script_args.model_name)
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
|
95 |
+
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
|
96 |
+
|
97 |
+
# Step 2: Load the model
|
98 |
+
if script_args.load_in_8bit and script_args.load_in_4bit:
|
99 |
+
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
|
100 |
+
elif script_args.load_in_8bit or script_args.load_in_4bit:
|
101 |
+
quantization_config = BitsAndBytesConfig(
|
102 |
+
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
|
103 |
+
)
|
104 |
+
# Copy the model to each device
|
105 |
+
device_map = (
|
106 |
+
{"": f"xpu:{Accelerator().local_process_index}"}
|
107 |
+
if is_xpu_available()
|
108 |
+
else {"": Accelerator().local_process_index}
|
109 |
+
)
|
110 |
+
torch_dtype = torch.bfloat16
|
111 |
+
else:
|
112 |
+
device_map = None
|
113 |
+
quantization_config = None
|
114 |
+
torch_dtype = None
|
115 |
+
|
116 |
+
model = VisionEncoderDecoderModel.from_pretrained(
|
117 |
+
script_args.model_name,
|
118 |
+
quantization_config=quantization_config,
|
119 |
+
device_map=device_map,
|
120 |
+
trust_remote_code=script_args.trust_remote_code,
|
121 |
+
torch_dtype=torch_dtype,
|
122 |
+
token=script_args.use_auth_token,
|
123 |
+
)
|
124 |
+
|
125 |
+
# set special tokens used for creating the decoder_input_ids from the labels
|
126 |
+
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
127 |
+
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
128 |
+
# make sure vocab size is set correctly
|
129 |
+
model.config.vocab_size = model.config.decoder.vocab_size
|
130 |
+
|
131 |
+
# set beam search parameters
|
132 |
+
model.config.eos_token_id = processor.tokenizer.sep_token_id
|
133 |
+
model.config.max_length = script_args.max_length
|
134 |
+
model.config.early_stopping = script_args.early_stopping
|
135 |
+
model.config.no_repeat_ngram_size = script_args.no_repeat_ngram_size
|
136 |
+
model.config.length_penalty = script_args.length_penalty
|
137 |
+
model.config.num_beams = script_args.num_beams
|
138 |
+
|
139 |
+
tokenizer = processor.tokenizer
|
140 |
+
sim = CERSimilarity(tokenizer)
|
141 |
+
loss = MarginLoss(sim, beta=0.1, num_samples=60)
|
142 |
+
reg = KLRegularization(model)
|
143 |
+
calibrator = EncoderDecoderCalibrator(model, loss, reg, 15, 15)
|
144 |
+
|
145 |
+
|
146 |
+
# # Step 3: Define the training arguments
|
147 |
+
training_args = Seq2SeqTrainingArguments(
|
148 |
+
predict_with_generate=True,
|
149 |
+
evaluation_strategy="steps",
|
150 |
+
per_device_train_batch_size=script_args.batch_size,
|
151 |
+
per_device_eval_batch_size=script_args.batch_size,
|
152 |
+
fp16=True,
|
153 |
+
output_dir=script_args.output_dir,
|
154 |
+
logging_steps=script_args.logging_steps,
|
155 |
+
save_steps=script_args.save_steps,
|
156 |
+
eval_steps=100,
|
157 |
+
save_total_limit=script_args.save_total_limit,
|
158 |
+
# load_best_model_at_end = True,
|
159 |
+
report_to=script_args.log_with,
|
160 |
+
num_train_epochs=script_args.num_train_epochs,
|
161 |
+
push_to_hub=script_args.push_to_hub,
|
162 |
+
hub_model_id=script_args.hub_model_id,
|
163 |
+
gradient_checkpointing=script_args.gradient_checkpointing,
|
164 |
+
# metric_for_best_model="eval/cer"
|
165 |
+
# TODO: uncomment that on the next release
|
166 |
+
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
# Step 4: Define a metric
|
171 |
+
|
172 |
+
def compute_metrics(pred):
|
173 |
+
# accuracy_metric = evaluate.load("precision")
|
174 |
+
cer_metric = evaluate.load("cer")
|
175 |
+
|
176 |
+
labels_ids = pred.label_ids
|
177 |
+
pred_ids = pred.predictions
|
178 |
+
|
179 |
+
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
|
180 |
+
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
|
181 |
+
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
|
182 |
+
|
183 |
+
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
184 |
+
# accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())
|
185 |
+
|
186 |
+
return {"cer": cer}
|
187 |
+
|
188 |
+
early_stop = EarlyStoppingCallback(10, .001)
|
189 |
+
# # Step 5: Define the Trainer
|
190 |
+
trainer = Seq2SeqTrainer(
|
191 |
+
model=model,
|
192 |
+
tokenizer=processor.feature_extractor,
|
193 |
+
args=training_args,
|
194 |
+
compute_metrics=compute_metrics,
|
195 |
+
train_dataset=train_dataset,
|
196 |
+
eval_dataset=eval_dataset,
|
197 |
+
data_collator=default_data_collator,
|
198 |
+
# callbacks = [early_stop]
|
199 |
+
)
|
200 |
+
|
201 |
+
trainer.train()
|
202 |
+
|
203 |
+
# # Step 6: Save the model
|
204 |
+
# trainer.save_model(script_args.output_dir)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b14472ca382e9d96ea7efd3c778cbf0b73a412e31bc41cfec8d97e8988e6063d
|
3 |
+
size 1335747032
|
model_sft.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from typing import Optional
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from datasets import load_dataset, Dataset, load_metric
|
9 |
+
from peft import LoraConfig
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, VisionEncoderDecoderModel, TrOCRProcessor, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, EarlyStoppingCallback
|
12 |
+
|
13 |
+
|
14 |
+
from trl import SFTTrainer, is_xpu_available
|
15 |
+
from data import AphaPenDataset
|
16 |
+
import evaluate
|
17 |
+
from sklearn.model_selection import train_test_split
|
18 |
+
|
19 |
+
import torchvision.transforms as transforms
|
20 |
+
# from utils import compute_metrics
|
21 |
+
|
22 |
+
tqdm.pandas()
|
23 |
+
|
24 |
+
|
25 |
+
# Define and parse arguments.
|
26 |
+
@dataclass
|
27 |
+
class ScriptArguments:
|
28 |
+
"""
|
29 |
+
The name of the OCR model we wish to fine with Seq2SeqTrainer
|
30 |
+
"""
|
31 |
+
|
32 |
+
model_name: Optional[str] = field(default="microsoft/trocr-base-handwritten", metadata={"help": "the model name"})
|
33 |
+
dataset_name: Optional[str] = field(
|
34 |
+
default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}
|
35 |
+
)
|
36 |
+
log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
|
37 |
+
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
|
38 |
+
batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
|
39 |
+
seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
|
40 |
+
gradient_accumulation_steps: Optional[int] = field(
|
41 |
+
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
42 |
+
)
|
43 |
+
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
|
44 |
+
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
|
45 |
+
use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
|
46 |
+
trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
|
47 |
+
output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
|
48 |
+
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
|
49 |
+
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
|
50 |
+
logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
|
51 |
+
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
|
52 |
+
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
|
53 |
+
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
|
54 |
+
max_length: Optional[int] = field(default=10, metadata={"help": "the maximum length"})
|
55 |
+
no_repeat_ngram_size: Optional[int] = field(default=3, metadata={"help": "the number of repeat"})
|
56 |
+
length_penalty: Optional[float] = field(default=2.0, metadata={"help": "the length of penalty"})
|
57 |
+
num_beams: Optional[int] = field(default=3, metadata={"help": "the number of beam search"})
|
58 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "Early stopping"})
|
59 |
+
save_steps: Optional[int] = field(
|
60 |
+
default=1000, metadata={"help": "Number of updates steps before two checkpoint saves"}
|
61 |
+
)
|
62 |
+
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
|
63 |
+
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
|
64 |
+
gradient_checkpointing: Optional[bool] = field(
|
65 |
+
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
|
66 |
+
)
|
67 |
+
gradient_checkpointing_kwargs: Optional[dict] = field(
|
68 |
+
default=None,
|
69 |
+
metadata={
|
70 |
+
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
|
71 |
+
},
|
72 |
+
)
|
73 |
+
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
|
74 |
+
|
75 |
+
parser = HfArgumentParser(ScriptArguments)
|
76 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
77 |
+
|
78 |
+
# # Step 1: Load the dataset
|
79 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
|
80 |
+
df = pd.read_csv(df_path)
|
81 |
+
df.dropna(inplace=True)
|
82 |
+
train_df, test_df = train_test_split(df, test_size=0.15, random_state=0)
|
83 |
+
# we reset the indices to start from zero
|
84 |
+
train_df.reset_index(drop=True, inplace=True)
|
85 |
+
test_df.reset_index(drop=True, inplace=True)
|
86 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
87 |
+
processor = TrOCRProcessor.from_pretrained(script_args.model_name)
|
88 |
+
|
89 |
+
# Transformation for training including augmentations
|
90 |
+
transform = transforms.Compose([
|
91 |
+
transforms.PILToTensor(),
|
92 |
+
transforms.RandomRotation(degrees=(0, 180))
|
93 |
+
])
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor, transform=transform)
|
98 |
+
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
|
99 |
+
|
100 |
+
# Step 2: Load the model
|
101 |
+
if script_args.load_in_8bit and script_args.load_in_4bit:
|
102 |
+
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
|
103 |
+
elif script_args.load_in_8bit or script_args.load_in_4bit:
|
104 |
+
quantization_config = BitsAndBytesConfig(
|
105 |
+
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
|
106 |
+
)
|
107 |
+
# Copy the model to each device
|
108 |
+
device_map = (
|
109 |
+
{"": f"xpu:{Accelerator().local_process_index}"}
|
110 |
+
if is_xpu_available()
|
111 |
+
else {"": Accelerator().local_process_index}
|
112 |
+
)
|
113 |
+
torch_dtype = torch.bfloat16
|
114 |
+
else:
|
115 |
+
device_map = None
|
116 |
+
quantization_config = None
|
117 |
+
torch_dtype = None
|
118 |
+
|
119 |
+
model = VisionEncoderDecoderModel.from_pretrained(
|
120 |
+
script_args.model_name,
|
121 |
+
quantization_config=quantization_config,
|
122 |
+
device_map=device_map,
|
123 |
+
trust_remote_code=script_args.trust_remote_code,
|
124 |
+
torch_dtype=torch_dtype,
|
125 |
+
token=script_args.use_auth_token,
|
126 |
+
)
|
127 |
+
|
128 |
+
# set special tokens used for creating the decoder_input_ids from the labels
|
129 |
+
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
130 |
+
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
131 |
+
# make sure vocab size is set correctly
|
132 |
+
model.config.vocab_size = model.config.decoder.vocab_size
|
133 |
+
|
134 |
+
# set beam search parameters
|
135 |
+
model.config.eos_token_id = processor.tokenizer.sep_token_id
|
136 |
+
model.config.max_length = script_args.max_length
|
137 |
+
model.config.early_stopping = script_args.early_stopping
|
138 |
+
model.config.no_repeat_ngram_size = script_args.no_repeat_ngram_size
|
139 |
+
model.config.length_penalty = script_args.length_penalty
|
140 |
+
model.config.num_beams = script_args.num_beams
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
# # Step 3: Define the training arguments
|
146 |
+
training_args = Seq2SeqTrainingArguments(
|
147 |
+
predict_with_generate=True,
|
148 |
+
evaluation_strategy="steps",
|
149 |
+
# per_device_train_batch_size=script_args.batch_size,
|
150 |
+
# per_device_eval_batch_size=script_args.batch_size,
|
151 |
+
fp16=True,
|
152 |
+
output_dir=script_args.output_dir,
|
153 |
+
logging_steps=script_args.logging_steps,
|
154 |
+
save_steps=script_args.save_steps,
|
155 |
+
eval_steps=100,
|
156 |
+
save_total_limit=script_args.save_total_limit,
|
157 |
+
load_best_model_at_end = True,
|
158 |
+
report_to=script_args.log_with,
|
159 |
+
num_train_epochs=script_args.num_train_epochs,
|
160 |
+
push_to_hub=script_args.push_to_hub,
|
161 |
+
hub_model_id=script_args.hub_model_id,
|
162 |
+
gradient_checkpointing=script_args.gradient_checkpointing,
|
163 |
+
auto_find_batch_size=True,
|
164 |
+
metric_for_best_model="eval/cer"
|
165 |
+
# TODO: uncomment that on the next release
|
166 |
+
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
# Step 4: Define a metric
|
171 |
+
|
172 |
+
def compute_metrics(pred):
|
173 |
+
# accuracy_metric = evaluate.load("precision")
|
174 |
+
cer_metric = evaluate.load("cer")
|
175 |
+
|
176 |
+
labels_ids = pred.label_ids
|
177 |
+
pred_ids = pred.predictions
|
178 |
+
|
179 |
+
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
|
180 |
+
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
|
181 |
+
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
|
182 |
+
|
183 |
+
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
184 |
+
# accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())
|
185 |
+
|
186 |
+
return {"cer": cer}
|
187 |
+
|
188 |
+
early_stop = EarlyStoppingCallback(10, .001)
|
189 |
+
|
190 |
+
# Step 5: Define the LoraConfig
|
191 |
+
if script_args.use_peft:
|
192 |
+
peft_config = LoraConfig(
|
193 |
+
r=script_args.peft_lora_r,
|
194 |
+
lora_alpha=script_args.peft_lora_alpha,
|
195 |
+
bias="none",
|
196 |
+
task_type="CAUSAL_LM",
|
197 |
+
target_modules="all-linear"
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
peft_config = None
|
201 |
+
# # Step 6: Define the Trainer
|
202 |
+
trainer = SFTTrainer(
|
203 |
+
model=model,
|
204 |
+
tokenizer=processor.feature_extractor,
|
205 |
+
args=training_args,
|
206 |
+
compute_metrics=compute_metrics,
|
207 |
+
train_dataset=train_dataset,
|
208 |
+
eval_dataset=eval_dataset,
|
209 |
+
data_collator=default_data_collator,
|
210 |
+
peft_config=peft_config,
|
211 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
|
212 |
+
)
|
213 |
+
|
214 |
+
trainer.train()
|
215 |
+
|
216 |
+
# # Step 6: Save the model
|
217 |
+
# trainer.save_model(script_args.output_dir)
|
phi3/checkpoint-25/adapter_config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha_pattern": {},
|
3 |
+
"auto_mapping": null,
|
4 |
+
"base_model_name_or_path": "HuggingFaceM4/idefics2-8b",
|
5 |
+
"bias": "none",
|
6 |
+
"fan_in_fan_out": false,
|
7 |
+
"inference_mode": false,
|
8 |
+
"init_lora_weights": "gaussian",
|
9 |
+
"layer_replication": null,
|
10 |
+
"layers_pattern": null,
|
11 |
+
"layers_to_transform": null,
|
12 |
+
"loftq_config": {},
|
13 |
+
"lora_alpha": 16,
|
14 |
+
"lora_dropout": 0.1,
|
15 |
+
"megatron_config": null,
|
16 |
+
"megatron_core": "megatron.core",
|
17 |
+
"modules_to_save": null,
|
18 |
+
"peft_type": "LORA",
|
19 |
+
"r": 64,
|
20 |
+
"rank_pattern": {},
|
21 |
+
"revision": null,
|
22 |
+
"target_modules": ".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
|
23 |
+
"task_type": null,
|
24 |
+
"use_dora": false,
|
25 |
+
"use_rslora": false
|
26 |
+
}
|
phi3/checkpoint-25/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c804cede36291fb0feb5cee74f4ffeeaec1178864af130c873b410c6f1fe1a18
|
3 |
+
size 746528304
|
phi3/checkpoint-25/generation_config.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bad_words_ids": [
|
4 |
+
[
|
5 |
+
32000
|
6 |
+
],
|
7 |
+
[
|
8 |
+
32001
|
9 |
+
]
|
10 |
+
],
|
11 |
+
"bos_token_id": 1,
|
12 |
+
"eos_token_id": [
|
13 |
+
2,
|
14 |
+
32002
|
15 |
+
],
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"transformers_version": "4.42.3"
|
18 |
+
}
|
phi3/checkpoint-25/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d93bc5e374a2ab92fce0a68bbe1800baf7f1ea49f5528c1bd22a6bb987d7a79
|
3 |
+
size 374547732
|
phi3/checkpoint-25/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3bc84e110eb9a1be206e02c97bf5c5d7133033f306401f2e818d8847834cab9f
|
3 |
+
size 14244
|
phi3/checkpoint-25/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:630fe59a784956405be1a950d9ce52e5bf6a2f1c12f3a8bd4f3869766a5850cd
|
3 |
+
size 1064
|
phi3/checkpoint-25/trainer_state.json
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"best_metric": null,
|
3 |
+
"best_model_checkpoint": null,
|
4 |
+
"epoch": 0.008670019074041963,
|
5 |
+
"eval_steps": 10,
|
6 |
+
"global_step": 25,
|
7 |
+
"is_hyper_param_search": false,
|
8 |
+
"is_local_process_zero": true,
|
9 |
+
"is_world_process_zero": true,
|
10 |
+
"log_history": [
|
11 |
+
{
|
12 |
+
"epoch": 0.0017340038148083927,
|
13 |
+
"grad_norm": 11.838030815124512,
|
14 |
+
"learning_rate": 0.00017600000000000002,
|
15 |
+
"loss": 12.8857,
|
16 |
+
"step": 5
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"epoch": 0.0034680076296167853,
|
20 |
+
"grad_norm": 0.8977920413017273,
|
21 |
+
"learning_rate": 0.00013600000000000003,
|
22 |
+
"loss": 0.6798,
|
23 |
+
"step": 10
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"epoch": 0.0034680076296167853,
|
27 |
+
"eval_loss": 0.2337629497051239,
|
28 |
+
"eval_runtime": 675.29,
|
29 |
+
"eval_samples_per_second": 13.597,
|
30 |
+
"eval_steps_per_second": 1.7,
|
31 |
+
"step": 10
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"epoch": 0.0052020114444251775,
|
35 |
+
"grad_norm": 0.34665364027023315,
|
36 |
+
"learning_rate": 9.6e-05,
|
37 |
+
"loss": 0.1571,
|
38 |
+
"step": 15
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"epoch": 0.006936015259233571,
|
42 |
+
"grad_norm": 0.26853781938552856,
|
43 |
+
"learning_rate": 5.6000000000000006e-05,
|
44 |
+
"loss": 0.1088,
|
45 |
+
"step": 20
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"epoch": 0.006936015259233571,
|
49 |
+
"eval_loss": 0.09983003884553909,
|
50 |
+
"eval_runtime": 686.1959,
|
51 |
+
"eval_samples_per_second": 13.381,
|
52 |
+
"eval_steps_per_second": 1.673,
|
53 |
+
"step": 20
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"epoch": 0.008670019074041963,
|
57 |
+
"grad_norm": 0.22648762166500092,
|
58 |
+
"learning_rate": 1.6000000000000003e-05,
|
59 |
+
"loss": 0.0911,
|
60 |
+
"step": 25
|
61 |
+
}
|
62 |
+
],
|
63 |
+
"logging_steps": 5,
|
64 |
+
"max_steps": 25,
|
65 |
+
"num_input_tokens_seen": 0,
|
66 |
+
"num_train_epochs": 1,
|
67 |
+
"save_steps": 25,
|
68 |
+
"stateful_callbacks": {
|
69 |
+
"TrainerControl": {
|
70 |
+
"args": {
|
71 |
+
"should_epoch_stop": false,
|
72 |
+
"should_evaluate": false,
|
73 |
+
"should_log": false,
|
74 |
+
"should_save": true,
|
75 |
+
"should_training_stop": true
|
76 |
+
},
|
77 |
+
"attributes": {}
|
78 |
+
}
|
79 |
+
},
|
80 |
+
"total_flos": 1803863725327872.0,
|
81 |
+
"train_batch_size": 8,
|
82 |
+
"trial_name": null,
|
83 |
+
"trial_params": null
|
84 |
+
}
|
phi3/checkpoint-25/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f453fa98240e7ba800f54183433d2dfc198cd72c7267c34a5547639a1d49da5c
|
3 |
+
size 5112
|
phi3_ocr.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Example inspired from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct
|
2 |
+
|
3 |
+
# Import necessary libraries
|
4 |
+
from PIL import Image
|
5 |
+
import requests
|
6 |
+
from transformers import AutoModelForCausalLM
|
7 |
+
from transformers import AutoProcessor
|
8 |
+
from transformers import BitsAndBytesConfig
|
9 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
10 |
+
import torch
|
11 |
+
import pandas as pd
|
12 |
+
from torchmetrics.text import CharErrorRate
|
13 |
+
from peft import PeftModel, PeftConfig
|
14 |
+
|
15 |
+
# Define model ID
|
16 |
+
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
17 |
+
peft_model_id = "hadrakey/alphapen_phi3"
|
18 |
+
peft_model_id_new = "hadrakey/alphapen_new_large"
|
19 |
+
|
20 |
+
# Load processor
|
21 |
+
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
22 |
+
|
23 |
+
# phi3 finetuned
|
24 |
+
# config = PeftConfig.from_pretrained(peft_model_id)
|
25 |
+
|
26 |
+
# processor_fine = AutoProcessor.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
|
27 |
+
|
28 |
+
# Finetuned model
|
29 |
+
# config_new = PeftConfig.from_pretrained(peft_model_id_new)
|
30 |
+
model_finetune = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_large")
|
31 |
+
# model_new_finetune = AutoModelForCausalLM.from_pretrained(config_new.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
|
32 |
+
|
33 |
+
# model_finetune_phi3 = AutoModelForCausalLM.from_pretrained("hadrakey/alphapen_phi3", trust_remote_code=True)
|
34 |
+
|
35 |
+
#Baseline
|
36 |
+
model_base = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
|
37 |
+
|
38 |
+
|
39 |
+
processor_ocr = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
|
40 |
+
# processor_ocr_new = AutoProcessor.from_pretrained(config_new.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
|
41 |
+
# Define BitsAndBytes configuration for 4-bit quantization
|
42 |
+
nf4_config = BitsAndBytesConfig(
|
43 |
+
load_in_4bit=True,
|
44 |
+
bnb_4bit_quant_type="nf4",
|
45 |
+
bnb_4bit_use_double_quant=True,
|
46 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
47 |
+
)
|
48 |
+
|
49 |
+
# Load model with 4-bit quantization and map to CUDA
|
50 |
+
model = AutoModelForCausalLM.from_pretrained(
|
51 |
+
model_id,
|
52 |
+
device_map="cuda",
|
53 |
+
trust_remote_code=True,
|
54 |
+
torch_dtype="auto",
|
55 |
+
quantization_config=nf4_config,
|
56 |
+
)
|
57 |
+
|
58 |
+
# base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
|
59 |
+
|
60 |
+
# model_finetune_phi3 = PeftModel.from_pretrained(base_model, peft_model_id)
|
61 |
+
# Define initial chat message with image placeholder
|
62 |
+
messages = [{"role": "user", "content": """<|image_1|>\nThis image contains handwritten French characters forming a complete or partial word. The image is blurred, which makes recognition challenging. Please analyze the image to the best of your ability and provide your best guess of the French word or partial word shown, even if you're not certain. Follow these guidelines:
|
63 |
+
|
64 |
+
1. Examine the overall shape and any discernible character features.
|
65 |
+
2. Consider common French letter combinations and word patterns.
|
66 |
+
3. If you can only identify some characters, provide those as a partial word.
|
67 |
+
4. Make an educated guess based on what you can see, even if it's just a few letters.
|
68 |
+
5. If you can see any characters at all, avoid responding with "indiscernible."
|
69 |
+
|
70 |
+
Your response should be only the predicted French word or partial word, using lowercase letters unless capital letters are clearly visible. If you can see any characters or shapes at all, provide the OCR from the image.
|
71 |
+
"""}]
|
72 |
+
|
73 |
+
# messages = [{"role": "user", "content": """<|image_1|>\nWhat is shown is this images ? You should only output only your guess otherwise output the OCR.
|
74 |
+
# """}]
|
75 |
+
|
76 |
+
# Download image from URL
|
77 |
+
url = "https://images.unsplash.com/photo-1528834342297-fdefb9a5a92b?ixlib=rb-4.0.3&q=85&fm=jpg&crop=entropy&cs=srgb&dl=roonz-nl-vjDbHCjHlEY-unsplash.jpg&w=640"
|
78 |
+
# image = Image.open(requests.get(url, stream=True).raw)
|
79 |
+
|
80 |
+
df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
|
81 |
+
data = pd.read_csv(df_path)
|
82 |
+
data.dropna(inplace=True)
|
83 |
+
data.reset_index(inplace=True)
|
84 |
+
sample = data.iloc[:5000,:]
|
85 |
+
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"
|
86 |
+
# Prepare prompt with image token
|
87 |
+
prompt = processor.tokenizer.apply_chat_template(
|
88 |
+
messages, tokenize=False, add_generation_prompt=True
|
89 |
+
)
|
90 |
+
cer_metric = CharErrorRate()
|
91 |
+
phi_output=[]
|
92 |
+
phi_finetune_output=[]
|
93 |
+
inf_baseline = []
|
94 |
+
inf_finetune = []
|
95 |
+
inf_finetune_new = []
|
96 |
+
|
97 |
+
cer_phi = []
|
98 |
+
cer_phi_finetune = []
|
99 |
+
cer_trocr_fine_new = []
|
100 |
+
cer_trocr_fine = []
|
101 |
+
cer_trocr_base = []
|
102 |
+
for idx in range(len(sample)):
|
103 |
+
|
104 |
+
# idx=30 # choose the image
|
105 |
+
image = Image.open(root_dir + "final_cropped_rotated_" + data.filename[idx]).convert("RGB")
|
106 |
+
|
107 |
+
# Process prompt and image for model input
|
108 |
+
inputs = processor(prompt, [image], return_tensors="pt").to("cuda:0")
|
109 |
+
|
110 |
+
# Generate text response using model
|
111 |
+
generate_ids = model.generate(
|
112 |
+
**inputs,
|
113 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
114 |
+
max_new_tokens=500,
|
115 |
+
do_sample=False,
|
116 |
+
)
|
117 |
+
|
118 |
+
# Remove input tokens from generated response
|
119 |
+
generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
|
120 |
+
|
121 |
+
# Decode generated IDs to text
|
122 |
+
response = processor.batch_decode(
|
123 |
+
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
124 |
+
)[0]
|
125 |
+
phi_output.append(response)
|
126 |
+
cer_phi.append(cer_metric(response.lower(), data.text[idx].lower()).detach().numpy())
|
127 |
+
|
128 |
+
# Generate text response using model finetuned
|
129 |
+
# generate_ids_fine = model_finetune_phi3.generate(
|
130 |
+
# **inputs,
|
131 |
+
# eos_token_id=processor.tokenizer.eos_token_id,
|
132 |
+
# max_new_tokens=500,
|
133 |
+
# do_sample=False,
|
134 |
+
# )
|
135 |
+
|
136 |
+
# # Remove input tokens from generated response
|
137 |
+
# inputs = processor_fine(prompt, [image], return_tensors="pt").to("cuda:0")
|
138 |
+
# generate_ids_fine = generate_ids_fine[:, inputs["input_ids"].shape[1] :]
|
139 |
+
|
140 |
+
# Decode generated IDs to text
|
141 |
+
# response = processor.batch_decode(
|
142 |
+
# generate_ids_fine, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
143 |
+
# )[0]
|
144 |
+
# phi_finetune_output.append(response)
|
145 |
+
# cer_phi_finetune.append(cer_metric(response, data.text[idx]).detach().numpy())
|
146 |
+
|
147 |
+
# Trocr
|
148 |
+
pixel_values = processor_ocr(image, return_tensors="pt").pixel_values
|
149 |
+
generated_ids_base = model_base.generate(pixel_values)
|
150 |
+
generated_ids_fine = model_finetune.generate(pixel_values)
|
151 |
+
# generated_ids_fine_new = model_finetune_new.generate(pixel_values)
|
152 |
+
generated_text_base = processor_ocr.batch_decode(generated_ids_base, skip_special_tokens=True)[0]
|
153 |
+
generated_text_fine= processor_ocr.batch_decode(generated_ids_fine, skip_special_tokens=True)[0]
|
154 |
+
# generated_text_fine_new= processor_ocr_new.batch_decode(generated_ids_fine_new, skip_special_tokens=True)[0]
|
155 |
+
|
156 |
+
inf_baseline.append(generated_text_base)
|
157 |
+
inf_finetune.append(generated_text_fine)
|
158 |
+
# inf_finetune_new.append(generated_text_fine_new)
|
159 |
+
|
160 |
+
# cer_trocr_fine_new.append(cer_metric(generated_text_fine_new, data.text[idx]).detach().numpy())
|
161 |
+
cer_trocr_fine.append(cer_metric(generated_text_fine.lower(), data.text[idx].lower()).detach().numpy())
|
162 |
+
cer_trocr_base.append(cer_metric(generated_text_base.lower(), data.text[idx].lower()).detach().numpy())
|
163 |
+
|
164 |
+
|
165 |
+
# Print the generated response
|
166 |
+
sample["phi3"]=phi_output
|
167 |
+
# sample["phi3_fine"]=phi_finetune_output
|
168 |
+
sample["Baseline"]=inf_baseline
|
169 |
+
sample["Finetune"]=inf_finetune
|
170 |
+
# sample["Finetune_new"]=inf_finetune_new
|
171 |
+
sample["cer_phi"]=cer_phi
|
172 |
+
# sample["cer_phi_fine"]=cer_phi_finetune
|
173 |
+
sample["cer_trocr_base"]=cer_trocr_base
|
174 |
+
sample["cer_trocr_fine"]=cer_trocr_fine
|
175 |
+
# sample["cer_trocr_fine_new"]=cer_trocr_fine_new
|
176 |
+
sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "sample_data.csv")
|