Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
import torch | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
from collections import namedtuple | |
import datetime | |
import calendar | |
from dateutil.relativedelta import relativedelta | |
from something import time2date, output2url | |
fields = ['device', 'model_name', 'max_source_length', 'max_target_length', 'beam_size'] | |
params = namedtuple('params', field_names=fields) | |
args = params( | |
device="cuda" if torch.cuda.is_available() else "cpu", | |
model_name='facebook/mbart-large-50-many-to-many-mmt', | |
max_source_length=256, | |
max_target_length=256, | |
beam_size=1 | |
) | |
model = AutoModelForSeq2SeqLM.from_pretrained( | |
"Huy1432884/db_retrieval", | |
use_auth_token="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" | |
) | |
model.to(args.device) | |
model.eval() | |
if "mbart" in args.model_name.lower(): | |
tokenizer = AutoTokenizer.from_pretrained( | |
args.model_name, src_lang="vi_VN", tgt_lang="vi_VN" | |
) | |
else: | |
tokenizer = AutoTokenizer.from_pretrained(args.model_name) | |
def text_analysis(text): | |
text = text.lower() | |
inputs = tokenizer( | |
[text], | |
text_target=None, | |
padding="longest", | |
max_length=args.max_source_length, | |
truncation=True, | |
return_tensors="pt", | |
) | |
for k, v in inputs.items(): | |
inputs[k] = v.to(args.device) | |
if "mbart" in args.model_name: | |
inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"] | |
outputs = model.generate( | |
**inputs, | |
max_length=args.max_target_length, | |
num_beams=args.beam_size, | |
early_stopping=True, | |
) | |
output_sentences = tokenizer.batch_decode( | |
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True | |
) | |
out = json.loads("{" + output_sentences[0] + "}") | |
if out['LOẠI BIỂU ĐỒ']=='dashboard': | |
if out['CHU KỲ THỜI GIAN']!='tháng': | |
chu_ky_in = 'ngày' | |
else: | |
chu_ky_in = 'tháng' | |
out['CHU KỲ THỜI GIAN']='ngày' if out['CHU KỲ THỜI GIAN'] not in ['ngày', 'tháng'] else out['CHU KỲ THỜI GIAN'] | |
check_dashboard = out['ĐƠN VỊ']+"_"+chu_ky_in | |
out['DB URL'] = output2url[check_dashboard] | |
out['DATE'] = str(time2date(out)).replace("-", "").replace("-", "") | |
out['FINAL URL'] = "https://vsds.viettel.vn"+ out['DB URL'] + "?toDate=" + out['DATE'] | |
show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN', 'DB URL', 'DATE', 'FINAL URL']} | |
elif out['LOẠI BIỂU ĐỒ']=='biểu đồ': | |
show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN']} | |
else: | |
show = out | |
return show | |
demo = gr.Interface( | |
text_analysis, | |
gr.Textbox(placeholder="Enter sentence here..."), | |
["json"], | |
examples=[ | |
["Mở dashboard vtc ngày hôm qua"], | |
["Mở biểu đồ cột td ngày này"], | |
["Hãy mở biểu đồ cơ cấu của tập đoàn trong ngày hôm nay"], | |
["Tháng này, vtc cần tôi mở biểu đồ rank để cập nhật danh sách khách hàng"], | |
["Các thông số NAT ngày hôm qua đã được ghi nhận trên đát bọt"], | |
["Hôm nay hãy mở của Viettel tt không gian mạng Viettel vtcc để kiểm tra"], | |
["Mở DB CTM ngày gốc"], | |
["Tôi đã sử dụng Dashboard để truy cập thông tin qti vào ngày hôm nay"], | |
["Trưởng phòng đã ra lệnh mở biểu đồ kết hợp đường và cột cho toàn tập đoàn vào hôm nay"] | |
], | |
) | |
demo.launch() | |