manhkhanhUIT commited on
Commit
8728fb1
1 Parent(s): d44c9b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -31,20 +31,20 @@ def lab2rgb(L, AB):
31
  rgb = color.lab2rgb(Lab) * 255
32
  return rgb
33
 
34
- def get_transform(params=None, grayscale=False, method=Image.BICUBIC):
35
  #params
36
- preprocess = 'resize_and_crop'
37
  load_size = 256
38
  crop_size = 256
39
  transform_list = []
40
  if grayscale:
41
  transform_list.append(transforms.Grayscale(1))
42
- if 'resize' in preprocess:
43
  osize = [load_size, load_size]
44
  transform_list.append(transforms.Resize(osize, method))
45
- if 'crop' in preprocess:
46
- if params is None:
47
- transform_list.append(transforms.RandomCrop(crop_size))
48
 
49
  return transforms.Compose(transform_list)
50
 
@@ -67,7 +67,7 @@ def inferRestoration(img, model_name):
67
  return result
68
 
69
  def inferColorization(img,model_name):
70
- print(model_name)
71
  if model_name == "Pix2Pix Resnet 9block":
72
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
73
  elif model_name == "Pix2Pix Unet 256":
@@ -96,10 +96,12 @@ def inferColorization(img,model_name):
96
  image_pil = transforms.ToPILImage()(result)
97
  return image_pil
98
 
99
- transform_seq = get_transform()
100
- im = transform_seq(img)
101
- im = np.array(img)
102
- lab = color.rgb2lab(im).astype(np.float32)
 
 
103
  lab_t = transforms.ToTensor()(lab)
104
  A = lab_t[[0], ...] / 50.0 - 1.0
105
  B = lab_t[[1, 2], ...] / 110.0
@@ -160,4 +162,4 @@ examples = [['example/1.jpeg',"BOPBTL","Deoldify"],['example/2.jpg',"BOPBTL","De
160
  iface = gr.Interface(run,
161
  [gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
162
  outputs="image",
163
- examples=examples).launch(debug=True,share=False)
 
31
  rgb = color.lab2rgb(Lab) * 255
32
  return rgb
33
 
34
+ def get_transform(model_name,params=None, grayscale=False, method=Image.BICUBIC):
35
  #params
36
+ preprocess = 'resize'
37
  load_size = 256
38
  crop_size = 256
39
  transform_list = []
40
  if grayscale:
41
  transform_list.append(transforms.Grayscale(1))
42
+ if model_name == "Pix2Pix Unet 256":
43
  osize = [load_size, load_size]
44
  transform_list.append(transforms.Resize(osize, method))
45
+ # if 'crop' in preprocess:
46
+ # if params is None:
47
+ # transform_list.append(transforms.RandomCrop(crop_size))
48
 
49
  return transforms.Compose(transform_list)
50
 
 
67
  return result
68
 
69
  def inferColorization(img,model_name):
70
+ #print(model_name)
71
  if model_name == "Pix2Pix Resnet 9block":
72
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
73
  elif model_name == "Pix2Pix Unet 256":
 
96
  image_pil = transforms.ToPILImage()(result)
97
  return image_pil
98
 
99
+ transform_seq = get_transform(model_name)
100
+ img = transform_seq(img)
101
+ # if model_name == "Pix2Pix Unet 256":
102
+ # img.resize((256,256))
103
+ img = np.array(img)
104
+ lab = color.rgb2lab(img).astype(np.float32)
105
  lab_t = transforms.ToTensor()(lab)
106
  A = lab_t[[0], ...] / 50.0 - 1.0
107
  B = lab_t[[1, 2], ...] / 110.0
 
162
  iface = gr.Interface(run,
163
  [gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
164
  outputs="image",
165
+ examples=examples).launch(debug=True,share=True)