foz committed
Commit aada7c5
1 Parent(s): 046b08b

Fix requirements

Files changed (5):
  1. app.py +7 -14
  2. app_pose.py +0 -2
  3. model.py +68 -96
  4. requirements.txt +0 -1
  5. utils.py +4 -6
app.py CHANGED
@@ -1,17 +1,14 @@
 import gradio as gr
 import torch
 
-from model import Model, ModelType
+from model import Model
 from app_pose import create_demo as create_demo_pose
 import argparse
 import os
 
-on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
-model = Model(device='cuda', dtype=torch.float16)
-parser = argparse.ArgumentParser()
-parser.add_argument('--public_access', action='store_true',
-                    help="if enabled, the app can be access from a public url", default=False)
-args = parser.parse_args()
+model = Model()
+
+
 
 
 with gr.Blocks(css='style.css') as demo:
@@ -22,10 +19,6 @@ with gr.Blocks(css='style.css') as demo:
     '''
 
 
-if on_huggingspace:
-    demo.queue(max_size=20)
-    demo.launch(debug=True)
-else:
-    _, _, link = demo.queue(api_open=False).launch(
-        file_directories=['temporal'], share=args.public_access)
-    print(link)
+
+demo.launch(debug=True)
+
app_pose.py CHANGED
@@ -1,7 +1,5 @@
 from model import Model
 import gradio as gr
-import os
-on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
 
 examples = [
     ['Motion 1', "An astronaut dancing in the outer space"],
model.py CHANGED
@@ -4,111 +4,95 @@ import numpy as np
 import torch
 
 
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
-from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
-from diffusers.schedulers import EulerAncestralDiscreteScheduler, DDIMScheduler
 
-from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from PIL import Image
+from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
 
 
 import utils
 import gradio_utils
 import os
-on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
 
 from einops import rearrange
 
+import matplotlib.pyplot as plt
 
-class ModelType(Enum):
-    ControlNetPose = 5,
-
+def create_key(seed=0):
+    return jax.random.PRNGKey(seed)
 
 class Model:
-    def __init__(self, device, dtype, **kwargs):
-        self.device = device
-        self.dtype = dtype
-        self.generator = torch.Generator(device=device)
-        self.pipe_dict = {
-            ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
-        }
-
-        self.pipe = None
-        self.model_type = None
-
-        self.states = {}
-        self.model_name = ""
-
-    def set_model(self, model_type: ModelType, model_id: str, **kwargs):
-        if hasattr(self, "pipe") and self.pipe is not None:
-            del self.pipe
-            torch.cuda.empty_cache()
-            gc.collect()
-        print('kwargs', kwargs)
-        print('device', self.device)
-        safety_checker = kwargs.pop('safety_checker', None)
-        controlnet = kwargs.pop('controlnet', None)
-        self.pipe = self.pipe_dict[model_type].from_pretrained(
-            model_id, safety_checker=safety_checker, controlnet=controlnet, torch_dtype=torch.float16).to(self.device)  # , torch_dtype=torch.float16).to(self.device)
-
-        self.pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
-        self.pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
-
-        self.model_type = model_type
-        self.model_name = model_id
-
-    def inference_chunk(self, frame_ids, **kwargs):
-        if not hasattr(self, "pipe") or self.pipe is None:
-            return
-
-        prompt = np.array(kwargs.pop('prompt'))
-        negative_prompt = np.array(kwargs.pop('negative_prompt', ''))
-        latents = None
-        if 'latents' in kwargs:
-            latents = kwargs.pop('latents')[frame_ids]
-        if 'image' in kwargs:
-            kwargs['image'] = kwargs['image'][frame_ids]
-        if 'video_length' in kwargs:
-            kwargs['video_length'] = len(frame_ids)
-        return self.pipe(prompt=prompt[frame_ids].tolist(),
-                         negative_prompt=negative_prompt[frame_ids].tolist(),
-                         latents=latents,
-                         generator=self.generator,
-                         **kwargs)
+    def __init__(self, **kwargs):
+        self.base_controlnet, self.base_controlnet_params = FlaxControlNetModel.from_pretrained(
+            # "JFoz/dog-cat-pose", dtype=jnp.bfloat16
+            "lllyasviel/control_v11p_sd15_openpose", dtype=jnp.bfloat16, from_pt=True
+        )
+        self.pipe, self.params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.base_controlnet, revision="flax", dtype=jnp.bfloat16,  # from_pt=True,
+        )
+
+    def infer_frame(self, frame_id, prompt, negative_prompt, rng, **kwargs):
+
+        print(prompt, frame_id)
+
+        num_samples = 1
+        prompt_ids = self.pipe.prepare_text_inputs([prompt[frame_id]] * num_samples)
+        negative_prompt_ids = self.pipe.prepare_text_inputs([negative_prompt[frame_id]] * num_samples)
+        processed_image = self.pipe.prepare_image_inputs([kwargs['image'][frame_id]] * num_samples)
+
+        self.params["controlnet"] = self.base_controlnet_params
+
+        p_params = replicate(self.params)
+        prompt_ids = shard(prompt_ids)
+        negative_prompt_ids = shard(negative_prompt_ids)
+        processed_image = shard(processed_image)
+
+        output = self.pipe(
+            prompt_ids=prompt_ids,
+            image=processed_image,
+            params=p_params,
+            prng_seed=rng,
+            num_inference_steps=50,
+            neg_prompt_ids=negative_prompt_ids,
+            jit=True,
+        ).images
+
+        output_images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
+        return output_images
 
     def inference(self, **kwargs):
-        if not hasattr(self, "pipe") or self.pipe is None:
-            return
-
+
         seed = kwargs.pop('seed', 0)
-        if seed < 0:
-            seed = self.generator.seed()
-        kwargs.pop('generator', '')
+
+        rng = create_key(0)
+        rng = jax.random.split(rng, jax.device_count())
+
+        f = len(kwargs['image'])
+        print('frames', f)
 
-        if 'image' in kwargs:
-            f = kwargs['image'].shape[0]
-        else:
-            f = kwargs['video_length']
 
         assert 'prompt' in kwargs
         prompt = [kwargs.pop('prompt')] * f
         negative_prompt = [kwargs.pop('negative_prompt', '')] * f
 
         frames_counter = 0
-
-        # Processing frame_by_frame
+
         result = []
-        for i in range(f):
-            frame_ids = [0] + [i]
-            self.generator.manual_seed(seed)
+        for i in range(0, f):
             print(f'Processing frame {i + 1} / {f}')
-            result.append(self.inference_chunk(frame_ids=frame_ids,
+            result.append(self.infer_frame(frame_id=i,
                                            prompt=prompt,
                                            negative_prompt=negative_prompt,
-                                           **kwargs).images[1:])
+                                           rng=rng,
+                                           **kwargs))
             frames_counter += 1
-            if on_huggingspace and frames_counter >= 80:
-                break
-        result = np.concatenate(result)
+        result = np.stack(result, axis=0)
         return result
 
     def process_controlnet_pose(self,
@@ -120,33 +104,22 @@
                                 seed=42,
                                 eta=0.0,
                                 resolution=512,
-                                use_cf_attn=True,
                                 save_path=None):
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
-        if self.model_type != ModelType.ControlNetPose:
-            controlnet = ControlNetModel.from_pretrained(
-                "fusing/stable-diffusion-v1-5-controlnet-openpose", torch_dtype=torch.float16)
-            self.set_model(ModelType.ControlNetPose,
-                           model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet)
-            self.pipe.scheduler = DDIMScheduler.from_config(
-                self.pipe.scheduler.config)
 
-        video_path = gradio_utils.motion_to_video_path(
-            video_path) if 'Motion' in video_path else video_path
 
        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
 
        video, fps = utils.prepare_video(
-            video_path, resolution, self.device, self.dtype, False, output_fps=4)
+            video_path, resolution, False, output_fps=4)
        control = utils.pre_process_pose(
-            video, apply_pose_detect=False).to(self.device).to(self.dtype)
+            video, apply_pose_detect=False)
+
+        print('N frames', len(control))
        f, _, h, w = video.shape
-        self.generator.manual_seed(seed)
-        latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype,
-                              device=self.device, generator=self.generator)
-        latents = latents.repeat(f, 1, 1, 1)
+
        result = self.inference(image=control,
                                 prompt=prompt + ', ' + added_prompt,
                                 height=h,
@@ -156,9 +129,8 @@
                                 guidance_scale=guidance_scale,
                                 controlnet_conditioning_scale=controlnet_conditioning_scale,
                                 eta=eta,
-                                latents=latents,
                                 seed=seed,
                                 output_type='numpy',
                                 )
-        return utils.create_gif(result, fps, path=save_path)
+        return utils.create_gif(result.astype(jnp.float16), fps, path=save_path)
 
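A minimal usage sketch of the rewritten JAX/Flax Model (illustrative, not part of this commit): it assumes the parameters of process_controlnet_pose that are not visible in this diff (video_path, prompt) keep the names used inside the method body, and that the method returns whatever utils.create_gif produces.

    # Illustrative sketch only -- not part of the commit. 'video_path' and
    # 'prompt' are assumed keyword names (they appear in the method body but
    # not in the visible part of the signature).
    from model import Model

    model = Model()  # loads the Flax OpenPose ControlNet and SD 1.5 pipelines
    gif = model.process_controlnet_pose(
        video_path='Motion 1',  # resolved via gradio_utils.motion_to_video_path
        prompt='An astronaut dancing in the outer space',
        seed=42,
        save_path='astronaut_pose.gif',  # hypothetical output file
    )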
requirements.txt CHANGED
@@ -7,7 +7,6 @@ git+https://github.com/huggingface/diffusers@main
 torch
 accelerate
 decord==0.6.0
-diffusers==0.16.1
 einops
 gradio
 imageio
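Since requirements.txt already installs diffusers from the git main branch (the git+https://github.com/huggingface/diffusers@main line in the hunk header), the pinned diffusers==0.16.1 is dropped, presumably to avoid clashing with that source install. A quick, illustrative sanity check (not part of the repo) that the installed build exposes the Flax classes model.py now imports:

    # Illustrative check, not part of the commit: confirm the installed
    # diffusers build provides the Flax ControlNet pipeline used by model.py.
    import diffusers
    from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel

    print(diffusers.__version__)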
utils.py CHANGED
@@ -15,7 +15,7 @@ from controlnet_aux import OpenposeDetector
 
 apply_openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
 
-def prepare_video(video_path:str, resolution:int, device, dtype, normalize=True, start_t:float=0, end_t:float=-1, output_fps:int=-1):
+def prepare_video(video_path:str, resolution:int, normalize=True, start_t:float=0, end_t:float=-1, output_fps:int=-1):
     vr = decord.VideoReader(video_path)
     initial_fps = vr.get_avg_fps()
     if output_fps == -1:
@@ -37,7 +37,7 @@ def prepare_video(video_path:str, resolution:int, device, dtype, normalize=True,
     video = video.asnumpy()
     _, h, w, _ = video.shape
     video = rearrange(video, "f h w c -> f c h w")
-    video = torch.Tensor(video).to(device).to(dtype)
+    video = torch.Tensor(video)
 
     # Use max if you want the larger side to be equal to resolution (e.g. 512)
     # k = float(resolution) / min(h, w)
@@ -63,10 +63,8 @@ def pre_process_pose(input_video, apply_pose_detect: bool = True):
         detected_map = img
         H, W, C = img.shape
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
-        detected_maps.append(detected_map[None])
-    detected_maps = np.concatenate(detected_maps)
-    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
-    return rearrange(control, 'f h w c -> f c h w')
+        detected_maps.append(Image.fromarray(detected_map))
+    return detected_maps
 
 
 
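The utils.py changes make preprocessing framework-agnostic: prepare_video no longer takes a device/dtype and leaves the frames as a CPU torch.Tensor, and pre_process_pose now returns a list of PIL images, matching what FlaxStableDiffusionControlNetPipeline.prepare_image_inputs consumes in model.py. A minimal sketch of the new call pattern, mirroring the post-commit call sites in model.py ('dance.mp4' is a placeholder path):

    # Illustrative sketch mirroring the post-commit call sites in model.py;
    # 'dance.mp4' is a placeholder video path.
    import utils

    video, fps = utils.prepare_video('dance.mp4', resolution=512,
                                     normalize=False, output_fps=4)
    poses = utils.pre_process_pose(video, apply_pose_detect=False)  # list of PIL.Image frames
    print(len(poses), 'pose frames at', fps, 'fps')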