vict0rsch commited on
Commit
ae61fe0
·
1 Parent(s): cd31093

update layout and add dev mode

Browse files
Files changed (2) hide show
  1. app.py +110 -32
  2. climategan_wrapper.py +20 -2
app.py CHANGED
@@ -6,18 +6,19 @@ import gradio as gr
6
  import googlemaps
7
  from skimage import io
8
  from urllib import parse
 
9
  from climategan_wrapper import ClimateGAN
10
 
11
 
12
- def predict(api_key):
13
  def _predict(*args):
14
- print("args: ", args)
15
- image = place = None
16
- if len(args) == 1:
17
  image = args[0]
 
18
  else:
19
- assert len(args) == 2, "Unknown number of inputs {}".format(len(args))
20
- image, place = args
21
 
22
  if api_key and place:
23
  geocode_result = gmaps.geocode(place)
@@ -27,8 +28,40 @@ def predict(api_key):
27
  img_np = io.imread(static_map_url)
28
  else:
29
  img_np = image
30
- flood, wildfire, smog = model.inference(img_np)
31
- return img_np, flood, wildfire, smog
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  return _predict
34
 
@@ -40,28 +73,73 @@ if __name__ == "__main__":
40
  if api_key is not None:
41
  gmaps = googlemaps.Client(key=api_key)
42
 
43
- model = ClimateGAN(model_path="config/model/masker")
44
-
45
- inputs = inputs = [gr.inputs.Image(label="Input Image")]
46
- if api_key:
47
- inputs += [gr.inputs.Textbox(label="Address or place name")]
48
 
49
- gr.Interface(
50
- predict(api_key),
51
- inputs=inputs,
52
- outputs=[
53
- gr.outputs.Image(type="numpy", label="Original image"),
54
- gr.outputs.Image(type="numpy", label="Flooding"),
55
- gr.outputs.Image(type="numpy", label="Wildfire"),
56
- gr.outputs.Image(type="numpy", label="Smog"),
57
- ],
58
- title="ClimateGAN: Visualize Climate Change",
59
- description='Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.', # noqa: E501
60
- article="<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>", # noqa: E501
61
- # examples=[
62
- # "Vancouver Art Gallery",
63
- # "Chicago Bean",
64
- # "Duomo Siracusa",
65
- # ],
66
- css=".footer{display:none !important}",
67
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import googlemaps
7
  from skimage import io
8
  from urllib import parse
9
+ import numpy as np
10
  from climategan_wrapper import ClimateGAN
11
 
12
 
13
+ def predict(cg: ClimateGAN, api_key):
14
  def _predict(*args):
15
+ image = place = painter = None
16
+ if len(args) == 2:
 
17
  image = args[0]
18
+ painter = args[1]
19
  else:
20
+ assert len(args) == 3, "Unknown number of inputs {}".format(len(args))
21
+ image, place, painter = args
22
 
23
  if api_key and place:
24
  geocode_result = gmaps.geocode(place)
 
28
  img_np = io.imread(static_map_url)
29
  else:
30
  img_np = image
31
+ output_dict = cg.infer_single(img_np, painter)
32
+
33
+ input_image = output_dict["input"]
34
+ masked_input = output_dict["masked_input"]
35
+ wildfire = output_dict["wildfire"]
36
+ smog = output_dict["smog"]
37
+
38
+ climategan_flood = output_dict.get(
39
+ "climategan_flood",
40
+ np.ones(input_image.shape) * 255,
41
+ )
42
+ stable_flood = output_dict.get(
43
+ "stable_flood",
44
+ np.ones(input_image.shape) * 255,
45
+ )
46
+ stable_copy_flood = output_dict.get(
47
+ "stable_copy_flood",
48
+ np.ones(input_image.shape) * 255,
49
+ )
50
+ concat = output_dict.get(
51
+ "concat",
52
+ np.ones(input_image.shape) * 255,
53
+ )
54
+
55
+ return (
56
+ input_image,
57
+ masked_input,
58
+ climategan_flood,
59
+ stable_flood,
60
+ stable_copy_flood,
61
+ concat,
62
+ wildfire,
63
+ smog,
64
+ )
65
 
66
  return _predict
67
 
 
73
  if api_key is not None:
74
  gmaps = googlemaps.Client(key=api_key)
75
 
76
+ cg = ClimateGAN(model_path="config/model/masker", dev_mode=True)
77
+ cg._setup_stable_diffusion()
 
 
 
78
 
79
+ with gr.Blocks() as blocks:
80
+ with gr.Row():
81
+ with gr.Column():
82
+ gr.Markdown("# ClimateGAN: Visualize Climate Change")
83
+ gr.HTML(
84
+ 'Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.' # noqa: E501
85
+ )
86
+ with gr.Column():
87
+ gr.HTML(
88
+ "<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>" # noqa: E501
89
+ )
90
+ with gr.Row():
91
+ gr.Markdown("## Inputs")
92
+ with gr.Row():
93
+ with gr.Column():
94
+ inputs = [gr.inputs.Image(label="Input Image")]
95
+ with gr.Column():
96
+ if api_key:
97
+ inputs += [gr.inputs.Textbox(label="Address or place name")]
98
+ inputs += [
99
+ gr.inputs.Dropdown(
100
+ choices=[
101
+ "ClimateGAN Painter",
102
+ "Stable Diffusion Painter",
103
+ "Both",
104
+ ],
105
+ label="Choose Flood Painter",
106
+ default="Both",
107
+ )
108
+ ]
109
+ btn = gr.Button("See for yourself!", label="Run")
110
+ with gr.Row():
111
+ gr.Markdown("## Outputs")
112
+ with gr.Row():
113
+ outputs = []
114
+ outputs.append(
115
+ gr.outputs.Image(type="numpy", label="Original image"),
116
+ )
117
+ outputs.append(
118
+ gr.outputs.Image(type="numpy", label="Masked input image"),
119
+ )
120
+ with gr.Row():
121
+ outputs.append(
122
+ gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
123
+ )
124
+ outputs.append(
125
+ gr.outputs.Image(type="numpy", label="Stable Diffusion-Flooded image"),
126
+ )
127
+ outputs.append(
128
+ gr.outputs.Image(
129
+ type="numpy",
130
+ label="Stable Diffusion-Flooded image (restricted to masked area)",
131
+ )
132
+ ),
133
+ with gr.Row():
134
+ outputs.append(
135
+ gr.outputs.Image(type="numpy", label="Comparison of previous images"),
136
+ )
137
+ with gr.Row():
138
+ outputs.append(
139
+ gr.outputs.Image(type="numpy", label="Wildfire"),
140
+ )
141
+ outputs.append(
142
+ gr.outputs.Image(type="numpy", label="Smog"),
143
+ )
144
+ btn.click(predict(cg, api_key), inputs=inputs, outputs=outputs)
145
+ blocks.launch()
climategan_wrapper.py CHANGED
@@ -115,7 +115,7 @@ def to_m1_p1(img):
115
 
116
  # No need to do any timing in this, since it's just for the HF Space
117
  class ClimateGAN:
118
- def __init__(self, model_path) -> None:
119
  """
120
  A wrapper for the ClimateGAN model that you can use to generate
121
  events from images or folders containing images.
@@ -125,6 +125,10 @@ class ClimateGAN:
125
  """
126
  torch.set_grad_enabled(False)
127
  self.target_size = 640
 
 
 
 
128
  self.trainer = Trainer.resume_from_path(
129
  model_path,
130
  setup=True,
@@ -132,7 +136,6 @@ class ClimateGAN:
132
  new_exp=None,
133
  )
134
  self.trainer.G.half()
135
- self._stable_diffusion_is_setup = False
136
 
137
  def _setup_stable_diffusion(self):
138
  """
@@ -140,6 +143,9 @@ class ClimateGAN:
140
  Make sure you have accepted the license on the model's card
141
  https://huggingface.co/CompVis/stable-diffusion-v1-4
142
  """
 
 
 
143
  try:
144
  self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
145
  "runwayml/stable-diffusion-inpainting",
@@ -216,6 +222,18 @@ class ClimateGAN:
216
  dict: a dictionary containing the output images {k: HxWxC}. C is omitted
217
  for masks (HxW).
218
  """
 
 
 
 
 
 
 
 
 
 
 
 
219
  image_array = (
220
  np.array(Image.open(orig_image))
221
  if isinstance(orig_image, str)
 
115
 
116
  # No need to do any timing in this, since it's just for the HF Space
117
  class ClimateGAN:
118
+ def __init__(self, model_path, dev_mode=False) -> None:
119
  """
120
  A wrapper for the ClimateGAN model that you can use to generate
121
  events from images or folders containing images.
 
125
  """
126
  torch.set_grad_enabled(False)
127
  self.target_size = 640
128
+ self._stable_diffusion_is_setup = False
129
+ self.dev_mode = dev_mode
130
+ if self.dev_mode:
131
+ return
132
  self.trainer = Trainer.resume_from_path(
133
  model_path,
134
  setup=True,
 
136
  new_exp=None,
137
  )
138
  self.trainer.G.half()
 
139
 
140
  def _setup_stable_diffusion(self):
141
  """
 
143
  Make sure you have accepted the license on the model's card
144
  https://huggingface.co/CompVis/stable-diffusion-v1-4
145
  """
146
+ if self.dev_mode:
147
+ return
148
+
149
  try:
150
  self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
151
  "runwayml/stable-diffusion-inpainting",
 
222
  dict: a dictionary containing the output images {k: HxWxC}. C is omitted
223
  for masks (HxW).
224
  """
225
+ if self.dev_mode:
226
+ return {
227
+ "input": np.random.randint(0, 255, (640, 640, 3)),
228
+ "mask": np.random.randint(0, 255, (640, 640)),
229
+ "masked_input": np.random.randint(0, 255, (640, 640, 3)),
230
+ "climategan_flood": np.random.randint(0, 255, (640, 640, 3)),
231
+ "stable_flood": np.random.randint(0, 255, (640, 640, 3)),
232
+ "stable_copy_flood": np.random.randint(0, 255, (640, 640, 3)),
233
+ "concat": np.random.randint(0, 255, (640, 640 * 5, 3)),
234
+ "smog": np.random.randint(0, 255, (640, 640, 3)),
235
+ "wildfire": np.random.randint(0, 255, (640, 640, 3)),
236
+ }
237
  image_array = (
238
  np.array(Image.open(orig_image))
239
  if isinstance(orig_image, str)