Tran Xuan Huy
Update app.py
eaf3054
import gradio as gr
import json
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from collections import namedtuple
import datetime
import calendar
from dateutil.relativedelta import relativedelta
from something import time2date, output2url
fields = ['device', 'model_name', 'max_source_length', 'max_target_length', 'beam_size']
params = namedtuple('params', field_names=fields)
args = params(
device="cuda" if torch.cuda.is_available() else "cpu",
model_name='facebook/mbart-large-50-many-to-many-mmt',
max_source_length=256,
max_target_length=256,
beam_size=1
)
model = AutoModelForSeq2SeqLM.from_pretrained(
"Huy1432884/db_retrieval",
use_auth_token="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"
)
model.to(args.device)
model.eval()
if "mbart" in args.model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(
args.model_name, src_lang="vi_VN", tgt_lang="vi_VN"
)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
def text_analysis(text):
text = text.lower()
inputs = tokenizer(
[text],
text_target=None,
padding="longest",
max_length=args.max_source_length,
truncation=True,
return_tensors="pt",
)
for k, v in inputs.items():
inputs[k] = v.to(args.device)
if "mbart" in args.model_name:
inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]
outputs = model.generate(
**inputs,
max_length=args.max_target_length,
num_beams=args.beam_size,
early_stopping=True,
)
output_sentences = tokenizer.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
out = json.loads("{" + output_sentences[0] + "}")
if out['LOẠI BIỂU ĐỒ']=='dashboard':
if out['CHU KỲ THỜI GIAN']!='tháng':
chu_ky_in = 'ngày'
else:
chu_ky_in = 'tháng'
out['CHU KỲ THỜI GIAN']='ngày' if out['CHU KỲ THỜI GIAN'] not in ['ngày', 'tháng'] else out['CHU KỲ THỜI GIAN']
check_dashboard = out['ĐƠN VỊ']+"_"+chu_ky_in
out['DB URL'] = output2url[check_dashboard]
out['DATE'] = str(time2date(out)).replace("-", "").replace("-", "")
out['FINAL URL'] = "https://vsds.viettel.vn"+ out['DB URL'] + "?toDate=" + out['DATE']
show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN', 'DB URL', 'DATE', 'FINAL URL']}
elif out['LOẠI BIỂU ĐỒ']=='biểu đồ':
show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN']}
else:
show = out
return show
demo = gr.Interface(
text_analysis,
gr.Textbox(placeholder="Enter sentence here..."),
["json"],
examples=[
["Mở dashboard vtc ngày hôm qua"],
["Mở biểu đồ cột td ngày này"],
["Hãy mở biểu đồ cơ cấu của tập đoàn trong ngày hôm nay"],
["Tháng này, vtc cần tôi mở biểu đồ rank để cập nhật danh sách khách hàng"],
["Các thông số NAT ngày hôm qua đã được ghi nhận trên đát bọt"],
["Hôm nay hãy mở của Viettel tt không gian mạng Viettel vtcc để kiểm tra"],
["Mở DB CTM ngày gốc"],
["Tôi đã sử dụng Dashboard để truy cập thông tin qti vào ngày hôm nay"],
["Trưởng phòng đã ra lệnh mở biểu đồ kết hợp đường và cột cho toàn tập đoàn vào hôm nay"]
],
)
demo.launch()