espejelomar commited on
Commit
ad895c8
1 Parent(s): df98ebc

Add Spanish titles and disclaimers

Browse files
Files changed (2) hide show
  1. app.py +36 -26
  2. utils.py +11 -8
app.py CHANGED
@@ -1,49 +1,59 @@
1
- import streamlit as st # HF spaces at v1.2.0
2
- from utils import load_model,generate
3
 
4
- ## Configuración de nuestro demo
 
 
5
  st.title("Butterfly GAN (GAN de mariposas)")
6
- st.write("Modelo Light-GAN entrenado con 1000 imágenes de mariposas tomadas de la colección del Museo Smithsonian.")
7
- st.write("*Disclaimers:")
8
- st.write("* Este demo es una versión simplificada del creado por [Ceyda Cinarel](https://github.com/cceyda) y [Jonathan Whitaker](https://datasciencecastnet.home.blog/) ([link](https://huggingface.co/spaces/huggan/butterfly-gan)) durante el hackathon [HugGan](https://github.com/huggingface/community-events).")
9
- st.write("* Modelo basado en el [paper](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*.")
10
 
 
11
  st.sidebar.subheader("¡Esta mariposa no existe! Ni en América Latina 🤯.")
12
  st.sidebar.image("assets/logo.png", width=200)
13
- st.sidebar.caption(f"[Modelo](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) y [Dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) usados.")
14
- st.sidebar.caption(f"*Contribuidores:*")
15
- st.sidebar.caption(f"[Ceyda Cinarel](https://github.com/cceyda) y [Jonathan Whitaker](https://datasciencecastnet.home.blog/). Edición de versión simplificada por [Omar Espejel](https://twitter.com/espejelomar).")
 
 
 
 
 
 
 
16
 
17
  ## Cargamos modelo
18
- repo_id = 'ceyda/butterfly_cropped_uniq1K_512'
19
- version_modelo ='57d36a15546909557d9f967f47713236c8288838'
20
- modelo_gan = load_model(repo_id, version_modelo)
21
 
22
  ## Generamos 4 mariposas
23
- n_mariposas =4
24
 
25
  ## Función que genera mariposas y lo guarda como un estado de la sesión
26
  def corre():
27
  with st.spinner("Generando, espera un poco..."):
28
- ims=generate(modelo_gan,n_mariposas)
29
- st.session_state['ims'] = ims
 
30
 
31
  ## Si no hay una imagen generada entonces generala
32
- if 'ims' not in st.session_state:
33
- st.session_state['ims'] = None
34
  corre()
35
 
36
  ## ims contiene las imágenes generadas
37
- ims=st.session_state["ims"]
38
 
39
  ## Si la usuaria da click en el botón entonces corremos la función genera()
40
- runb=st.button("Genera mariposas, porfa.", on_click=corre ,help="Estamos en pleno vuelo, puede tardar.")
 
 
 
 
41
 
42
  if ims is not None:
43
- cols=st.columns(n_mariposas)
44
- # picks=[False]*n_mariposas
45
- for j,im in enumerate(ims):
46
- i=j%n_mariposas
47
  cols[i].image(im, use_column_width=True)
48
-
49
-
1
+ import streamlit as st
 
2
 
3
+ from utils import carga_modelo, genera
4
+
5
+ ## Página principal
6
  st.title("Butterfly GAN (GAN de mariposas)")
7
+ st.write(
8
+ "Modelo Light-GAN entrenado con 1000 imágenes de mariposas tomadas de la colección del Museo Smithsonian."
9
+ )
 
10
 
11
+ ## Barra lateral
12
  st.sidebar.subheader("¡Esta mariposa no existe! Ni en América Latina 🤯.")
13
  st.sidebar.image("assets/logo.png", width=200)
14
+ st.sidebar.caption(
15
+ f"[Modelo](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) y [Dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) usados."
16
+ )
17
+ st.sidebar.caption(f"*Disclaimers:*")
18
+ st.sidebar.caption(
19
+ "* Este demo es una versión simplificada del creado por [Ceyda Cinarel](https://github.com/cceyda) y [Jonathan Whitaker](https://datasciencecastnet.home.blog/) ([link](https://huggingface.co/spaces/huggan/butterfly-gan)) durante el hackathon [HugGan](https://github.com/huggingface/community-events). Cualquier error se atribuye a [Omar Espejel](https://twitter.com/espejelomar)."
20
+ )
21
+ st.sidebar.caption(
22
+ "* Modelo basado en el [paper](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*."
23
+ )
24
 
25
  ## Cargamos modelo
26
+ repo_id = "ceyda/butterfly_cropped_uniq1K_512"
27
+ version_modelo = "57d36a15546909557d9f967f47713236c8288838"
28
+ modelo_gan = carga_modelo(repo_id, version_modelo)
29
 
30
  ## Generamos 4 mariposas
31
+ n_mariposas = 4
32
 
33
  ## Función que genera mariposas y lo guarda como un estado de la sesión
34
  def corre():
35
  with st.spinner("Generando, espera un poco..."):
36
+ ims = genera(modelo_gan, n_mariposas)
37
+ st.session_state["ims"] = ims
38
+
39
 
40
  ## Si no hay una imagen generada entonces generala
41
+ if "ims" not in st.session_state:
42
+ st.session_state["ims"] = None
43
  corre()
44
 
45
  ## ims contiene las imágenes generadas
46
+ ims = st.session_state["ims"]
47
 
48
  ## Si la usuaria da click en el botón entonces corremos la función genera()
49
+ corre_boton = st.button(
50
+ "Genera mariposas, porfa.",
51
+ on_click=corre,
52
+ help="Estamos en pleno vuelo, puede tardar.",
53
+ )
54
 
55
  if ims is not None:
56
+ cols = st.columns(n_mariposas)
57
+ for j, im in enumerate(ims):
58
+ i = j % n_mariposas
 
59
  cols[i].image(im, use_column_width=True)
 
 
utils.py CHANGED
@@ -1,15 +1,18 @@
 
1
  import torch
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
- import numpy as np
4
 
5
 
6
- def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version=None):
7
- gan = LightweightGAN.from_pretrained(model_name,version=model_version)
 
8
  gan.eval()
9
  return gan
10
-
11
- def generate(gan,batch_size=1):
 
 
12
  with torch.no_grad():
13
- ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
14
- ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
15
- return ims
1
+ import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
 
4
 
5
 
6
+ ## Cargamos el modelo desde el Hub de Hugging Face
7
+ def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
8
+ gan = LightweightGAN.from_pretrained(model_name, version=model_version)
9
  gan.eval()
10
  return gan
11
+
12
+
13
+ ## Usamos el modelo GAN para generar imágenes
14
+ def genera(gan, batch_size=1):
15
  with torch.no_grad():
16
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
17
+ ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
18
+ return ims