ccm commited on
Commit
d4731a4
1 Parent(s): 7bca059

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -12
app.py CHANGED
@@ -1,21 +1,16 @@
1
  import numpy
2
  import gradio
3
- from huggingface_hub import from_pretrained_keras
4
  import json
5
- from json import JSONEncoder
6
 
7
- class NumpyArrayEncoder(JSONEncoder):
8
  def default(self, obj):
9
  if isinstance(obj, numpy.ndarray):
10
  return obj.tolist()
11
  return JSONEncoder.default(self, obj)
12
 
13
- def print_shape_through(x):
14
- print(x.shape)
15
- return x
16
-
17
- analysis_network = from_pretrained_keras("cmudrc/wave-energy-analysis")
18
- synthesis_network = from_pretrained_keras("cmudrc/wave-energy-synthesis")
19
 
20
  with gradio.Blocks() as demo:
21
  geometry = gradio.Textbox(label="geometry")
@@ -24,7 +19,7 @@ with gradio.Blocks() as demo:
24
  analyze_it = gradio.Button("Analyze")
25
  synthesize_it = gradio.Button("Synthesize")
26
 
27
- analyze_it.click(fn=lambda x: json.dumps(analysis_network.predict(print_shape_through(numpy.asarray(json.loads(x)))), cls=NumpyArrayEncoder), inputs=[geometry], outputs=[spectrum], api_name="analyze")
28
- synthesize_it.click(fn=lambda x: json.dumps(synthesis_network.predict(print_shape_through(numpy.asarray(json.loads(x)))), cls=NumpyArrayEncoder), inputs=[spectrum], outputs=[geometry], api_name="synthesize")
29
 
30
- demo.launch()
 
1
  import numpy
2
  import gradio
3
+ import huggingface_hub
4
  import json
 
5
 
6
+ class NumpyArrayEncoder(json.JSONEncoder):
7
  def default(self, obj):
8
  if isinstance(obj, numpy.ndarray):
9
  return obj.tolist()
10
  return JSONEncoder.default(self, obj)
11
 
12
+ analysis_network = huggingface_hub.from_pretrained_keras("cmudrc/wave-energy-analysis")
13
+ synthesis_network = huggingface_hub.from_pretrained_keras("cmudrc/wave-energy-synthesis")
 
 
 
 
14
 
15
  with gradio.Blocks() as demo:
16
  geometry = gradio.Textbox(label="geometry")
 
19
  analyze_it = gradio.Button("Analyze")
20
  synthesize_it = gradio.Button("Synthesize")
21
 
22
+ analyze_it.click(fn=lambda x: json.dumps(analysis_network.predict(numpy.asarray([json.loads(x)])), cls=NumpyArrayEncoder), inputs=[geometry], outputs=[spectrum], api_name="analyze")
23
+ synthesize_it.click(fn=lambda x: json.dumps(synthesis_network.predict(numpy.asarray(json.loads(x))), cls=NumpyArrayEncoder), inputs=[spectrum], outputs=[geometry], api_name="synthesize")
24
 
25
+ demo.launch(debug=True)