|
import gradio as gr |
|
|
|
import json |
|
from functools import partial |
|
from typing import Callable, Dict, List |
|
import transformers |
|
from transformers import ( |
|
AutoModelForSequenceClassification, |
|
AutoTokenizer, |
|
pipeline |
|
) |
|
|
|
model = AutoModelForSequenceClassification.from_pretrained( |
|
'airesearch/wangchanbart-large', |
|
revision='finetuned@wisesight_sentiment', |
|
) |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
'airesearch/wangchanbart-large', |
|
) |
|
model.config.return_all_scores = True |
|
LABEL_MAPPING = { |
|
'pos': '🤗 Positive', |
|
'neu': '😐 Neutral', |
|
'neg': '😡 Negative', |
|
'q': '🤔 Quesiton', |
|
} |
|
CSS_PROGRESS_BAR_MAPPING = { |
|
'pos':'w3-green', |
|
'neu': 'w3-light-blue', |
|
'neg': 'w3-red', |
|
'q': 'w3-blue', |
|
} |
|
LABEL_MAPPING_REVERSED = {v:k for k,v in LABEL_MAPPING.items() } |
|
text_cls_pipeline = pipeline(task='text-classification', |
|
tokenizer=tokenizer, |
|
model=model, |
|
return_all_scores=True) |
|
|
|
css_text = """<link rel="stylesheet" href="https://www.w3schools.com/w3css/4/w3.css">""" |
|
|
|
def render_html(items: List[Dict]): |
|
html_text = '' |
|
for item in items: |
|
|
|
label, score = item['label'], item['score'] |
|
label_id = LABEL_MAPPING_REVERSED[label] |
|
|
|
progress_bar_class_text = CSS_PROGRESS_BAR_MAPPING[label_id] |
|
|
|
html_text += f'<span>{label.replace(" ", " ")}: {(score*100):8.2f}%<span>' + \ |
|
f'<div class="w3-light-grey w3-round"><div class="{progress_bar_class_text} w3-round" style="height:19px;width:{round(score*100,2)}%"></div></div><div style="height:8px;"></div>' |
|
|
|
return '<div class="w3-container">' + html_text + '</div>' |
|
|
|
def classify_text(text: str): |
|
text = text.replace(' ', '<_>') |
|
results = text_cls_pipeline(text)[0] |
|
print(f'results:\n {results}') |
|
for i, result in enumerate(results): |
|
results[i]['label'] = LABEL_MAPPING[result['label']] |
|
results[i]['score'] = float(round(float(result['score']), 4)) |
|
html_text = css_text + render_html(results) |
|
print(html_text) |
|
return json.dumps(results, ensure_ascii=False, indent=4), html_text |
|
|
|
|
|
demo = gr.Interface(fn=classify_text, |
|
inputs=gr.Textbox(lines=5, placeholder='Input text in Thai', label='Input text'), |
|
examples=[ |
|
['ขอบคุณมากค๊าา เดี๋ยวไปทานแล้วจะถ่ายรูปสวยๆ มาให้นะคะ 😊'], |
|
['ฟอร์ด บุกตลาด อีวี ในอินเดีย #prachachat #ตลาดรถยนต์'], |
|
['สั่งไป2 เมนู คือมัชฉะลาเต้ร้อน กับ ไอศครีมชาเขียว มัชฉะลาเต้ร้อน รสชาเขียวเข้มข้น หอม มัน แต่ไม่กลมกล่อม มันจืดแบบจืดสนิท ส่วนไอศครีมชาเขียว ทานแล้วรสมันออกใบไม้ๆมากกว่าชาเขียว แล้วก็หวานไป โดยรวมแล้วเฉยมากก ดีแค่รสชาเขียวเข้ม มีน้ำเปล่าบริการฟรี'], |
|
['สาขานี้มีลิปของ Etude ไหมอ่าคะ '], |
|
], |
|
|
|
outputs=[gr.Textbox(), gr.HTML()]) |
|
|
|
print(f'\nINFO: transformers.__version__: {transformers.__version__}') |
|
|
|
|
|
demo.launch() |
|
|