Joe Davison commited on
Commit
85dd546
1 Parent(s): 039194f

fix model caching error

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -13,6 +13,9 @@ import psutil
13
  with open("hit_log.txt", mode='a') as file:
14
  file.write(str(datetime.datetime.now()) + '\n')
15
 
 
 
 
16
  MODEL_DESC = {
17
  'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
18
  'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
@@ -58,10 +61,22 @@ models = load_models()
58
  def load_tokenizer(tok_id):
59
  return AutoTokenizer.from_pretrained(tok_id)
60
 
61
- @st.cache(allow_output_mutation=True, show_spinner=False)
62
- def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class, do_print_code):
63
- classifier = pipeline('zero-shot-classification', model=models[nli_model_id], tokenizer=load_tokenizer(nli_model_id), device=device)
64
- outputs = classifier(sequence, labels, hypothesis_template, multi_class)
 
 
 
 
 
 
 
 
 
 
 
 
65
  return outputs['labels'], outputs['scores']
66
 
67
  def load_examples(model_id):
@@ -88,7 +103,6 @@ def plot_result(top_topics, scores):
88
  fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
89
  st.plotly_chart(fig)
90
 
91
-
92
 
93
  def main():
94
  with open("style.css") as f:
@@ -124,18 +138,11 @@ def main():
124
  st.markdown(CODE_DESC.format(model_id))
125
 
126
  with st.spinner('Classifying...'):
127
- top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class, do_print_code)
128
-
129
- plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
130
-
131
- if "socat" not in [p.name() for p in psutil.process_iter()]:
132
- os.system('socat tcp-listen:8000,reuseaddr,fork tcp:localhost:8001 &')
133
-
134
-
135
 
 
136
 
137
 
138
 
139
  if __name__ == '__main__':
140
  main()
141
-
 
13
  with open("hit_log.txt", mode='a') as file:
14
  file.write(str(datetime.datetime.now()) + '\n')
15
 
16
+
17
+ MAX_GRAPH_ROWS = 10
18
+
19
  MODEL_DESC = {
20
  'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
21
  'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
 
61
  def load_tokenizer(tok_id):
62
  return AutoTokenizer.from_pretrained(tok_id)
63
 
64
+ @st.cache(allow_output_mutation=True, show_spinner=False, hash_funcs={
65
+ torch.nn.Parameter: lambda _: None
66
+ })
67
+ def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class):
68
+ classifier = pipeline(
69
+ 'zero-shot-classification',
70
+ model=models[nli_model_id],
71
+ tokenizer=load_tokenizer(nli_model_id),
72
+ device=device
73
+ )
74
+ outputs = classifier(
75
+ sequence,
76
+ candidate_labels=labels,
77
+ hypothesis_template=hypothesis_template,
78
+ multi_label=multi_class
79
+ )
80
  return outputs['labels'], outputs['scores']
81
 
82
  def load_examples(model_id):
 
103
  fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
104
  st.plotly_chart(fig)
105
 
 
106
 
107
  def main():
108
  with open("style.css") as f:
 
138
  st.markdown(CODE_DESC.format(model_id))
139
 
140
  with st.spinner('Classifying...'):
141
+ top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class)
 
 
 
 
 
 
 
142
 
143
+ plot_result(top_topics[::-1][-MAX_GRAPH_ROWS:], scores[::-1][-MAX_GRAPH_ROWS:])
144
 
145
 
146
 
147
  if __name__ == '__main__':
148
  main()