bhadresh-savani commited on
Commit
7fc0728
1 Parent(s): 6d6a110

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -268
app.py CHANGED
@@ -1,271 +1,23 @@
1
  import gradio as gr
2
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
3
- import matplotlib.pyplot as plt
4
  from PIL import Image
5
- import numpy as np
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torch
9
-
10
- from pathlib import Path
11
- from re import TEMPLATE
12
- from typing import Optional, Union
13
- import os
14
-
15
- from huggingface_hub import PyTorchModelHubMixin, HfApi, HfFolder, Repository
16
-
17
- TEMPLATE_MODEL_CARD_PATH = "dummy"
18
-
19
- class HugGANModelHubMixin(PyTorchModelHubMixin):
20
- """A mixin to push PyTorch Models to the Hugging Face Hub. This
21
- mixin was adapted from the PyTorchModelHubMixin to also push a template
22
- README.md for the HugGAN sprint.
23
- """
24
-
25
- def push_to_hub(
26
- self,
27
- repo_path_or_name: Optional[str] = None,
28
- repo_url: Optional[str] = None,
29
- commit_message: Optional[str] = "Add model",
30
- organization: Optional[str] = None,
31
- private: Optional[bool] = None,
32
- api_endpoint: Optional[str] = None,
33
- use_auth_token: Optional[Union[bool, str]] = None,
34
- git_user: Optional[str] = None,
35
- git_email: Optional[str] = None,
36
- config: Optional[dict] = None,
37
- skip_lfs_files: bool = False,
38
- default_model_card: Optional[str] = TEMPLATE_MODEL_CARD_PATH
39
- ) -> str:
40
- """
41
- Upload model checkpoint or tokenizer files to the Hub while
42
- synchronizing a local clone of the repo in `repo_path_or_name`.
43
- Parameters:
44
- repo_path_or_name (`str`, *optional*):
45
- Can either be a repository name for your model or tokenizer in
46
- the Hub or a path to a local folder (in which case the
47
- repository will have the name of that local folder). If not
48
- specified, will default to the name given by `repo_url` and a
49
- local directory with that name will be created.
50
- repo_url (`str`, *optional*):
51
- Specify this in case you want to push to an existing repository
52
- in the hub. If unspecified, a new repository will be created in
53
- your namespace (unless you specify an `organization`) with
54
- `repo_name`.
55
- commit_message (`str`, *optional*):
56
- Message to commit while pushing. Will default to `"add config"`,
57
- `"add tokenizer"` or `"add model"` depending on the type of the
58
- class.
59
- organization (`str`, *optional*):
60
- Organization in which you want to push your model or tokenizer
61
- (you must be a member of this organization).
62
- private (`bool`, *optional*):
63
- Whether the repository created should be private.
64
- api_endpoint (`str`, *optional*):
65
- The API endpoint to use when pushing the model to the hub.
66
- use_auth_token (`bool` or `str`, *optional*):
67
- The token to use as HTTP bearer authorization for remote files.
68
- If `True`, will use the token generated when running
69
- `transformers-cli login` (stored in `~/.huggingface`). Will
70
- default to `True` if `repo_url` is not specified.
71
- git_user (`str`, *optional*):
72
- will override the `git config user.name` for committing and
73
- pushing files to the hub.
74
- git_email (`str`, *optional*):
75
- will override the `git config user.email` for committing and
76
- pushing files to the hub.
77
- config (`dict`, *optional*):
78
- Configuration object to be saved alongside the model weights.
79
- default_model_card (`str`, *optional*):
80
- Path to a markdown file to use as your default model card.
81
- Returns:
82
- The url of the commit of your model in the given repository.
83
- """
84
-
85
- if repo_path_or_name is None and repo_url is None:
86
- raise ValueError(
87
- "You need to specify a `repo_path_or_name` or a `repo_url`."
88
- )
89
-
90
- if use_auth_token is None and repo_url is None:
91
- token = HfFolder.get_token()
92
- if token is None:
93
- raise ValueError(
94
- "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
95
- "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
96
- "token as the `use_auth_token` argument."
97
- )
98
- elif isinstance(use_auth_token, str):
99
- token = use_auth_token
100
- else:
101
- token = None
102
-
103
- if repo_path_or_name is None:
104
- repo_path_or_name = repo_url.split("/")[-1]
105
-
106
- # If no URL is passed and there's no path to a directory containing files, create a repo
107
- if repo_url is None and not os.path.exists(repo_path_or_name):
108
- repo_id = Path(repo_path_or_name).name
109
- if organization:
110
- repo_id = f"{organization}/{repo_id}"
111
- repo_url = HfApi(endpoint=api_endpoint).create_repo(
112
- repo_id=repo_id,
113
- token=token,
114
- private=private,
115
- repo_type=None,
116
- exist_ok=True,
117
- )
118
-
119
- repo = Repository(
120
- repo_path_or_name,
121
- clone_from=repo_url,
122
- use_auth_token=use_auth_token,
123
- git_user=git_user,
124
- git_email=git_email,
125
- skip_lfs_files=skip_lfs_files
126
- )
127
- repo.git_pull(rebase=True)
128
-
129
- # Save the files in the cloned repo
130
- self.save_pretrained(repo_path_or_name, config=config)
131
-
132
- model_card_path = Path(repo_path_or_name) / 'README.md'
133
- if not model_card_path.exists():
134
- model_card_path.write_text(TEMPLATE_MODEL_CARD_PATH.read_text())
135
-
136
- # Commit and push!
137
- repo.git_add()
138
- repo.git_commit(commit_message)
139
- return repo.git_push()
140
-
141
-
142
- def weights_init_normal(m):
143
- classname = m.__class__.__name__
144
- if classname.find("Conv") != -1:
145
- torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
146
- elif classname.find("BatchNorm2d") != -1:
147
- torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
148
- torch.nn.init.constant_(m.bias.data, 0.0)
149
-
150
-
151
- ##############################
152
- # U-NET
153
- ##############################
154
-
155
-
156
- class UNetDown(nn.Module):
157
- def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
158
- super(UNetDown, self).__init__()
159
- layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
160
- if normalize:
161
- layers.append(nn.InstanceNorm2d(out_size))
162
- layers.append(nn.LeakyReLU(0.2))
163
- if dropout:
164
- layers.append(nn.Dropout(dropout))
165
- self.model = nn.Sequential(*layers)
166
-
167
- def forward(self, x):
168
- return self.model(x)
169
-
170
-
171
- class UNetUp(nn.Module):
172
- def __init__(self, in_size, out_size, dropout=0.0):
173
- super(UNetUp, self).__init__()
174
- layers = [
175
- nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
176
- nn.InstanceNorm2d(out_size),
177
- nn.ReLU(inplace=True),
178
- ]
179
- if dropout:
180
- layers.append(nn.Dropout(dropout))
181
-
182
- self.model = nn.Sequential(*layers)
183
-
184
- def forward(self, x, skip_input):
185
- x = self.model(x)
186
- x = torch.cat((x, skip_input), 1)
187
-
188
- return x
189
-
190
-
191
- class GeneratorUNet(nn.Module, HugGANModelHubMixin):
192
- def __init__(self, in_channels=3, out_channels=3):
193
- super(GeneratorUNet, self).__init__()
194
-
195
- self.down1 = UNetDown(in_channels, 64, normalize=False)
196
- self.down2 = UNetDown(64, 128)
197
- self.down3 = UNetDown(128, 256)
198
- self.down4 = UNetDown(256, 512, dropout=0.5)
199
- self.down5 = UNetDown(512, 512, dropout=0.5)
200
- self.down6 = UNetDown(512, 512, dropout=0.5)
201
- self.down7 = UNetDown(512, 512, dropout=0.5)
202
- self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)
203
-
204
- self.up1 = UNetUp(512, 512, dropout=0.5)
205
- self.up2 = UNetUp(1024, 512, dropout=0.5)
206
- self.up3 = UNetUp(1024, 512, dropout=0.5)
207
- self.up4 = UNetUp(1024, 512, dropout=0.5)
208
- self.up5 = UNetUp(1024, 256)
209
- self.up6 = UNetUp(512, 128)
210
- self.up7 = UNetUp(256, 64)
211
-
212
- self.final = nn.Sequential(
213
- nn.Upsample(scale_factor=2),
214
- nn.ZeroPad2d((1, 0, 1, 0)),
215
- nn.Conv2d(128, out_channels, 4, padding=1),
216
- nn.Tanh(),
217
- )
218
-
219
- def forward(self, x):
220
- # U-Net generator with skip connections from encoder to decoder
221
- d1 = self.down1(x)
222
- d2 = self.down2(d1)
223
- d3 = self.down3(d2)
224
- d4 = self.down4(d3)
225
- d5 = self.down5(d4)
226
- d6 = self.down6(d5)
227
- d7 = self.down7(d6)
228
- d8 = self.down8(d7)
229
- u1 = self.up1(d8, d7)
230
- u2 = self.up2(u1, d6)
231
- u3 = self.up3(u2, d5)
232
- u4 = self.up4(u3, d4)
233
- u5 = self.up5(u4, d3)
234
- u6 = self.up6(u5, d2)
235
- u7 = self.up7(u6, d1)
236
-
237
- return self.final(u7)
238
-
239
- def load_image_infer(image_file):
240
- # Configure dataloaders
241
- transform = Compose([
242
- Resize((args.image_size, args.image_size), Image.BICUBIC),
243
- ToTensor(),
244
- Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
245
- ])
246
- image_file = Image.fromarray(np.array(image_file)[:, ::-1, :], "RGB")
247
- image_file = transform(image_file)
248
-
249
- return image_file
250
-
251
- def generate_images(test_input):
252
- test_input = load_image_infer(test_input)
253
- prediction = generator(test_input).data
254
- fig = plt.figure(figsize=(128, 128))
255
- title = ['Predicted Image']
256
-
257
- plt.title('Predicted Image')
258
- # Getting the pixel values in the [0, 1] range to plot.
259
- plt.imshow(prediction[0,:,:,:] * 0.5 + 0.5)
260
- plt.axis('off')
261
- return fig
262
-
263
- generator = GeneratorUNet()
264
- generator.from_pretrained("huggan/pix2pix-edge2shoes")
265
-
266
- img = gr.inputs.Image(shape=(256,256))
267
- plot = gr.outputs.Image(type="plot")
268
-
269
- description = "Pix2pix model that translates image-to-image."
270
- gr.Interface(generate_images, inputs = img, outputs = plot,
271
- title = "Pix2Pix Shoes Reconstructor", description = description).launch()
 
1
  import gradio as gr
2
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
 
3
  from PIL import Image
4
+ from torchvision.utils import save_image
5
+
6
+ from huggan.pytorch.pix2pix.modeling_pix2pix import GeneratorUNet
7
+
8
+ transform = Compose(
9
+ [
10
+ Resize((256, 256), Image.BICUBIC),
11
+ ToTensor(),
12
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
13
+ ]
14
+ )
15
+ model = GeneratorUNet.from_pretrained('huggan/pix2pix-cityscapes')
16
+
17
+ def predict_fn(img):
18
+ inp = transform(img).unsqueeze(0)
19
+ out = model(inp)
20
+ save_image(out, 'out.png', normalize=True)
21
+ return 'out.png'
22
+
23
+ gr.Interface(predict_fn, inputs=gr.inputs.Image(type='pil'), outputs='image', examples=[['sample.jpeg']]).launch()