cccc committed on
Commit
93cb5d2
·
1 Parent(s): f17457c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -42
app.py CHANGED
@@ -42,56 +42,48 @@ def sentiment_analysis(sentence, model_name):
42
  translated_tokens = model.generate(
43
  **tokenizer(sentences, return_tensors="pt", padding=True)
44
  )
45
- # for t in translated:
46
- # print( tokenizer.decode(t, skip_special_tokens=True) )
47
-
48
- # output = [sentence ['translation_text'] for sentence in tokenizer.decode(translated_tokens, skip_special_tokens=True)]
49
- output = []
50
  for t in translated_tokens:
51
- output.append(tokenizer.decode(t, skip_special_tokens=True))
52
 
53
-
54
-
55
- # testdata = []
56
- # for i,sentence in enumerate(sentences):
57
- # testdata.append(InputExample(guid=i,text_a=sentence,label=0))
58
 
59
- # plm, tokenizer, model_config, WrapperClass = load_plm(type_dic[model_name], model_name)
60
 
61
- # promptTemplate = ManualTemplate(
62
- # text = template,
63
- # tokenizer = tokenizer,
64
- # )
65
- # promptVerbalizer = ManualVerbalizer(
66
- # classes = classes,
67
- # label_words = label_words,
68
- # tokenizer = tokenizer,
69
- # )
70
- # test_dataloader = PromptDataLoader(
71
- # dataset = testdata,
72
- # tokenizer = tokenizer,
73
- # template = promptTemplate,
74
- # tokenizer_wrapper_class = WrapperClass,
75
- # batch_size = 4,
76
- # max_seq_length = 512,
77
- # )
78
- # prompt_model = PromptForClassification(
79
- # plm=plm,
80
- # template=promptTemplate,
81
- # verbalizer=promptVerbalizer,
82
- # freeze_plm=True
83
- # )
84
- # result = []
85
- # for step, inputs in enumerate(test_dataloader):
86
- # logits = prompt_model(inputs)
87
- # result.extend(torch.argmax(logits, dim=-1))
88
- # output = '\n'.join([classes[i] for i in result])
89
  return str(output)
90
 
91
 
92
 
93
-
94
-
95
  demo = gr.Interface(fn=sentiment_analysis,
96
  inputs = [gr.Textbox(placeholder="Enter sentence here. If you have multiple sentences, separate them with '\\n'.",label="Sentence",lines=5),
97
  gr.Radio(choices=["RoBERTa_Chinese_AnnualReport_tuned","RoBERTa_Chinese_Financial_News_tuned","RoBERTa_English_AnnualReport_tuned",
 
42
  translated_tokens = model.generate(
43
  **tokenizer(sentences, return_tensors="pt", padding=True)
44
  )
45
+ sentences_list = []
 
 
 
 
46
  for t in translated_tokens:
47
+ sentences_list.append(tokenizer.decode(t, skip_special_tokens=True))
48
 
49
+ testdata = []
50
+ for i,sentence in enumerate(sentences_list):
51
+ testdata.append(InputExample(guid=i,text_a=sentence,label=0))
 
 
52
 
53
+ plm, tokenizer, model_config, WrapperClass = load_plm(type_dic[model_name], model_name)
54
 
55
+ promptTemplate = ManualTemplate(
56
+ text = template,
57
+ tokenizer = tokenizer,
58
+ )
59
+ promptVerbalizer = ManualVerbalizer(
60
+ classes = classes,
61
+ label_words = label_words,
62
+ tokenizer = tokenizer,
63
+ )
64
+ test_dataloader = PromptDataLoader(
65
+ dataset = testdata,
66
+ tokenizer = tokenizer,
67
+ template = promptTemplate,
68
+ tokenizer_wrapper_class = WrapperClass,
69
+ batch_size = 4,
70
+ max_seq_length = 512,
71
+ )
72
+ prompt_model = PromptForClassification(
73
+ plm=plm,
74
+ template=promptTemplate,
75
+ verbalizer=promptVerbalizer,
76
+ freeze_plm=True
77
+ )
78
+ result = []
79
+ for step, inputs in enumerate(test_dataloader):
80
+ logits = prompt_model(inputs)
81
+ result.extend(torch.argmax(logits, dim=-1))
82
+ output = '\n'.join([classes[i] for i in result])
83
  return str(output)
84
 
85
 
86
 
 
 
87
  demo = gr.Interface(fn=sentiment_analysis,
88
  inputs = [gr.Textbox(placeholder="Enter sentence here. If you have multiple sentences, separate them with '\\n'.",label="Sentence",lines=5),
89
  gr.Radio(choices=["RoBERTa_Chinese_AnnualReport_tuned","RoBERTa_Chinese_Financial_News_tuned","RoBERTa_English_AnnualReport_tuned",