hoang-quoc-trung commited on
Commit
3d52ce7
1 Parent(s): 76f844e

Upload 8 files

Browse files
requirements.txt ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ accelerate==0.23.0
3
+ aiofiles==23.2.1
4
+ aiohttp==3.8.6
5
+ aiosignal==1.3.1
6
+ albumentations==1.3.1
7
+ altair==5.2.0
8
+ annotated-types==0.6.0
9
+ anyconfig==0.13.0
10
+ anyio==4.3.0
11
+ appdirs==1.4.4
12
+ asttokens==2.4.1
13
+ async-timeout==4.0.3
14
+ attrs==23.1.0
15
+ backcall==0.2.0
16
+ blinker==1.7.0
17
+ cachetools==5.3.2
18
+ certifi==2023.7.22
19
+ charset-normalizer==3.3.2
20
+ click==8.1.7
21
+ cmake==3.27.7
22
+ comm==0.2.0
23
+ contourpy==1.1.1
24
+ cycler==0.12.1
25
+ datasets==2.14.7
26
+ debugpy==1.8.0
27
+ decorator==5.1.1
28
+ dill==0.3.7
29
+ docker-pycreds==0.4.0
30
+ evaluate==0.4.1
31
+ exceptiongroup==1.2.0
32
+ executing==2.0.1
33
+ fastapi==0.110.0
34
+ ffmpy==0.3.2
35
+ filelock==3.13.1
36
+ fonttools==4.49.0
37
+ frozenlist==1.4.0
38
+ fsspec==2023.10.0
39
+ gitdb==4.0.11
40
+ GitPython==3.1.40
41
+ google-auth==2.23.4
42
+ google-auth-oauthlib==1.0.0
43
+ gradio==3.50.2
44
+ gradio_client==0.6.1
45
+ grpcio==1.59.2
46
+ h11==0.14.0
47
+ httpcore==1.0.4
48
+ httpx==0.27.0
49
+ huggingface-hub==0.20.3
50
+ idna==3.4
51
+ imageio==2.32.0
52
+ importlib-metadata==6.8.0
53
+ importlib_resources==6.1.2
54
+ ipykernel==6.25.2
55
+ ipython==8.12.3
56
+ jedi==0.19.1
57
+ Jinja2==3.1.2
58
+ jiwer==3.0.3
59
+ joblib==1.3.2
60
+ jsonschema==4.21.1
61
+ jsonschema-specifications==2023.12.1
62
+ jupyter_client==8.6.0
63
+ jupyter_core==5.5.0
64
+ kiwisolver==1.4.5
65
+ lazy_loader==0.3
66
+ Levenshtein==0.23.0
67
+ lightning==2.1.1
68
+ lightning-utilities==0.9.0
69
+ lion-pytorch==0.1.2
70
+ lit==17.0.5
71
+ Markdown==3.5.1
72
+ markdown-it-py==3.0.0
73
+ MarkupSafe==2.1.3
74
+ matplotlib==3.7.5
75
+ matplotlib-inline==0.1.6
76
+ mdurl==0.1.2
77
+ mpmath==1.3.0
78
+ multidict==6.0.4
79
+ multiprocess==0.70.15
80
+ munch==4.0.0
81
+ natsort==8.4.0
82
+ nest-asyncio==1.5.8
83
+ networkx==3.1
84
+ nltk==3.8.1
85
+ nougat-ocr==0.1.17
86
+ numpy==1.22.3
87
+ nvidia-cublas-cu11==11.10.3.66
88
+ nvidia-cuda-cupti-cu11==11.7.101
89
+ nvidia-cuda-nvrtc-cu11==11.7.99
90
+ nvidia-cuda-runtime-cu11==11.7.99
91
+ nvidia-cudnn-cu11==8.5.0.96
92
+ nvidia-cufft-cu11==10.9.0.58
93
+ nvidia-curand-cu11==10.2.10.91
94
+ nvidia-cusolver-cu11==11.4.0.1
95
+ nvidia-cusparse-cu11==11.7.4.91
96
+ nvidia-nccl-cu11==2.14.3
97
+ nvidia-nvtx-cu11==11.7.91
98
+ oauthlib==3.2.2
99
+ opencv-python-headless==4.8.1.78
100
+ orjson==3.9.10
101
+ packaging==23.2
102
+ pandas==2.0.3
103
+ parso==0.8.3
104
+ peft==0.8.2
105
+ pexpect==4.8.0
106
+ pickleshare==0.7.5
107
+ Pillow==10.0.1
108
+ pip==23.3.1
109
+ pkgutil_resolve_name==1.3.10
110
+ platformdirs==4.0.0
111
+ prompt-toolkit==3.0.41
112
+ protobuf==4.25.0
113
+ psutil==5.9.6
114
+ ptyprocess==0.7.0
115
+ pure-eval==0.2.2
116
+ pyarrow==14.0.1
117
+ pyarrow-hotfix==0.5
118
+ pyasn1==0.5.0
119
+ pyasn1-modules==0.3.0
120
+ pydantic==2.6.2
121
+ pydantic_core==2.16.3
122
+ pydeck==0.8.1b0
123
+ pydub==0.25.1
124
+ Pygments==2.16.1
125
+ pyparsing==3.1.1
126
+ pypdf==3.17.1
127
+ pypdfium2==4.24.0
128
+ python-dateutil==2.8.2
129
+ python-Levenshtein==0.23.0
130
+ python-multipart==0.0.9
131
+ pytorch-lightning==2.1.1
132
+ pytz==2023.3.post1
133
+ PyWavelets==1.4.1
134
+ PyYAML==6.0.1
135
+ pyzmq==25.1.1
136
+ qudida==0.0.4
137
+ rapidfuzz==3.5.2
138
+ referencing==0.33.0
139
+ regex==2023.10.3
140
+ requests==2.31.0
141
+ requests-oauthlib==1.3.1
142
+ responses==0.18.0
143
+ rich==13.7.1
144
+ rpds-py==0.18.0
145
+ rsa==4.9
146
+ ruamel.yaml==0.18.5
147
+ ruamel.yaml.clib==0.2.8
148
+ safetensors==0.4.0
149
+ scikit-image==0.21.0
150
+ scikit-learn==1.3.2
151
+ scipy==1.10.1
152
+ sconf==0.2.5
153
+ semantic-version==2.10.0
154
+ sentencepiece==0.1.99
155
+ sentry-sdk==1.37.0
156
+ setproctitle==1.3.3
157
+ setuptools==68.2.2
158
+ six==1.16.0
159
+ smmap==5.0.1
160
+ sniffio==1.3.1
161
+ stack-data==0.6.3
162
+ starlette==0.36.3
163
+ streamlit==1.33.0
164
+ sympy==1.12
165
+ tenacity==8.2.3
166
+ tensorboard==2.14.0
167
+ tensorboard-data-server==0.7.2
168
+ tensorboardX==2.6.2.2
169
+ threadpoolctl==3.2.0
170
+ tifffile==2023.7.10
171
+ timm==0.5.4
172
+ tokenizers==0.15.1
173
+ toml==0.10.2
174
+ toolz==0.12.1
175
+ torch==2.0.0
176
+ torchmetrics==1.2.0
177
+ torchvision==0.15.1
178
+ tornado==6.3.3
179
+ tqdm==4.66.1
180
+ traitlets==5.13.0
181
+ transformers==4.37.0
182
+ triton==2.0.0
183
+ typing_extensions==4.8.0
184
+ tzdata==2023.3
185
+ urllib3==2.1.0
186
+ uvicorn==0.27.1
187
+ wandb==0.16.0
188
+ watchdog==4.0.0
189
+ wcwidth==0.2.10
190
+ websockets==11.0.3
191
+ Werkzeug==3.0.1
192
+ wheel==0.41.3
193
+ xxhash==3.4.1
194
+ yarl==1.9.2
195
+ zipp==3.17.0
src/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
import os
import sys

# Make this package's directory importable as a top-level module path.
# (The original wrapped dirname() in a single-argument os.path.join, a no-op.)
module = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module)
src/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
import os
import sys

# Make this package's directory importable as a top-level module path.
# (The original wrapped dirname() in a single-argument os.path.join, a no-op.)
module = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module)
src/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (331 Bytes). View file
 
src/utils/__pycache__/common_utils.cpython-38.pyc ADDED
Binary file (3.09 kB). View file
 
src/utils/__pycache__/metrics.cpython-38.pyc ADDED
Binary file (1.79 kB). View file
 
src/utils/common_utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import torch
4
+ import numpy
5
+
6
+
7
def check_device(logger=None):
    """Select the torch device to run on.

    Prefers CUDA when available, otherwise falls back to the CPU.

    Args:
        logger: optional logger; when provided, device information is logged.
            (Previously the default ``logger=None`` crashed with
            AttributeError because ``logger.info`` was called unconditionally.)

    Returns:
        torch.device: ``cuda`` if available, else ``cpu``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if logger is not None:
            logger.info("There are {} GPU(s) available.".format(torch.cuda.device_count()))
            logger.info('We will use the GPU: {}'.format(torch.cuda.get_device_name(0)))
    else:
        if logger is not None:
            logger.info('No GPU available, using the CPU instead.')
        device = torch.device("cpu")
    return device
16
+
17
+
18
def print_trainable_parameters(model, logger):
    """Log a summary of the model's parameter counts.

    Reports the total parameter count (in millions and raw) together with
    the number and percentage of parameters that require gradients.

    Args:
        model: a torch module exposing ``named_parameters()``.
        logger: logger used to emit the summary line.
    """
    n_total = sum(p.numel() for _, p in model.named_parameters())
    n_trainable = sum(
        p.numel() for _, p in model.named_parameters() if p.requires_grad
    )
    logger.info(
        "Total params: {}M ({}) || Trainable params: {} || Trainable: {}%".format(
            round(n_total / 1000000),
            n_total,
            n_trainable,
            100 * n_trainable / n_total,
        )
    )
36
+
37
+
38
def save_log(
    loss: float,
    bleu: float,
    edit_distance: float,
    exact_match: float,
    wer: float,
    exprate: float,
    exprate_error_1: float,
    exprate_error_2: float,
    exprate_error_3: float,
    file_name="test_log.csv",
):
    """Append one row of evaluation metrics to ``log/<file_name>``.

    Creates the ``log`` directory on first use and writes the CSV header
    only when the target file is still empty, so repeated calls accumulate
    one data row each.
    """
    columns = [
        "loss",
        "bleu",
        "edit_distance",
        "exact_match",
        "wer",
        "exprate",
        "exprate_error_1",
        "exprate_error_2",
        "exprate_error_3"
    ]
    values = [
        loss,
        bleu,
        edit_distance,
        exact_match,
        wer,
        exprate,
        exprate_error_1,
        exprate_error_2,
        exprate_error_3,
    ]
    os.makedirs('log', exist_ok=True)
    target = os.path.join('log', file_name)
    with open(target, mode="a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        # Header goes in only when the file is brand new (position 0).
        if handle.tell() == 0:
            writer.writeheader()
        writer.writerow(dict(zip(columns, values)))
83
+
84
+
85
def cmp_result(label, rec):
    """Levenshtein distance between two token sequences.

    Fills a ``(len(label)+1) x (len(rec)+1)`` DP table where cell ``(i, j)``
    holds the minimum number of edits turning ``label[:i]`` into ``rec[:j]``.

    Returns:
        tuple: ``(edit distance, len(label))``.
    """
    rows = len(label) + 1
    cols = len(rec) + 1
    table = numpy.zeros((rows, cols), dtype='int32')
    table[0, :] = range(cols)
    table[:, 0] = range(rows)
    for i in range(1, rows):
        for j in range(1, cols):
            # Substitution is free when the tokens already match.
            substitute = table[i - 1, j - 1] + (label[i - 1] != rec[j - 1])
            insert = table[i, j - 1] + 1
            delete = table[i - 1, j] + 1
            table[i, j] = min(substitute, insert, delete)
    return table[rows - 1, cols - 1], len(label)
97
+
98
+
99
def compute_exprate(predictions, references):
    """Compute the expression recognition rate (ExpRate) and its relaxations.

    ExpRate is the fraction of predictions whose token-level edit distance
    to the reference is exactly 0; ``error_n`` additionally accepts
    predictions that are at most ``n`` edits away.

    Args:
        predictions: list of predicted strings (tokens separated by spaces).
        references: list of ground-truth strings, same length as predictions.

    Returns:
        tuple: ``(exprate, error_1, error_2, error_3)``, each in [0, 1].
        Empty input yields all zeros instead of raising ZeroDivisionError.
    """
    total_line = len(references)
    if total_line == 0:
        # Guard: the original division crashed with ZeroDivisionError here.
        return 0.0, 0.0, 0.0, 0.0
    # within[d] counts pairs whose edit distance is exactly d (0..3).
    within = [0, 0, 0, 0]
    for pred, ref in zip(predictions, references):
        # cmp_result's signature is (label, rec): pass the reference as the
        # label. (The distance itself is symmetric, but the original swapped
        # the arguments and accumulated a dead `total_label` from it.)
        dist, _ = cmp_result(ref.split(), pred.split())
        if dist <= 3:
            within[dist] += 1
    exprate = within[0] / total_line
    error_1 = (within[0] + within[1]) / total_line
    error_2 = (within[0] + within[1] + within[2]) / total_line
    error_3 = (within[0] + within[1] + within[2] + within[3]) / total_line
    return exprate, error_1, error_2, error_3
src/utils/metrics.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import evaluate
3
+ from nltk import edit_distance as compute_edit_distance
4
+ from src.utils.common_utils import compute_exprate
5
+
6
+
7
class Metrics:
    """Evaluation metrics for a seq2seq recognition model.

    Wraps HuggingFace ``evaluate`` metrics (BLEU, WER, exact match) plus a
    normalized edit-distance score and ExpRate, computed from a
    transformers-style prediction object.
    """

    def __init__(self, processor):
        self.processor = processor
        self.bleu = evaluate.load("bleu")
        self.wer = evaluate.load("wer")
        self.exact_match = evaluate.load("exact_match")

    def compute_metrics(self, pred):
        """Decode predictions and labels, then score them.

        Args:
            pred: object with ``predictions`` (token ids) and ``label_ids``
                arrays, as produced by a transformers Trainer.

        Returns:
            dict: metric name -> score on a 0-100 scale, rounded to 2 places.
        """
        tokenizer = self.processor.tokenizer
        label_ids = pred.label_ids
        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
        # -100 marks ignored label positions; restore the pad id before decoding.
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        n = len(pred_str)
        sum_edit, sum_bleu, sum_exact = 0, 0, 0
        for hyp, ref in zip(pred_str, label_str):
            # Edit distance normalized by the longer of the two strings.
            sum_edit += compute_edit_distance(hyp, ref) / max(len(hyp), len(ref))

            try:
                one_bleu = self.bleu.compute(
                    predictions=[hyp],
                    references=[ref],
                    max_order=4  # Maximum n-gram order to use when computing BLEU score
                )
                sum_bleu += one_bleu['bleu']
            except ZeroDivisionError:
                # Hypothesis too short for any n-gram overlap; score it 0.
                sum_bleu += 0

            one_exact = self.exact_match.compute(
                predictions=[hyp],
                references=[ref],
                regexes_to_ignore=[' ']
            )
            sum_exact += one_exact['exact_match']

        bleu = sum_bleu / n
        exact_match = sum_exact / n
        # Flip the (minimizing) edit distance into a similarity score.
        edit_distance = 1 - (sum_edit / n)
        # Word error rate over the whole corpus at once.
        wer = self.wer.compute(predictions=pred_str, references=label_str)
        # Expression rate and its <=1/2/3-edit relaxations.
        exprate, error_1, error_2, error_3 = compute_exprate(
            predictions=pred_str,
            references=label_str
        )

        return {
            "bleu": round(bleu*100, 2),
            "maximun_edit_distance": round(edit_distance*100, 2),
            "exact_match": round(exact_match*100, 2),
            "wer": round(wer*100, 2),
            "exprate": round(exprate*100, 2),
            "exprate_error_1": round(error_1*100, 2),
            "exprate_error_2": round(error_2*100, 2),
            "exprate_error_3": round(error_3*100, 2),
        }