trminhnam20082002 commited on
Commit
7147095
·
1 Parent(s): b91b62a

chore: update streamlit deprecation

Browse files
Files changed (2) hide show
  1. app.py +15 -10
  2. utils.py +2 -16
app.py CHANGED
@@ -49,8 +49,6 @@ else:
49
  device = "cpu"
50
  max_len = st.sidebar.slider("Max length", 32, 512, 256, 32)
51
  beam_size = st.sidebar.slider("Beam size", 1, 10, 3, 1)
52
- tokenizer = load_tokenizer(model_name)
53
- model = load_model(model_name, device)
54
 
55
  # create a text input box for each of the following item
56
  # CHỈ TIÊU ĐƠN VỊ ĐIỀU KIỆN KPI mục tiêu tháng Tháng 9.2022 Đánh giá T8.2022 So sánh T8.2022 Tăng giảm T9.2021 So sánh T9.2021 Tăng giảm
@@ -138,13 +136,20 @@ data = {
138
  "Previous year": previous_year,
139
  }
140
 
 
 
141
 
142
  if st.button("Generate"):
143
- with st.spinner("Generating..."):
144
- input_string = make_input_sentence_from_strings(data)
145
- print(input_string)
146
- descriptions = generate_description(
147
- input_string, model, tokenizer, device, max_len, model_name, beam_size
148
- )
149
-
150
- st.success(descriptions)
 
 
 
 
 
 
49
  device = "cpu"
50
  max_len = st.sidebar.slider("Max length", 32, 512, 256, 32)
51
  beam_size = st.sidebar.slider("Beam size", 1, 10, 3, 1)
 
 
52
 
53
  # create a text input box for each of the following item
54
  # CHỈ TIÊU ĐƠN VỊ ĐIỀU KIỆN KPI mục tiêu tháng Tháng 9.2022 Đánh giá T8.2022 So sánh T8.2022 Tăng giảm T9.2021 So sánh T9.2021 Tăng giảm
 
136
  "Previous year": previous_year,
137
  }
138
 
139
+ tokenizer = load_tokenizer(model_name)
140
+ model = load_model(model_name, device)
141
 
142
  if st.button("Generate"):
143
+ if objective_name == "":
144
+ st.error("Please input objective name")
145
+ elif unit == "":
146
+ st.error("Please input unit")
147
+ else:
148
+ with st.spinner("Generating..."):
149
+ input_string = make_input_sentence_from_strings(data)
150
+ print(input_string)
151
+ descriptions = generate_description(
152
+ input_string, model, tokenizer, device, max_len, model_name, beam_size
153
+ )
154
+
155
+ st.success(descriptions)
utils.py CHANGED
@@ -15,21 +15,7 @@ import streamlit as st
15
  from typing import Dict, List
16
 
17
 
18
- def get_model(args):
19
- print(f"Using model {args.model_name}")
20
- model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
21
- model.to(args.device)
22
-
23
- if args.load_model_path:
24
- print(f"Loading model from {args.load_model_path}")
25
- model.load_state_dict(
26
- torch.load(args.load_model_path, map_location=torch.device(args.device))
27
- )
28
-
29
- return model
30
-
31
-
32
- @st.cache(allow_output_mutation=True)
33
  def load_model(model_name, device):
34
  print(f"Using model {model_name}")
35
  os.makedirs("cache", exist_ok=True)
@@ -46,7 +32,7 @@ def load_model(model_name, device):
46
  return model
47
 
48
 
49
- @st.cache(allow_output_mutation=True)
50
  def load_tokenizer(model_name):
51
  print(f"Loading tokenizer {model_name}")
52
  if "mbart" in model_name.lower():
 
15
  from typing import Dict, List
16
 
17
 
18
+ @st.cache_resource
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def load_model(model_name, device):
20
  print(f"Using model {model_name}")
21
  os.makedirs("cache", exist_ok=True)
 
32
  return model
33
 
34
 
35
+ @st.cache_resource
36
  def load_tokenizer(model_name):
37
  print(f"Loading tokenizer {model_name}")
38
  if "mbart" in model_name.lower():