khaiphan29 committed on
Commit
0217fc8
·
1 Parent(s): 0af052d

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.json filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

# Spaces run the container as user ID 1000. Create that user before any COPY
# or download so the application can write to its home directory (the model
# and tokenizer download caches live there).
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

# Install dependencies first so this layer stays cached until
# requirements.txt changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# NOTE(review): the application calls nltk.tokenize.sent_tokenize, which
# presumably needs the NLTK "punkt" data at runtime -- confirm it is
# downloaded somewhere (it is not fetched in this Dockerfile).
COPY --chown=user . $HOME/app

# Spaces route traffic to port 7860.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,9 @@
1
  ---
2
- title: Fact Check Api
3
- emoji: 📈
4
- colorFrom: blue
5
  colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.12.0
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
1
  ---
2
+ title: Fact Checking Api
3
+ emoji: 📊
4
+ colorFrom: pink
5
  colorTo: blue
6
+ sdk: docker
 
 
7
  pinned: false
8
  ---
9
 
main.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #uvicorn main:app --reload
2
+ from fastapi import FastAPI, status
3
+ from fastapi.responses import Response, JSONResponse
4
+ from pydantic import BaseModel
5
+
6
+ from typing import List
7
+
8
+ import os
9
+ import json
10
+ import time
11
+
12
+ from src.myNLI import FactChecker
13
+ from src.crawler import MyCrawler
14
+
15
+ #request body
16
# Request body of POST /ai-fact-check.
class Claim(BaseModel):
    claim: str  # the statement to be fact-checked
18
+
19
# One entry of the POST /scraping-check request: a source whose URL should be
# probed. `id` and `name` are echoed back unchanged in the response.
class ScrapeBase(BaseModel):
    id: int
    name: str
    scraping_url: str  # URL handed to the crawler
23
+
24
# Request body of POST /scraping-check: the list of sources to probe.
class ScrapeList(BaseModel):
    data: List[ScrapeBase]
26
+
27
# ASGI application object; the route decorators below register on it.
app = FastAPI()

# Build the fact checker once, at import time, and report how long it took.
t_0 = time.time()
fact_checker = FactChecker()
t_load = time.time() - t_0
print(f"time load model: {t_load}")

# Crawler instance used by the /scraping-check endpoint.
crawler = MyCrawler()

# Verdict string returned by the fact checker -> integer code sent to clients.
# NOTE(review): this numbering differs from src.myNLI.int2label
# (0=SUPPORTED, 1=NEI, 2=REFUTED) -- presumably intentional for the API
# consumer; confirm.
label_code = dict(REFUTED=0, SUPPORTED=1, NEI=2)
42
+
43
@app.get("/")
async def root():
    # Landing endpoint: answers with a fixed, human-readable description.
    greeting = {"msg": "This is for interacting with Fact-checking AI Model"}
    return greeting
46
+
47
@app.post("/ai-fact-check")
def get_claim(req: Claim):
    """Fact-check the submitted claim.

    Returns the claim together with the integer verdict code (see
    ``label_code``), the supporting evidence text and its source, or an empty
    ``204 No Content`` response when the fact checker produced no result.

    Declared with plain ``def`` on purpose: ``fact_checker.predict`` performs
    blocking HTTP requests and model inference, and FastAPI runs non-async
    handlers in a worker thread, so the event loop stays free to serve other
    requests in the meantime.
    """
    # NOTE(review): the /scraping-check handler below reuses the name
    # ``get_claim``; both routes are registered when decorated, so this works,
    # but the module-level name ends up bound to the later function only.
    claim = req.claim
    result = fact_checker.predict(claim)
    print(result)

    if not result:
        return Response(status_code=status.HTTP_204_NO_CONTENT)

    return {
        "claim": claim,
        "final_label": label_code[result["label"]],
        "evidence": result["evidence"],
        "provider": result["provider"],
        "url": result["url"],
    }
62
+
63
@app.post("/scraping-check")
def get_claim(req: ScrapeList):
    """Report, for every submitted source, whether its URL can be scraped.

    Each entry of the request is echoed back with an added boolean ``status``
    taken from ``crawler.scraping``. The result is wrapped as
    ``{"list": [...]}``.

    Declared with plain ``def`` on purpose: ``crawler.scraping`` issues
    blocking HTTP requests, and FastAPI runs non-async handlers in a worker
    thread instead of on the event loop.
    """
    # NOTE(review): this handler shares its function name with the
    # /ai-fact-check handler above. Both routes stay registered, but a
    # distinct name would be clearer.
    response = [
        {
            "id": ele.id,
            "name": ele.name,
            "scraping_url": ele.scraping_url,
            "status": crawler.scraping(ele.scraping_url),
        }
        for ele in req.data
    ]

    return JSONResponse({
        "list": response
    })
requirements.txt ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.9.1
2
+ aiosignal==1.3.1
3
+ annotated-types==0.6.0
4
+ anyio==4.2.0
5
+ async-timeout==4.0.3
6
+ attrs==23.2.0
7
+ beautifulsoup4==4.12.2
8
+ certifi==2023.11.17
9
+ charset-normalizer==3.3.2
10
+ click==8.1.7
11
+ datasets==2.16.1
12
+ dill==0.3.7
13
+ exceptiongroup==1.2.0
14
+ fastapi==0.108.0
15
+ filelock==3.13.1
16
+ frozenlist==1.4.1
17
+ fsspec==2023.10.0
18
+ h11==0.14.0
19
+ huggingface-hub==0.20.1
20
+ idna==3.6
21
+ Jinja2==3.1.2
22
+ joblib==1.3.2
23
+ MarkupSafe==2.1.3
24
+ mpmath==1.3.0
25
+ multidict==6.0.4
26
+ multiprocess==0.70.15
27
+ networkx==3.2.1
28
+ nltk==3.8.1
29
+ numpy==1.26.2
30
+ packaging==23.2
31
+ pandas==2.1.4
32
+ Pillow==10.1.0
33
+ pyarrow==14.0.2
34
+ pyarrow-hotfix==0.6
35
+ pydantic==2.5.3
36
+ pydantic_core==2.14.6
37
+ python-dateutil==2.8.2
38
+ pytz==2023.3.post1
39
+ PyYAML==6.0.1
40
+ regex==2023.12.25
41
+ requests==2.31.0
42
+ safetensors==0.4.1
43
+ scikit-learn==1.3.2
44
+ scipy==1.11.4
45
+ sentence-transformers==2.2.2
46
+ sentencepiece==0.1.99
47
+ six==1.16.0
48
+ sniffio==1.3.0
49
+ soupsieve==2.5
50
+ starlette==0.32.0.post1
51
+ sympy==1.12
52
+ threadpoolctl==3.2.0
53
+ tokenizers==0.15.0
54
+ torch==2.1.2
55
+ torchvision==0.16.2
56
+ tqdm==4.66.1
57
+ transformers==4.36.2
58
+ typing_extensions==4.9.0
59
+ tzdata==2023.4
60
+ urllib3==2.1.0
61
+ uvicorn==0.25.0
62
+ xxhash==3.4.1
63
+ yarl==1.9.4
script.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Upload the local ./src folder to the Space repository.
#
# The original script called ``api.upload_folder`` without ever defining
# ``api`` (NameError when run as a file); create the Hub client explicitly.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./src",
    repo_id="khaiphan29/fact-check-api",
    repo_type="space",
)
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/__init__.py ADDED
File without changes
src/crawler.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import re
4
+ import time
5
+
6
+ from .utils import timer_func
7
+
8
def remove_emoji(string):
    """Return ``string`` with all characters of the emoji class removed.

    The class is the union of the code-point ranges listed below; every run
    of matching characters is replaced by the empty string.
    """
    # NOTE(review): the last range (U+24C2..U+1F251) is very wide -- it also
    # spans e.g. CJK ideographs (U+4E00..U+9FFF), so such text is removed as
    # well. Confirm this is intended before narrowing it.
    code_point_ranges = (
        u"\U0001F300-\U0001FAD6"  # emoticons
        u"\U0001F300-\U0001F5FF"  # symbols & pictographs
        u"\U0001F680-\U0001F6FF"  # transport & map symbols
        u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
        u"\U00002702-\U000027B0"
        u"\U000024C2-\U0001F251"
    )
    return re.sub("[" + code_point_ranges + "]+", r'', string, flags=re.UNICODE)
18
+
19
def preprocess(texts):
    """Normalise every string in ``texts`` and return them as a new list.

    Per string, in order: underscores become spaces, the text is lower-cased,
    emoji are removed (``remove_emoji``), every character that is neither a
    word character nor whitespace is dropped, whitespace runs collapse to a
    single space, and one leading / trailing whitespace character is trimmed.

    The patterns are raw strings: ``\\w`` / ``\\s`` inside a normal string
    literal are invalid escape sequences, which newer Python versions warn
    about.
    """
    cleaned = []
    for text in texts:
        text = remove_emoji(text.replace("_", " ").lower())
        text = re.sub(r'[^\w\d\s]', '', text)
        text = re.sub(r'\s+|\n', ' ', text)
        # After the collapse above at most one space remains at either end.
        text = re.sub(r'^\s|\s$', '', text)
        cleaned.append(text)

    return cleaned
32
+
33
+
34
class MyCrawler:
    """Collect evidence text by scraping search results and news articles.

    ``search`` (Bing) and ``searchGoogle`` look a claim up on a search engine
    and return the article text of the result pages whose site is listed in
    ``_ARTICLE_CONTAINERS``. ``scraping`` reports whether a single URL yields
    any article text.
    """

    # Browser-like headers sent with every HTTP request.
    headers = {
        "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.67 Safari/537.36",
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
        'Accept-Language': 'en-US,en;q=0.5',
        'Accept-Encoding': 'gzip, deflate',
        'DNT': '1',
        'Connection': 'keep-alive',
        'Upgrade-Insecure-Requests': '1'
    }

    # Seconds to wait on a single HTTP request. ``requests`` applies no
    # timeout by default, so a stalled server would otherwise block forever.
    request_timeout = 10

    # Number of leading words of the cleaned claim used as the search query.
    max_query_words = 100

    # How often a search is attempted while it yields no result links.
    max_search_attempts = 3

    # Site host -> (tag name, CSS class) of the element holding the article.
    _ARTICLE_CONTAINERS = {
        "thanhnien.vn": ("div", "afcbc-body"),
        "tuoitre.vn": ("div", "afcbc-body"),
        "vietnamnet.vn": ("div", "maincontent"),
        "vnexpress.net": ("article", "fck_detail"),
        "www.24h.com.vn": ("article", "cate-24h-foot-arti-deta-info"),
        "vov.vn": ("div", "article-content"),
        "vtv.vn": ("div", "ta-justify"),
        "vi.wikipedia.org": ("div", "mw-content-ltr"),
        "www.vinmec.com": ("div", "block-content"),
        "vietstock.vn": ("div", "single_post_heading"),
        "vneconomy.vn": ("article", "detail-wrap"),
        "dantri.com.vn": ("article", "singular-container"),
        # Disabled in the original code:
        # "plo.vn": ("div", "article__body"),
    }

    @staticmethod
    def _provider_of(url: str) -> str:
        """Return the host part of ``url`` ("" when it has none)."""
        from urllib.parse import urlparse

        return urlparse(url).netloc

    @staticmethod
    def _bing_links(soup):
        """Return the result links found on a parsed Bing result page."""
        links = []
        for item in soup.find_all("li", {"class": "b_algo"}):
            anchor = item.find("a", href=True)
            # A result item without a link is skipped instead of raising.
            if anchor:
                links.append(anchor["href"])
        return links

    @staticmethod
    def _google_links(soup):
        """Return the result links found on a parsed Google result page."""
        anchors = soup.find_all("a", {"jsname": "UWckNb"})
        return [anchor["href"] for anchor in anchors if anchor.get("href")]

    def getSoup(self, url: str):
        """Download ``url`` and return its HTML parsed with BeautifulSoup."""
        response = requests.get(url, headers=self.headers, timeout=self.request_timeout)
        return BeautifulSoup(response.text, 'html.parser')

    def crawl_byContainer(self, url: str, article_container: str, body_class: str):
        """Return the text of all ``<p>`` elements inside the article container.

        The container is the first ``article_container`` tag with CSS class
        ``body_class``. Paragraphs are joined with newlines. For vnexpress.net
        the page's ``p.description`` text is prepended. Returns "" when the
        container or its paragraphs are missing.
        """
        soup = self.getSoup(url)

        container = soup.find(article_container, {"class": body_class})
        if not container:
            return ""

        contents = [p.get_text() for p in container.find_all("p")]
        if not contents:
            return ""

        result = "\n".join(contents)
        if self._provider_of(url) == "vnexpress.net":
            result = self.crawl_byElement(soup, "p", "description") + "\n" + result
        return result

    def crawl_byElement(self, soup, element: str, ele_class: str):
        """Return the text of the first ``element`` with class ``ele_class``.

        Returns "" when ``soup`` contains no such element.
        """
        print("by Elements...")

        paragraph = soup.find(element, {"class": ele_class})
        if not paragraph:
            return ""

        text = paragraph.get_text()
        print(text)
        return text

    def crawl_webcontent(self, url: str):
        """Scrape the article at ``url``.

        Returns ``(provider, url, content)`` where ``provider`` is the host of
        the URL and ``content`` is "" for hosts without a known container.
        """
        provider = self._provider_of(url)
        content = ""

        container = self._ARTICLE_CONTAINERS.get(provider)
        if container is not None:
            tag, css_class = container
            content = self.crawl_byContainer(url, tag, css_class)

        return provider, url, content

    def _search_and_crawl(self, claim, count, endpoint, extra_params, extract_links):
        """Query a search engine for ``claim`` and scrape the result pages.

        ``extract_links`` turns the parsed result page into a list of URLs.
        Returns a list of ``{"provider", "url", "content"}`` dicts for pages
        that yielded text, stopping once ``count`` of them were collected.
        Any unexpected error is printed and results in ``[]``.
        """
        processed_claim = preprocess([claim])[0]
        query = " ".join(processed_claim.split(" ")[:self.max_query_words])
        print(query)

        try:
            urls = []
            for attempt in range(self.max_search_attempts):
                if attempt:
                    # Pause only before a retry, not after a successful search.
                    time.sleep(1)

                response = requests.get(
                    endpoint,
                    headers=self.headers,
                    params={"q": query, **extra_params},
                    timeout=self.request_timeout,
                )
                print("Query URL: " + response.url)
                print("Crawling Attempt " + str(attempt))

                urls = extract_links(BeautifulSoup(response.text, 'html.parser'))
                if urls:
                    break

            print("Got " + str(len(urls)) + " urls")

            result = []
            for url in urls:
                print("Crawling... " + url)
                try:
                    provider, url, content = self.crawl_webcontent(url)
                except requests.RequestException as e:
                    # One unreachable page must not discard the other results.
                    print(e)
                    continue

                if content:
                    result.append({
                        "provider": provider,
                        "url": url,
                        "content": content
                    })
                    # As before, a non-positive ``count`` never matches and
                    # therefore means "no limit".
                    if len(result) == count:
                        break

            return result

        except Exception as e:
            # Best effort: callers treat an empty list as "no evidence".
            print(e)
            return []

    @timer_func
    def search(self, claim: str, count: int = 1):
        """Search Bing for ``claim`` and return up to ``count`` scraped articles.

        See ``_search_and_crawl`` for the shape of the result.
        """
        # The original dict listed "responseFilter" twice, so "-images" was
        # silently overwritten by "-videos"; send both in one comma-separated
        # value instead.
        return self._search_and_crawl(
            claim,
            count,
            "https://www.bing.com/search?",
            {"responseFilter": "-images,-videos"},
            self._bing_links,
        )

    @timer_func
    def searchGoogle(self, claim: str, count: int = 1):
        """Search Google for ``claim`` and return up to ``count`` scraped articles.

        See ``_search_and_crawl`` for the shape of the result.
        """
        return self._search_and_crawl(
            claim,
            count,
            "https://www.google.com/search?",
            {},
            self._google_links,
        )

    @timer_func
    def scraping(self, url: str):
        """Return True when ``url`` yields article text, otherwise False.

        Errors while downloading or parsing are printed and reported as False.
        """
        try:
            _, _, content = self.crawl_webcontent(url)
        except Exception as e:
            print(e)
            return False

        return bool(content)
src/mDeBERTa (ft) V6/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/mDeBERTa (ft) V6/cls.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1c3c8eae44569fd01a746b220091611125f9eb04e09af2d60a6d80befcdb769
3
+ size 11064
src/mDeBERTa (ft) V6/cls_log.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Step 0 -- Accuracy: 0.3039772727272727 -- macro_f1: 0.20810584530698015 -- loss: 1.0453389883041382
3
+
4
+ Step 100 -- Accuracy: 0.859375 -- macro_f1: 0.8598470398571504 -- loss: 0.11795929819345474
5
+
6
+ Step 200 -- Accuracy: 0.8747159090909091 -- macro_f1: 0.8755251824421424 -- loss: 0.22730453312397003
7
+
8
+ Step 300 -- Accuracy: 0.8536931818181818 -- macro_f1: 0.8533303214529117 -- loss: 0.18725647032260895
9
+
10
+ Step 400 -- Accuracy: 0.8690340909090909 -- macro_f1: 0.8687299763460793 -- loss: 0.28860458731651306
11
+
12
+ Step 500 -- Accuracy: 0.8798295454545455 -- macro_f1: 0.8802316356122608 -- loss: 0.6372634172439575
13
+
14
+ Step 600 -- Accuracy: 0.8610795454545455 -- macro_f1: 0.8612099869711884 -- loss: 0.41530805826187134
15
+
16
+ Step 700 -- Accuracy: 0.8491477272727272 -- macro_f1: 0.849751664990205 -- loss: 0.5970628261566162
17
+
18
+ Step 800 -- Accuracy: 0.8764204545454546 -- macro_f1: 0.8766266441048876 -- loss: 0.2515469491481781
19
+
20
+ Step 900 -- Accuracy: 0.8710227272727272 -- macro_f1: 0.8712350728851791 -- loss: 0.619756817817688
21
+
22
+ Step 1000 -- Accuracy: 0.8744318181818181 -- macro_f1: 0.8746062203201398 -- loss: 0.5634986758232117
23
+
24
+ Step 1100 -- Accuracy: 0.8735795454545454 -- macro_f1: 0.8735921715063891 -- loss: 0.2514641284942627
25
+
26
+ Step 1200 -- Accuracy: 0.8375 -- macro_f1: 0.8368621880475362 -- loss: 0.44521981477737427
27
+
28
+ Step 1300 -- Accuracy: 0.8551136363636364 -- macro_f1: 0.8555806721970362 -- loss: 0.048632219433784485
29
+
30
+ Step 1400 -- Accuracy: 0.8508522727272727 -- macro_f1: 0.8506097642423027 -- loss: 0.24613773822784424
31
+
32
+ Step 1500 -- Accuracy: 0.8673295454545454 -- macro_f1: 0.8671847303392856 -- loss: 0.1494443565607071
33
+
34
+ Step 1600 -- Accuracy: 0.834375 -- macro_f1: 0.8342641066244109 -- loss: 0.17161081731319427
35
+
36
+ Step 1700 -- Accuracy: 0.865625 -- macro_f1: 0.8651594643017528 -- loss: 0.154042050242424
37
+
38
+ Step 1800 -- Accuracy: 0.865909090909091 -- macro_f1: 0.8657615265484808 -- loss: 0.1435176134109497
39
+
40
+ Step 1900 -- Accuracy: 0.8176136363636364 -- macro_f1: 0.8171586288909666 -- loss: 0.09292535483837128
41
+
42
+ Step 2000 -- Accuracy: 0.8440340909090909 -- macro_f1: 0.843042759250924 -- loss: 0.34320467710494995
43
+
44
+ Step 2100 -- Accuracy: 0.8428977272727273 -- macro_f1: 0.8428498174495328 -- loss: 0.5764151811599731
45
+
46
+ Step 2200 -- Accuracy: 0.8417613636363637 -- macro_f1: 0.8418818479059557 -- loss: 0.28757143020629883
47
+
48
+ Step 2300 -- Accuracy: 0.840625 -- macro_f1: 0.8406394626850148 -- loss: 0.8960273861885071
49
+
50
+ Step 2400 -- Accuracy: 0.8142045454545455 -- macro_f1: 0.8140964442024906 -- loss: 0.8550783395767212
51
+
52
+ Step 2500 -- Accuracy: 0.8144886363636363 -- macro_f1: 0.8147455224461172 -- loss: 0.39625313878059387
53
+
54
+ Step 2600 -- Accuracy: 0.8053977272727273 -- macro_f1: 0.8021211300036969 -- loss: 0.3774358034133911
55
+
56
+ Step 2700 -- Accuracy: 0.8292613636363636 -- macro_f1: 0.8292382309283113 -- loss: 0.16644884645938873
57
+
58
+ Step 2800 -- Accuracy: 0.8150568181818182 -- macro_f1: 0.814290740222007 -- loss: 0.237399160861969
59
+
60
+ Step 2900 -- Accuracy: 0.8107954545454545 -- macro_f1: 0.8111709474507229 -- loss: 0.5621077418327332
61
+
62
+ Step 3000 -- Accuracy: 0.7926136363636364 -- macro_f1: 0.7930916669737708 -- loss: 0.4253169298171997
63
+
64
+ Step 3100 -- Accuracy: 0.8099431818181818 -- macro_f1: 0.8102288703246834 -- loss: 0.43165838718414307
65
+
66
+ Step 3200 -- Accuracy: 0.772159090909091 -- macro_f1: 0.7717788019596861 -- loss: 0.673878014087677
67
+
68
+ Step 3300 -- Accuracy: 0.7897727272727273 -- macro_f1: 0.7895567869064662 -- loss: 0.1990412026643753
69
+
70
+ Step 3400 -- Accuracy: 0.8008522727272728 -- macro_f1: 0.7997998535844976 -- loss: 0.4523601531982422
71
+
72
+ Step 3500 -- Accuracy: 0.7798295454545454 -- macro_f1: 0.7780260696858295 -- loss: 0.8848648071289062
73
+
74
+ Step 3600 -- Accuracy: 0.7775568181818182 -- macro_f1: 0.7779453966289696 -- loss: 0.5041539669036865
75
+
76
+ Step 3700 -- Accuracy: 0.709659090909091 -- macro_f1: 0.7069128111001839 -- loss: 0.6758942604064941
src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-mean/config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/content/checkpoint",
3
+ "architectures": [
4
+ "DebertaV2Model"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "id2label": {
11
+ "0": "entailment",
12
+ "1": "neutral",
13
+ "2": "contradiction"
14
+ },
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "label2id": {
18
+ "contradiction": 2,
19
+ "entailment": 0,
20
+ "neutral": 1
21
+ },
22
+ "layer_norm_eps": 1e-07,
23
+ "max_position_embeddings": 512,
24
+ "max_relative_positions": -1,
25
+ "model_type": "deberta-v2",
26
+ "norm_rel_ebd": "layer_norm",
27
+ "num_attention_heads": 12,
28
+ "num_hidden_layers": 12,
29
+ "pad_token_id": 0,
30
+ "pooler_dropout": 0,
31
+ "pooler_hidden_act": "gelu",
32
+ "pooler_hidden_size": 768,
33
+ "pos_att_type": [
34
+ "p2c",
35
+ "c2p"
36
+ ],
37
+ "position_biased_input": false,
38
+ "position_buckets": 256,
39
+ "relative_attention": true,
40
+ "share_att_key": true,
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.35.0",
43
+ "type_vocab_size": 0,
44
+ "vocab_size": 251000
45
+ }
src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-mean/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c7e80e8237ad2969b1c989d71f97fa7b950fd239bfa8b3329f0535a0b8a2aca
3
+ size 1112897768
src/mDeBERTa (ft) V6/mean.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f963dfcdad5469498af3b396c5af0e27365e59a01498c51896b9e6547851cd4
3
+ size 11071
src/mDeBERTa (ft) V6/mean_log.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Step 0 -- Accuracy: 0.275 -- macro_f1: 0.24245894645844043 -- loss: 1.1975505352020264
3
+
4
+ Step 100 -- Accuracy: 0.8230113636363636 -- macro_f1: 0.8247917227891541 -- loss: 0.5072745084762573
5
+
6
+ Step 200 -- Accuracy: 0.8585227272727273 -- macro_f1: 0.8596474113005192 -- loss: 0.3576969504356384
7
+
8
+ Step 300 -- Accuracy: 0.8616477272727273 -- macro_f1: 0.8619445917534628 -- loss: 0.22678352892398834
9
+
10
+ Step 400 -- Accuracy: 0.8710227272727272 -- macro_f1: 0.8713149438253084 -- loss: 0.3302939534187317
11
+
12
+ Step 500 -- Accuracy: 0.8491477272727272 -- macro_f1: 0.8497535984618637 -- loss: 0.8534196615219116
13
+
14
+ Step 600 -- Accuracy: 0.8627840909090909 -- macro_f1: 0.8630171351987245 -- loss: 0.27207863330841064
15
+
16
+ Step 700 -- Accuracy: 0.8676136363636363 -- macro_f1: 0.8681189318753203 -- loss: 0.5472040772438049
17
+
18
+ Step 800 -- Accuracy: 0.8480113636363636 -- macro_f1: 0.8474828960740969 -- loss: 0.20389704406261444
19
+
20
+ Step 900 -- Accuracy: 0.8625 -- macro_f1: 0.8627369387200629 -- loss: 0.7003616094589233
21
+
22
+ Step 1000 -- Accuracy: 0.8471590909090909 -- macro_f1: 0.8474576933366409 -- loss: 0.39897170662879944
23
+
24
+ Step 1100 -- Accuracy: 0.8647727272727272 -- macro_f1: 0.8648449015557045 -- loss: 0.30028393864631653
25
+
26
+ Step 1200 -- Accuracy: 0.8355113636363637 -- macro_f1: 0.8357176579844655 -- loss: 0.5329824090003967
27
+
28
+ Step 1300 -- Accuracy: 0.8318181818181818 -- macro_f1: 0.832158484567787 -- loss: 0.04946904629468918
29
+
30
+ Step 1400 -- Accuracy: 0.8275568181818181 -- macro_f1: 0.8270568913757921 -- loss: 0.290753036737442
31
+
32
+ Step 1500 -- Accuracy: 0.8619318181818182 -- macro_f1: 0.8620216901652552 -- loss: 0.17760200798511505
33
+
34
+ Step 1600 -- Accuracy: 0.8366477272727273 -- macro_f1: 0.8372501215741125 -- loss: 0.18745465576648712
35
+
36
+ Step 1700 -- Accuracy: 0.8556818181818182 -- macro_f1: 0.8555692365839257 -- loss: 0.09077112376689911
37
+
38
+ Step 1800 -- Accuracy: 0.8571022727272727 -- macro_f1: 0.8569408344903815 -- loss: 0.24079212546348572
39
+
40
+ Step 1900 -- Accuracy: 0.8122159090909091 -- macro_f1: 0.8117034674801616 -- loss: 0.3681311309337616
41
+
42
+ Step 2000 -- Accuracy: 0.8318181818181818 -- macro_f1: 0.8319676688379705 -- loss: 0.2374744713306427
43
+
44
+ Step 2100 -- Accuracy: 0.8443181818181819 -- macro_f1: 0.8442918629955193 -- loss: 0.4600515365600586
45
+
46
+ Step 2200 -- Accuracy: 0.8278409090909091 -- macro_f1: 0.8269904995679983 -- loss: 0.3283902704715729
47
+
48
+ Step 2300 -- Accuracy: 0.8298295454545455 -- macro_f1: 0.8299882032010862 -- loss: 1.0965081453323364
49
+
50
+ Step 2400 -- Accuracy: 0.8159090909090909 -- macro_f1: 0.8159808860940237 -- loss: 0.7295159697532654
51
+
52
+ Step 2500 -- Accuracy: 0.8159090909090909 -- macro_f1: 0.8142475187664063 -- loss: 0.3925968408584595
53
+
54
+ Step 2600 -- Accuracy: 0.8204545454545454 -- macro_f1: 0.820545798600696 -- loss: 0.3808274567127228
55
+
56
+ Step 2700 -- Accuracy: 0.8198863636363637 -- macro_f1: 0.8199413434559383 -- loss: 0.26008090376853943
57
+
58
+ Step 2800 -- Accuracy: 0.8056818181818182 -- macro_f1: 0.8051566431375038 -- loss: 0.20567485690116882
59
+
60
+ Step 2900 -- Accuracy: 0.784375 -- macro_f1: 0.7848921849530183 -- loss: 0.5506788492202759
61
+
62
+ Step 3000 -- Accuracy: 0.8153409090909091 -- macro_f1: 0.8150634367874668 -- loss: 0.4250873923301697
63
+
64
+ Step 3100 -- Accuracy: 0.7991477272727273 -- macro_f1: 0.8000715520252392 -- loss: 0.4798588752746582
65
+
66
+ Step 3200 -- Accuracy: 0.7840909090909091 -- macro_f1: 0.7836356305606565 -- loss: 0.5604580640792847
67
+
68
+ Step 3300 -- Accuracy: 0.7977272727272727 -- macro_f1: 0.7965403402362528 -- loss: 0.26682722568511963
69
+
70
+ Step 3400 -- Accuracy: 0.809375 -- macro_f1: 0.8087947373143304 -- loss: 0.3252097964286804
71
+
72
+ Step 3500 -- Accuracy: 0.7568181818181818 -- macro_f1: 0.7548780108676749 -- loss: 0.9467527866363525
73
+
74
+ Step 3600 -- Accuracy: 0.7889204545454546 -- macro_f1: 0.7892382882596812 -- loss: 0.29441171884536743
75
+
76
+ Step 3700 -- Accuracy: 0.7227272727272728 -- macro_f1: 0.7227876418017654 -- loss: 0.8389160633087158
src/mDeBERTa (ft) V6/plot.png ADDED
src/mDeBERTa (ft) V6/public_train_v4.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56c03b9bb2cab8ffbe138badea76b6275ebad727e99f5040d2a8c21f2dcfaff2
3
+ size 227113690
src/myNLI.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
3
+ from sentence_transformers import SentenceTransformer, util
4
+ import nltk
5
+
6
+ # import datasets
7
+ from datasets import Dataset, DatasetDict
8
+
9
+ from typing import List
10
+
11
+ from .utils import timer_func
12
+ from .nli_v3 import NLI_model
13
+ from .crawler import MyCrawler
14
+
15
# Integer class id predicted by the classifier -> verdict string.
int2label = {
    0: 'SUPPORTED',
    1: 'NEI',
    2: 'REFUTED',
}
16
+
17
class FactChecker:
    """Fact-check a claim against evidence scraped from the web.

    Pipeline used by ``predict``: search for the claim, split claim and
    article into sentences, keep the article sentences most similar to the
    claim (sentence embeddings), then classify the (evidence, claim) pair with
    the fine-tuned mDeBERTa encoder plus a linear classification head.
    """

    @timer_func
    def __init__(self):
        # Which pooled embedding feeds the classifier: "mean" or "cls". It
        # also selects the model directory and checkpoint file to load.
        self.INPUT_TYPE = "mean"
        self.load_model()

    @timer_func
    def load_model(self):
        """Load tokenizer, encoder, classification head and similarity model."""
        self.envir = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Load LLM
        self.tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") # LOAD mDEBERTa TOKENIZER
        # The encoder must live on the same device as its inputs (see
        # inferSample); it was previously left on the CPU, which fails as soon
        # as CUDA is available.
        self.mDeBertaModel = AutoModel.from_pretrained(f"src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-{self.INPUT_TYPE}").to(self.envir) # LOAD FINETUNED MODEL

        # Load classifier model
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        self.checkpoints = torch.load(f"src/mDeBERTa (ft) V6/{self.INPUT_TYPE}.pt", map_location=self.envir)

        # The zero class weights are presumably placeholders that the
        # checkpoint overwrites -- TODO confirm.
        self.classifierModel = NLI_model(768, torch.tensor([0., 0., 0.])).to(self.envir)
        self.classifierModel.load_state_dict(self.checkpoints['model_state_dict'])

        # Load model for predict similarity
        self.model_sbert = SentenceTransformer('keepitreal/vietnamese-sbert')

    @timer_func
    def get_similarity_v2(self, src_sents, dst_sents, threshold = 0.4):
        """Rank ``dst_sents`` by cosine similarity to each sentence of ``src_sents``.

        Returns ``(None, results)``. ``results`` holds one dict per source
        sentence with the keys ``top_k``, ``claim``, ``sim_score`` (the
        ``torch.topk`` result) and ``evidences`` (the up-to-five most similar
        sentences, best first). ``threshold`` is accepted but not used.
        """
        corpus_embeddings = self.model_sbert.encode(dst_sents, convert_to_tensor=True)
        top_k = min(5, len(dst_sents))
        ls_top_results = []
        for query in src_sents:
            query_embedding = self.model_sbert.encode(query, convert_to_tensor=True)
            # We use cosine-similarity and torch.topk to find the highest 5 scores
            cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
            top_results = torch.topk(cos_scores, k=top_k)

            ls_top_results.append({
                "top_k": top_k,
                "claim": query,
                "sim_score": top_results,
                # top_results[1] holds the indices of the selected sentences.
                "evidences": [dst_sents[idx] for idx in top_results[1].tolist()],
            })

        return None, ls_top_results

    @timer_func
    def inferSample(self, evidence, claim):
        """Classify one (evidence, claim) pair.

        Returns ``{'labels': <verdict string>, 'confidence': <probability>}``.
        """

        @timer_func
        def mDeBERTa_tokenize(data): # mDeBERTa model: Taking input_ids
            premises = [premise for premise, _ in data['sample']]
            hypothesis = [hypothesis for _, hypothesis in data['sample']]

            with torch.no_grad():
                # Only the input ids are passed to the encoder.
                input_token = (self.tokenizer(premises, hypothesis, truncation=True, return_tensors="pt", padding = True)['input_ids']).to(self.envir)
                embedding = self.mDeBertaModel(input_token).last_hidden_state

                # Mean over every position but the first; the first position
                # is kept separately as the "cls" embedding.
                mean_embedding = torch.mean(embedding[:, 1:, :], dim = 1)
                cls_embedding = embedding[:, 0, :]

            return {'mean':mean_embedding, 'cls':cls_embedding}

        @timer_func
        def predict_mapping(batch):
            with torch.no_grad():
                predict_label, predict_prob = self.classifierModel.predict_step((batch[self.INPUT_TYPE].to(self.envir), None))
            # predict_step returns the negated probability; flip the sign back.
            return {'label':predict_label, 'prob':-predict_prob}

        # Mapping the predict label into corresponding string labels
        @timer_func
        def output_predictedDataset(predict_dataset):
            # The dataset holds a single record here; for several records the
            # values of the last one are returned.
            for record in predict_dataset:
                labels = int2label[ record['label'].item() ]
                confidence = record['prob'].item()

            return {'labels':labels, 'confidence':confidence}

        dataset = {'sample':[(evidence, claim)], 'key': [0]}
        output_dataset = DatasetDict({
            'infer': Dataset.from_dict(dataset)
        })

        @timer_func
        def tokenize_dataset():

            tokenized_dataset = output_dataset.map(mDeBERTa_tokenize, batched=True, batch_size=1)
            return tokenized_dataset

        tokenized_dataset = tokenize_dataset()
        tokenized_dataset = tokenized_dataset.with_format("torch", [self.INPUT_TYPE, 'key'])
        # Running inference step
        predicted_dataset = tokenized_dataset.map(predict_mapping, batched=True, batch_size=tokenized_dataset['infer'].num_rows)
        return output_predictedDataset(predicted_dataset['infer'])

    def _build_result(self, claim, processed_evidence, evidence):
        """Run the NLI model and assemble the response dict of ``predict``."""
        nli_result = self.inferSample(processed_evidence, claim)
        return {
            "claim": claim,
            "label": nli_result["labels"],
            "confidence": nli_result['confidence'],
            # No evidence text is reported for a "not enough information" verdict.
            "evidence": processed_evidence if nli_result["labels"] != "NEI" else "",
            "provider": evidence['provider'],
            "url": evidence['url']
        }

    @timer_func
    def predict_vt(self, claim: str) -> List:
        """Legacy prediction path; returns ``(ls_evidence, final_verdict)`` or None.

        NOTE(review): this method calls ``self.get_result_nli_v2``, which is
        not defined in this class, so it raises AttributeError once evidence
        was found. Kept for interface compatibility; ``predict`` is the
        working entry point.
        """
        # step 1: crawl evidences from Google search
        crawler = MyCrawler()
        evidences = crawler.searchGoogle(claim)

        if not evidences:
            return None

        # Only the first search hit is evaluated (the original loop returned
        # during its first iteration).
        evidence = evidences[0]
        print(evidence['url'])

        # step 2: use sentence embeddings to find the most related sentences
        post_message = nltk.tokenize.sent_tokenize(claim)
        evidence_sents = nltk.tokenize.sent_tokenize(evidence["content"])
        _, top_rst = self.get_similarity_v2(post_message, evidence_sents)

        print(top_rst)

        ls_evidence, final_verdict = self.get_result_nli_v2(top_rst)

        print("FINAL: " + final_verdict)
        return ls_evidence, final_verdict

    @timer_func
    def predict(self, claim):
        """Fact-check ``claim`` against the first article found for it.

        Returns a dict with the keys ``claim``, ``label``, ``confidence``,
        ``evidence``, ``provider`` and ``url``, or ``None`` when no usable
        evidence was found.
        """
        crawler = MyCrawler()
        evidences = crawler.searchGoogle(claim)
        if not evidences:
            return None

        evidence = evidences[0]
        tokenized_claim = nltk.tokenize.sent_tokenize(claim)
        tokenized_evidence = nltk.tokenize.sent_tokenize(evidence["content"])
        if not tokenized_claim or not tokenized_evidence:
            # Nothing to compare; treat it like "no evidence" rather than
            # failing on the empty similarity result below.
            return None

        _, top_rst = self.get_similarity_v2(tokenized_claim, tokenized_evidence)

        # Only the ranking for the first sentence of the claim is used.
        processed_evidence = "\n".join(top_rst[0]["evidences"])
        print(processed_evidence)

        return self._build_result(claim, processed_evidence, evidence)

    @timer_func
    def predict_nofilter(self, claim):
        """Like ``predict`` but feeds the whole article text to the NLI model.

        Returns ``None`` when the search yields no evidence (this used to
        raise IndexError).
        """
        crawler = MyCrawler()
        evidences = crawler.searchGoogle(claim)
        if not evidences:
            return None

        evidence = evidences[0]
        return self._build_result(claim, evidence['content'], evidence)
src/nli_v3.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ import pandas as pd
4
+
5
+ from transformers import AutoModel, AutoTokenizer
6
+
7
+ # import datasets
8
+ from datasets import Dataset, DatasetDict
9
+
10
+ from sklearn.metrics import classification_report
11
+ from sklearn.metrics._classification import _check_targets
12
+
13
# Device every tensor and model in this module is moved to: CUDA when
# available, CPU otherwise.
envir = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Class index predicted by the classifier -> human-readable verdict.
int2label = dict(enumerate(('SUPPORTED', 'NEI', 'REFUTED')))
16
+
17
class NLI_model(nn.Module):
    """Three-way classification head made of a single linear layer.

    The module maps a feature vector of size ``input_dims`` to three logits
    and bundles the loss, prediction and evaluation steps around it.
    """

    def __init__(self, input_dims, class_weights=torch.tensor([0., 0., 0.])):
        """Build the linear head and the weighted cross-entropy loss.

        Args:
            input_dims: Size of the last dimension of the input features.
            class_weights: Per-class weights for ``nn.CrossEntropyLoss``
                (three values). The default is a placeholder of zeros; it
                keeps the ``criterion.weight`` entry present in the state
                dict so checkpoints that contain it can be loaded.
        """
        super().__init__()

        self.classification = nn.Sequential(
            nn.Linear(input_dims, 3)
        )

        # CrossEntropyLoss registers the weight as a buffer and
        # load_state_dict copies into buffers in place. Cloning keeps the
        # shared default tensor (and any caller-owned tensor) from being
        # overwritten when a checkpoint is loaded into this instance.
        self.criterion = nn.CrossEntropyLoss(class_weights.detach().clone())

    def forward(self, input):
        """Return the raw logits for ``input``."""
        output_linear = self.classification(input)
        return output_linear

    def training_step(self, train_batch, batch_idx=0):
        """Return the cross-entropy loss for a ``(inputs, targets)`` batch."""
        input_data, targets = train_batch
        outputs = self.forward(input_data)
        loss = self.criterion(outputs, targets)
        return loss

    def predict_step(self, batch, batch_idx=0):
        """Return the top class and its NEGATED probability per row.

        Args:
            batch: ``(inputs, targets)``; the second element is ignored.

        Returns:
            ``(labels, negated_prob)``. Probabilities are negated before the
            ascending sort so the most likely class lands in column 0;
            callers must flip the sign to recover the probability.
        """
        input_data, _ = batch
        outputs = self.forward(input_data)
        prob = outputs.softmax(dim=-1)
        sort_prob, sort_indices = torch.sort(-prob, 1)
        return sort_indices[:, 0], sort_prob[:, 0]

    def validation_step(self, val_batch, batch_idx=0):
        """Return a ``classification_report`` dict for a validation batch."""
        _, targets = val_batch
        sort_indices, _ = self.predict_step(val_batch, batch_idx)
        report = classification_report(
            list(targets.to('cpu').numpy()),
            list(sort_indices.to('cpu').numpy()),
            output_dict=True,
            zero_division=1,
        )
        return report

    def test_step(self, batch, dict_form, batch_idx=0):
        """Return a ``classification_report`` for a test batch.

        Args:
            batch: ``(inputs, targets)``.
            dict_form: Forwarded as ``output_dict``; a dict is returned when
                true, a formatted string otherwise.
        """
        _, targets = batch
        sort_indices, _ = self.predict_step(batch, batch_idx)
        report = classification_report(
            targets.to('cpu').numpy(),
            sort_indices.to('cpu').numpy(),
            output_dict=dict_form,
            zero_division=1,
        )
        return report

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr = 1e-5)."""
        return torch.optim.Adam(self.parameters(), lr=1e-5)
58
+
59
+
60
def inferSample(evidence, claim, tokenizer, mDeBertaModel, classifierModel, input_type):
    """Classify one (evidence, claim) pair and return its label and confidence.

    The pair is wrapped in a one-row dataset, encoded by ``mDeBertaModel``
    into a mean-pooled and a first-token embedding, and the embedding named
    by ``input_type`` is fed to ``classifierModel.predict_step``.

    Args:
        evidence: Premise text.
        claim: Hypothesis text.
        tokenizer: Callable tokenizer taking (premises, hypotheses).
        mDeBertaModel: Encoder whose output exposes ``last_hidden_state``.
        classifierModel: Model providing ``predict_step``.
        input_type: Which embedding to classify: ``'mean'`` or ``'cls'``.

    Returns:
        ``{'labels': <str from int2label>, 'confidence': <float>}``.
    """

    def embed_pairs(batch):
        # Encoder step: only the tokenizer's input_ids are fed to the model.
        premises, hypotheses = [], []
        for premise, hypothesis in batch['sample']:
            premises.append(premise)
            hypotheses.append(hypothesis)

        with torch.no_grad():
            encoded = tokenizer(premises, hypotheses, truncation=True, return_tensors="pt", padding=True)
            input_ids = encoded['input_ids'].to(envir)
            hidden_states = mDeBertaModel(input_ids).last_hidden_state

        # Two pooled views: the average over every position after the
        # first, and the first position on its own.
        pooled_mean = torch.mean(hidden_states[:, 1:, :], dim=1)
        first_token = hidden_states[:, 0, :]

        return {'mean': pooled_mean, 'cls': first_token}

    def classify(batch):
        with torch.no_grad():
            features = batch[input_type].to(envir)
            predicted_label, negated_prob = classifierModel.predict_step((features, None))
        # predict_step hands back the negated probability; flip the sign.
        return {'label': predicted_label, 'prob': -negated_prob}

    def summarize(records):
        # Translate the numeric label to its string form; the values of the
        # last record iterated are the ones returned.
        for row in records:
            labels = int2label[row['label'].item()]
            confidence = row['prob'].item()

        return {'labels': labels, 'confidence': confidence}

    pair_dataset = DatasetDict({
        'infer': Dataset.from_dict({'sample': [(evidence, claim)], 'key': [0]})
    })

    embedded = pair_dataset.map(embed_pairs, batched=True, batch_size=1)
    embedded = embedded.with_format("torch", [input_type, 'key'])

    # Run the classifier over the whole split in a single batch.
    row_count = embedded['infer'].num_rows
    predicted = embedded.map(classify, batched=True, batch_size=row_count)
    return summarize(predicted['infer'])
100
+
101
if __name__ == '__main__':
    # Manual smoke test: load the encoder and classifier head from disk and
    # classify one hard-coded evidence/claim pair.

    # CHANGE 'INPUT_TYPE' TO CHANGE MODEL
    INPUT_TYPE = 'mean'  # USE "MEAN" OR "CLS" LAST HIDDEN STATE

    # Load LLM
    tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")  # LOAD mDEBERTa TOKENIZER
    # inferSample moves the input ids to `envir`, so the encoder has to live
    # on the same device; previously it was never moved, which fails with a
    # device mismatch whenever CUDA is available.
    mDeBertaModel = AutoModel.from_pretrained(
        f"src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-{INPUT_TYPE}"
    ).to(envir)  # LOAD FINETUNED MODEL

    # Load classifier model
    checkpoints = torch.load(f"src/mDeBERTa (ft) V6/{INPUT_TYPE}.pt", map_location=envir)
    # The zero weights are a placeholder so the loss-weight entry exists
    # before the checkpoint's values are loaded over it.
    classifierModel = NLI_model(768, torch.tensor([0., 0., 0.])).to(envir)
    classifierModel.load_state_dict(checkpoints['model_state_dict'])

    evidence = "Sau khi thẩm định, Liên đoàn Bóng đá châu Á AFC xác nhận thủ thành mới nhập quốc tịch của Việt Nam Filip Nguyễn đủ điều kiện thi đấu ở Asian Cup 2024."
    claim = "Filip Nguyễn đủ điều kiện dự Asian Cup 2024"
    print(inferSample(evidence, claim, tokenizer, mDeBertaModel, classifierModel, INPUT_TYPE))
src/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import time
2
+
3
+ def timer_func(func):
4
+ # This function shows the execution time of
5
+ # the function object passed
6
+ def wrap_func(*args, **kwargs):
7
+ t1 = time()
8
+ result = func(*args, **kwargs)
9
+ t2 = time()
10
+ print(f'Function {func.__name__!r} executed in {(t2-t1):.4f}s')
11
+ return result
12
+ return wrap_func