DD0101 committed
Commit: c938124
1 Parent(s): ce2151a

first commit
JointModel.png ADDED
LICENSE ADDED
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+ software and other kinds of works, specifically designed to ensure
+ cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ our General Public Licenses are intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+ with two steps: (1) assert copyright on the software, and (2) offer
+ you this License which gives you legal permission to copy, distribute
+ and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+ improvements made in alternate versions of the program, if they
+ receive widespread use, become available for other developers to
+ incorporate. Many developers of free software are heartened and
+ encouraged by the resulting cooperation. However, in the case of
+ software used on network servers, this result may fail to come about.
+ The GNU General Public License permits making a modified version and
+ letting the public access it on a server without ever releasing its
+ source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ ensure that, in such cases, the modified source code becomes available
+ to the community. It requires the operator of a network server to
+ provide the source code of the modified version running there to the
+ users of that server. Therefore, public use of a modified version, on
+ a publicly accessible server, gives the public access to the source
+ code of the modified version.
+
+ An older license, called the Affero General Public License and
+ published by Affero, was designed to accomplish similar goals. This is
+ a different license, not a version of the Affero GPL, but Affero has
+ released a new version of the Affero GPL which permits relicensing under
+ this license.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+ Program, your modified version must prominently offer all users
+ interacting with it remotely through a computer network (if your version
+ supports such interaction) an opportunity to receive the Corresponding
+ Source of your version by providing access to the Corresponding Source
+ from a network server at no charge, through some standard or customary
+ means of facilitating copying of software. This Corresponding Source
+ shall include the Corresponding Source for any work covered by version 3
+ of the GNU General Public License that is incorporated pursuant to the
+ following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+ permission to link or combine any covered work with a work licensed
+ under version 3 of the GNU General Public License into a single
+ combined work, and to convey the resulting work. The terms of this
+ License will continue to apply to the part which is the covered work,
+ but the work with which it is combined will remain governed by version
+ 3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+ the GNU Affero General Public License from time to time. Such new versions
+ will be similar in spirit to the present version, but may differ in detail to
+ address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+ Program specifies that a certain numbered version of the GNU Affero General
+ Public License "or any later version" applies to it, you have the
+ option of following the terms and conditions either of that numbered
+ version or of any later version published by the Free Software
+ Foundation. If the Program does not specify a version number of the
+ GNU Affero General Public License, you may choose any version ever published
+ by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+ versions of the GNU Affero General Public License can be used, that proxy's
+ public statement of acceptance of a version permanently authorizes you
+ to choose that version for the Program.
+
+ Later license versions may give you additional or different
+ permissions. However, no additional obligations are imposed on any
+ author or copyright holder as a result of your choosing to follow a
+ later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+ above cannot be given local legal effect according to their terms,
+ reviewing courts shall apply local law that most closely approximates
+ an absolute waiver of all civil liability in connection with the
+ Program, unless a warranty or assumption of liability accompanies a
+ copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+ possible use to the public, the best way to achieve this is to make it
+ free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+ to attach them to the start of each source file to most effectively
+ state the exclusion of warranty; and each file should have at least
+ the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+ network, you should also make sure that it provides a way for users to
+ get its source. For example, if your program is a web application, its
+ interface could display a "Source" link that leads users to an archive
+ of the code. There are many ways you could offer source, and different
+ solutions will be better for different programs; see section 13 for the
+ specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
+ For more information on this, and how to apply and follow the GNU AGPL, see
+ <https://www.gnu.org/licenses/>.
data_loader.py ADDED
@@ -0,0 +1,269 @@
+ import copy
+ import json
+ import logging
+ import os
+
+ import torch
+ from torch.utils.data import TensorDataset
+ from utils import get_intent_labels, get_slot_labels
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class InputExample(object):
+     """
+     A single training/test example for simple sequence classification.
+
+     Args:
+         guid: Unique id for the example.
+         words: list. The words of the sequence.
+         intent_label: (Optional) string. The intent label of the example.
+         slot_labels: (Optional) list. The slot labels of the example.
+     """
+
+     def __init__(self, guid, words, intent_label=None, slot_labels=None):
+         self.guid = guid
+         self.words = words
+         self.intent_label = intent_label
+         self.slot_labels = slot_labels
+
+     def __repr__(self):
+         return str(self.to_json_string())
+
+     def to_dict(self):
+         """Serializes this instance to a Python dictionary."""
+         output = copy.deepcopy(self.__dict__)
+         return output
+
+     def to_json_string(self):
+         """Serializes this instance to a JSON string."""
+         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
+
+ class InputFeatures(object):
+     """A single set of features of data."""
+
+     def __init__(self, input_ids, attention_mask, token_type_ids, intent_label_id, slot_labels_ids):
+         self.input_ids = input_ids
+         self.attention_mask = attention_mask
+         self.token_type_ids = token_type_ids
+         self.intent_label_id = intent_label_id
+         self.slot_labels_ids = slot_labels_ids
+
+     def __repr__(self):
+         return str(self.to_json_string())
+
+     def to_dict(self):
+         """Serializes this instance to a Python dictionary."""
+         output = copy.deepcopy(self.__dict__)
+         return output
+
+     def to_json_string(self):
+         """Serializes this instance to a JSON string."""
+         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
+
+ class JointProcessor(object):
+     """Processor for the JointBERT data set."""
+
+     def __init__(self, args):
+         self.args = args
+         self.intent_labels = get_intent_labels(args)
+         self.slot_labels = get_slot_labels(args)
+
+         self.input_text_file = "seq.in"
+         self.intent_label_file = "label"
+         self.slot_labels_file = "seq.out"
+
+     @classmethod
+     def _read_file(cls, input_file, quotechar=None):
+         """Reads a tab separated value file."""
+         with open(input_file, "r", encoding="utf-8") as f:
+             lines = []
+             for line in f:
+                 lines.append(line.strip())
+             return lines
+
+     def _create_examples(self, texts, intents, slots, set_type):
+         """Creates examples for the training and dev sets."""
+         examples = []
+         for i, (text, intent, slot) in enumerate(zip(texts, intents, slots)):
+             guid = "%s-%s" % (set_type, i)
+             # 1. input_text
+             words = text.split()  # split() also collapses accidental double spaces
+             # 2. intent
+             intent_label = (
+                 self.intent_labels.index(intent) if intent in self.intent_labels else self.intent_labels.index("UNK")
+             )
+             # 3. slot
+             slot_labels = []
+             for s in slot.split():
+                 slot_labels.append(
+                     self.slot_labels.index(s) if s in self.slot_labels else self.slot_labels.index("UNK")
+                 )
+
+             assert len(words) == len(slot_labels)
+             examples.append(InputExample(guid=guid, words=words, intent_label=intent_label, slot_labels=slot_labels))
+         return examples
+
+     def get_examples(self, mode):
+         """
+         Args:
+             mode: train, dev, test
+         """
+         data_path = os.path.join(self.args.data_dir, self.args.token_level, mode)
+         logger.info("LOOKING AT {}".format(data_path))
+         return self._create_examples(
+             texts=self._read_file(os.path.join(data_path, self.input_text_file)),
+             intents=self._read_file(os.path.join(data_path, self.intent_label_file)),
+             slots=self._read_file(os.path.join(data_path, self.slot_labels_file)),
+             set_type=mode,
+         )
+
+
+ processors = {"syllable-level": JointProcessor, "word-level": JointProcessor}
+
+
+ def convert_examples_to_features(
+     examples,
+     max_seq_len,
+     tokenizer,
+     pad_token_label_id=-100,
+     cls_token_segment_id=0,
+     pad_token_segment_id=0,
+     sequence_a_segment_id=0,
+     mask_padding_with_zero=True,
+ ):
+     # Setting based on the current model type
+     cls_token = tokenizer.cls_token
+     sep_token = tokenizer.sep_token
+     unk_token = tokenizer.unk_token
+     pad_token_id = tokenizer.pad_token_id
+
+     features = []
+     for (ex_index, example) in enumerate(examples):
+         if ex_index % 5000 == 0:
+             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
+
+         # Tokenize word by word (for NER)
+         tokens = []
+         slot_labels_ids = []
+         for word, slot_label in zip(example.words, example.slot_labels):
+             word_tokens = tokenizer.tokenize(word)
+             if not word_tokens:
+                 word_tokens = [unk_token]  # For handling a badly encoded word
+             tokens.extend(word_tokens)
+             # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+             slot_labels_ids.extend([int(slot_label)] + [pad_token_label_id] * (len(word_tokens) - 1))
+
+         # Account for [CLS] and [SEP]
+         special_tokens_count = 2
+         if len(tokens) > max_seq_len - special_tokens_count:
+             tokens = tokens[: (max_seq_len - special_tokens_count)]
+             slot_labels_ids = slot_labels_ids[: (max_seq_len - special_tokens_count)]
+
+         # Add [SEP] token
+         tokens += [sep_token]
+         slot_labels_ids += [pad_token_label_id]
+         token_type_ids = [sequence_a_segment_id] * len(tokens)
+
+         # Add [CLS] token
+         tokens = [cls_token] + tokens
+         slot_labels_ids = [pad_token_label_id] + slot_labels_ids
+         token_type_ids = [cls_token_segment_id] + token_type_ids
+
+         input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+         # The mask has 1 for real tokens and 0 for padding tokens. Only real
+         # tokens are attended to.
+         attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+         # Zero-pad up to the sequence length.
+         padding_length = max_seq_len - len(input_ids)
+         input_ids = input_ids + ([pad_token_id] * padding_length)
+         attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+         token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
+         slot_labels_ids = slot_labels_ids + ([pad_token_label_id] * padding_length)
+
+         assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
+         assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(
+             len(attention_mask), max_seq_len
+         )
+         assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(
+             len(token_type_ids), max_seq_len
+         )
+         assert len(slot_labels_ids) == max_seq_len, "Error with slot labels length {} vs {}".format(
+             len(slot_labels_ids), max_seq_len
+         )
+
+         intent_label_id = int(example.intent_label)
+
+         if ex_index < 5:
+             logger.info("*** Example ***")
+             logger.info("guid: %s" % example.guid)
+             logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
+             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+             logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
+             logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
+             logger.info("intent_label: %s (id = %d)" % (example.intent_label, intent_label_id))
+             logger.info("slot_labels: %s" % " ".join([str(x) for x in slot_labels_ids]))
+
+         features.append(
+             InputFeatures(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 token_type_ids=token_type_ids,
+                 intent_label_id=intent_label_id,
+                 slot_labels_ids=slot_labels_ids,
+             )
+         )
+
+     return features
+
+
+ def load_and_cache_examples(args, tokenizer, mode):
+     processor = processors[args.token_level](args)
+
+     # Load data features from cache or dataset file
+     cached_features_file = os.path.join(
+         args.data_dir,
+         "cached_{}_{}_{}_{}".format(
+             mode, args.token_level, list(filter(None, args.model_name_or_path.split("/"))).pop(), args.max_seq_len
+         ),
+     )
+
+     if os.path.exists(cached_features_file):
+         logger.info("Loading features from cached file %s", cached_features_file)
+         features = torch.load(cached_features_file)
+     else:
+         # Load data features from dataset file
+         logger.info("Creating features from dataset file at %s", args.data_dir)
+         if mode == "train":
+             examples = processor.get_examples("train")
+         elif mode == "dev":
+             examples = processor.get_examples("dev")
+         elif mode == "test":
+             examples = processor.get_examples("test")
+         else:
+             raise Exception("For mode, only train, dev, test are available")
+
+         # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
+         pad_token_label_id = args.ignore_index
+         features = convert_examples_to_features(
+             examples, args.max_seq_len, tokenizer, pad_token_label_id=pad_token_label_id
+         )
+         logger.info("Saving features into cached file %s", cached_features_file)
+         torch.save(features, cached_features_file)
+
+     # Convert to Tensors and build dataset
+     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+     all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+     all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+     all_intent_label_ids = torch.tensor([f.intent_label_id for f in features], dtype=torch.long)
+     all_slot_labels_ids = torch.tensor([f.slot_labels_ids for f in features], dtype=torch.long)
+
+     dataset = TensorDataset(
+         all_input_ids, all_attention_mask, all_token_type_ids, all_intent_label_ids, all_slot_labels_ids
+     )
+     return dataset
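
The entry point of this file is `load_and_cache_examples`, which funnels raw `seq.in`/`label`/`seq.out` files through `JointProcessor` into padded tensors, caching the features keyed by mode, token level, model name, and sequence length. A minimal sketch of how it might be driven, assuming a Hugging Face-compatible tokenizer; the checkpoint name and the `intent_label_file`/`slot_label_file` fields are guesses, since `get_intent_labels`/`get_slot_labels` live in the repo's `utils`, which is not part of this commit:

    # Hypothetical driver for load_and_cache_examples (sketch, not part of the commit).
    # Namespace fields mirror the attributes the code above reads from `args`.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    from data_loader import load_and_cache_examples

    args = SimpleNamespace(
        data_dir="PhoATIS",                        # dataset root with <token_level>/<mode>/seq.in, label, seq.out
        token_level="word-level",                  # selects JointProcessor via `processors`
        model_name_or_path="vinai/phobert-base",   # assumed checkpoint; used here only in the cache key
        max_seq_len=50,
        ignore_index=-100,                         # pad_token_label_id fed to convert_examples_to_features
        intent_label_file="intent_label.txt",      # assumption: consumed by get_intent_labels(args)
        slot_label_file="slot_label.txt",          # assumption: consumed by get_slot_labels(args)
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
    print(len(train_dataset), train_dataset[0][0].shape)  # example count, input_ids shape of first example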
dataset_statistic.png ADDED
early_stopping.py ADDED
@@ -0,0 +1,64 @@
+ import os
+
+ import numpy as np
+ import torch
+
+
+ class EarlyStopping:
+     """Early stops the training if the validation metric doesn't improve after a given patience."""
+
+     def __init__(self, patience=7, verbose=False):
+         """
+         Args:
+             patience (int): How long to wait after the last time the validation metric improved.
+                 Default: 7
+             verbose (bool): If True, prints a message for each improvement.
+                 Default: False
+         """
+         self.patience = patience
+         self.verbose = verbose
+         self.counter = 0
+         self.best_score = None
+         self.early_stop = False
+         self.val_loss_min = np.inf  # np.inf rather than the deprecated np.Inf (removed in NumPy 2.0)
+
+     def __call__(self, val_loss, model, args):
+         # For loss, lower is better, so negate it; for accuracy/F1, higher is already better
+         if args.tuning_metric == "loss":
+             score = -val_loss
+         else:
+             score = val_loss
+         if self.best_score is None:
+             self.best_score = score
+             self.save_checkpoint(val_loss, model, args)
+         elif score < self.best_score:
+             self.counter += 1
+             print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.save_checkpoint(val_loss, model, args)
+             self.counter = 0
+
+     def save_checkpoint(self, val_loss, model, args):
+         """Saves the model when the validation loss decreases or accuracy/F1 increases."""
+         if self.verbose:
+             if args.tuning_metric == "loss":
+                 print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
+             else:
+                 print(
+                     f"{args.tuning_metric} increased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
+                 )
+         model.save_pretrained(args.model_dir)
+         torch.save(args, os.path.join(args.model_dir, "training_args.bin"))
+         self.val_loss_min = val_loss
+
+         # # Save model checkpoint (Overwrite)
+         # if not os.path.exists(self.args.model_dir):
+         #     os.makedirs(self.args.model_dir)
+         # model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
+         # model_to_save.save_pretrained(self.args.model_dir)
+
+         # # Save training arguments together with the trained model
+         # torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
+         # logger.info("Saving model checkpoint to %s", self.args.model_dir)
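
For context, a small self-contained sketch of how this class is typically wired into a validation loop. `DummyModel` exists only to satisfy the `save_pretrained()` call inside `save_checkpoint()`, and the loss sequence is made up to show an improvement phase followed by a plateau:

    # Hypothetical training-loop wiring for EarlyStopping (sketch, not part of the commit).
    import os
    from types import SimpleNamespace

    import torch

    from early_stopping import EarlyStopping


    class DummyModel(torch.nn.Module):
        """Stand-in exposing the save_pretrained() hook that save_checkpoint() calls."""
        def save_pretrained(self, model_dir):
            os.makedirs(model_dir, exist_ok=True)
            torch.save(self.state_dict(), os.path.join(model_dir, "pytorch_model.bin"))


    args = SimpleNamespace(tuning_metric="loss", model_dir="checkpoints/demo")
    model = DummyModel()
    early_stopping = EarlyStopping(patience=3, verbose=True)

    # Simulated per-epoch validation losses: improves for three epochs, then plateaus
    for epoch, val_loss in enumerate([0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64]):
        early_stopping(val_loss, model, args)  # negates the loss internally, checkpoints on improvement
        if early_stopping.early_stop:
            print(f"Stopping early at epoch {epoch}")
            break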
gradio_demo.py ADDED
@@ -0,0 +1,250 @@
1
+ import gradio as gr
2
+
3
+ import argparse
4
+ import logging
5
+ import os
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
10
+ from tqdm import tqdm
11
+ from utils import MODEL_CLASSES, get_intent_labels, get_slot_labels, init_logger, load_tokenizer
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def get_device(pred_config):
18
+ return "cuda" if torch.cuda.is_available() and not pred_config.no_cuda else "cpu"
19
+
20
+
21
+ def get_args(pred_config):
22
+ args = torch.load(os.path.join(pred_config.model_dir, "training_args.bin"))
23
+
24
+ args.model_dir = pred_config.model_dir
25
+ args.data_dir = 'PhoATIS'
26
+
27
+ return args
28
+
29
+
30
+ def load_model(pred_config, args, device):
31
+ # Check whether model exists
32
+ if not os.path.exists(pred_config.model_dir):
33
+ raise Exception("Model doesn't exists! Train first!")
34
+
35
+ try:
36
+ model = MODEL_CLASSES[args.model_type][1].from_pretrained(
37
+ args.model_dir, args=args, intent_label_lst=get_intent_labels(args), slot_label_lst=get_slot_labels(args)
38
+ )
39
+ model.to(device)
40
+ model.eval()
41
+ logger.info("***** Model Loaded *****")
42
+ except Exception:
43
+ raise Exception("Some model files might be missing...")
44
+
45
+ return model
46
+
47
+ def convert_input_file_to_tensor_dataset(
48
+ lines,
49
+ pred_config,
50
+ args,
51
+ tokenizer,
52
+ pad_token_label_id,
53
+ cls_token_segment_id=0,
54
+ pad_token_segment_id=0,
55
+ sequence_a_segment_id=0,
56
+ mask_padding_with_zero=True,
57
+ ):
58
+ # Setting based on the current model type
59
+ cls_token = tokenizer.cls_token
60
+ sep_token = tokenizer.sep_token
61
+ unk_token = tokenizer.unk_token
62
+ pad_token_id = tokenizer.pad_token_id
63
+
64
+ all_input_ids = []
65
+ all_attention_mask = []
66
+ all_token_type_ids = []
67
+ all_slot_label_mask = []
68
+
69
+ for words in lines:
70
+ tokens = []
71
+ slot_label_mask = []
72
+ for word in words:
73
+ word_tokens = tokenizer.tokenize(word)
74
+ if not word_tokens:
75
+ word_tokens = [unk_token] # For handling the bad-encoded word
76
+ tokens.extend(word_tokens)
77
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
78
+ slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))
79
+
80
+ # Account for [CLS] and [SEP]
81
+ special_tokens_count = 2
82
+ if len(tokens) > args.max_seq_len - special_tokens_count:
83
+ tokens = tokens[: (args.max_seq_len - special_tokens_count)]
84
+ slot_label_mask = slot_label_mask[: (args.max_seq_len - special_tokens_count)]
85
+
86
+ # Add [SEP] token
87
+ tokens += [sep_token]
88
+ token_type_ids = [sequence_a_segment_id] * len(tokens)
89
+ slot_label_mask += [pad_token_label_id]
90
+
91
+ # Add [CLS] token
92
+ tokens = [cls_token] + tokens
93
+ token_type_ids = [cls_token_segment_id] + token_type_ids
94
+ slot_label_mask = [pad_token_label_id] + slot_label_mask
95
+
96
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
97
+
98
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
99
+ attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
100
+
101
+ # Zero-pad up to the sequence length.
102
+ padding_length = args.max_seq_len - len(input_ids)
103
+ input_ids = input_ids + ([pad_token_id] * padding_length)
104
+ attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
105
+ token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
106
+ slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)
107
+
108
+ all_input_ids.append(input_ids)
109
+ all_attention_mask.append(attention_mask)
110
+ all_token_type_ids.append(token_type_ids)
111
+ all_slot_label_mask.append(slot_label_mask)
112
+
113
+ # Change to Tensor
114
+ all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
115
+ all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
116
+ all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
117
+ all_slot_label_mask = torch.tensor(all_slot_label_mask, dtype=torch.long)
118
+
119
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_slot_label_mask)
120
+
121
+ return dataset
122
+
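The masking above is the core of first-subtoken slot labeling: only the first piece of each word gets a real (non-pad) position, so slot predictions are read off once per word. A minimal sketch of just that step, assuming pad_token_label_id = 0 (the --ignore_index default in main.py) and a made-up sub-word split:

# Sketch of first-subtoken masking; the word_pieces split here is hypothetical.
pad_token_label_id = 0
word_pieces = {"khứ_hồi": ["khứ_", "hồi"], "đà_nẵng": ["đà_nẵng"]}

slot_label_mask = []
for word in ["khứ_hồi", "đà_nẵng"]:
    pieces = word_pieces[word]
    # real label position (pad + 1) on the first piece, pad ids on the rest
    slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(pieces) - 1))

print(slot_label_mask)  # [1, 0, 1] -> predictions are kept only where the mask differs from pad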
123
+ def predict(text):
124
+
125
+ lines = text
126
+ dataset = convert_input_file_to_tensor_dataset(lines, pred_config, args, tokenizer, pad_token_label_id)
127
+
128
+ # Predict
129
+ sampler = SequentialSampler(dataset)
130
+ data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)
131
+
132
+ all_slot_label_mask = None
133
+ intent_preds = None
134
+ slot_preds = None
135
+
136
+ for batch in tqdm(data_loader, desc="Predicting"):
137
+ batch = tuple(t.to(device) for t in batch)
138
+ with torch.no_grad():
139
+ inputs = {
140
+ "input_ids": batch[0],
141
+ "attention_mask": batch[1],
142
+ "intent_label_ids": None,
143
+ "slot_labels_ids": None,
144
+ }
145
+ if args.model_type != "distilbert":
146
+ inputs["token_type_ids"] = batch[2]
147
+ outputs = model(**inputs)
148
+ _, (intent_logits, slot_logits) = outputs[:2]
149
+
150
+ # Intent Prediction
151
+ if intent_preds is None:
152
+ intent_preds = intent_logits.detach().cpu().numpy()
153
+ else:
154
+ intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
155
+
156
+ # Slot prediction
157
+ if slot_preds is None:
158
+ if args.use_crf:
159
+ # decode() in `torchcrf` directly returns the best label index sequences
160
+ slot_preds = np.array(model.crf.decode(slot_logits))
161
+ else:
162
+ slot_preds = slot_logits.detach().cpu().numpy()
163
+ all_slot_label_mask = batch[3].detach().cpu().numpy()
164
+ else:
165
+ if args.use_crf:
166
+ slot_preds = np.append(slot_preds, np.array(model.crf.decode(slot_logits)), axis=0)
167
+ else:
168
+ slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)
169
+ all_slot_label_mask = np.append(all_slot_label_mask, batch[3].detach().cpu().numpy(), axis=0)
170
+
171
+ intent_preds = np.argmax(intent_preds, axis=1)
172
+
173
+ if not args.use_crf:
174
+ slot_preds = np.argmax(slot_preds, axis=2)
175
+
176
+ slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}
177
+ slot_preds_list = [[] for _ in range(slot_preds.shape[0])]
178
+
179
+ for i in range(slot_preds.shape[0]):
180
+ for j in range(slot_preds.shape[1]):
181
+ if all_slot_label_mask[i, j] != pad_token_label_id:
182
+ slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
183
+
184
+ return (lines, slot_preds_list, intent_preds)
185
+
186
+
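predict() accumulates logits batch by batch with np.append and defers argmax to the end; intents are argmaxed over axis 1 (one label per sentence) and slots over axis 2 (one label per token). A toy sketch of that aggregation with random logits, assuming 2 intents and 3 slot labels:

import numpy as np

rng = np.random.default_rng(0)
intent_preds, slot_preds = None, None
for _ in range(2):  # two fake batches: 4 sentences each, sequence length 5
    intent_logits = rng.standard_normal((4, 2))
    slot_logits = rng.standard_normal((4, 5, 3))
    if intent_preds is None:
        intent_preds, slot_preds = intent_logits, slot_logits
    else:
        intent_preds = np.append(intent_preds, intent_logits, axis=0)
        slot_preds = np.append(slot_preds, slot_logits, axis=0)

print(np.argmax(intent_preds, axis=1).shape)  # (8,)   one intent per sentence
print(np.argmax(slot_preds, axis=2).shape)    # (8, 5) one slot label per token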
187
+ def text_analysis(text):
188
+ text = [text.strip().split()]
189
+
190
+ lines, slot_preds_list, intent_preds = predict(text)  # run inference once instead of three times
+ words, slot_preds, intent_pred = lines[0], slot_preds_list[0], intent_preds[0]
191
+
192
+ slot_tokens = []
193
+
194
+ for word, pred in zip(words, slot_preds):
195
+ if pred == 'O':
196
+ slot_tokens.extend([(word, None), (" ", None)])
197
+ elif pred[0] == 'I' and len(slot_tokens) >= 2:  # guard against an orphan I- tag at sentence start
198
+ added_tokens = list(slot_tokens[-2])
199
+ added_tokens[0] += f' {word}'
200
+ slot_tokens[-2] = tuple(added_tokens)
201
+ else:
202
+ slot_tokens.extend([(word, pred[2:]), (" ", None)])
203
+
204
+ intent_label = intent_label_lst[intent_pred]
205
+
206
+ return slot_tokens, intent_label
207
+
208
+
209
+
210
+ if __name__ == "__main__":
211
+ init_logger()
212
+ parser = argparse.ArgumentParser()
213
+
214
+ # parser.add_argument("--input_file", default="sample_pred_in.txt", type=str, help="Input file for prediction")
215
+ # parser.add_argument("--output_file", default="sample_pred_out.txt", type=str, help="Output file for prediction")
216
+ parser.add_argument("--model_dir", default="./atis_model", type=str, help="Path to save, load model")
217
+
218
+ parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
219
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
220
+
221
+ pred_config = parser.parse_args()
222
+
223
+ # load model and args
224
+ args = get_args(pred_config)
225
+ device = get_device(pred_config)
226
+ model = load_model(pred_config, args, device)
227
+ logger.info(args)
228
+
229
+ intent_label_lst = get_intent_labels(args)
230
+ slot_label_lst = get_slot_labels(args)
231
+
232
+ # Convert input file to TensorDataset
233
+ pad_token_label_id = args.ignore_index
234
+ tokenizer = load_tokenizer(args)
235
+
236
+
237
+ examples = ["tôi muốn bay một chuyến khứ_hồi từ đà_nẵng đến đà_lạt",
238
+ ("giá vé khứ_hồi từ đà_nẵng đến vinh dưới 2 triệu đồng giá vé khứ_hồi từ quy nhơn đến vinh dưới 3 triệu đồng giá vé khứ_hồi từ"
239
+ " buôn_ma_thuột đến vinh dưới 4 triệu rưỡi"),
240
+ "cho tôi biết các chuyến bay đến đà_nẵng vào ngày 14 tháng sáu",
241
+ "những chuyến bay nào khởi_hành từ thành_phố hồ_chí_minh bay đến frankfurt mà nối chuyến ở singapore và hạ_cánh trước 9 giờ tối"]
242
+
243
+ demo = gr.Interface(
244
+ text_analysis,
245
+ gr.Textbox(placeholder="Enter sentence here...", label="Input"),
246
+ [gr.HighlightedText(label='Highlighted Output'), gr.Textbox(label='Intent Label')],
247
+ examples=examples,
248
+ )
249
+
250
+ demo.launch(share=True)
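text_analysis() emits (token, label-or-None) tuples in exactly the shape gr.HighlightedText consumes, merging I- tagged words into the preceding span. An illustrative result (the slot and intent names below are made up, not read from the label files):

# Hypothetical output shape of text_analysis("tôi muốn bay đến đà_nẵng")
slot_tokens = [("tôi", None), (" ", None), ("muốn", None), (" ", None),
               ("đà_nẵng", "toloc.city_name"), (" ", None)]
intent_label = "flight"
# The demo renders slot_tokens in the HighlightedText component and intent_label in the Textbox.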
main.py ADDED
@@ -0,0 +1,139 @@
1
+ import argparse
2
+
3
+ from data_loader import load_and_cache_examples
4
+ from trainer import Trainer
5
+ from utils import MODEL_CLASSES, MODEL_PATH_MAP, init_logger, load_tokenizer, set_seed
6
+
7
+
8
+ def main(args):
9
+ init_logger()
10
+ set_seed(args)
11
+ tokenizer = load_tokenizer(args)
12
+
13
+ train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
14
+ dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
15
+ test_dataset = load_and_cache_examples(args, tokenizer, mode="test")
16
+
17
+ trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
18
+
19
+ if args.do_train:
20
+ trainer.train()
21
+
22
+ if args.do_eval:
23
+ trainer.load_model()
24
+ trainer.evaluate("test")
25
+ if args.do_eval_dev:
26
+ trainer.load_model()
27
+ trainer.evaluate("dev")
28
+
29
+
30
+ if __name__ == "__main__":
31
+ parser = argparse.ArgumentParser()
32
+
33
+ # parser.add_argument("--task", default=None, required=True, type=str, help="The name of the task to train")
34
+ parser.add_argument("--model_dir", default=None, required=True, type=str, help="Path to save, load model")
35
+ parser.add_argument("--data_dir", default="./PhoATIS", type=str, help="The input data dir")
36
+ parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
37
+ parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
38
+
39
+ parser.add_argument(
40
+ "--model_type",
41
+ default="phobert",
42
+ type=str,
43
+ help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
44
+ )
45
+ parser.add_argument("--tuning_metric", default="loss", type=str, help="Metrics to tune when training")
46
+ parser.add_argument("--seed", type=int, default=1, help="random seed for initialization")
47
+ parser.add_argument("--train_batch_size", default=32, type=int, help="Batch size for training.")
48
+ parser.add_argument("--eval_batch_size", default=64, type=int, help="Batch size for evaluation.")
49
+ parser.add_argument(
50
+ "--max_seq_len", default=50, type=int, help="The maximum total input sequence length after tokenization."
51
+ )
52
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
53
+ parser.add_argument(
54
+ "--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform."
55
+ )
56
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
57
+ parser.add_argument(
58
+ "--gradient_accumulation_steps",
59
+ type=int,
60
+ default=1,
61
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
62
+ )
63
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
64
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
65
+ parser.add_argument(
66
+ "--max_steps",
67
+ default=-1,
68
+ type=int,
69
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
70
+ )
71
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
72
+ parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
73
+
74
+ parser.add_argument("--logging_steps", type=int, default=200, help="Log every X updates steps.")
75
+ parser.add_argument("--save_steps", type=int, default=200, help="Save checkpoint every X updates steps.")
76
+
77
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
78
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
79
+ parser.add_argument("--do_eval_dev", action="store_true", help="Whether to run eval on the dev set.")
80
+
81
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
82
+
83
+ parser.add_argument(
84
+ "--ignore_index",
85
+ default=0,
86
+ type=int,
87
+ help="Specifies a target value that is ignored and does not contribute to the input gradient",
88
+ )
89
+
90
+ parser.add_argument("--intent_loss_coef", type=float, default=0.5, help="Coefficient for the intent loss.")
91
+ parser.add_argument(
92
+ "--token_level",
93
+ type=str,
94
+ default="word-level",
95
+ help="Tokens are at syllable level or word level (Vietnamese) [word-level, syllable-level]",
96
+ )
97
+ parser.add_argument(
98
+ "--early_stopping",
99
+ type=int,
100
+ default=50,
101
+ help="Number of validation steps without improvement to wait before early stopping",
102
+ )
103
+ parser.add_argument("--gpu_id", type=int, default=0, help="Select gpu id")
104
+ # CRF option
105
+ parser.add_argument("--use_crf", action="store_true", help="Whether to use CRF")
106
+ # init pretrained
107
+ parser.add_argument("--pretrained", action="store_true", help="Whether to initialize the model from a pretrained base model")
108
+ parser.add_argument("--pretrained_path", default="./viatis_xlmr_crf", type=str, help="The pretrained model path")
109
+
110
+ # Slot-intent interaction
111
+ parser.add_argument(
112
+ "--use_intent_context_concat",
113
+ action="store_true",
114
+ help="Whether to feed context information of intent into slots vectors (simple concatenation)",
115
+ )
116
+ parser.add_argument(
117
+ "--use_intent_context_attention",
118
+ action="store_true",
119
+ help="Whether to feed context information of intent into slots vectors (dot product attention)",
120
+ )
121
+ parser.add_argument(
122
+ "--attention_embedding_size", type=int, default=200, help="hidden size of attention output vector"
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--slot_pad_label",
127
+ default="PAD",
128
+ type=str,
129
+ help="Pad token for slot labels (ignored when calculating the loss)",
130
+ )
131
+ parser.add_argument(
132
+ "--embedding_type", default="soft", type=str, help="Embedding type for intent vector (hard/soft)"
133
+ )
134
+ parser.add_argument("--use_attention_mask", action="store_true", help="Whether to use attention mask")
135
+
136
+ args = parser.parse_args()
137
+
138
+ args.model_name_or_path = MODEL_PATH_MAP[args.model_type]
139
+ main(args)
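Before calling main(), the script maps --model_type to a Hugging Face checkpoint name through MODEL_PATH_MAP (defined in utils.py):

from utils import MODEL_PATH_MAP

print(MODEL_PATH_MAP["phobert"])  # vinai/phobert-base
print(MODEL_PATH_MAP["xlmr"])     # xlm-roberta-base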
predict.py ADDED
@@ -0,0 +1,232 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
8
+ from tqdm import tqdm
9
+ from utils import MODEL_CLASSES, get_intent_labels, get_slot_labels, init_logger, load_tokenizer
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def get_device(pred_config):
16
+ return "cuda" if torch.cuda.is_available() and not pred_config.no_cuda else "cpu"
17
+
18
+
19
+ def get_args(pred_config):
20
+ args = torch.load(os.path.join(pred_config.model_dir, "training_args.bin"))
21
+
22
+ args.model_dir = 'JointBERT-CRF_PhoBERTencoder'  # NOTE: overrides the --model_dir CLI flag with a hardcoded path
23
+ args.data_dir = 'PhoATIS'
24
+
25
+ return args
26
+
27
+
28
+ def load_model(pred_config, args, device):
29
+ # Check whether model exists
30
+ if not os.path.exists(pred_config.model_dir):
31
+ raise Exception("Model doesn't exist! Train first!")
32
+
33
+ try:
34
+ model = MODEL_CLASSES[args.model_type][1].from_pretrained(
35
+ args.model_dir, args=args, intent_label_lst=get_intent_labels(args), slot_label_lst=get_slot_labels(args)
36
+ )
37
+ model.to(device)
38
+ model.eval()
39
+ logger.info("***** Model Loaded *****")
40
+ except Exception as e:
41
+ raise Exception("Some model files might be missing...") from e
42
+
43
+ return model
44
+
45
+
46
+ def read_input_file(pred_config):
47
+ lines = []
48
+ with open(pred_config.input_file, "r", encoding="utf-8") as f:
49
+ for line in f:
50
+ line = line.strip()
51
+ words = line.split()
52
+ lines.append(words)
53
+
54
+ return lines
55
+
56
+
57
+ def convert_input_file_to_tensor_dataset(
58
+ lines,
59
+ pred_config,
60
+ args,
61
+ tokenizer,
62
+ pad_token_label_id,
63
+ cls_token_segment_id=0,
64
+ pad_token_segment_id=0,
65
+ sequence_a_segment_id=0,
66
+ mask_padding_with_zero=True,
67
+ ):
68
+ # Setting based on the current model type
69
+ cls_token = tokenizer.cls_token
70
+ sep_token = tokenizer.sep_token
71
+ unk_token = tokenizer.unk_token
72
+ pad_token_id = tokenizer.pad_token_id
73
+
74
+ all_input_ids = []
75
+ all_attention_mask = []
76
+ all_token_type_ids = []
77
+ all_slot_label_mask = []
78
+
79
+ for words in lines:
80
+ tokens = []
81
+ slot_label_mask = []
82
+ for word in words:
83
+ word_tokens = tokenizer.tokenize(word)
84
+ if not word_tokens:
85
+ word_tokens = [unk_token]  # Fall back to <unk> for words the tokenizer cannot encode
86
+ tokens.extend(word_tokens)
87
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
88
+ slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))
89
+
90
+ # Account for [CLS] and [SEP]
91
+ special_tokens_count = 2
92
+ if len(tokens) > args.max_seq_len - special_tokens_count:
93
+ tokens = tokens[: (args.max_seq_len - special_tokens_count)]
94
+ slot_label_mask = slot_label_mask[: (args.max_seq_len - special_tokens_count)]
95
+
96
+ # Add [SEP] token
97
+ tokens += [sep_token]
98
+ token_type_ids = [sequence_a_segment_id] * len(tokens)
99
+ slot_label_mask += [pad_token_label_id]
100
+
101
+ # Add [CLS] token
102
+ tokens = [cls_token] + tokens
103
+ token_type_ids = [cls_token_segment_id] + token_type_ids
104
+ slot_label_mask = [pad_token_label_id] + slot_label_mask
105
+
106
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
107
+
108
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
109
+ attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
110
+
111
+ # Zero-pad up to the sequence length.
112
+ padding_length = args.max_seq_len - len(input_ids)
113
+ input_ids = input_ids + ([pad_token_id] * padding_length)
114
+ attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
115
+ token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
116
+ slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)
117
+
118
+ all_input_ids.append(input_ids)
119
+ all_attention_mask.append(attention_mask)
120
+ all_token_type_ids.append(token_type_ids)
121
+ all_slot_label_mask.append(slot_label_mask)
122
+
123
+ # Change to Tensor
124
+ all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
125
+ all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
126
+ all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
127
+ all_slot_label_mask = torch.tensor(all_slot_label_mask, dtype=torch.long)
128
+
129
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_slot_label_mask)
130
+
131
+ return dataset
132
+
133
+
134
+ def predict(pred_config):
135
+ # load model and args
136
+ args = get_args(pred_config)
137
+ device = get_device(pred_config)
138
+ model = load_model(pred_config, args, device)
139
+ logger.info(args)
140
+
141
+ intent_label_lst = get_intent_labels(args)
142
+ slot_label_lst = get_slot_labels(args)
143
+
144
+ # Convert input file to TensorDataset
145
+ pad_token_label_id = args.ignore_index
146
+ tokenizer = load_tokenizer(args)
147
+ lines = read_input_file(pred_config)
148
+ dataset = convert_input_file_to_tensor_dataset(lines, pred_config, args, tokenizer, pad_token_label_id)
149
+
150
+ # Predict
151
+ sampler = SequentialSampler(dataset)
152
+ data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)
153
+
154
+ all_slot_label_mask = None
155
+ intent_preds = None
156
+ slot_preds = None
157
+
158
+ for batch in tqdm(data_loader, desc="Predicting"):
159
+ batch = tuple(t.to(device) for t in batch)
160
+ with torch.no_grad():
161
+ inputs = {
162
+ "input_ids": batch[0],
163
+ "attention_mask": batch[1],
164
+ "intent_label_ids": None,
165
+ "slot_labels_ids": None,
166
+ }
167
+ if args.model_type != "distilbert":
168
+ inputs["token_type_ids"] = batch[2]
169
+ outputs = model(**inputs)
170
+ _, (intent_logits, slot_logits) = outputs[:2]
171
+
172
+ # Intent Prediction
173
+ if intent_preds is None:
174
+ intent_preds = intent_logits.detach().cpu().numpy()
175
+ else:
176
+ intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
177
+
178
+ # Slot prediction
179
+ if slot_preds is None:
180
+ if args.use_crf:
181
+ # decode() in `torchcrf` directly returns the best label index sequences
182
+ slot_preds = np.array(model.crf.decode(slot_logits))
183
+ else:
184
+ slot_preds = slot_logits.detach().cpu().numpy()
185
+ all_slot_label_mask = batch[3].detach().cpu().numpy()
186
+ else:
187
+ if args.use_crf:
188
+ slot_preds = np.append(slot_preds, np.array(model.crf.decode(slot_logits)), axis=0)
189
+ else:
190
+ slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)
191
+ all_slot_label_mask = np.append(all_slot_label_mask, batch[3].detach().cpu().numpy(), axis=0)
192
+
193
+ intent_preds = np.argmax(intent_preds, axis=1)
194
+
195
+ if not args.use_crf:
196
+ slot_preds = np.argmax(slot_preds, axis=2)
197
+
198
+ slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}
199
+ slot_preds_list = [[] for _ in range(slot_preds.shape[0])]
200
+
201
+ for i in range(slot_preds.shape[0]):
202
+ for j in range(slot_preds.shape[1]):
203
+ if all_slot_label_mask[i, j] != pad_token_label_id:
204
+ slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
205
+
206
+ # Write to output file
207
+ with open(pred_config.output_file, "w", encoding="utf-8") as f:
208
+ for words, word_slot_preds, intent_pred in zip(lines, slot_preds_list, intent_preds):
209
+ line = ""
210
+ for word, pred in zip(words, word_slot_preds):
211
+ if pred == "O":
212
+ line = line + word + " "
213
+ else:
214
+ line = line + "[{}:{}] ".format(word, pred)
215
+ f.write("<{}> -> {}\n".format(intent_label_lst[intent_pred], line.strip()))
216
+
217
+ logger.info("Prediction Done!")
218
+
219
+
220
+ if __name__ == "__main__":
221
+ init_logger()
222
+ parser = argparse.ArgumentParser()
223
+
224
+ parser.add_argument("--input_file", default="sample_pred_in.txt", type=str, help="Input file for prediction")
225
+ parser.add_argument("--output_file", default="sample_pred_out.txt", type=str, help="Output file for prediction")
226
+ parser.add_argument("--model_dir", default="./atis_model", type=str, help="Path to save, load model")
227
+
228
+ parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
229
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
230
+
231
+ pred_config = parser.parse_args()
232
+ predict(pred_config)
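predict.py reads one whitespace-tokenized sentence per line and writes one "<intent> -> ..." line per input, wrapping non-O words as [word:label]. A small sketch of the output formatting, with illustrative labels:

# Illustrative labels; real ones come from slot_label.txt / intent_label.txt.
words_and_preds = [("bay", "O"), ("đến", "O"), ("đà_nẵng", "B-toloc.city_name")]
line = ""
for word, pred in words_and_preds:
    line += word + " " if pred == "O" else "[{}:{}] ".format(word, pred)
print("<{}> -> {}".format("flight", line.strip()))
# <flight> -> bay đến [đà_nẵng:B-toloc.city_name]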
predict.sh ADDED
@@ -0,0 +1,3 @@
1
+ python3 predict.py --input_file data/viatis/test/seq.in \
2
+ --output_file predictions.txt \
3
+ --model_dir viatis_phobert_crf_attn/4e-5/0.15
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch==2.0.0
2
+ transformers
3
+ seqeval
4
+ pytorch-crf
5
+ tensorflow
6
+ sentencepiece
7
+ tensorboard
8
+ numpy>=1.21.2
9
+ tqdm
10
+ typing_extensions
11
+ protobuf<5,>=3.20.3
run_jointBERT-CRF_PhoBERTencoder.sh ADDED
@@ -0,0 +1,23 @@
1
+ export lr=3e-5
2
+ export c=0.6
3
+ export s=100
4
+ echo "${lr}"
5
+ export MODEL_DIR=JointBERT-CRF_PhoBERTencoder
6
+ export MODEL_DIR=$MODEL_DIR"/"$lr"/"$c"/"$s
7
+ echo "${MODEL_DIR}"
8
+ python3 main.py --token_level word-level \
9
+ --model_type phobert \
10
+ --model_dir $MODEL_DIR \
11
+ --data_dir PhoATIS \
12
+ --seed $s \
13
+ --do_train \
14
+ --do_eval \
15
+ --save_steps 140 \
16
+ --logging_steps 140 \
17
+ --num_train_epochs 50 \
18
+ --tuning_metric mean_intent_slot \
19
+ --use_crf \
20
+ --gpu_id 0 \
21
+ --embedding_type soft \
22
+ --intent_loss_coef $c \
23
+ --learning_rate $lr
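The checkpoint directory is composed as name/lr/coef/seed, which is the exact path the JointIDSF scripts later pass to --pretrained_path:

import os
# With lr=3e-5, c=0.6, s=100 as exported above:
print(os.path.join("JointBERT-CRF_PhoBERTencoder", "3e-5", "0.6", "100"))
# JointBERT-CRF_PhoBERTencoder/3e-5/0.6/100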
run_jointBERT-CRF_XLM-Rencoder.sh ADDED
@@ -0,0 +1,23 @@
1
+ export lr=4e-5
2
+ export c=0.45
3
+ export s=10
4
+ echo "${lr}"
5
+ export MODEL_DIR=JointBERT-CRF_XLM-Rencoder
6
+ export MODEL_DIR=$MODEL_DIR"/"$lr"/"$c"/"$s
7
+ echo "${MODEL_DIR}"
8
+ python3 main.py --token_level syllable-level \
9
+ --model_type xlmr \
10
+ --model_dir $MODEL_DIR \
11
+ --data_dir PhoATIS \
12
+ --seed $s \
13
+ --do_train \
14
+ --do_eval \
15
+ --save_steps 140 \
16
+ --logging_steps 140 \
17
+ --num_train_epochs 50 \
18
+ --tuning_metric mean_intent_slot \
19
+ --use_crf \
20
+ --gpu_id 0 \
21
+ --embedding_type soft \
22
+ --intent_loss_coef $c \
23
+ --learning_rate $lr
run_jointIDSF_PhoBERTencoder.sh ADDED
@@ -0,0 +1,30 @@
1
+ # As JointIDSF is initialized from JointBERT, users need to train a JointBERT base model first
2
+ ./run_jointBERT-CRF_PhoBERTencoder.sh
3
+ #Train JointIDSF
4
+ export lr=4e-5
5
+ export c=0.15
6
+ export s=100
7
+ echo "${lr}"
8
+ export MODEL_DIR=JointIDSF_PhoBERTencoder
9
+ export MODEL_DIR=$MODEL_DIR"/"$lr"/"$c"/"$s
10
+ echo "${MODEL_DIR}"
11
+ python3 main.py --token_level word-level \
12
+ --model_type phobert \
13
+ --model_dir $MODEL_DIR \
14
+ --data_dir PhoATIS \
15
+ --seed $s \
16
+ --do_train \
17
+ --do_eval \
18
+ --save_steps 140 \
19
+ --logging_steps 140 \
20
+ --num_train_epochs 50 \
21
+ --tuning_metric mean_intent_slot \
22
+ --use_intent_context_attention \
23
+ --attention_embedding_size 200 \
24
+ --use_crf \
25
+ --gpu_id 0 \
26
+ --embedding_type soft \
27
+ --intent_loss_coef $c \
28
+ --pretrained \
29
+ --pretrained_path JointBERT-CRF_PhoBERTencoder/3e-5/0.6/100 \
30
+ --learning_rate $lr
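--pretrained_path must point at the base run's lr/coef/seed directory created by the script above. A convenience check (a sketch, not part of the repo) that the base checkpoint was actually written before fine-tuning starts:

import os

base = "JointBERT-CRF_PhoBERTencoder/3e-5/0.6/100"
# Trainer.save_model() writes training_args.bin next to the weights.
assert os.path.exists(os.path.join(base, "training_args.bin")), "Train the JointBERT base model first"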
run_jointIDSF_XLM-Rencoder.sh ADDED
@@ -0,0 +1,30 @@
1
+ # As JointIDSF is initialized from JointBERT, users need to train a JointBERT base model first
2
+ ./run_jointBERT-CRF_XLM-Rencoder.sh
3
+ #Train JointIDSF
4
+ export lr=3e-5
5
+ export c=0.25
6
+ export s=10
7
+ echo "${lr}"
8
+ export MODEL_DIR=JointIDSF_XLM-Rencoder
9
+ export MODEL_DIR=$MODEL_DIR"/"$lr"/"$c"/"$s
10
+ echo "${MODEL_DIR}"
11
+ python3 main.py --token_level syllable-level \
12
+ --model_type xlmr \
13
+ --model_dir $MODEL_DIR \
14
+ --data_dir PhoATIS \
15
+ --seed $s \
16
+ --do_train \
17
+ --do_eval \
18
+ --save_steps 140 \
19
+ --logging_steps 140 \
20
+ --num_train_epochs 50 \
21
+ --tuning_metric mean_intent_slot \
22
+ --use_intent_context_attention \
23
+ --attention_embedding_size 200 \
24
+ --use_crf \
25
+ --gpu_id 0 \
26
+ --embedding_type soft \
27
+ --intent_loss_coef $c \
28
+ --pretrained \
29
+ --pretrained_path JointBERT-CRF_XLM-Rencoder/4e-5/0.45/10 \
30
+ --learning_rate $lr
trainer.py ADDED
@@ -0,0 +1,300 @@
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ from early_stopping import EarlyStopping
7
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ from tqdm.auto import tqdm, trange
10
+ from transformers import AdamW, get_linear_schedule_with_warmup
11
+ from utils import MODEL_CLASSES, compute_metrics, get_intent_labels, get_slot_labels
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class Trainer(object):
18
+ def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
19
+ self.args = args
20
+ self.train_dataset = train_dataset
21
+ self.dev_dataset = dev_dataset
22
+ self.test_dataset = test_dataset
23
+
24
+ self.intent_label_lst = get_intent_labels(args)
25
+ self.slot_label_lst = get_slot_labels(args)
26
+ # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
27
+ self.pad_token_label_id = args.ignore_index
28
+ self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type]
29
+ # self.config = self.config_class.from_pretrained(model_path, finetuning_task=args.task)
30
+
31
+ if args.pretrained:
32
+ print(args.model_name_or_path)
33
+ self.model = self.model_class.from_pretrained(
34
+ args.pretrained_path,
35
+ args=args,
36
+ intent_label_lst=self.intent_label_lst,
37
+ slot_label_lst=self.slot_label_lst,
38
+ )
39
+ else:
40
+ self.config = self.config_class.from_pretrained(args.model_name_or_path, finetuning_task=args.token_level)
41
+ self.model = self.model_class.from_pretrained(
42
+ args.model_name_or_path,
43
+ config=self.config,
44
+ args=args,
45
+ intent_label_lst=self.intent_label_lst,
46
+ slot_label_lst=self.slot_label_lst,
47
+ )
48
+ # GPU or CPU
49
+ torch.cuda.set_device(self.args.gpu_id)
50
+ print(self.args.gpu_id)
51
+ print(torch.cuda.current_device())
52
+ self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
53
+ self.model.to(self.device)
54
+
55
+ def train(self):
56
+ train_sampler = RandomSampler(self.train_dataset)
57
+ train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)
58
+ writer = SummaryWriter(log_dir=self.args.model_dir)
59
+ if self.args.max_steps > 0:
60
+ t_total = self.args.max_steps
61
+ self.args.num_train_epochs = (
62
+ self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
63
+ )
64
+ else:
65
+ t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
66
+ print("check init")
67
+ results = self.evaluate("dev")
68
+ print(results)
69
+ # Prepare optimizer and schedule (linear warmup and decay)
70
+ no_decay = ["bias", "LayerNorm.weight"]
71
+ optimizer_grouped_parameters = [
72
+ {
73
+ "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
74
+ "weight_decay": self.args.weight_decay,
75
+ },
76
+ {
77
+ "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
78
+ "weight_decay": 0.0,
79
+ },
80
+ ]
81
+ optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
82
+ scheduler = get_linear_schedule_with_warmup(
83
+ optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total
84
+ )
85
+
86
+ # Train!
87
+ logger.info("***** Running training *****")
88
+ logger.info(" Num examples = %d", len(self.train_dataset))
89
+ logger.info(" Num Epochs = %d", self.args.num_train_epochs)
90
+ logger.info(" Total train batch size = %d", self.args.train_batch_size)
91
+ logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
92
+ logger.info(" Total optimization steps = %d", t_total)
93
+ logger.info(" Logging steps = %d", self.args.logging_steps)
94
+ logger.info(" Save steps = %d", self.args.save_steps)
95
+
96
+ global_step = 0
97
+ tr_loss = 0.0
98
+ self.model.zero_grad()
99
+
100
+ train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")
101
+ early_stopping = EarlyStopping(patience=self.args.early_stopping, verbose=True)
102
+
103
+ for epoch in train_iterator:
104
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", position=0, leave=True)
105
+ print("\nEpoch", epoch)
106
+
107
+ for step, batch in enumerate(epoch_iterator):
108
+ self.model.train()
109
+ batch = tuple(t.to(self.device) for t in batch) # GPU or CPU
110
+
111
+ inputs = {
112
+ "input_ids": batch[0],
113
+ "attention_mask": batch[1],
114
+ "intent_label_ids": batch[3],
115
+ "slot_labels_ids": batch[4],
116
+ }
117
+ if self.args.model_type != "distilbert":
118
+ inputs["token_type_ids"] = batch[2]
119
+ outputs = self.model(**inputs)
120
+ loss = outputs[0]
121
+
122
+ if self.args.gradient_accumulation_steps > 1:
123
+ loss = loss / self.args.gradient_accumulation_steps
124
+
125
+ loss.backward()
126
+
127
+ tr_loss += loss.item()
128
+ if (step + 1) % self.args.gradient_accumulation_steps == 0:
129
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
130
+
131
+ optimizer.step()
132
+ scheduler.step() # Update learning rate schedule
133
+ self.model.zero_grad()
134
+ global_step += 1
135
+
136
+ if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
137
+ print("\nTuning metrics:", self.args.tuning_metric)
138
+ results = self.evaluate("dev")
139
+ writer.add_scalar("Loss/validation", results["loss"], epoch)
140
+ writer.add_scalar("Intent Accuracy/validation", results["intent_acc"], epoch)
141
+ writer.add_scalar("Slot F1/validation", results["slot_f1"], epoch)
142
+ writer.add_scalar("Mean Intent Slot", results["mean_intent_slot"], epoch)
143
+ writer.add_scalar("Sentence Accuracy/validation", results["semantic_frame_acc"], epoch)
144
+ early_stopping(results[self.args.tuning_metric], self.model, self.args)
145
+ if early_stopping.early_stop:
146
+ print("Early stopping")
147
+ break
148
+
149
+ # if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
150
+ # self.save_model()
151
+
152
+ if 0 < self.args.max_steps < global_step:
153
+ epoch_iterator.close()
154
+ break
155
+
156
+ if 0 < self.args.max_steps < global_step or early_stopping.early_stop:
157
+ train_iterator.close()
158
+ break
159
+ writer.add_scalar("Loss/train", tr_loss / global_step, epoch)
160
+
161
+ return global_step, tr_loss / global_step
162
+
163
+ def write_evaluation_result(self, out_file, results):
164
+ out_file = self.args.model_dir + "/" + out_file
165
+ w = open(out_file, "w", encoding="utf-8")
166
+ w.write("***** Eval results *****\n")
167
+ for key in sorted(results.keys()):
168
+ to_write = " {key} = {value}".format(key=key, value=str(results[key]))
169
+ w.write(to_write)
170
+ w.write("\n")
171
+ w.close()
172
+
173
+ def evaluate(self, mode):
174
+ if mode == "test":
175
+ dataset = self.test_dataset
176
+ elif mode == "dev":
177
+ dataset = self.dev_dataset
178
+ else:
179
+ raise Exception("Only dev and test datasets are available")
180
+
181
+ eval_sampler = SequentialSampler(dataset)
182
+ eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)
183
+
184
+ # Eval!
185
+ logger.info("***** Running evaluation on %s dataset *****", mode)
186
+ logger.info(" Num examples = %d", len(dataset))
187
+ logger.info(" Batch size = %d", self.args.eval_batch_size)
188
+ eval_loss = 0.0
189
+ nb_eval_steps = 0
190
+ intent_preds = None
191
+ slot_preds = None
192
+ out_intent_label_ids = None
193
+ out_slot_labels_ids = None
194
+
195
+ self.model.eval()
196
+
197
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
198
+ batch = tuple(t.to(self.device) for t in batch)
199
+ with torch.no_grad():
200
+ inputs = {
201
+ "input_ids": batch[0],
202
+ "attention_mask": batch[1],
203
+ "intent_label_ids": batch[3],
204
+ "slot_labels_ids": batch[4],
205
+ }
206
+ if self.args.model_type != "distilbert":
207
+ inputs["token_type_ids"] = batch[2]
208
+ outputs = self.model(**inputs)
209
+ tmp_eval_loss, (intent_logits, slot_logits) = outputs[:2]
210
+
211
+ eval_loss += tmp_eval_loss.mean().item()
212
+ nb_eval_steps += 1
213
+
214
+ # Intent prediction
215
+ if intent_preds is None:
216
+ intent_preds = intent_logits.detach().cpu().numpy()
217
+ out_intent_label_ids = inputs["intent_label_ids"].detach().cpu().numpy()
218
+ else:
219
+ intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
220
+ out_intent_label_ids = np.append(
221
+ out_intent_label_ids, inputs["intent_label_ids"].detach().cpu().numpy(), axis=0
222
+ )
223
+
224
+ # Slot prediction
225
+ if slot_preds is None:
226
+ if self.args.use_crf:
227
+ # decode() in `torchcrf` directly returns the best label index sequences
228
+ slot_preds = np.array(self.model.crf.decode(slot_logits))
229
+ else:
230
+ slot_preds = slot_logits.detach().cpu().numpy()
231
+
232
+ out_slot_labels_ids = inputs["slot_labels_ids"].detach().cpu().numpy()
233
+ else:
234
+ if self.args.use_crf:
235
+ slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
236
+ else:
237
+ slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)
238
+
239
+ out_slot_labels_ids = np.append(
240
+ out_slot_labels_ids, inputs["slot_labels_ids"].detach().cpu().numpy(), axis=0
241
+ )
242
+
243
+ eval_loss = eval_loss / nb_eval_steps
244
+ results = {"loss": eval_loss}
245
+
246
+ # Intent result
247
+ intent_preds = np.argmax(intent_preds, axis=1)
248
+
249
+ # Slot result
250
+ if not self.args.use_crf:
251
+ slot_preds = np.argmax(slot_preds, axis=2)
252
+ slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
253
+ out_slot_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
254
+ slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
255
+
256
+ for i in range(out_slot_labels_ids.shape[0]):
257
+ for j in range(out_slot_labels_ids.shape[1]):
258
+ if out_slot_labels_ids[i, j] != self.pad_token_label_id:
259
+ out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
260
+ slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
261
+
262
+ total_result = compute_metrics(intent_preds, out_intent_label_ids, slot_preds_list, out_slot_label_list)
263
+ results.update(total_result)
264
+
265
+ logger.info("***** Eval results *****")
266
+ for key in sorted(results.keys()):
267
+ logger.info(" %s = %s", key, str(results[key]))
268
+ if mode == "test":
269
+ self.write_evaluation_result("eval_test_results.txt", results)
270
+ elif mode == "dev":
271
+ self.write_evaluation_result("eval_dev_results.txt", results)
272
+ return results
273
+
274
+ def save_model(self):
275
+ # Save model checkpoint (Overwrite)
276
+ if not os.path.exists(self.args.model_dir):
277
+ os.makedirs(self.args.model_dir)
278
+ model_to_save = self.model.module if hasattr(self.model, "module") else self.model
279
+ model_to_save.save_pretrained(self.args.model_dir)
280
+
281
+ # Save training arguments together with the trained model
282
+ torch.save(self.args, os.path.join(self.args.model_dir, "training_args.bin"))
283
+ logger.info("Saving model checkpoint to %s", self.args.model_dir)
284
+
285
+ def load_model(self):
286
+ # Check whether model exists
287
+ if not os.path.exists(self.args.model_dir):
288
+ raise Exception("Model doesn't exist! Train first!")
289
+
290
+ try:
291
+ self.model = self.model_class.from_pretrained(
292
+ self.args.model_dir,
293
+ args=self.args,
294
+ intent_label_lst=self.intent_label_lst,
295
+ slot_label_lst=self.slot_label_lst,
296
+ )
297
+ self.model.to(self.device)
298
+ logger.info("***** Model Loaded *****")
299
+ except Exception as e:
300
+ raise Exception("Some model files might be missing...") from e
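early_stopping.py is imported by trainer.py but is not part of this commit. A minimal sketch of an EarlyStopping class compatible with how Trainer calls it (patience/verbose kwargs, invoked as early_stopping(score, model, args), exposing .early_stop), assuming a higher-is-better metric such as mean_intent_slot; the real implementation may differ, e.g. when tuning_metric is "loss":

import os
import torch

class EarlyStopping:
    """Sketch only; assumes the tracked metric is higher-is-better."""

    def __init__(self, patience=7, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score, model, args):
        if self.best_score is None or score > self.best_score:
            # New best score: reset the counter and checkpoint like Trainer.save_model()
            self.best_score = score
            self.counter = 0
            if not os.path.exists(args.model_dir):
                os.makedirs(args.model_dir)
            model.save_pretrained(args.model_dir)
            torch.save(args, os.path.join(args.model_dir, "training_args.bin"))
        else:
            self.counter += 1
            if self.verbose:
                print("EarlyStopping counter: {} out of {}".format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True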
utils.py ADDED
@@ -0,0 +1,115 @@
1
+ import logging
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ from model import JointPhoBERT, JointXLMR
8
+ from seqeval.metrics import f1_score, precision_score, recall_score
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ RobertaConfig,
12
+ XLMRobertaConfig,
13
+ XLMRobertaTokenizer,
14
+ )
15
+
16
+
17
+ MODEL_CLASSES = {
18
+ "xlmr": (XLMRobertaConfig, JointXLMR, XLMRobertaTokenizer),
19
+ "phobert": (RobertaConfig, JointPhoBERT, AutoTokenizer),
20
+ }
21
+
22
+ MODEL_PATH_MAP = {
23
+ "xlmr": "xlm-roberta-base",
24
+ "phobert": "vinai/phobert-base",
25
+ }
26
+
27
+
28
+ def get_intent_labels(args):
29
+ return [
30
+ label.strip()
31
+ for label in open(os.path.join(args.data_dir, args.token_level, args.intent_label_file), "r", encoding="utf-8")
32
+ ]
33
+
34
+
35
+ def get_slot_labels(args):
36
+ return [
37
+ label.strip()
38
+ for label in open(os.path.join(args.data_dir, args.token_level, args.slot_label_file), "r", encoding="utf-8")
39
+ ]
40
+
41
+
42
+ def load_tokenizer(args):
43
+ return MODEL_CLASSES[args.model_type][2].from_pretrained(args.model_name_or_path)
44
+
45
+
46
+ def init_logger():
47
+ logging.basicConfig(
48
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
49
+ datefmt="%m/%d/%Y %H:%M:%S",
50
+ level=logging.INFO,
51
+ )
52
+
53
+
54
+ def set_seed(args):
55
+ random.seed(args.seed)
56
+ np.random.seed(args.seed)
57
+ torch.manual_seed(args.seed)
58
+ if not args.no_cuda and torch.cuda.is_available():
59
+ torch.cuda.manual_seed_all(args.seed)
60
+
61
+
62
+ def compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels):
63
+ assert len(intent_preds) == len(intent_labels) == len(slot_preds) == len(slot_labels)
64
+ results = {}
65
+ intent_result = get_intent_acc(intent_preds, intent_labels)
66
+ slot_result = get_slot_metrics(slot_preds, slot_labels)
67
+ semantic_result = get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels)
68
+
69
+ mean_intent_slot = (intent_result["intent_acc"] + slot_result["slot_f1"]) / 2
70
+
71
+ results.update(intent_result)
72
+ results.update(slot_result)
73
+ results.update(semantic_result)
74
+ results["mean_intent_slot"] = mean_intent_slot
75
+
76
+ return results
77
+
78
+
79
+ def get_slot_metrics(preds, labels):
80
+ assert len(preds) == len(labels)
81
+ return {
82
+ "slot_precision": precision_score(labels, preds),
83
+ "slot_recall": recall_score(labels, preds),
84
+ "slot_f1": f1_score(labels, preds),
85
+ }
86
+
87
+
88
+ def get_intent_acc(preds, labels):
89
+ acc = (preds == labels).mean()
90
+ return {"intent_acc": acc}
91
+
92
+
93
+ def read_prediction_text(args):
94
+ return [text.strip() for text in open(os.path.join(args.pred_dir, args.pred_input_file), "r", encoding="utf-8")]
95
+
96
+
97
+ def get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels):
98
+ """Semantic frame accuracy: a sentence counts as correct only when the intent and every slot are correct."""
99
+ # Get the intent comparison result
100
+ intent_result = intent_preds == intent_labels
101
+
102
+ # Get the slot comparison result
103
+ slot_result = []
104
+ for preds, labels in zip(slot_preds, slot_labels):
105
+ assert len(preds) == len(labels)
106
+ one_sent_result = True
107
+ for p, l in zip(preds, labels):
108
+ if p != l:
109
+ one_sent_result = False
110
+ break
111
+ slot_result.append(one_sent_result)
112
+ slot_result = np.array(slot_result)
113
+
114
+ semantic_acc = np.multiply(intent_result, slot_result).mean()
115
+ return {"semantic_frame_acc": semantic_acc}
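A toy invocation of compute_metrics above, using illustrative IOB slot labels (seqeval expects per-sentence label lists):

import numpy as np
from utils import compute_metrics

intent_preds = np.array([1, 0])
intent_labels = np.array([1, 1])
slot_preds = [["B-toloc.city_name", "O"], ["O", "O"]]
slot_labels = [["B-toloc.city_name", "O"], ["O", "O"]]

results = compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels)
print(results["intent_acc"])          # 0.5  (second intent is wrong)
print(results["slot_f1"])             # 1.0  (all slots match)
print(results["semantic_frame_acc"])  # 0.5  (sentence 2 fails on intent)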