Benjamin Bossan commited on
Commit
5444690
·
1 Parent(s): 7a8ea79

Simplify plotting in a grid

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -234,6 +234,16 @@ description = (
234
  "Colored cirles are (predicted) labels and black x are outliers."
235
  )
236
 
 
 
 
 
 
 
 
 
 
 
237
  with gr.Blocks(title=title) as demo:
238
  gr.HTML(f"<b>{title}</b>")
239
  gr.Markdown(description)
@@ -253,20 +263,16 @@ with gr.Blocks(title=title) as demo:
253
  )
254
  n_rows = int(math.ceil(len(input_models) / N_COLS))
255
  counter = 0
256
- # code below is not very elegant, maybe there is a better way?
257
- for i in range(n_rows):
258
- with gr.Row():
259
- for j in range(N_COLS):
260
- with gr.Column():
261
- if counter >= len(input_models):
262
- break
263
-
264
- input_model = input_models[counter]
265
- plot = gr.Plot(label=input_model)
266
- fn = partial(cluster, clustering_algorithm=input_model)
267
- input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
268
- input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
269
- counter += 1
270
 
271
 
272
  demo.launch()
 
234
  "Colored cirles are (predicted) labels and black x are outliers."
235
  )
236
 
237
+
238
+ def iter_grid(n_rows, n_cols):
239
+ # create a grid using gradio Block
240
+ for _ in range(n_rows):
241
+ with gr.Row():
242
+ for _ in range(n_cols):
243
+ with gr.Column():
244
+ yield
245
+
246
+
247
  with gr.Blocks(title=title) as demo:
248
  gr.HTML(f"<b>{title}</b>")
249
  gr.Markdown(description)
 
263
  )
264
  n_rows = int(math.ceil(len(input_models) / N_COLS))
265
  counter = 0
266
+ for _ in iter_grid(n_rows, N_COLS):
267
+ if counter >= len(input_models):
268
+ break
269
+
270
+ input_model = input_models[counter]
271
+ plot = gr.Plot(label=input_model)
272
+ fn = partial(cluster, clustering_algorithm=input_model)
273
+ input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
274
+ input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
275
+ counter += 1
 
 
 
 
276
 
277
 
278
  demo.launch()