import os |
import sys |
import gradio as gr |
import html |
import torch |
from transformers import MBartForConditionalGeneration, AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, AutoModelForTokenClassification, pipeline |
from torch import nn |
import torch.nn.functional as F |
from underthesea import word_tokenize |
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_bartpho(checkpoint):
    """Load a BARTPho multi-task checkpoint and its tokenizer, moving the model to `device`."""
    seq2seq = MBartForConditionalGeneration.from_pretrained(checkpoint)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq.to(device)
    return seq2seq, seq2seq_tokenizer


bartpho_mt_base, bartpho_mt_base_tokenizer = _load_bartpho("mc0c0z/BARTPho-multi-task")
bartpho_mt, bartpho_mt_tokenizer = _load_bartpho("mc0c0z/BARTPho-Large-multi-task")
def segmenter(text):
    """Word-segment *text* with underthesea, joining multi-syllable words by underscores.

    HTML entities are decoded first. Any token returned by ``word_tokenize``
    that contains spaces has them replaced with ``_``; other tokens pass
    through unchanged.

    Returns:
        list of str tokens, in order.
    """
    decoded = html.unescape(text)
    return [word.replace(' ', '_') for word in word_tokenize(decoded)]
class MultiTaskModel:
    """Prompt-driven wrapper around a seq2seq model fine-tuned on several tasks.

    The task is selected by prepending a task-specific instruction (see
    ``get_prompt``) to the word-segmented input; the generated text is the answer.
    """

    # Task code -> instruction prefix prepended to the input.
    _PROMPTS = {
        'sa': "Classify the sentiment: ",
        'mt-en-vi': "Translate English to Vietnamese: ",
        'mt-vi-en': "Translate Vietnamese to English: ",
    }
    # Index -> label for the class digit generated on the 'sa' task.
    _SENTIMENT_CLASSES = ("Negative", "Positive")

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def get_prompt(self, task):
        """Return the instruction prefix for *task*, or "" for an unknown task code."""
        return self._PROMPTS.get(task, "")

    @staticmethod
    def _sentiment_label(prediction):
        """Map generated sentiment text to a class name.

        The first character is read as the class index. When the text is
        empty, does not start with a digit, or names an index with no label,
        the stripped raw text is returned instead of raising. This keeps one
        odd generation from aborting every other model's result.
        """
        text = prediction.strip()
        if text[:1].isdecimal():
            index = int(text[0])
            if index < len(MultiTaskModel._SENTIMENT_CLASSES):
                return MultiTaskModel._SENTIMENT_CLASSES[index]
        return text

    def inference(self, task, sentence, device=None):
        """Generate the model's answer for *sentence* on *task*.

        Args:
            task: task code ('sa', 'mt-en-vi', 'mt-vi-en').
            sentence: raw user text; it is word-segmented before prompting.
            device: device for the input tensors; defaults to ``self.device``.

        Returns:
            For 'sa', a class name (or the raw generation if it cannot be
            parsed); otherwise the generated text with HTML entities decoded.
        """
        device = self.device if device is None else device
        source = self.get_prompt(task) + " ".join(segmenter(sentence))
        inputs = self.tokenizer(source, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        self.model.eval()
        with torch.no_grad():
            generated_output = self.model.generate(input_ids, attention_mask=attention_mask, max_length=128)
        prediction = self.tokenizer.decode(generated_output[0], skip_special_tokens=True)
        if task == 'sa':
            return self._sentiment_label(prediction)
        return html.unescape(prediction)
class CustomModel(nn.Module):
    """Classifier head over the first-token vectors of an encoder's last layers.

    For each of the last ``num_layers`` hidden-state tensors returned by the
    encoder, the position-0 vector is taken. These vectors are concatenated and
    passed through a 3-layer MLP producing ``num_classes`` logits.

    Args:
        bert_model: HF-style encoder called with ``input_ids``/``attention_mask``
            keyword arguments whose output exposes ``hidden_states``.
        num_layers: how many trailing hidden-state tensors to pool (default 5,
            matching the saved checkpoints).
        num_classes: number of output logits (default 3).
        hidden_size: width of each hidden state. Defaults to
            ``bert_model.config.hidden_size``, or 768 when the encoder has no config.

    Raises:
        ValueError: if ``num_layers`` < 1, or the encoder returns fewer hidden
            states than ``num_layers``.
    """

    def __init__(self, bert_model, num_layers=5, num_classes=3, hidden_size=None):
        super().__init__()
        if num_layers < 1:
            # hidden_states[-0:] would silently select every layer.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if hidden_size is None:
            hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.num_layers = num_layers
        self.bert = bert_model
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size * num_layers, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of token ids."""
        # Request hidden states explicitly so the head works even if the
        # encoder was not loaded with output_hidden_states=True.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        selected = outputs.hidden_states[-self.num_layers:]
        if len(selected) != self.num_layers:
            raise ValueError(
                f"encoder returned {len(selected)} hidden states, need {self.num_layers}"
            )
        cls_embeddings = torch.cat([state[:, 0, :] for state in selected], dim=1)
        return self.mlp(cls_embeddings)
def _load_sentiment_classifier(checkpoint, weights_path):
    """Wrap the *checkpoint* encoder in CustomModel and load fine-tuned weights.

    The encoder is loaded with hidden states enabled. The head's state dict
    is read from *weights_path* onto `device`, and the model is moved there.
    Returns (classifier, tokenizer).
    """
    encoder = AutoModel.from_pretrained(checkpoint, output_hidden_states=True)
    encoder_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = CustomModel(encoder)
    classifier.load_state_dict(torch.load(weights_path, map_location=device))
    classifier.to(device)
    return classifier, encoder_tokenizer


phobert_sa, phobert_sa_tokenizer = _load_sentiment_classifier("vinai/phobert-base", 'phobert_sentiment_analysis.pth')
phobertv2_sa, phobertv2_sa_tokenizer = _load_sentiment_classifier("vinai/phobert-base-v2", 'phobertv2_sentiment_analysis.pth')
m_bert_sa, m_bert_sa_tokenizer = _load_sentiment_classifier("google-bert/bert-base-multilingual-cased", 'bert_model_sentiment_analysis.pth')
def _load_qa_model(checkpoint, **model_kwargs):
    """Load an extractive-QA checkpoint and its tokenizer; move the model to `device`."""
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint, **model_kwargs)
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    qa_model.to(device)
    return qa_model, qa_tokenizer


roberta_large_qa, roberta_large_qa_tokenizer = _load_qa_model("HungLV2512/Vietnamese-QA-fine-tuned")
roberta_base_qa, roberta_base_qa_tokenizer = _load_qa_model(
    "HungLV2512/xlm-roberta-base-fine-tuned-qa-vietnamese", output_hidden_states=True
)
m_bert_qa, m_bert_qa_tokenizer = _load_qa_model("HungLV2512/bert-base-multilingual-cased-fine-tuned-qa-vietnamese")
# NER tag -> class index used by the token-classification models below.
# The list order defines the indices (0..8).
label_map = {
    tag: index
    for index, tag in enumerate(
        ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']
    )
}
def _load_ner_model(checkpoint):
    """Load a token-classification checkpoint sized to `label_map`; move it to `device`."""
    tagger = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=len(label_map))
    tagger_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tagger.to(device)
    return tagger, tagger_tokenizer


phobert_ner, phobert_ner_tokenizer = _load_ner_model("DrRinS/NER-PhoBERT")
phobertv2_ner, phobertv2_ner_tokenizer = _load_ner_model("DrRinS/NER-PhoBERTv2")
m_bert_ner, m_bert_ner_tokenizer = _load_ner_model("DrRinS/NER_MultilingualBERT")
def sentiment_inference(model, tokenizer, text, device):
    """Classify *text* with a CustomModel sentiment classifier.

    The text is word-segmented, padded/truncated to 128 tokens, and scored.
    The arg-max logit index is mapped onto ["Negative", "Positive", "Neutral"].

    Returns:
        the predicted class name.
    """
    segmented = " ".join(segmenter(text))
    encoded = tokenizer(segmented, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)
    # Ensure a leading batch dimension.
    if ids.dim() == 1:
        ids = ids.unsqueeze(0)
    if mask.dim() == 1:
        mask = mask.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        logits = model(ids, mask)
    best = torch.max(logits, dim=1)[1]
    labels = ["Negative", "Positive", "Neutral"]
    return labels[best.cpu().item()]
def multitask_inference(model, tokenizer, text, task, device):
    """Run one prompt-based BARTPho generation for task code *task* on *text*."""
    wrapper = MultiTaskModel(model, tokenizer, device)
    answer = wrapper.inference(task, text, device)
    return answer
def qa_inference(model, tokenizer, question, context, device):
    """Extract the answer span for *question* from *context*.

    Args:
        model: an extractive question-answering model.
        tokenizer: the model's tokenizer.
        question: question text.
        context: passage to search for the answer.
        device: device the pipeline should run on (the same one the model was
            moved to), so inputs and weights end up on the same device.

    Returns:
        the answer string, or "" when the question or the context is blank.
        The HF QA pipeline raises on empty inputs, which would abort the
        other models' results.
    """
    if not (question and question.strip()) or not (context and context.strip()):
        return ""
    qa_pipeline = pipeline('question-answering', model=model, tokenizer=tokenizer, device=device)
    res = qa_pipeline(question=question, context=context)
    return res['answer']
def ner_inference(model, tokenizer, text, device, label_mapping=None):
    """Tag each tokenizer token of *text* with a named-entity label.

    Args:
        model: token-classification model whose output exposes ``logits``.
        tokenizer: the model's tokenizer.
        text: input text (padded/truncated to 128 tokens).
        device: device for the input tensors.
        label_mapping: label-name -> class-index dict; defaults to the
            module-level ``label_map``.

    Returns:
        list of (token, label) tuples, one per real token in order. Padding and
        special marker tokens (CLS/SEP/<s>/</s>, ...) are omitted; unknown
        tokens are kept.
    """
    mapping = label_map if label_mapping is None else label_mapping
    id_to_label = {idx: name for name, idx in mapping.items()}
    inputs = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
    pred_ids = outputs.logits.argmax(dim=2)[0].tolist()

    ids = input_ids[0].tolist()
    real = attention_mask[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(ids)
    # Filter tokens and predictions by the same positions. Previously
    # skip_special_tokens=True also dropped <unk>/[UNK] tokens while their
    # labels were kept, which shifted every later (token, label) pair.
    special_ids = set(tokenizer.all_special_ids)
    special_ids.discard(tokenizer.unk_token_id)
    return [
        (token, id_to_label[pred])
        for token, token_id, is_real, pred in zip(tokens, ids, real, pred_ids)
        if is_real and token_id not in special_ids
    ]
def process_input(input_text, context, task):
    """Run every model registered for *task* on the user's input.

    Args:
        input_text: text to analyse (the question, for Question Answering).
        context: passage searched for the answer; only used for Question Answering.
        task: one of the task labels offered by the UI radio button.

    Returns:
        dict mapping a model display name to that model's output; empty when
        *task* matches none of the known labels.
    """

    def bartpho_results(task_code):
        # Both BARTPho sizes share the prompt-based multi-task path.
        return {
            "BARTPho Base": multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, task_code, device),
            "BARTPho Large": multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, task_code, device),
        }

    def sentiment_results():
        return {
            "PhoBERT": sentiment_inference(phobert_sa, phobert_sa_tokenizer, input_text, device),
            "PhoBERTv2": sentiment_inference(phobertv2_sa, phobertv2_sa_tokenizer, input_text, device),
            "Multilingual BERT": sentiment_inference(m_bert_sa, m_bert_sa_tokenizer, input_text, device),
            **bartpho_results("sa"),
        }

    def qa_results():
        return {
            "RoBERTa Base": qa_inference(roberta_base_qa, roberta_base_qa_tokenizer, input_text, context, device),
            "RoBERTa Large": qa_inference(roberta_large_qa, roberta_large_qa_tokenizer, input_text, context, device),
            "Multilingual BERT": qa_inference(m_bert_qa, m_bert_qa_tokenizer, input_text, context, device),
        }

    def ner_results():
        return {
            "PhoBERT": ner_inference(phobert_ner, phobert_ner_tokenizer, input_text, device),
            "PhoBERTv2": ner_inference(phobertv2_ner, phobertv2_ner_tokenizer, input_text, device),
            "Multilingual BERT": ner_inference(m_bert_ner, m_bert_ner_tokenizer, input_text, device),
        }

    # Builders are only invoked for the matching task; labels are checked in order.
    dispatch = (
        ("Sentiment Analysis", sentiment_results),
        ("English to Vietnamese", lambda: bartpho_results("mt-en-vi")),
        ("Vietnamese to English", lambda: bartpho_results("mt-vi-en")),
        ("Question Answering", qa_results),
        ("Named Entity Recognition", ner_results),
    )
    for label, build in dispatch:
        if task == label:
            return build()
    return {}
with gr.Blocks() as iface:
    gr.Markdown("# Multi-task NLP Demo")
    gr.Markdown("Perform sentiment analysis, machine translation, question answering, or named entity recognition using various models.")

    task_choices = [
        "Sentiment Analysis",
        "Question Answering",
        "Named Entity Recognition",
        "English to Vietnamese",
        "Vietnamese to English",
    ]
    with gr.Row():
        task = gr.Radio(task_choices, label="Task")
    with gr.Row():
        input_text = gr.Textbox(label="Input Text")
        # Only shown for Question Answering (toggled by on_task_change).
        context = gr.Textbox(label="Context", visible=False)
    output = gr.JSON(label="Results")
    submit = gr.Button("Submit")

    def on_task_change(selected_task):
        """Relabel the input box and show the context box only for Question Answering."""
        is_qa = selected_task == "Question Answering"
        return {
            input_text: gr.update(label="Question" if is_qa else "Input Text", visible=True),
            context: gr.update(visible=is_qa),
        }

    task.change(fn=on_task_change, inputs=task, outputs=[input_text, context])
    submit.click(fn=process_input, inputs=[input_text, context, task], outputs=output)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    iface.launch(share=True)