vrushabh94 commited on
Commit
0251703
·
1 Parent(s): 5f9dcd8

Removed the sharing

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -11,12 +11,14 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
  premise = ""
12
  hypothesis = ""
13
 
14
- def zeroShotClassification(text_input, hypothesis, candidate_labels):
 
 
15
  input = tokenizer(text_input, hypothesis, truncation=True, return_tensors="pt")
16
  output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
17
  prediction = torch.softmax(output["logits"][0], -1).tolist()
18
- label_names = candidate_labels #["entailment", "neutral", "contradiction"]
19
- prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
20
  return prediction
21
 
22
  examples = [
@@ -28,5 +30,5 @@ examples = [
28
  ["Submission Receipt", "Submission Receipt"]
29
  ]
30
 
31
- demo = gr.Interface(fn=zeroShotClassification, inputs=[gr.Textbox(label="Input"), gr.Textbox(label="Candidate Labels", value="Meeting Minutes / Outcomes,Submission Receipt")], outputs=gr.Label(label="Classification"), examples=examples)
32
  demo.launch();
 
11
  premise = ""
12
  hypothesis = ""
13
 
14
+ def zeroShotClassification(text_input, candidate_labels, hypothesis):
15
+ print(text_input)
16
+ print(candidate_labels)
17
  input = tokenizer(text_input, hypothesis, truncation=True, return_tensors="pt")
18
  output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
19
  prediction = torch.softmax(output["logits"][0], -1).tolist()
20
+ labels = [label.strip(' ') for label in candidate_labels.split(',')]
21
+ prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, labels)}
22
  return prediction
23
 
24
  examples = [
 
30
  ["Submission Receipt", "Submission Receipt"]
31
  ]
32
 
33
+ demo = gr.Interface(fn=zeroShotClassification, inputs=[gr.Textbox(label="Input"), gr.Textbox(label="Candidate Labels", value="Meeting Minutes / Outcomes,Submission Receipt"), gr.Textbox(label="Hypothesys", value="Meeting Minutes / Outcomes,Submission Receipt")], outputs=gr.Label(label="Classification"), examples=examples)
34
  demo.launch();