truong-xuan-linh committed on
Commit
7d5dab0
1 Parent(s): abc1c82
.github/workflows/main.yml CHANGED
@@ -17,4 +17,4 @@ jobs:
   - name: Push to hub
     env:
      HF_TOKEN: ${{ secrets.HF_TOKEN }}
-   run: git push https://truong-xuan-linh:$HF_TOKEN@huggingface.co/spaces/truong-xuan-linh/content-category-classification master
+   run: git push --force https://truong-xuan-linh:$HF_TOKEN@huggingface.co/spaces/truong-xuan-linh/content-category-classification main
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
README.md CHANGED
@@ -1 +1,12 @@
+ ---
+ title: Content Category Classification
+ emoji: 🏆
+ colorFrom: pink
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: 1.26.0
+ app_file: app.py
+ pinned: false
+ ---
+
  # content_category_classification
app.py ADDED
@@ -0,0 +1,46 @@
+ import streamlit as st
+ from omegaconf import OmegaConf
+
+ # Trick: initialize the model only once; Streamlit reruns this whole script on every interaction
+ if "category_model" not in st.session_state:
+     print("INIT MODEL")
+     from src.category_model import CategoryModel
+     from src.category_model import PhoBERT_classification  # needed so torch.load can unpickle the saved model
+     src_config = OmegaConf.load('config/config.yaml')
+     st.session_state.category_model = CategoryModel(config=src_config)
+     print("DONE INIT MODEL")
+
+ st.set_page_config(page_title="Vietnamese Category Classification", layout="wide", page_icon="./linhai.jpeg")
+ hide_menu_style = """
+     <style>
+     footer {visibility: hidden;}
+     </style>
+ """
+ st.markdown(hide_menu_style, unsafe_allow_html=True)
+
+ st.markdown(
+     """
+     <style>
+     [data-testid="stSidebar"][aria-expanded="true"] > div:first-child{
+         width: 400px;
+     }
+     [data-testid="stSidebar"][aria-expanded="false"] > div:first-child{
+         margin-left: -400px;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ st.markdown("<h2 style='text-align: center; color: grey;'>Input: Vietnamese content</h2>", unsafe_allow_html=True)
+ st.markdown("<h2 style='text-align: center; color: grey;'>Output: Content classification</h2>", unsafe_allow_html=True)
+
+ content = st.text_input("Enter your content", value="The content must contain at least 50 words.")
+
+ if st.button("Submit"):
+     st.write("**RESULT:**")
+     if len(content.split()) < 50:
+         st.write("The content must contain at least 50 words.")
+     else:
+         result = st.session_state.category_model.predict(content)
+         st.write(result)
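A note on the `st.session_state` guard at the top of `app.py`: Streamlit re-executes the entire script on every widget interaction, so expensive setup such as loading the PhoBERT checkpoint has to be cached explicitly. Below is a minimal sketch of the same pattern, with a hypothetical `load_model()` standing in for the real `CategoryModel` initialization:

```python
import streamlit as st

def load_model():
    # stand-in for the expensive CategoryModel(...) construction
    return {"name": "dummy-model"}

# The guard runs the loader once per session; later reruns reuse the cached object
if "model" not in st.session_state:
    st.session_state.model = load_model()

model = st.session_state.model
```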
config/classes.json ADDED
@@ -0,0 +1,31 @@
+ {"music": 0,
+  "food": 1,
+  "technology": 2,
+  "travel": 3,
+  "animal": 4,
+  "life": 5,
+  "family": 6,
+  "entertainment": 7,
+  "education": 8,
+  "youth": 9,
+  "fun": 10,
+  "cartoon": 11,
+  "science": 12,
+  "economy": 13,
+  "history": 14,
+  "shopping": 15,
+  "celebrity": 16,
+  "law": 17,
+  "movie": 18,
+  "book": 19,
+  "beauty": 20,
+  "health": 21,
+  "world": 22,
+  "sports": 23,
+  "nature": 24,
+  "news": 25,
+  "fashion": 26,
+  "game": 27,
+  "culture": 28,
+  "vehicles": 29,
+  "medical": 30}
config/config.yaml ADDED
@@ -0,0 +1,5 @@
+ model:
+   path: ./models/freeze_clean_warnup_0.0005_0.6648_2.8259.pt
+   url: https://drive.google.com/uc?id=1gKBx1sgHhJOyLmCidCm_serwgDYG6U1g
+   threshold: 0.3
+   min_length: 50
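Both `app.py` and `src/category_model.py` read this file with OmegaConf, which exposes the nested YAML keys as attributes. A minimal sketch of the access pattern, assuming the file is at the path below:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("config/config.yaml")
print(config.model.path)        # ./models/freeze_clean_warnup_0.0005_0.6648_2.8259.pt
print(config.model.threshold)   # 0.3  (score cutoff used by predict())
print(config.model.min_length)  # 50   (matches the word-count check in app.py)
```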
linhai.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers==4.28.1
+ torch==2.0.0
+ gdown==4.7.1
+ underthesea==6.7.0
+ omegaconf==2.0.6
+ streamlit==1.26.0
src/category_model.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ import re
+ import json
+ import gdown
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from underthesea import word_tokenize
+ from transformers import AutoTokenizer
+
+ class PhoBERT_classification(nn.Module):
+     def __init__(self, phobert, num_classes=31, device="cpu"):
+         # num_classes defaults to the 31 labels in config/classes.json
+         super(PhoBERT_classification, self).__init__()
+
+         self.phobert = phobert
+         self.dropout = nn.Dropout(0.2)
+         self.relu = nn.ReLU()
+         self.fc1 = nn.Linear(768, 512, device=device)
+         self.fc2 = nn.Linear(512, num_classes, device=device)
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, input_ids, attention_mask):
+         last_hidden_states, cls_hs = self.phobert(input_ids=input_ids,
+                                                   attention_mask=attention_mask,
+                                                   return_dict=False)
+
+         # Classify from the [CLS] token representation
+         x = self.fc1(last_hidden_states[:, 0, :])
+         x = self.relu(x)
+         x = self.dropout(x)
+
+         x = self.fc2(x)
+         x = self.softmax(x)
+
+         return x
+
+
+ class CategoryModel():
+     def __init__(self, config):
+         self.DEVICE = "cpu"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.classes = json.load(open("./config/classes.json", "r"))
+         self.id2label = {v: k for k, v in self.classes.items()}
+
+         self.config = config
+         self.get_model()
+
+     def get_model(self):
+         self.tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
+
+         # Download the checkpoint from Google Drive on first run
+         if not os.path.isfile(self.config.model.path):
+             gdown.download(self.config.model.url, self.config.model.path, quiet=True)
+         self.model = torch.load(self.config.model.path, map_location=self.DEVICE)
+         self.model.eval()
+
+     def predict(self, paragraph):
+
+         def clean_string(input_string):
+             # Title-case all-uppercase words, then remove every character
+             # that is not a letter, digit, or whitespace
+             input_string = input_string.replace("\n", " ")
+             split_string = input_string.split()
+             input_string = " ".join([text.title() if text.isupper() else text for text in split_string])
+             cleaned_string = re.sub(r'[^\w\s]', '', input_string)
+             return cleaned_string
+
+         def input_tokenizer(text):
+             # Word-segment the Vietnamese text before PhoBERT tokenization
+             text = clean_string(text)
+             segment_text = word_tokenize(text, format="text")
+             tokenized_text = self.tokenizer(segment_text,
+                                             padding="max_length",
+                                             truncation=True,
+                                             max_length=256,
+                                             return_tensors="pt")
+             tokenized_text = {k: v.to(self.DEVICE) for k, v in tokenized_text.items()}
+             return tokenized_text
+
+         def get_top_acc(predictions, thre):
+             # Keep every class whose probability exceeds the threshold, sorted by score
+             results = {}
+             indexes = np.where(predictions[0] > thre)[0]
+             for index in indexes:
+                 results[self.id2label[index]] = float(predictions[0][index])
+             results = {k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)}
+
+             return results
+
+         tokenized_text = input_tokenizer(paragraph)
+         input_ids = tokenized_text["input_ids"]
+         attention_mask = tokenized_text["attention_mask"]
+         with torch.no_grad():
+             # forward() already applies softmax, so these are probabilities
+             probs = self.model(input_ids, attention_mask)
+
+         results = get_top_acc(probs.cpu().numpy(), self.config.model.threshold)
+         results_arr = []
+         for rs in results:
+             results_arr.append({
+                 "category": rs,
+                 "score": results[rs]
+             })
+         return results_arr
+
+
+ # if __name__ == '__main__':
+ #     src_config = OmegaConf.load('config/config.yaml')
+ #     category_model = CategoryModel(config=src_config)
+ #     result = category_model.predict('''''')
+ #     print(result)
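For reference, a minimal driver mirroring the commented-out `__main__` block above, under the assumption that `config/` is in place and the Google Drive checkpoint is reachable; `PhoBERT_classification` must be imported so `torch.load` can resolve the pickled class:

```python
from omegaconf import OmegaConf

# Importing the class definition is required for torch.load to unpickle the checkpoint
from src.category_model import CategoryModel, PhoBERT_classification

src_config = OmegaConf.load("config/config.yaml")
category_model = CategoryModel(config=src_config)

# predict() returns every category whose probability exceeds model.threshold,
# sorted by score, e.g. [{"category": "food", "score": 0.83}, ...]
result = category_model.predict("Chàng trai 9X Quảng Trị khởi nghiệp từ nấm sò")
print(result)
```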
test.ipynb ADDED
@@ -0,0 +1,171 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       " from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     }
+    ],
+    "source": [
+     "from omegaconf import OmegaConf\n",
+     "from src.category_model import CategoryModel\n",
+     "from src.category_model import PhoBERT_classification"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "VnCoreNLP model folder . already exists! Please load VnCoreNLP from this folder!\n",
+       "2023-09-07 13:26:04 INFO WordSegmenter:24 - Loading Word Segmentation model\n"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "{'fun': 0.9741222262382507}\n"
+      ]
+     }
+    ],
+    "source": [
+     "src_config = OmegaConf.load('config/config.yaml')\n",
+     "CategoryModel = CategoryModel(config=src_config)\n",
+     "\n",
+     "result = CategoryModel.predict('''''')\n",
+     "print(result)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Collecting underthesea\n",
+       " Obtaining dependency information for underthesea from https://files.pythonhosted.org/packages/c2/08/f8827734caf4fee1642bb08129afca92579633d8f72fbf0bc2f9a73aa69c/underthesea-6.7.0-py3-none-any.whl.metadata\n",
+       " Downloading underthesea-6.7.0-py3-none-any.whl.metadata (14 kB)\n",
+       "Requirement already satisfied: Click>=6.0 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from underthesea) (8.1.7)\n",
+       "Collecting python-crfsuite>=0.9.6 (from underthesea)\n",
+       " Using cached python_crfsuite-0.9.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n",
+       "Collecting nltk (from underthesea)\n",
+       " Using cached nltk-3.8.1-py3-none-any.whl (1.5 MB)\n",
+       "Requirement already satisfied: tqdm in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from underthesea) (4.66.1)\n",
+       "Requirement already satisfied: requests in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from underthesea) (2.31.0)\n",
+       "Collecting joblib (from underthesea)\n",
+       " Obtaining dependency information for joblib from https://files.pythonhosted.org/packages/10/40/d551139c85db202f1f384ba8bcf96aca2f329440a844f924c8a0040b6d02/joblib-1.3.2-py3-none-any.whl.metadata\n",
+       " Using cached joblib-1.3.2-py3-none-any.whl.metadata (5.4 kB)\n",
+       "Collecting scikit-learn (from underthesea)\n",
+       " Obtaining dependency information for scikit-learn from https://files.pythonhosted.org/packages/d4/61/966d3238f6cbcbb13350d31bd0accfc5efdf9e349cd2a42d9761b8b67a18/scikit_learn-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata\n",
+       " Downloading scikit_learn-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
+       "Requirement already satisfied: PyYAML in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from underthesea) (6.0.1)\n",
+       "Collecting underthesea-core==1.0.4 (from underthesea)\n",
+       " Obtaining dependency information for underthesea-core==1.0.4 from https://files.pythonhosted.org/packages/ab/09/63b71ed80c7c9f31f53297fede1345cafd5323debde4afb0ddbca8b2d800/underthesea_core-1.0.4-cp39-cp39-manylinux2010_x86_64.whl.metadata\n",
+       " Downloading underthesea_core-1.0.4-cp39-cp39-manylinux2010_x86_64.whl.metadata (1.7 kB)\n",
+       "Requirement already satisfied: regex>=2021.8.3 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from nltk->underthesea) (2023.8.8)\n",
+       "Requirement already satisfied: charset-normalizer<4,>=2 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from requests->underthesea) (3.2.0)\n",
+       "Requirement already satisfied: idna<4,>=2.5 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from requests->underthesea) (3.4)\n",
+       "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from requests->underthesea) (2.0.4)\n",
+       "Requirement already satisfied: certifi>=2017.4.17 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from requests->underthesea) (2023.7.22)\n",
+       "Requirement already satisfied: numpy>=1.17.3 in /home/linh/hahalolo/storage/anaconda3/envs/vietnamese_categories_classification/lib/python3.9/site-packages (from scikit-learn->underthesea) (1.25.2)\n",
+       "Collecting scipy>=1.5.0 (from scikit-learn->underthesea)\n",
+       " Obtaining dependency information for scipy>=1.5.0 from https://files.pythonhosted.org/packages/a3/d3/f88285098505c8e5d141678a24bb9620d902c683f11edc1eb9532b02624e/scipy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata\n",
+       " Using cached scipy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (59 kB)\n",
+       "Collecting threadpoolctl>=2.0.0 (from scikit-learn->underthesea)\n",
+       " Obtaining dependency information for threadpoolctl>=2.0.0 from https://files.pythonhosted.org/packages/81/12/fd4dea011af9d69e1cad05c75f3f7202cdcbeac9b712eea58ca779a72865/threadpoolctl-3.2.0-py3-none-any.whl.metadata\n",
+       " Using cached threadpoolctl-3.2.0-py3-none-any.whl.metadata (10.0 kB)\n",
+       "Downloading underthesea-6.7.0-py3-none-any.whl (20.9 MB)\n",
+       "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.9/20.9 MB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0mm\n",
+       "\u001b[?25hUsing cached underthesea_core-1.0.4-cp39-cp39-manylinux2010_x86_64.whl (657 kB)\n",
+       "Using cached joblib-1.3.2-py3-none-any.whl (302 kB)\n",
+       "Using cached scikit_learn-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.9 MB)\n",
+       "Using cached scipy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.5 MB)\n",
+       "Using cached threadpoolctl-3.2.0-py3-none-any.whl (15 kB)\n",
+       "\u001b[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
+       "\u001b[0mInstalling collected packages: underthesea-core, python-crfsuite, threadpoolctl, scipy, joblib, scikit-learn, nltk, underthesea\n",
+       "Successfully installed joblib-1.3.2 nltk-3.8.1 python-crfsuite-0.9.9 scikit-learn-1.3.0 scipy-1.11.2 threadpoolctl-3.2.0 underthesea-6.7.0 underthesea-core-1.0.4\n",
+       "Note: you may need to restart the kernel to use updated packages.\n"
+      ]
+     }
+    ],
+    "source": [
+     "pip install underthesea\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "'Chàng trai 9X Quảng_Trị khởi_nghiệp từ nấm sò'"
+       ]
+      },
+      "execution_count": 4,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "from underthesea import word_tokenize\n",
+     "sentence = \"Chàng trai 9X Quảng Trị khởi nghiệp từ nấm sò\"\n",
+     "\n",
+     "word_tokenize(sentence, format=\"text\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "vietnamese_ocr",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.0"
+   },
+   "orig_nbformat": 4
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }