youj2005 committed on
Commit 78cf820
1 Parent(s): 2eba1ff

Add multiclass and shrink LLM model size

Files changed (1)
  1. app.py +8 -8
app.py CHANGED
@@ -5,10 +5,10 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
 
 te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
 te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
-qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
-qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto")
+qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
+qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")
 
-def predict(context, intent):
+def predict(context, intent, multi_class):
     input_text = "In one word, what is the opposite of: " + intent + "?"
     input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
     opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
@@ -39,19 +39,19 @@ def predict(context, intent)
     pn_tensor[0] = pn_tensor[0] * outputs[2][1]
 
     pn_tensor = F.normalize(pn_tensor, p=1, dim=0)
-
-    pn_tensor = pn_tensor.softmax(dim=0)
+    if (multi_class):
+        pn_tensor = F.normalize(pn_tensor, p=1, dim=0)
+    else:
+        pn_tensor = pn_tensor.softmax(dim=0)
     pn_tensor = pn_tensor.tolist()
     return {"entailment": pn_tensor[0], "neutral": pn_tensor[1], "contradiction": pn_tensor[2]}
 
 gradio_app = gr.Interface(
     predict,
-    inputs=["text", "text"],
+    inputs=[gr.Text("Sentence"), gr.Text("Class"), gr.Checkbox("Allow multiple true classes", default=True)],
     outputs=[gr.Label(num_top_classes=3)],
     title="Intent Analysis",
 )
 
-print(predict("The cat is short.", "long"))
-
 if __name__ == "__main__":
     gradio_app.launch()
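
The behavioural change inside predict() is limited to how the raw entailment/neutral/contradiction scores become the returned label dictionary: L1 normalization when multiple classes may be true at once, softmax when the labels are treated as mutually exclusive. Below is a minimal sketch of just that step; the helper name scores_to_labels and the example tensor are illustrative and not part of the commit.

import torch
import torch.nn.functional as F

def scores_to_labels(pn_tensor, multi_class):
    # Hypothetical helper mirroring the tail of predict(); pn_tensor holds the
    # raw (entailment, neutral, contradiction) scores.
    if multi_class:
        # Rescale so the scores sum to 1 while keeping their relative
        # magnitudes, so more than one class can score highly.
        pn_tensor = F.normalize(pn_tensor, p=1, dim=0)
    else:
        # Treat the scores as logits of a single, mutually exclusive label.
        pn_tensor = pn_tensor.softmax(dim=0)
    pn_tensor = pn_tensor.tolist()
    return {"entailment": pn_tensor[0], "neutral": pn_tensor[1], "contradiction": pn_tensor[2]}

print(scores_to_labels(torch.tensor([0.6, 0.3, 0.1]), multi_class=True))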