Songwei Ge commited on
Commit
9d776c8
1 Parent(s): 3430584
Files changed (2) hide show
  1. app.py +2 -2
  2. models/region_diffusion.py +0 -5
app.py CHANGED
@@ -28,7 +28,7 @@ def main():
28
  model = RegionDiffusion(device)
29
 
30
  def generate(
31
- text_input: str,
32
  negative_text: str,
33
  height: int,
34
  width: int,
@@ -44,7 +44,7 @@ def main():
44
  # parse json to span attributes
45
  base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
46
  color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
47
- text_input)
48
 
49
  # create control input for region diffusion
50
  region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
 
28
  model = RegionDiffusion(device)
29
 
30
  def generate(
31
+ json.loads(text_input): str,
32
  negative_text: str,
33
  height: int,
34
  width: int,
 
44
  # parse json to span attributes
45
  base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
46
  color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
47
+ json.loads(text_input))
48
 
49
  # create control input for region diffusion
50
  region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
models/region_diffusion.py CHANGED
@@ -22,17 +22,12 @@ class RegionDiffusion(nn.Module):
22
  print(f'[INFO] loading stable diffusion...')
23
  model_id = 'runwayml/stable-diffusion-v1-5'
24
 
25
- # 1. Load the autoencoder model which will be used to decode the latents into image space.
26
  self.vae = AutoencoderKL.from_pretrained(
27
  model_id, subfolder="vae").to(self.device)
28
-
29
- # 2. Load the tokenizer and text encoder to tokenize and encode the text.
30
  self.tokenizer = CLIPTokenizer.from_pretrained(
31
  model_id, subfolder='tokenizer')
32
  self.text_encoder = CLIPTextModel.from_pretrained(
33
  model_id, subfolder='text_encoder').to(self.device)
34
-
35
- # 3. The UNet model for generating the latents.
36
  self.unet = UNet2DConditionModel.from_pretrained(
37
  model_id, subfolder="unet").to(self.device)
38
 
 
22
  print(f'[INFO] loading stable diffusion...')
23
  model_id = 'runwayml/stable-diffusion-v1-5'
24
 
 
25
  self.vae = AutoencoderKL.from_pretrained(
26
  model_id, subfolder="vae").to(self.device)
 
 
27
  self.tokenizer = CLIPTokenizer.from_pretrained(
28
  model_id, subfolder='tokenizer')
29
  self.text_encoder = CLIPTextModel.from_pretrained(
30
  model_id, subfolder='text_encoder').to(self.device)
 
 
31
  self.unet = UNet2DConditionModel.from_pretrained(
32
  model_id, subfolder="unet").to(self.device)
33