Rhasan97 commited on
Commit
0441b30
1 Parent(s): 2e7b57c
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -15,11 +15,13 @@ input_name = inf_session.get_inputs()[0].name
15
  output_name = inf_session.get_outputs()[0].name
16
 
17
  def classify_recipe(description):
18
- input_ids = tokenizer(description)['input_ids'][:512]
19
- logits = inf_session.run([output_name],{input_name:[input_ids]})[0]
20
- logits = torch.FloatTensor(logits)
21
- probs = torch.sigmoid(logits)[0]
22
- return dict(zip(recipe, map(float, probs)))
 
 
23
 
24
  label = gr.outputs.Label(num_top_classes=5)
25
  iface = gr.Interface(fn=classify_recipe, inputs="text", outputs=label)
 
15
  output_name = inf_session.get_outputs()[0].name
16
 
17
  def classify_recipe(description):
18
+ input_ids = tokenizer(description)['input_ids'][:512]
19
+ logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
20
+ logits = torch.FloatTensor(logits)
21
+ probs = torch.sigmoid(logits)[0]
22
+
23
+ result = {class_label: float(prob) for class_label, prob in zip(recipe, probs)}
24
+ return result
25
 
26
  label = gr.outputs.Label(num_top_classes=5)
27
  iface = gr.Interface(fn=classify_recipe, inputs="text", outputs=label)