AlekseyKorshuk commited on
Commit
fd63a18
1 Parent(s): bdbeae2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -179
app.py CHANGED
@@ -1,121 +1,75 @@
1
- import subprocess
2
- from pathlib import Path
3
- import einops
4
- import numpy as np
5
- import torch
6
- from huggingface_hub import hf_hub_download
7
- from PIL import Image
8
- from torch import nn
9
- from torchvision.utils import save_image
10
  from huggingface_hub.hf_api import HfApi
11
  import streamlit as st
 
 
12
 
13
  hfapi = HfApi()
 
14
 
 
 
15
 
16
- class Generator(nn.Module):
17
- def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
18
- super(Generator, self).__init__()
19
- self.model = nn.Sequential(
20
- # input is Z, going into a convolution
21
- nn.ConvTranspose2d(latent_dim, hidden_size * 8, 4, 1, 0, bias=False),
22
- nn.BatchNorm2d(hidden_size * 8),
23
- nn.ReLU(True),
24
- # state size. (hidden_size*8) x 4 x 4
25
- nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 2, 1, bias=False),
26
- nn.BatchNorm2d(hidden_size * 4),
27
- nn.ReLU(True),
28
- # state size. (hidden_size*4) x 8 x 8
29
- nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1, bias=False),
30
- nn.BatchNorm2d(hidden_size * 2),
31
- nn.ReLU(True),
32
- # state size. (hidden_size*2) x 16 x 16
33
- nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),
34
- nn.BatchNorm2d(hidden_size),
35
- nn.ReLU(True),
36
- # state size. (hidden_size) x 32 x 32
37
- nn.ConvTranspose2d(hidden_size, num_channels, 4, 2, 1, bias=False),
38
- nn.Tanh()
39
- # state size. (num_channels) x 64 x 64
40
- )
41
 
42
- def forward(self, noise):
43
- pixel_values = self.model(noise)
44
-
45
- return pixel_values
46
-
47
-
48
- @torch.no_grad()
49
- def interpolate(model, save_dir='./lerp/', frames=100, rows=8, cols=8):
50
- save_dir = Path(save_dir)
51
- save_dir.mkdir(exist_ok=True, parents=True)
52
-
53
- z1 = torch.randn(rows * cols, 100, 1, 1)
54
- z2 = torch.randn(rows * cols, 100, 1, 1)
55
-
56
- zs = []
57
- for i in range(frames):
58
- alpha = i / frames
59
- z = (1 - alpha) * z1 + alpha * z2
60
- zs.append(z)
61
-
62
- zs += zs[::-1] # also go in reverse order to complete loop
63
-
64
- frames = []
65
- for i, z in enumerate(zs):
66
- imgs = model(z)
67
-
68
- save_image(imgs, save_dir / f"{i:03}.png", normalize=True)
69
- img = Image.open(save_dir / f"{i:03}.png").convert('RGBA')
70
- img.putalpha(255)
71
- frames.append(img)
72
- img.save(save_dir / f"{i:03}.png")
73
- frames[0].save("out.gif", format="GIF", append_images=frames,
74
- save_all=True, duration=100, loop=1)
75
-
76
-
77
- def predict(model_name, choice, seed):
78
- try:
79
- model = Generator(3)
80
- weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
81
- model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
82
- except:
83
- model = Generator(4)
84
- weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
85
- model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
86
- torch.manual_seed(seed)
87
-
88
- if choice == 'interpolation':
89
- interpolate(model)
90
- return 'out.gif'
91
- else:
92
- z = torch.randn(64, 100, 1, 1)
93
- punks = model(z)
94
- save_image(punks, "image.png", normalize=True)
95
- img = Image.open(f"image.png").convert('RGBA')
96
- img.putalpha(255)
97
- img.save("image.png")
98
- return 'image.png'
99
 
 
 
 
 
 
100
 
101
- model_names = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
 
102
 
103
- st.set_page_config(page_title="Hugging NFT")
104
 
105
- st.title("Hugging NFT")
106
- st.sidebar.markdown(
107
- """
108
- <style>
109
- .aligncenter {
110
- text-align: center;
111
- }
112
- </style>
113
- <p class="aligncenter">
114
- <img src="https://raw.githubusercontent.com/AlekseyKorshuk/optimum-transformers/master/data/social_preview.png" width="300" />
115
- </p>
116
- """,
117
- unsafe_allow_html=True,
118
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  st.sidebar.markdown(
120
  """
121
  <style>
@@ -123,11 +77,9 @@ st.sidebar.markdown(
123
  text-align: center;
124
  }
125
  </style>
126
-
127
  <p style='text-align: center'>
128
- <a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">GitHub</a>
129
  </p>
130
-
131
  <p class="aligncenter">
132
  <a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
133
  <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
@@ -142,73 +94,95 @@ st.sidebar.markdown(
142
  unsafe_allow_html=True,
143
  )
144
 
145
- st.markdown(
146
- "🤗 [Hugging NFT](https://github.com/AlekseyKorshuk/huggingnft) - Generate NFT by OpenSea collection name.")
147
-
148
- st.markdown(
149
- "🚀️ SN-GAN used to train all models.")
150
-
151
- st.markdown(
152
- "⁉️ Want to train your model? Check [project repository](https://github.com/AlekseyKorshuk/huggingnft) and make this in few clicks!")
153
- #
154
- # st.markdown("🚀 Up to 1ms on Bert-based transformers")
155
- #
156
- # st.markdown(
157
- # "‼️ NOTE: This Space **does not show** the real power of this project because: low recources, not possbile to optimize models. Check [project repository](https://github.com/AlekseyKorshuk/optimum-transformers) with real bechmarks!")
158
-
159
- # st.sidebar.header("Settings:")
160
- model_name = st.selectbox(
161
- 'Choose model:',
162
- model_names)
163
-
164
- output_type = st.selectbox(
165
- 'Output type:',
166
- ['image', 'interpolation'])
167
-
168
- seed_value = st.slider("Seed:",
169
- min_value=1,
170
- max_value=1000,
171
- step=1,
172
- value=100,
173
- )
174
-
175
- model_html = """
176
-
177
- <div class="inline-flex flex-col" style="line-height: 1.5;">
178
- <div class="flex">
179
- <div
180
- \t\t\tstyle="display:DISPLAY_1; margin-left: auto; margin-right: auto; width: 92px; height:92px; border-radius: 50%; background-size: cover; background-image: url(&#39;USER_PROFILE&#39;)">
181
- </div>
182
- </div>
183
- <div style="text-align: center; margin-top: 3px; font-size: 16px; font-weight: 800">🤖 HuggingArtists Model 🤖</div>
184
- <div style="text-align: center; font-size: 16px; font-weight: 800">USER_NAME</div>
185
- <a href="https://genius.com/artists/USER_HANDLE">
186
- \t<div style="text-align: center; font-size: 14px;">@USER_HANDLE</div>
187
- </a>
188
- </div>
189
- """
190
-
191
- if st.button("Run"):
192
- with st.spinner(text=f"Generating..."):
193
- st.image(predict(model_name, output_type, seed_value))
194
- st.subheader("Please star project repository, this space and follow my Twitter:")
195
- st.markdown(
196
- """
197
- <style>
198
- .aligncenter {
199
- text-align: center;
200
- }
201
- </style>
202
- <p class="aligncenter">
203
- <a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
204
- <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
205
- </a>
206
- </p>
207
- <p class="aligncenter">
208
- <a href="https://twitter.com/alekseykorshuk" target="_blank">
209
- <img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
210
- </a>
211
- </p>
212
- """,
213
- unsafe_allow_html=True,
214
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from huggingnft.lightweight_gan.train import timestamped_filename
3
+ from streamlit_option_menu import option_menu
4
+
5
+ from huggingface_hub import hf_hub_download, file_download
6
+
 
 
 
7
  from huggingface_hub.hf_api import HfApi
8
  import streamlit as st
9
+ from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer
10
+ from accelerate import Accelerator
11
 
12
  hfapi = HfApi()
13
+ model_names = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
14
 
15
+ # streamlit-option-menu
16
+ # st.set_page_config(page_title="Sharone's Streamlit App Gallery", page_icon="", layout="wide")
17
 
18
+ # sysmenu = '''
19
+ # <style>
20
+ # #MainMenu {visibility:hidden;}
21
+ # footer {visibility:hidden;}
22
+ # '''
23
+ # st.markdown(sysmenu,unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # # Add a logo (optional) in the sidebar
26
+ # logo = Image.open(r'C:\Users\13525\Desktop\Insights_Bees_logo.png')
27
+ # profile = Image.open(r'C:\Users\13525\Desktop\medium_profile.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ ABOUT_TEXT = "🤗 Hugging NFT - Generate NFT by OpenSea collection name."
30
+ CONTACT_TEXT = "Here is some contact info"
31
+ GENERATE_IMAGE_TEXT = "Text about generation"
32
+ INTERPOLATION_TEXT = "Text about Interpolation"
33
+ COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"
34
 
35
+ STOPWORDS = ["-old"]
36
+ COLLECTION2COLLECTION_KEYS = ["2"]
37
 
 
38
 
39
+ def load_lightweight_model(model_name):
40
+ file_path = file_download.hf_hub_download(
41
+ repo_id=model_name,
42
+ filename="config.json"
43
+ )
44
+ config = json.loads(open(file_path).read())
45
+ organization_name, name = model_name.split("/")
46
+ model = Trainer(**config, organization_name=organization_name, name=name)
47
+ model.load(use_cpu=True)
48
+ model.accelerator = Accelerator()
49
+ return model
50
+
51
+
52
+ def clean_models(model_names, stopwords):
53
+ cleaned_model_names = []
54
+ for model_name in model_names:
55
+ clear = True
56
+ for stopword in stopwords:
57
+ if stopword in model_name:
58
+ clear = False
59
+ break
60
+ if clear:
61
+ cleaned_model_names.append(model_name)
62
+ return cleaned_model_names
63
+
64
+
65
+ model_names = clean_models(model_names, STOPWORDS)
66
+
67
+ with st.sidebar:
68
+ choose = option_menu("Hugging NFT",
69
+ ["About", "Generate image", "Interpolation", "Collection2Collection", "Contact"],
70
+ icons=['house', 'camera fill', 'bi bi-youtube', 'book', 'person lines fill'],
71
+ menu_icon="app-indicator", default_index=0,
72
+ )
73
  st.sidebar.markdown(
74
  """
75
  <style>
 
77
  text-align: center;
78
  }
79
  </style>
 
80
  <p style='text-align: center'>
81
+ <a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">Project Repository</a>
82
  </p>
 
83
  <p class="aligncenter">
84
  <a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
85
  <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
 
94
  unsafe_allow_html=True,
95
  )
96
 
97
+ if choose == "About":
98
+ st.title(choose)
99
+ st.markdown(ABOUT_TEXT)
100
+
101
+ if choose == "Contact":
102
+ st.title(choose)
103
+ st.markdown(CONTACT_TEXT)
104
+
105
+ if choose == "Generate image":
106
+ st.title(choose)
107
+ st.markdown(GENERATE_IMAGE_TEXT)
108
+
109
+ model_name = st.selectbox(
110
+ 'Choose model:',
111
+ clean_models(model_names, COLLECTION2COLLECTION_KEYS)
112
+ )
113
+ generation_type = st.selectbox(
114
+ 'Select generation type:',
115
+ ["default", "ema"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  )
117
+
118
+ nrows = st.number_input("Number of rows:",
119
+ min_value=1,
120
+ max_value=10,
121
+ step=1,
122
+ value=8,
123
+ )
124
+ generate_image_button = st.button("Generate")
125
+
126
+ if generate_image_button:
127
+ with st.spinner(text=f"Downloading selected model..."):
128
+ model = load_lightweight_model(f"huggingnft/{model_name}")
129
+ with st.spinner(text=f"Generating..."):
130
+ st.image(
131
+ model.generate_app(
132
+ num=timestamped_filename(),
133
+ nrow=nrows,
134
+ checkpoint=-1,
135
+ types=generation_type
136
+ )
137
+ )
138
+
139
+ if choose == "Interpolation":
140
+ st.title(choose)
141
+ st.markdown(INTERPOLATION_TEXT)
142
+
143
+ model_name = st.selectbox(
144
+ 'Choose model:',
145
+ clean_models(model_names, COLLECTION2COLLECTION_KEYS)
146
+ )
147
+ nrows = st.number_input("Number of rows:",
148
+ min_value=1,
149
+ max_value=10,
150
+ step=1,
151
+ value=1,
152
+ )
153
+
154
+ num_steps = st.number_input("Number of steps:",
155
+ min_value=1,
156
+ max_value=1000,
157
+ step=1,
158
+ value=100,
159
+ )
160
+ generate_image_button = st.button("Generate")
161
+
162
+ if generate_image_button:
163
+ with st.spinner(text=f"Downloading selected model..."):
164
+ model = load_lightweight_model(f"huggingnft/{model_name}")
165
+ my_bar = st.progress(0)
166
+ result = model.generate_interpolation(
167
+ num=timestamped_filename(),
168
+ num_image_tiles=nrows,
169
+ num_steps=num_steps,
170
+ save_frames=False,
171
+ progress_bar=my_bar
172
+ )
173
+ my_bar.empty()
174
+ with st.spinner(text=f"Uploading result..."):
175
+ st.image(result)
176
+
177
+ if choose == "Collection2Collection":
178
+ st.title(choose)
179
+ st.markdown(INTERPOLATION_TEXT)
180
+
181
+ model_name = st.selectbox(
182
+ 'Choose model:',
183
+ set(model_names) - set(clean_models(model_names, COLLECTION2COLLECTION_KEYS))
184
+ )
185
+ generate_image_button = st.button("Generate")
186
+
187
+ if generate_image_button:
188
+ st.markdown("generating Collection2Collection")