AgaMiko commited on
Commit
b9afdfb
1 Parent(s): ca0f425

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -3,10 +3,11 @@ import streamlit as st
3
  from PIL import Image
4
  import os
5
 
 
6
  @st.cache(allow_output_mutation=True)
7
  def load_model_cache():
8
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
9
-
10
  tokenizer_pl = T5Tokenizer.from_pretrained(
11
  "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
12
  )
@@ -31,18 +32,19 @@ st.set_page_config(
31
 
32
  tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()
33
 
 
34
  def get_predictions(text):
35
- input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
36
- output = model_pl.generate(
37
- input_ids,
38
- no_repeat_ngram_size=1,
39
- num_beams=3,
40
- num_beam_groups=3,
41
- min_length=10,
42
- max_length=100,
43
- )
44
- predicted_kw = tokenizer_pl.decode(output[0], skip_special_tokens=True)
45
- return predicted_kw
46
 
47
 
48
  def trim_length():
 
3
  from PIL import Image
4
  import os
5
 
6
+
7
  @st.cache(allow_output_mutation=True)
8
  def load_model_cache():
9
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
10
+
11
  tokenizer_pl = T5Tokenizer.from_pretrained(
12
  "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
13
  )
 
32
 
33
  tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()
34
 
35
+
36
  def get_predictions(text):
37
+ input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
38
+ output = model_pl.generate(
39
+ input_ids,
40
+ no_repeat_ngram_size=1,
41
+ num_beams=3,
42
+ num_beam_groups=3,
43
+ min_length=10,
44
+ max_length=100,
45
+ )
46
+ predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
47
+ return predicted_rfc
48
 
49
 
50
  def trim_length():