Zhouyan248 committed on
Commit
1dade9e
1 Parent(s): 03c3c6a

Update base/app.py

Browse files
Files changed (1) hide show
  1. base/app.py +54 -22
base/app.py CHANGED
@@ -15,12 +15,11 @@ args = OmegaConf.load("./base/configs/sample.yaml")
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  # ------- get model ---------------
18
- model_t2V = model_t2v_fun(args)
19
- model_t2V.to(device)
20
- if device == "cuda":
21
- model_t2V.enable_xformers_memory_efficient_attention()
22
 
23
- # model_t2V.enable_xformers_memory_efficient_attention()
24
  css = """
25
  h1 {
26
  text-align: center;
@@ -31,13 +30,46 @@ h1 {
31
  }
32
  """
33
 
34
- def infer(prompt, seed_inp, ddim_steps,cfg):
 
 
 
 
 
 
 
 
 
 
 
35
  if seed_inp!=-1:
36
  setup_seed(seed_inp)
37
  else:
38
  seed_inp = random.choice(range(10000000))
39
  setup_seed(seed_inp)
40
- videos = model_t2V(prompt, video_length=16, height = 320, width= 512, num_inference_steps=ddim_steps, guidance_scale=cfg).video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  print(videos[0].shape)
42
  if not os.path.exists(args.output_folder):
43
  os.mkdir(args.output_folder)
@@ -82,7 +114,7 @@ with gr.Blocks(css='style.css') as demo:
82
  with gr.Column():
83
 
84
  prompt = gr.Textbox(value="a corgi walking in the park at sunrise, oil painting style", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
85
-
86
  ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
87
  seed_inp = gr.Slider(value=-1,label="seed (for random generation, use -1)",show_label=True,minimum=-1,maximum=2147483647)
88
  cfg = gr.Number(label="guidance_scale",value=7.5)
@@ -94,24 +126,24 @@ with gr.Blocks(css='style.css') as demo:
94
  clean_btn = gr.Button("Clean video")
95
  video_out = gr.Video(label="Video result", elem_id="video-output")
96
 
97
- inputs = [prompt, seed_inp, ddim_steps,cfg]
98
  outputs = [video_out]
99
 
100
  ex = gr.Examples(
101
- examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7],
102
- ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7],
103
- ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7],
104
- ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7],
105
- ['a teddy bear walking in the park, oil painting style, high quality',400,50,7],
106
- ['a teddy bear walking on the street, 2k, high quality',100,50,7],
107
- ['a panda taking a selfie, 2k, high quality',400,50,7],
108
- ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7],
109
- ['jungle river at sunset, ultra quality',400,50,7],
110
- ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7],
111
- ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7],
112
- ['a confused grizzly bear in calculus class',1000,50,7]],
113
  fn = infer,
114
- inputs=[prompt, seed_inp, ddim_steps,cfg],
115
  outputs=[video_out],
116
  cache_examples=False,
117
  )
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  # ------- get model ---------------
18
+ # model_t2V = model_t2v_fun(args)
19
+ # model_t2V.to(device)
20
+ # if device == "cuda":
21
+ # model_t2V.enable_xformers_memory_efficient_attention()
22
 
 
23
  css = """
24
  h1 {
25
  text-align: center;
 
30
  }
31
  """
32
 
33
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
34
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
35
+ state_dict = find_model("./pretrained_models/lavie_base.pt")
36
+ unet.load_state_dict(state_dict)
37
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
38
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
39
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
40
+ unet.eval()
41
+ vae.eval()
42
+ text_encoder_one.eval()
43
+
44
+ def infer(prompt, seed_inp, ddim_steps,cfg, infer_type):
45
  if seed_inp!=-1:
46
  setup_seed(seed_inp)
47
  else:
48
  seed_inp = random.choice(range(10000000))
49
  setup_seed(seed_inp)
50
+ if infer_type == 'ddim':
51
+ scheduler = DDIMScheduler.from_pretrained(sd_path,
52
+ subfolder="scheduler",
53
+ beta_start=args.beta_start,
54
+ beta_end=args.beta_end,
55
+ beta_schedule=args.beta_schedule)
56
+ elif infer_type == 'eulerdiscrete':
57
+ scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
58
+ subfolder="scheduler",
59
+ beta_start=args.beta_start,
60
+ beta_end=args.beta_end,
61
+ beta_schedule=args.beta_schedule)
62
+ elif infer_type == 'ddpm':
63
+ scheduler = DDPMScheduler.from_pretrained(sd_path,
64
+ subfolder="scheduler",
65
+ beta_start=args.beta_start,
66
+ beta_end=args.beta_end,
67
+ beta_schedule=args.beta_schedule)
68
+ model = VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
69
+ model.to(device)
70
+ if device == "cuda":
71
+ model.enable_xformers_memory_efficient_attention()
72
+ videos = model(prompt, video_length=16, height = 320, width= 512, num_inference_steps=ddim_steps, guidance_scale=cfg).video
73
  print(videos[0].shape)
74
  if not os.path.exists(args.output_folder):
75
  os.mkdir(args.output_folder)
 
114
  with gr.Column():
115
 
116
  prompt = gr.Textbox(value="a corgi walking in the park at sunrise, oil painting style", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
117
+ infer_type = gr.Dropdown(['ddpm','ddim','eulerdiscrete'], label='infer_type',value='ddim')
118
  ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
119
  seed_inp = gr.Slider(value=-1,label="seed (for random generation, use -1)",show_label=True,minimum=-1,maximum=2147483647)
120
  cfg = gr.Number(label="guidance_scale",value=7.5)
 
126
  clean_btn = gr.Button("Clean video")
127
  video_out = gr.Video(label="Video result", elem_id="video-output")
128
 
129
+ inputs = [prompt, seed_inp, ddim_steps, cfg, infer_type]
130
  outputs = [video_out]
131
 
132
  ex = gr.Examples(
133
+ examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7,'ddim'],
134
+ ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7,'ddim'],
135
+ ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7,'ddim'],
136
+ ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7,'ddim'],
137
+ ['a teddy bear walking in the park, oil painting style, high quality',400,50,7,'ddim'],
138
+ ['a teddy bear walking on the street, 2k, high quality',100,50,7,'ddim'],
139
+ ['a panda taking a selfie, 2k, high quality',400,50,7,'ddim'],
140
+ ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7,'ddim'],
141
+ ['jungle river at sunset, ultra quality',400,50,7,'ddim'],
142
+ ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7,'ddim'],
143
+ ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7,'ddim'],
144
+ ['a confused grizzly bear in calculus class',1000,50,7,'ddim']],
145
  fn = infer,
146
+ inputs=[prompt, seed_inp, ddim_steps,cfg,infer_type],
147
  outputs=[video_out],
148
  cache_examples=False,
149
  )