bhadresh-savani commited on
Commit
b09c42f
1 Parent(s): 7c82ec5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -22
app.py CHANGED
@@ -1,23 +1,241 @@
1
- import tensorflow as tf
2
- import pathlib
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
- from huggingface_hub import from_pretrained_keras
6
- import numpy as np
7
 
8
- # Normalizing the images to [-1, 1]
9
- def normalize_test(input_image):
10
- input_image = tf.cast(input_image, tf.float32)
11
- input_image = (input_image / 127.5) - 1
12
- return input_image
13
-
14
- def resize(input_image, height, width):
15
- input_image = tf.image.resize(input_image, [height, width],
16
- method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
17
- return input_image
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def load_image_infer(image_file):
20
- input_image = resize(image_file, 256, 256)
21
  input_image = normalize_test(input_image)
22
 
23
  return input_image
@@ -27,20 +245,19 @@ def generate_images(test_input):
27
  prediction = generator(np.expand_dims(test_input, axis=0), training=True)
28
  fig = plt.figure(figsize=(128, 128))
29
  title = ['Predicted Image']
30
-
31
  plt.title('Predicted Image')
32
  # Getting the pixel values in the [0, 1] range to plot.
33
  plt.imshow(prediction[0,:,:,:] * 0.5 + 0.5)
34
  plt.axis('off')
35
  return fig
36
-
37
-
38
- generator = from_pretrained_keras("keras-io/pix2pix-generator")
39
 
 
 
40
 
41
  img = gr.inputs.Image(shape=(256,256))
42
  plot = gr.outputs.Image(type="plot")
43
 
44
- description = "Conditional GAN model that translates image-to-image."
45
  gr.Interface(generate_images, inputs = img, outputs = plot,
46
- title = "Pix2Pix Shoes Reconstructor", description = description, examples = [["./img.png"]]).launch()
 
 
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
 
 
3
 
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch
7
+
8
+ from pathlib import Path
9
+ from re import TEMPLATE
10
+ from typing import Optional, Union
11
+ import os
12
+
13
+ from huggingface_hub import PyTorchModelHubMixin, HfApi, HfFolder, Repository
14
+
15
+ TEMPLATE_MODEL_CARD_PATH = "dummy"
16
+
17
+ class HugGANModelHubMixin(PyTorchModelHubMixin):
18
+ """A mixin to push PyTorch Models to the Hugging Face Hub. This
19
+ mixin was adapted from the PyTorchModelHubMixin to also push a template
20
+ README.md for the HugGAN sprint.
21
+ """
22
+
23
+ def push_to_hub(
24
+ self,
25
+ repo_path_or_name: Optional[str] = None,
26
+ repo_url: Optional[str] = None,
27
+ commit_message: Optional[str] = "Add model",
28
+ organization: Optional[str] = None,
29
+ private: Optional[bool] = None,
30
+ api_endpoint: Optional[str] = None,
31
+ use_auth_token: Optional[Union[bool, str]] = None,
32
+ git_user: Optional[str] = None,
33
+ git_email: Optional[str] = None,
34
+ config: Optional[dict] = None,
35
+ skip_lfs_files: bool = False,
36
+ default_model_card: Optional[str] = TEMPLATE_MODEL_CARD_PATH
37
+ ) -> str:
38
+ """
39
+ Upload model checkpoint or tokenizer files to the Hub while
40
+ synchronizing a local clone of the repo in `repo_path_or_name`.
41
+ Parameters:
42
+ repo_path_or_name (`str`, *optional*):
43
+ Can either be a repository name for your model or tokenizer in
44
+ the Hub or a path to a local folder (in which case the
45
+ repository will have the name of that local folder). If not
46
+ specified, will default to the name given by `repo_url` and a
47
+ local directory with that name will be created.
48
+ repo_url (`str`, *optional*):
49
+ Specify this in case you want to push to an existing repository
50
+ in the hub. If unspecified, a new repository will be created in
51
+ your namespace (unless you specify an `organization`) with
52
+ `repo_name`.
53
+ commit_message (`str`, *optional*):
54
+ Message to commit while pushing. Will default to `"add config"`,
55
+ `"add tokenizer"` or `"add model"` depending on the type of the
56
+ class.
57
+ organization (`str`, *optional*):
58
+ Organization in which you want to push your model or tokenizer
59
+ (you must be a member of this organization).
60
+ private (`bool`, *optional*):
61
+ Whether the repository created should be private.
62
+ api_endpoint (`str`, *optional*):
63
+ The API endpoint to use when pushing the model to the hub.
64
+ use_auth_token (`bool` or `str`, *optional*):
65
+ The token to use as HTTP bearer authorization for remote files.
66
+ If `True`, will use the token generated when running
67
+ `transformers-cli login` (stored in `~/.huggingface`). Will
68
+ default to `True` if `repo_url` is not specified.
69
+ git_user (`str`, *optional*):
70
+ will override the `git config user.name` for committing and
71
+ pushing files to the hub.
72
+ git_email (`str`, *optional*):
73
+ will override the `git config user.email` for committing and
74
+ pushing files to the hub.
75
+ config (`dict`, *optional*):
76
+ Configuration object to be saved alongside the model weights.
77
+ default_model_card (`str`, *optional*):
78
+ Path to a markdown file to use as your default model card.
79
+ Returns:
80
+ The url of the commit of your model in the given repository.
81
+ """
82
+
83
+ if repo_path_or_name is None and repo_url is None:
84
+ raise ValueError(
85
+ "You need to specify a `repo_path_or_name` or a `repo_url`."
86
+ )
87
+
88
+ if use_auth_token is None and repo_url is None:
89
+ token = HfFolder.get_token()
90
+ if token is None:
91
+ raise ValueError(
92
+ "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
93
+ "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
94
+ "token as the `use_auth_token` argument."
95
+ )
96
+ elif isinstance(use_auth_token, str):
97
+ token = use_auth_token
98
+ else:
99
+ token = None
100
+
101
+ if repo_path_or_name is None:
102
+ repo_path_or_name = repo_url.split("/")[-1]
103
+
104
+ # If no URL is passed and there's no path to a directory containing files, create a repo
105
+ if repo_url is None and not os.path.exists(repo_path_or_name):
106
+ repo_id = Path(repo_path_or_name).name
107
+ if organization:
108
+ repo_id = f"{organization}/{repo_id}"
109
+ repo_url = HfApi(endpoint=api_endpoint).create_repo(
110
+ repo_id=repo_id,
111
+ token=token,
112
+ private=private,
113
+ repo_type=None,
114
+ exist_ok=True,
115
+ )
116
+
117
+ repo = Repository(
118
+ repo_path_or_name,
119
+ clone_from=repo_url,
120
+ use_auth_token=use_auth_token,
121
+ git_user=git_user,
122
+ git_email=git_email,
123
+ skip_lfs_files=skip_lfs_files
124
+ )
125
+ repo.git_pull(rebase=True)
126
+
127
+ # Save the files in the cloned repo
128
+ self.save_pretrained(repo_path_or_name, config=config)
129
+
130
+ model_card_path = Path(repo_path_or_name) / 'README.md'
131
+ if not model_card_path.exists():
132
+ model_card_path.write_text(TEMPLATE_MODEL_CARD_PATH.read_text())
133
+
134
+ # Commit and push!
135
+ repo.git_add()
136
+ repo.git_commit(commit_message)
137
+ return repo.git_push()
138
+
139
+
140
+ def weights_init_normal(m):
141
+ classname = m.__class__.__name__
142
+ if classname.find("Conv") != -1:
143
+ torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
144
+ elif classname.find("BatchNorm2d") != -1:
145
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
146
+ torch.nn.init.constant_(m.bias.data, 0.0)
147
+
148
+
149
+ ##############################
150
+ # U-NET
151
+ ##############################
152
+
153
+
154
+ class UNetDown(nn.Module):
155
+ def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
156
+ super(UNetDown, self).__init__()
157
+ layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
158
+ if normalize:
159
+ layers.append(nn.InstanceNorm2d(out_size))
160
+ layers.append(nn.LeakyReLU(0.2))
161
+ if dropout:
162
+ layers.append(nn.Dropout(dropout))
163
+ self.model = nn.Sequential(*layers)
164
+
165
+ def forward(self, x):
166
+ return self.model(x)
167
+
168
+
169
+ class UNetUp(nn.Module):
170
+ def __init__(self, in_size, out_size, dropout=0.0):
171
+ super(UNetUp, self).__init__()
172
+ layers = [
173
+ nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
174
+ nn.InstanceNorm2d(out_size),
175
+ nn.ReLU(inplace=True),
176
+ ]
177
+ if dropout:
178
+ layers.append(nn.Dropout(dropout))
179
+
180
+ self.model = nn.Sequential(*layers)
181
+
182
+ def forward(self, x, skip_input):
183
+ x = self.model(x)
184
+ x = torch.cat((x, skip_input), 1)
185
+
186
+ return x
187
+
188
+
189
+ class GeneratorUNet(nn.Module, HugGANModelHubMixin):
190
+ def __init__(self, in_channels=3, out_channels=3):
191
+ super(GeneratorUNet, self).__init__()
192
+
193
+ self.down1 = UNetDown(in_channels, 64, normalize=False)
194
+ self.down2 = UNetDown(64, 128)
195
+ self.down3 = UNetDown(128, 256)
196
+ self.down4 = UNetDown(256, 512, dropout=0.5)
197
+ self.down5 = UNetDown(512, 512, dropout=0.5)
198
+ self.down6 = UNetDown(512, 512, dropout=0.5)
199
+ self.down7 = UNetDown(512, 512, dropout=0.5)
200
+ self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)
201
+
202
+ self.up1 = UNetUp(512, 512, dropout=0.5)
203
+ self.up2 = UNetUp(1024, 512, dropout=0.5)
204
+ self.up3 = UNetUp(1024, 512, dropout=0.5)
205
+ self.up4 = UNetUp(1024, 512, dropout=0.5)
206
+ self.up5 = UNetUp(1024, 256)
207
+ self.up6 = UNetUp(512, 128)
208
+ self.up7 = UNetUp(256, 64)
209
+
210
+ self.final = nn.Sequential(
211
+ nn.Upsample(scale_factor=2),
212
+ nn.ZeroPad2d((1, 0, 1, 0)),
213
+ nn.Conv2d(128, out_channels, 4, padding=1),
214
+ nn.Tanh(),
215
+ )
216
+
217
+ def forward(self, x):
218
+ # U-Net generator with skip connections from encoder to decoder
219
+ d1 = self.down1(x)
220
+ d2 = self.down2(d1)
221
+ d3 = self.down3(d2)
222
+ d4 = self.down4(d3)
223
+ d5 = self.down5(d4)
224
+ d6 = self.down6(d5)
225
+ d7 = self.down7(d6)
226
+ d8 = self.down8(d7)
227
+ u1 = self.up1(d8, d7)
228
+ u2 = self.up2(u1, d6)
229
+ u3 = self.up3(u2, d5)
230
+ u4 = self.up4(u3, d4)
231
+ u5 = self.up5(u4, d3)
232
+ u6 = self.up6(u5, d2)
233
+ u7 = self.up7(u6, d1)
234
+
235
+ return self.final(u7)
236
+
237
  def load_image_infer(image_file):
238
+ imageA = Image.fromarray(np.array(imageA)[:, ::-1, :], "RGB")
239
  input_image = normalize_test(input_image)
240
 
241
  return input_image
 
245
  prediction = generator(np.expand_dims(test_input, axis=0), training=True)
246
  fig = plt.figure(figsize=(128, 128))
247
  title = ['Predicted Image']
248
+
249
  plt.title('Predicted Image')
250
  # Getting the pixel values in the [0, 1] range to plot.
251
  plt.imshow(prediction[0,:,:,:] * 0.5 + 0.5)
252
  plt.axis('off')
253
  return fig
 
 
 
254
 
255
+ generator = GeneratorUNet()
256
+ generator.from_pretrained("huggan/pix2pix-edge2shoes")
257
 
258
  img = gr.inputs.Image(shape=(256,256))
259
  plot = gr.outputs.Image(type="plot")
260
 
261
+ description = "Pix2pix model that translates image-to-image."
262
  gr.Interface(generate_images, inputs = img, outputs = plot,
263
+ title = "Pix2Pix Shoes Reconstructor", description = description).launch()