rdose committed
Commit 52cbc64
Parent: 3994894

Update app.py

Files changed (1):
  app.py +29 -8
app.py CHANGED
@@ -110,13 +110,28 @@ def _inference_classifier(text):
 
     return sigmoid(ort_outs[0])
 
-def inference(input_batch,isurl,use_archive,limit_companies=10):
+def inference(file_in,file_col_name,input_batch,isurl,use_archive,limit_companies=10):
     input_batch_content = []
-    print("->Input size:",len(input_batch))
-    print("+",input_batch)
+    if file_in is not None:
+        dft = pd.read_csv(
+            file_in,
+            compression=dict(method='zip')
+        )
+        assert file_col_name in dft.columns, "Indicated col_name not found in file"
+        input_batch_r = dft[file_col_name].values.tolist()
+    else:
+        assert len(input_batch) > 0, "input_batch array is empty"
+        input_batch_r = input_batch
+
+    print("->Input size:",len(input_batch_r))
+    print("+",input_batch_r)
+
     if isurl:
-        for row_in in input_batch:
-            url = row_in[0]
+        for row_in in input_batch_r:
+            if isinstance(row_in , list):
+                url = row_in[0]
+            else:
+                url = row_in
             if use_archive:
                 archive = is_in_archive(url)
                 if archive['archived']:
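The zipped-CSV branch leans on pandas' built-in decompression: compression=dict(method='zip') is just the expanded form of compression='zip', and pandas raises ValueError unless the archive contains exactly one data file. A minimal sketch of the same read-and-validate pattern; the path 'batch.zip' and the column header 'url' are placeholders, not values from the app:

import pandas as pd

# Read a CSV straight out of a zip archive; equivalent to compression='zip'.
# pandas raises ValueError if the archive holds more than one file.
dft = pd.read_csv('batch.zip', compression=dict(method='zip'))  # placeholder path

col = 'url'  # placeholder header, standing in for file_col_name
assert col in dft.columns, "Indicated col_name not found in file"
rows = dft[col].values.tolist()  # flat list, one string per CSV row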
@@ -125,8 +140,12 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
                 extracted = Extractor().extract(requests.get(url).text)
                 input_batch_content.append(extracted['content'])
     else:
-        for row_in in input_batch:
-            input_batch_content.append(row_in[0])
+        if isinstance(input_batch_r[0], list):
+            for row_in in input_batch_r:
+                input_batch_content.append(row_in[0])
+        else:
+            input_batch_content = input_batch_r
+
     print("->Batch size:",len(input_batch_content))
     print("+",input_batch_content)
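Both branches now have to cope with two batch shapes: the Dataframe component still delivers one-element rows ([['a'], ['b']]), while the new CSV path delivers a flat list (['a', 'b']). A sketch of that normalization as a standalone helper, assuming a homogeneous batch; the name normalize_batch is ours, not the app's:

def normalize_batch(batch):
    # Dataframe input arrives as [['a'], ['b']]; the CSV path yields
    # ['a', 'b']. Flatten the first shape, pass the second through.
    if batch and isinstance(batch[0], list):
        return [row[0] for row in batch]
    return batch

assert normalize_batch([['x'], ['y']]) == ['x', 'y']
assert normalize_batch(['x', 'y']) == ['x', 'y']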
 
@@ -162,7 +181,9 @@ examples = [[[['https://www.bbc.com/news/uk-62732447'],
              ['https://www.bbc.com/news/business-62728621'],
              ['https://www.bbc.com/news/science-environment-62680423']],'url',False,5]]
 demo = gr.Interface(fn=inference,
-                    inputs=[gr.Dataframe(label='input batch', col_count=1, datatype='str', type='array', wrap=True),
+                    inputs=[gr.File(label='zipped csv file'),
+                            gr.Textbox(label='If csv, column header name that contains the relevant data:'),
+                            gr.Dataframe(label='input batch', col_count=1, datatype='str', type='array', wrap=True),
                             gr.Dropdown(label='data type', choices=['text','url'], type='index', value='url'),
                             gr.Checkbox(label='if url parse cached in archive.org'),
                             gr.Slider(minimum=1, maximum=10, step=1, label='Limit NER output', value=5)],
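Since gr.Interface feeds inputs to the handler positionally, the new gr.File and gr.Textbox must sit ahead of the Dataframe to line up with file_in and file_col_name in the new signature. Note that the examples rows above still carry four values; if they are wired up via gr.Interface(examples=...), each row would need one entry per input component. A hypothetical adjustment, assuming exactly that wiring:

# Two leading placeholders (no file, empty column name) cover the new
# File and Textbox inputs so each example row matches all six components.
examples = [[None, '',
             [['https://www.bbc.com/news/uk-62732447'],
              ['https://www.bbc.com/news/business-62728621'],
              ['https://www.bbc.com/news/science-environment-62680423']],
             'url', False, 5]]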
 