# trminhnam20082002 — commit 7147095: chore: update streamlit deprecation
# -*- coding: utf-8 -*-
import datetime

import pandas as pd
import streamlit as st
import torch

from utils import (
    load_model,
    load_tokenizer,
    make_input_sentence_from_strings,
    generate_description,
)
# Configure the browser tab (title, icon), the page layout, and the links
# shown in Streamlit's app menu.
st.set_page_config(
    page_title="Table-to-text generation",
    page_icon="📝",
    layout="wide",
    initial_sidebar_state="auto",
    # ``menu_items`` only customizes entries of the app menu; it does not
    # hide any footer.
    # NOTE(review): "Report a bug" points at the GitHub homepage rather than
    # this project's issue tracker — TODO replace with the real repository URL.
    menu_items={
        "Get Help": "https://huggingface.co/transformers/master/index.html",
        "Report a bug": "https://github.com",
    },
)
# Page heading plus a short description of the demo and its training data.
st.title("Table-to-text generation with multilingual pre-trained models")
st.markdown(
    """
This is a demo of table-to-text generation with multilingual pre-trained models.
The models are trained on our custom dataset, which is sampled from Viettel report templates, with descriptions generated by ChatGPT.
"""
)
# Sidebar: model choice and inference settings.
st.sidebar.title("Settings")

# Checkpoints offered in the sidebar; the chosen name is later passed to
# load_tokenizer / load_model.
MODEL_CHOICES = [
    "vinai/bartpho-syllable",
    "vinai/bartpho-syllable-base",
    "google/byt5-base",
    "google/byt5-small",
    "facebook/mbart-large-50",
]
model_name = st.sidebar.selectbox("Model name", MODEL_CHOICES)

# The GPU toggle is always rendered. When CUDA is unavailable it is greyed
# out, keeps its default of False, and inference therefore runs on the CPU.
cuda_available = torch.cuda.is_available()
use_gpu = st.sidebar.checkbox("Use GPU", False, disabled=not cuda_available)
device = "cuda" if cuda_available and use_gpu else "cpu"

# Decoding parameters forwarded to generate_description.
max_len = st.sidebar.slider(
    "Max length", min_value=32, max_value=512, value=256, step=32
)
beam_size = st.sidebar.slider(
    "Beam size", min_value=1, max_value=10, value=3, step=1
)
# One input widget per column of the report template, whose headers are:
# CHỈ TIÊU ĐƠN VỊ ĐIỀU KIỆN KPI mục tiêu tháng Tháng 9.2022 Đánh giá T8.2022 So sánh T8.2022 Tăng giảm T9.2021 So sánh T9.2021 Tăng giảm
objective_name = st.text_input("CHỈ TIÊU", "")

# Unit, comparison condition, and monthly KPI target sit side by side.
# Calling the widget on the column object places it in that column.
unit_col, condition_col, kpi_target_col = st.columns(3)
unit = unit_col.text_input("ĐƠN VỊ", "")
condition = condition_col.selectbox("ĐIỀU KIỆN", [">=", "<=", None])
kpi_target = kpi_target_col.text_input("KPI mục tiêu tháng", "")
# Reporting-period row: report date, the actual value for that month, and
# the KPI verdict.
current_date_col, real_value_col, evaluation_col = st.columns(3)
with current_date_col:
    # Pass today's date explicitly. Since Streamlit 1.28, ``value=None``
    # renders an empty picker that returns None, which made the year/month
    # extraction below crash on first load. Older releases already treated
    # None as "today", so this keeps their behavior.
    current_date = st.date_input(
        "Thời gian báo cáo",
        value=datetime.date.today(),
        min_value=None,
        max_value=None,
        key=None,
    )
    # Defensive fallback in case the widget still yields no date.
    if current_date is None:
        current_date = datetime.date.today()
    # [year, month] of the reporting period; used to build column labels.
    current_time = [current_date.year, current_date.month]
with real_value_col:
    real_value = st.text_input(f"T{current_time[1]}.{current_time[0]} thực tế", "")
with evaluation_col:
    # Default to "Theo dõi" (index 2) when there is no target or condition to
    # evaluate against; otherwise default to "Đạt" (index 0).
    evaluation_value = st.selectbox(
        "Đánh giá",
        ["Đạt", "Không đạt", "Theo dõi"],
        index=2 if (kpi_target == "" or condition is None) else 0,
    )
# current_time is [year, month]. Derive the two comparison periods from it,
# in the same [year, month] format.
# Previous month: January wraps back to December of the prior year.
previous_month = (
    [current_time[0], current_time[1] - 1]
    if current_time[1] > 1
    else [current_time[0] - 1, 12]
)
# Same month, one year earlier.
previous_year = [current_time[0] - 1, current_time[1]]
def _difference_or_blank(current: str, previous: str) -> str:
    """Return ``current - previous`` as a string, or "" if it cannot be computed.

    Both arguments are raw text-box contents. If ``float`` cannot parse
    either one (including the empty string), this returns "" instead of
    raising. Previously, typing a comparison value before the current value
    crashed the app.
    """
    try:
        return str(float(current) - float(previous))
    except ValueError:
        return ""


# Comparison row: previous-month and same-month-last-year values, each
# followed by its difference from the current month's actual value.
(
    previous_month_value_col,
    previous_month_compare_col,
    previous_year_value_col,
    previous_year_compare_col,
) = st.columns(4)
with previous_month_value_col:
    previous_month_value = st.text_input(
        f"T{previous_month[1]}.{previous_month[0]}", ""
    )
with previous_month_compare_col:
    # Pre-filled with the computed difference but left editable, presumably
    # so users can override it (a ``disabled=True`` option had been
    # commented out here).
    previous_month_compare = st.text_input(
        f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
        _difference_or_blank(real_value, previous_month_value),
    )
with previous_year_value_col:
    previous_year_value = st.text_input(f"T{previous_year[1]}.{previous_year[0]}", "")
with previous_year_compare_col:
    # Same as above, against the same month of the previous year.
    previous_year_compare = st.text_input(
        f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
        _difference_or_blank(real_value, previous_year_value),
    )
# Period-dependent column labels, built once and reused as dictionary keys.
_current_label = f"T{current_time[1]}.{current_time[0]} thực tế"
_prev_month_label = f"T{previous_month[1]}.{previous_month[0]}"
_prev_year_label = f"T{previous_year[1]}.{previous_year[0]}"
_prev_month_compare_label = f"So sánh {_prev_month_label} Tăng giảm"
_prev_year_compare_label = f"So sánh {_prev_year_label} Tăng giảm"

# Record handed to make_input_sentence_from_strings. Each "... key" entry
# stores the name of a period-dependent label. Presumably this lets the
# formatter in utils locate those columns — confirm against utils.
data = {
    "CHỈ TIÊU": objective_name,
    "ĐƠN VỊ": unit,
    "ĐIỀU KIỆN": condition,
    "KPI mục tiêu tháng": kpi_target,
    "Đánh giá": evaluation_value,
    "Thời gian báo cáo": current_time,
    _current_label: real_value,
    "Previous month value key": _prev_month_label,
    _prev_month_label: previous_month_value,
    "Previous year value key": _prev_year_label,
    _prev_year_label: previous_year_value,
    "Previous month compare key": _prev_month_compare_label,
    _prev_month_compare_label: previous_month_compare,
    "Previous year compare key": _prev_year_compare_label,
    _prev_year_compare_label: previous_year_compare,
    "Previous month": previous_month,
    "Previous year": previous_year,
}
# Load the selected checkpoint.
# NOTE(review): this runs on every Streamlit rerun, even before "Generate" is
# pressed. Presumably load_tokenizer/load_model cache their results — confirm
# in utils.
tokenizer = load_tokenizer(model_name)
model = load_model(model_name, device)

if st.button("Generate"):
    # Required fields, checked in order; only the first missing one is shown.
    _missing_messages = [
        message
        for field_value, message in (
            (objective_name, "Please input objective name"),
            (unit, "Please input unit"),
        )
        if field_value == ""
    ]
    if _missing_messages:
        st.error(_missing_messages[0])
    else:
        with st.spinner("Generating..."):
            input_string = make_input_sentence_from_strings(data)
            # Echo the exact model input to the server's stdout for debugging.
            print(input_string)
            descriptions = generate_description(
                input_string,
                model,
                tokenizer,
                device,
                max_len,
                model_name,
                beam_size,
            )
            st.success(descriptions)