sasha HF staff committed on
Commit
0622968
1 Parent(s): 3801f61

Update app.py

Browse files

attempting cache

Files changed (1) hide show
  1. app.py +29 -15
app.py CHANGED
@@ -51,26 +51,40 @@ metrics = st.multiselect(
51
  st.markdown("### Please wait for the dataset and models to load (this can take some time if they are big!")
52
 
53
  ### Loading data
54
- data = datasets.load_dataset(dset, split=dset_split)
55
- st.text("Loaded the "+ str(dset_split)+ " split of dataset "+ str(dset))
 
 
 
 
 
56
 
57
  ### Defining Evaluator
58
  eval = evaluator("text-classification")
59
 
60
  ### Loading models
61
-
62
- model_list=[]
63
- for i in range (len(models)):
64
- try:
65
- globals()[f"tokenizer_{i}"] = AutoTokenizer.from_pretrained(models[i])
66
- globals()[f"model_{i}"] = AutoModelForSequenceClassification.from_pretrained(models[i])
67
- model_list.append(models[i])
68
- except:
69
- print("Sorry, I can't load model "+ str(models[i]))
70
-
71
- for i in range (len(model_list)):
72
- globals()[f"pipe_{i}"] = pipeline("text-classification", model = models[i], tokenizer = models[i], device=-1)
73
- st.text("Loaded pipeline "+ str(models[i]))
 
 
 
 
 
 
 
 
 
74
 
75
  ### Defining metrics
76
  for i in range (len(metrics)):
 
51
  st.markdown("### Please wait for the dataset and models to load (this can take some time if they are big!")
52
 
53
  ### Loading data
54
### Loading data
@st.cache
def loaddset(d, d_split):
    """Load (and cache) the given split of a Hugging Face dataset.

    Args:
        d: dataset identifier accepted by ``datasets.load_dataset``.
        d_split: split name (e.g. "train", "test").

    Returns:
        The loaded ``datasets`` split object.
    """
    # Keep the cached body side-effect free: st.* calls inside an
    # @st.cache function are skipped on cache hits, so any status
    # message written here would vanish on reruns.
    return datasets.load_dataset(d, split=d_split)

data = loaddset(dset, dset_split)
# Report outside the cached function so the message shows on every rerun,
# including cache hits.
st.text("Loaded the "+ str(dset_split)+ " split of dataset "+ str(dset))
61
 
62
  ### Defining Evaluator
63
  eval = evaluator("text-classification")
64
 
65
  ### Loading models
66
### Loading models
@st.cache
def load_models(mod_names):
    """Try to load a tokenizer + model for each requested name.

    Args:
        mod_names: list of Hugging Face model identifiers.

    Returns:
        The list of names that loaded successfully (may be shorter than
        ``mod_names`` when some loads fail).

    NOTE(review): the ``globals()[...]`` assignments are side effects inside
    an ``@st.cache`` function — they are skipped on a cache hit, so the
    ``tokenizer_i`` / ``model_i`` globals may be absent on reruns; confirm
    this is intended.
    """
    model_list = []
    for i in range(len(mod_names)):
        try:
            globals()[f"tokenizer_{i}"] = AutoTokenizer.from_pretrained(mod_names[i])
            globals()[f"model_{i}"] = AutoModelForSequenceClassification.from_pretrained(mod_names[i])
            model_list.append(mod_names[i])
        except Exception:
            # Narrowed from a bare `except:`, which would also swallow
            # KeyboardInterrupt / SystemExit.
            print("Sorry, I can't load model "+ str(mod_names[i]))
    # Return the successfully loaded names instead of a status string:
    # downstream code (load_pipes) needs this list, and the original
    # version left `model_list` undefined at module level.
    return model_list

model_list = load_models(models)
80
### Loading pipelines
@st.cache
def load_pipes(mod_list):
    """Build a text-classification pipeline for each successfully loaded model.

    Args:
        mod_list: list of model identifiers that loaded successfully.

    Returns:
        A short status string ("Loaded N models").
    """
    for i in range(len(mod_list)):
        # Bug fix: index the parameter `mod_list` (models that actually
        # loaded), not the raw global `models` selection — the two differ
        # whenever a load failed in load_models, which previously caused
        # wrong or crashing pipelines here.
        globals()[f"pipe_{i}"] = pipeline("text-classification", model=mod_list[i], tokenizer=mod_list[i], device=-1)
        st.text("Loaded pipeline "+ str(mod_list[i]))
    return "Loaded "+ str(len(mod_list))+ " models"

# NOTE(review): `model_list` must exist at module level — it is the list of
# successfully loaded models produced by load_models; confirm the caller
# assigns it (the original `load_models` discarded it, raising NameError here).
load_pipes(model_list)
88
 
89
  ### Defining metrics
90
  for i in range (len(metrics)):