---
language:
- en
- ko
license: apache-2.0
datasets: AI-Hub
metrics:
- accuracy
pipeline_tag: text-classification
---

# 1. Introduction

## 1.1 Examples

![examples](https://github.com/BurningFalls/algorithm-study/assets/30232837/596e5010-53b6-4598-8dd3-4ef7fc65e60e)

## 1.2 F1-score

![bert_accuracy](https://github.com/BurningFalls/algorithm-study/assets/30232837/58830340-aebe-4dc2-85fa-313138ac3020)

---

# 2. Requirements

```text
# my environment
python==3.11.3
tensorflow==2.12.0
transformers==4.29.2

# you may need at least
python>=3.6
tensorflow>=2.0
transformers>=4.0
```

---

# 3. Load

```python
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from transformers import TextClassificationPipeline

BERT_PATH = "burningfalls/my-fine-tuned-bert"

def load_bert():
    # Download (or load from the local cache) the fine-tuned tokenizer and TensorFlow model
    loaded_tokenizer = AutoTokenizer.from_pretrained(BERT_PATH)
    loaded_model = TFAutoModelForSequenceClassification.from_pretrained(BERT_PATH)

    # top_k=1 keeps only the highest-scoring label for each input
    text_classifier = TextClassificationPipeline(
        tokenizer=loaded_tokenizer,
        model=loaded_model,
        framework='tf',
        top_k=1
    )
    return text_classifier

# Module-level classifier used by predict_sentiment in section 4
text_classifier = load_bert()
```

---

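# 4. Usage

```python
import re

import sentiments  # local sentiments.py, listed in section 5

def predict_sentiment(text):
    # The pipeline output for a single string is nested;
    # [0] selects the predictions for this input
    result = text_classifier(text)[0]
    # Keep only the digits of the label name to get the class index
    feel_idx = int(re.sub(r'[^0-9]', '', result[0]['label']))
    # Feel is looked up by list position, not by its "index" field
    feel = sentiments.Feel[feel_idx]["label"]

    return feel
```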
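
The label handling in `predict_sentiment` can be tried without downloading the model. The sketch below assumes label names of the form `LABEL_<n>`, which is inferred from the regular expression above and not confirmed against the model config. `label_to_index` and `sample_output` are hypothetical names, and the score is made up for illustration.

```python
import re

def label_to_index(label):
    """Extract the class index from a label name such as "LABEL_12"."""
    digits = re.sub(r"[^0-9]", "", label)
    if not digits:
        # Fail loudly instead of letting int("") raise a less helpful error
        raise ValueError(f"no class index found in label {label!r}")
    return int(digits)

# Hypothetical pipeline output for one input string. The nesting is assumed
# from the indexing used in predict_sentiment; the score is invented.
sample_output = [[{"label": "LABEL_12", "score": 0.87}]]

top = sample_output[0][0]
print(label_to_index(top["label"]))  # prints 12
```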

---

# 5. sentiments.py

```python
# Emotion labels of the classifier; English glosses are given in the comments.
Feel = [
    {"label": "가난한, 불우한", "index": 0},  # poor, underprivileged
    {"label": "감사하는", "index": 1},  # grateful
    {"label": "걱정스러운", "index": 2},  # worried
    {"label": "고립된", "index": 3},  # isolated
    {"label": "괴로워하는", "index": 4},  # distressed
    {"label": "구역질 나는", "index": 5},  # nauseated
    {"label": "기쁨", "index": 6},  # joy
    {"label": "낙담한", "index": 7},  # discouraged
    {"label": "남의 시선을 의식하는", "index": 8},  # self-conscious
    {"label": "노여워하는", "index": 9},  # angered
    {"label": "눈물이 나는", "index": 10},  # tearful
    {"label": "느긋", "index": 11},  # relaxed
    {"label": "당혹스러운", "index": 12},  # bewildered
    {"label": "당황", "index": 13},  # flustered
    {"label": "두려운", "index": 14},  # fearful
    {"label": "마비된", "index": 15},  # numb
    {"label": "만족스러운", "index": 16},  # satisfied
    {"label": "방어적인", "index": 17},  # defensive
    {"label": "배신당한", "index": 18},  # betrayed
    {"label": "버려진", "index": 19},  # abandoned
    {"label": "부끄러운", "index": 20},  # ashamed
    {"label": "분노", "index": 21},  # anger
    {"label": "불안", "index": 22},  # anxiety
    {"label": "비통한", "index": 23},  # grief-stricken
    {"label": "상처", "index": 24},  # hurt
    {"label": "성가신", "index": 25},  # annoyed
    {"label": "스트레스 받는", "index": 26},  # stressed
    {"label": "슬픔", "index": 27},  # sadness
    {"label": "신뢰하는", "index": 28},  # trusting
    {"label": "신이 난", "index": 29},  # excited
    {"label": "실망한", "index": 30},  # disappointed
    {"label": "악의적인", "index": 31},  # malicious
    {"label": "안달하는", "index": 32},  # fretful
    {"label": "안도", "index": 33},  # relief
    {"label": "억울한", "index": 34},  # feeling wronged
    {"label": "열등감", "index": 35},  # sense of inferiority
    {"label": "염세적인", "index": 36},  # pessimistic
    {"label": "외로운", "index": 37},  # lonely
    {"label": "우울한", "index": 38},  # depressed
    {"label": "자신하는", "index": 39},  # confident
    {"label": "조심스러운", "index": 40},  # cautious
    {"label": "좌절한", "index": 41},  # frustrated
    {"label": "죄책감의", "index": 42},  # guilty
    {"label": "질투하는", "index": 43},  # jealous
    {"label": "짜증내는", "index": 44},  # irritated
    {"label": "초조한", "index": 45},  # on edge
    {"label": "충격 받은", "index": 46},  # shocked
    {"label": "취약한", "index": 47},  # vulnerable
    {"label": "툴툴대는", "index": 48},  # grumbling
    {"label": "편안한", "index": 49},  # comfortable
    {"label": "한심한", "index": 50},  # pathetic
    {"label": "혐오스러운", "index": 51},  # disgusted
    {"label": "혼란스러운", "index": 52},  # confused
    {"label": "환멸을 느끼는", "index": 53},  # disillusioned
    {"label": "회의적인", "index": 54},  # skeptical
    {"label": "후회되는", "index": 55},  # regretful
    {"label": "흥분", "index": 56},  # excitement
    {"label": "희생된", "index": 57},  # victimized
]
```
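
Because `predict_sentiment` looks entries up by list position, the `"index"` field only helps if it agrees with that position. A minimal sketch of a consistency check follows, using just the first three entries of `Feel` as an excerpt. `FEEL_EXCERPT`, `check_positions`, and `build_id2label` are hypothetical names introduced for illustration and are not part of `sentiments.py`.

```python
# Excerpt only: the first three entries of Feel, copied from sentiments.py.
FEEL_EXCERPT = [
    {"label": "가난한, 불우한", "index": 0},
    {"label": "감사하는", "index": 1},
    {"label": "걱정스러운", "index": 2},
]

def check_positions(feel):
    """Return the list positions whose "index" field disagrees with the position."""
    return [pos for pos, entry in enumerate(feel) if entry["index"] != pos]

def build_id2label(feel):
    """Map each "index" value to its label, independent of list order."""
    return {entry["index"]: entry["label"] for entry in feel}

# An empty list means positional lookup and the "index" field agree.
mismatches = check_positions(FEEL_EXCERPT)
id2label = build_id2label(FEEL_EXCERPT)
```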

---

# 6. Reference

* BERT: [klue/bert-base](https://huggingface.co/klue/bert-base)
* Dataset: [AI-Hub Emotional Dialogue Corpus (감성 대화 말뭉치)](https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=86)