Rietta commited on
Commit
848df26
1 Parent(s): 61f1751

requisitos gradio

Browse files
Files changed (2) hide show
  1. app.py +31 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import choices
2
+ import numpy as np
3
+ import gradio as gr
4
+ from glob import glob
5
+ from huggingface_hub import from_pretrained_keras
6
+
7
+ model = from_pretrained_keras('Rietta/CycleGAN_DL', compile=False)
8
+
9
+ def transform(img, direction):
10
+ img = (img / 127.5) - 1
11
+ if direction==0:
12
+ pred = model.generator_sims.predict(img[None,:,:,:])[0]
13
+ else:
14
+ pred = model.generator_wow.predict(img[None,:,:,:])[0]
15
+ pred = (pred-pred.min())/(pred.max()-pred.min())
16
+ pred = (pred * 255).astype(np.uint8)
17
+ return pred
18
+
19
+ #examples_gta = [[path, 'GTA->REAL'] for path in glob('Examples/gta*')]
20
+ #examples_real = [[path, 'REAL->GTA'] for path in glob('Examples/real*')]
21
+ #examples = [*examples_gta, *examples_real]
22
+
23
+ demo = gr.Interface(fn=transform,
24
+ inputs=[gr.inputs.Image(shape=(256, 256), type='numpy'),
25
+ gr.inputs.Radio(choices=['Sims', 'Warcraft'],
26
+ type='index')],
27
+ outputs=gr.outputs.Image(type='numpy'))
28
+ #examples=examples)
29
+
30
+ if __name__ == '__main__':
31
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tensorflow>2.6