# -*- coding: utf-8 -*-
import os
import re
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    T5ForConditionalGeneration,
    MBartForConditionalGeneration,
    AutoModelForSeq2SeqLM,
)
from tqdm.auto import tqdm
import streamlit as st
from typing import Dict, List


@st.cache_resource
def load_model(model_name, device):
    """Load a pretrained seq2seq model and overwrite its weights with the
    locally fine-tuned checkpoint ``models/<short-name>-best_loss.bin``.

    Cached by Streamlit so the model is constructed only once per process.

    Args:
        model_name: Hugging Face hub id (possibly namespaced, e.g. "org/name").
        device: torch device string the model is moved to (e.g. "cpu", "cuda").

    Returns:
        The model with the fine-tuned state dict loaded, on ``device``.
    """
    print(f"Using model {model_name}")
    os.makedirs("cache", exist_ok=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir="cache")
    model.to(device)
    # Strip the hub namespace ("org/name" -> "name") to build the local
    # checkpoint filename.
    model_name = model_name.split("/")[-1]
    load_model_path = os.path.join("models", f"{model_name}-best_loss.bin")
    print(f"Loading model from {load_model_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source (consider weights_only=True).
    model.load_state_dict(
        torch.load(load_model_path, map_location=torch.device(device))
    )
    return model


@st.cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name``.

    mBART tokenizers are pinned to Vietnamese ("vi_VN") for both source and
    target language; every other architecture uses the default constructor.
    """
    print(f"Loading tokenizer {model_name}")
    if "mbart" in model_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, src_lang="vi_VN", tgt_lang="vi_VN"
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer


def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
    """Tokenize a batch and move every resulting tensor to ``device``.

    Args:
        batch: dict with a "src" entry (and a "tgt" entry used only when
            ``is_train`` is True, to produce labels via ``text_target``).
        tokenizer: Hugging Face tokenizer.
        max_len: truncation length.
        is_train: when True, also tokenize ``batch["tgt"]`` as the target.
        device: torch device string for the output tensors.

    Returns:
        The tokenizer's BatchEncoding with all tensors on ``device``.
    """
    inputs = tokenizer(
        batch["src"],
        text_target=batch["tgt"] if is_train else None,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    for k, v in inputs.items():
        inputs[k] = v.to(device)
    return inputs


def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
    """Tokenize a single source string (or list of strings) for inference and
    move the resulting tensors to ``device``."""
    inputs = tokenizer(
        src,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    for k, v in inputs.items():
        inputs[k] = v.to(device)
    return inputs


def make_input_sentence_from_strings(data):
    """Flatten a KPI report record into the model's flat JSON-like input string.

    ``data`` is a dict keyed by Vietnamese report-column names, e.g.:

        {
            "CHỈ TIÊU": objective_name,
            "ĐƠN VỊ": unit,
            "ĐIỀU KIỆN": condition,
            "KPI mục tiêu tháng": kpi_target,
            "Đánh giá": evaluation_value,
            "Thời gian báo cáo": current_time,
            f"T{current_time[1]}.{current_time[0]} thực tế": real_value,
            "Previous month value key": f"T{previous_month[1]}.{previous_month[0]}",
            f"T{previous_month[1]}.{previous_month[0]}": previous_month_value,
            "Previous year value key": f"T{previous_year[1]}.{previous_year[0]}",
            f"T{previous_year[1]}.{previous_year[0]}": previous_year_value,
            "Previous month compare key": f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
            f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm": previous_month_compare,
            "Previous year compare key": f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
            "Previous month": previous_month,
            "Previous year": previous_year,
        }

    Time values ("Thời gian báo cáo", "Previous month", "Previous year") are
    indexed as [1] -> month and [0] -> year in the "T{month}.{year}" keys, so
    they are presumably (year, month) pairs — TODO confirm against the caller.

    Returns:
        A single formatted string, e.g.:
        "CHỈ TIÊU": "Tỷ lệ kết nối thành công đến tổng đài - KHCN_Di động Vip",
        "ĐƠN VỊ": "%", "ĐIỀU KIỆN": ">=", "KPI mục tiêu tháng": 95.0,
        "Tháng 9.2022": 97.5, "Đánh giá": "Đạt", "T8.2022": 96.6, ...
    """
    # The concrete keys holding previous-month/-year figures depend on the
    # report date, so they are looked up indirectly via fixed "... key" entries.
    previous_month_value_key = data["Previous month value key"]
    previous_year_value_key = data["Previous year value key"]
    objective_name = data["CHỈ TIÊU"]
    unit = data["ĐƠN VỊ"]
    condition = data["ĐIỀU KIỆN"]
    kpi_target = data["KPI mục tiêu tháng"]
    current_time = data["Thời gian báo cáo"]
    real_value = data[f"T{current_time[1]}.{current_time[0]} thực tế"]
    evaluation_value = data["Đánh giá"]
    previous_month_value = data[previous_month_value_key]
    previous_year_value = data[previous_year_value_key]
    previous_month_compare_key = data["Previous month compare key"]
    previous_year_compare_key = data["Previous year compare key"]
    previous_month_compare = data[previous_month_compare_key]
    previous_year_compare = data[previous_year_compare_key]
    previous_month = data["Previous month"]
    previous_year = data["Previous year"]
    template_str = '"CHỈ TIÊU": "{}", "ĐƠN VỊ": "{}", "ĐIỀU KIỆN": "{}", "KPI mục tiêu tháng": {}, "Tháng {}.{}": {}, "Đánh giá": "{}", "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}, "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}'
    return template_str.format(
        objective_name,
        unit,
        condition,
        kpi_target,
        current_time[1],
        current_time[0],
        real_value,
        evaluation_value,
        previous_month[1],
        previous_month[0],
        previous_month_value,
        previous_month[1],
        previous_month[0],
        previous_month_compare,
        previous_year[1],
        previous_year[0],
        previous_year_value,
        previous_year[1],
        previous_year[0],
        previous_year_compare,
    )


@torch.no_grad()
def generate_description(
    input_string, model, tokenizer, device, max_len, model_name, beam_size
):
    """Generate a description for ``input_string`` with beam search.

    Args:
        input_string: the flattened KPI string (see
            ``make_input_sentence_from_strings``).
        model: a seq2seq model supporting ``generate``.
        tokenizer: matching tokenizer.
        device: torch device string.
        max_len: both the tokenization truncation length and the generation
            ``max_length``.
        model_name: used only to detect mBART and force the Vietnamese BOS.
        beam_size: ``num_beams`` for generation.

    Returns:
        List of decoded output strings (one per generated sequence).
    """
    model.eval()
    model = model.to(device)
    inputs = prepare_single_model_inputs(
        input_string, tokenizer, max_len=max_len, device=device
    )
    if "mbart" in model_name.lower():
        # mBART must be told which language to start decoding in.
        inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]
    outputs = model.generate(
        **inputs,
        max_length=max_len,
        num_beams=beam_size,
        # early_stopping=True,
    )
    return tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )