Unggi committed on
Commit
bd5680b
·
1 Parent(s): c6206e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -16,25 +16,42 @@ def load_model(model_name):
16
  return model, tokenizer
17
 
18
 
19
def inference(prompt):
    """Classify *prompt* with the feedback-prize model and return its label.

    Loads the model/tokenizer on every call via the module's `load_model`
    helper, scores the full prompt, and maps the argmax class index to its
    human-readable label string.
    """
    model, tokenizer = load_model(
        model_name="Unggi/feedback_prize_kor"
    )

    # Encode the prompt and score it without tracking gradients
    # (inference only — no backward pass needed).
    encoded = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        scores = model(**encoded).logits

    # Highest-scoring class index -> label via the model's id2label map.
    best_index = scores.argmax().item()
    return model.config.id2label[best_index]
38
 
39
  demo = gr.Interface(
40
  fn=inference,
 
16
  return model, tokenizer
17
 
18
 
19
def inference(prompt_inputs):
    """Classify each sentence of *prompt_inputs* and return labelled lines.

    Splits the input on sentence-ending punctuation (``.``, ``!``, ``?``),
    classifies each fragment with the feedback-prize model, and returns one
    line per fragment formatted as ``"<fragment>\t<label>"``, joined by
    newlines.
    """
    # Local import: the file's top-level import block is outside this view.
    import re

    model_name = "Unggi/feedback_prize_kor"

    model, tokenizer = load_model(
        model_name=model_name
    )

    # Split the prompt on sentence-ending punctuation.
    # BUG FIX: the original used prompt_inputs.split('.|!|?'), which splits
    # on the literal five-character string ".|!|?" (almost never present),
    # so the whole input passed through as one "sentence". re.split applies
    # the intended alternation. Empty/whitespace-only fragments (e.g. after
    # the final period) are dropped so the tokenizer never sees "".
    prompt_list = [p.strip() for p in re.split(r'[.!?]', prompt_inputs) if p.strip()]

    class_id_list = []

    for prompt in prompt_list:
        inputs = tokenizer(
            prompt,
            return_tensors="pt"
        )

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Highest-scoring class index -> label via the model's id2label map.
        predicted_class_id = logits.argmax().item()
        class_id = model.config.id2label[predicted_class_id]

        class_id_list.append(class_id)

    outputs = []

    for p, c_id in zip(prompt_list, class_id_list):
        outputs.append(p + '\t' + c_id)

    # BUG FIX: the original called outputs.join('\n'), but list has no
    # .join — str.join takes the iterable as its argument.
    return '\n'.join(outputs)
55
 
56
  demo = gr.Interface(
57
  fn=inference,