shoukaku commited on
Commit
802b789
·
1 Parent(s): 5d9208e

make the app more 'user friendly'

Browse files
Files changed (1) hide show
  1. app.py +111 -101
app.py CHANGED
@@ -10,56 +10,56 @@ if gr.NO_RELOAD:
10
 
11
  DEVICE = 'cpu'
12
  MODELS = [
13
- (
14
- 'bert-model_1950',
15
- lambda: BaseTransferLearningModel(
16
- 'bert-base-uncased',
17
- [('linear', ['in', 'out']), ('softmax')],
18
- 2,
19
- device=DEVICE,
20
- state_dict='src/ckpt/bert-model_1950.pt',
21
- ),
22
- ),
23
- (
24
- 'bert-model_2000',
25
- lambda: BaseTransferLearningModel(
26
- 'bert-base-uncased',
27
- [('linear', ['in', 'out']), ('softmax')],
28
- 2,
29
- device=DEVICE,
30
- state_dict='src/ckpt/bert-model_2000.pt',
31
- ),
32
- ),
33
- (
34
- 'deberta-base-model_1100',
35
- lambda: BaseTransferLearningModel(
36
- 'microsoft/deberta-base',
37
- [('linear', ['in', 'out']), ('softmax')],
38
- 2,
39
- device=DEVICE,
40
- state_dict='src/ckpt/deberta-base-model_4400.pt',
41
- ),
42
- ),
43
- (
44
- 'deberta-base-model_2000',
45
- lambda: BaseTransferLearningModel(
46
- 'microsoft/deberta-base',
47
- [('linear', ['in', 'out']), ('softmax')],
48
- 2,
49
- device=DEVICE,
50
- state_dict='src/ckpt/deberta-base-model_8000.pt',
51
- ),
52
- ),
53
- (
54
- 'deberta-v3-base-model_1700',
55
- lambda: BaseTransferLearningModel(
56
- 'microsoft/deberta-v3-base',
57
- [('linear', ['in', 'out']), ('softmax')],
58
- 2,
59
- device=DEVICE,
60
- state_dict='src/ckpt/deberta-v3-base-model_3400.pt',
61
- ),
62
- ),
63
  (
64
  'deberta-v3-base-model_2000',
65
  lambda: BaseTransferLearningModel(
@@ -70,58 +70,64 @@ MODELS = [
70
  state_dict='src/ckpt/deberta-v3-base-model_4000.pt',
71
  ),
72
  ),
73
- (
74
- 'distilbert-model_1850',
75
- lambda: BaseTransferLearningModel(
76
- 'distilbert-base-uncased',
77
- [('linear', ['in', 'out']), ('softmax')],
78
- 2,
79
- device=DEVICE,
80
- state_dict='src/ckpt/distilbert-model_1850.pt',
81
- ),
82
- ),
83
- (
84
- 'distilbert-model_2000',
85
- lambda: BaseTransferLearningModel(
86
- 'distilbert-base-uncased',
87
- [('linear', ['in', 'out']), ('softmax')],
88
- 2,
89
- device=DEVICE,
90
- state_dict='src/ckpt/distilbert-model_2000.pt',
91
- ),
92
- ),
93
- (
94
- 'roberta-base-model_1250',
95
- lambda: BaseTransferLearningModel(
96
- 'FacebookAI/roberta-base',
97
- [('linear', ['in', 'out']), ('softmax')],
98
- 2,
99
- device=DEVICE,
100
- state_dict='src/ckpt/roberta-base-model_1250.pt',
101
- ),
102
- ),
103
- (
104
- 'roberta-base-model_2000',
105
- lambda: BaseTransferLearningModel(
106
- 'FacebookAI/roberta-base',
107
- [('linear', ['in', 'out']), ('softmax')],
108
- 2,
109
- device=DEVICE,
110
- state_dict='src/ckpt/roberta-base-model_2000.pt',
111
- ),
112
- ),
113
  ]
114
 
115
 
116
  class WebUI:
117
 
118
- def __init__(self, models: list[(str, Callable)] = [], device: str = 'cpu') -> None:
 
 
 
 
 
119
  self.models = models
120
  self.device = device
121
  self.is_ready = False
122
  self.model = self.models[0][1]()
123
  self.is_ready = True
124
  self.scraper = GenericScraper()
 
125
 
126
  def _change_model(self, idx: int) -> str:
127
  if gr.NO_RELOAD:
@@ -142,7 +148,9 @@ class WebUI:
142
  if self.is_ready == False:
143
  return 'Model is not yet ready!'
144
  output = self.model.predict(text, self.device).detach().cpu().numpy()[0]
145
- return f'Fake: {output[0]:.10f}, Real: {output[1]:.10f}'
 
 
146
 
147
  def _scrape(self, url: str) -> str:
148
  try:
@@ -173,16 +181,18 @@ class WebUI:
173
  )
174
  btn_submit = gr.Button(value='Submit', variant='primary')
175
  with gr.Column():
176
- ddl_model = gr.Dropdown(
177
- label='Model',
178
- choices=[model[0] for model in self.models],
179
- value=self._change_model(0),
180
- type='index',
181
- interactive=True,
182
- filterable=True,
183
- )
 
184
  t_out = gr.Textbox(label='Output')
185
- ddl_model.change(fn=self._change_model, inputs=ddl_model)
 
186
  btn_scrape.click(fn=self._scrape, inputs=t_url, outputs=t_inp)
187
  btn_submit.click(fn=self._predict, inputs=t_inp, outputs=t_out)
188
  return ui
 
10
 
11
  DEVICE = 'cpu'
12
  MODELS = [
13
+ # (
14
+ # 'bert-model_1950',
15
+ # lambda: BaseTransferLearningModel(
16
+ # 'bert-base-uncased',
17
+ # [('linear', ['in', 'out']), ('softmax')],
18
+ # 2,
19
+ # device=DEVICE,
20
+ # state_dict='src/ckpt/bert-model_1950.pt',
21
+ # ),
22
+ # ),
23
+ # (
24
+ # 'bert-model_2000',
25
+ # lambda: BaseTransferLearningModel(
26
+ # 'bert-base-uncased',
27
+ # [('linear', ['in', 'out']), ('softmax')],
28
+ # 2,
29
+ # device=DEVICE,
30
+ # state_dict='src/ckpt/bert-model_2000.pt',
31
+ # ),
32
+ # ),
33
+ # (
34
+ # 'deberta-base-model_1100',
35
+ # lambda: BaseTransferLearningModel(
36
+ # 'microsoft/deberta-base',
37
+ # [('linear', ['in', 'out']), ('softmax')],
38
+ # 2,
39
+ # device=DEVICE,
40
+ # state_dict='src/ckpt/deberta-base-model_4400.pt',
41
+ # ),
42
+ # ),
43
+ # (
44
+ # 'deberta-base-model_2000',
45
+ # lambda: BaseTransferLearningModel(
46
+ # 'microsoft/deberta-base',
47
+ # [('linear', ['in', 'out']), ('softmax')],
48
+ # 2,
49
+ # device=DEVICE,
50
+ # state_dict='src/ckpt/deberta-base-model_8000.pt',
51
+ # ),
52
+ # ),
53
+ # (
54
+ # 'deberta-v3-base-model_1700',
55
+ # lambda: BaseTransferLearningModel(
56
+ # 'microsoft/deberta-v3-base',
57
+ # [('linear', ['in', 'out']), ('softmax')],
58
+ # 2,
59
+ # device=DEVICE,
60
+ # state_dict='src/ckpt/deberta-v3-base-model_3400.pt',
61
+ # ),
62
+ # ),
63
  (
64
  'deberta-v3-base-model_2000',
65
  lambda: BaseTransferLearningModel(
 
70
  state_dict='src/ckpt/deberta-v3-base-model_4000.pt',
71
  ),
72
  ),
73
+ # (
74
+ # 'distilbert-model_1850',
75
+ # lambda: BaseTransferLearningModel(
76
+ # 'distilbert-base-uncased',
77
+ # [('linear', ['in', 'out']), ('softmax')],
78
+ # 2,
79
+ # device=DEVICE,
80
+ # state_dict='src/ckpt/distilbert-model_1850.pt',
81
+ # ),
82
+ # ),
83
+ # (
84
+ # 'distilbert-model_2000',
85
+ # lambda: BaseTransferLearningModel(
86
+ # 'distilbert-base-uncased',
87
+ # [('linear', ['in', 'out']), ('softmax')],
88
+ # 2,
89
+ # device=DEVICE,
90
+ # state_dict='src/ckpt/distilbert-model_2000.pt',
91
+ # ),
92
+ # ),
93
+ # (
94
+ # 'roberta-base-model_1250',
95
+ # lambda: BaseTransferLearningModel(
96
+ # 'FacebookAI/roberta-base',
97
+ # [('linear', ['in', 'out']), ('softmax')],
98
+ # 2,
99
+ # device=DEVICE,
100
+ # state_dict='src/ckpt/roberta-base-model_1250.pt',
101
+ # ),
102
+ # ),
103
+ # (
104
+ # 'roberta-base-model_2000',
105
+ # lambda: BaseTransferLearningModel(
106
+ # 'FacebookAI/roberta-base',
107
+ # [('linear', ['in', 'out']), ('softmax')],
108
+ # 2,
109
+ # device=DEVICE,
110
+ # state_dict='src/ckpt/roberta-base-model_2000.pt',
111
+ # ),
112
+ # ),
113
  ]
114
 
115
 
116
  class WebUI:
117
 
118
+ def __init__(
119
+ self,
120
+ models: list[(str, Callable)] = [],
121
+ device: str = 'cpu',
122
+ debug: bool = False,
123
+ ) -> None:
124
  self.models = models
125
  self.device = device
126
  self.is_ready = False
127
  self.model = self.models[0][1]()
128
  self.is_ready = True
129
  self.scraper = GenericScraper()
130
+ self.debug = debug
131
 
132
  def _change_model(self, idx: int) -> str:
133
  if gr.NO_RELOAD:
 
148
  if self.is_ready == False:
149
  return 'Model is not yet ready!'
150
  output = self.model.predict(text, self.device).detach().cpu().numpy()[0]
151
+ if self.debug:
152
+ return f'Fake: {output[0]:.10f}, Real: {output[1]:.10f}'
153
+ return f'We think that this is a {"fake" if output[0] > output[1] else "real"} news article with {max(output[0], output[1]) * 100:.2f}% certainty.'
154
 
155
  def _scrape(self, url: str) -> str:
156
  try:
 
181
  )
182
  btn_submit = gr.Button(value='Submit', variant='primary')
183
  with gr.Column():
184
+ if self.debug:
185
+ ddl_model = gr.Dropdown(
186
+ label='Model',
187
+ choices=[model[0] for model in self.models],
188
+ value=self._change_model(0),
189
+ type='index',
190
+ interactive=True,
191
+ filterable=True,
192
+ )
193
  t_out = gr.Textbox(label='Output')
194
+ if self.debug:
195
+ ddl_model.change(fn=self._change_model, inputs=ddl_model)
196
  btn_scrape.click(fn=self._scrape, inputs=t_url, outputs=t_inp)
197
  btn_submit.click(fn=self._predict, inputs=t_inp, outputs=t_out)
198
  return ui