michellemoorre committed
Commit 412b3d8 · Parent(s): 26a558b

Fix ui and add apex

Files changed (3)
  1. app.py +35 -25
  2. models/pipeline.py +2 -1
  3. requirements.txt +1 -0
app.py CHANGED
@@ -13,7 +13,6 @@ model_repo_id = "michellemoorre/var-test"
 pipe = TVARPipeline.from_pretrained(model_repo_id, device=device)
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
 
 
 @spaces.GPU(duration=65)
@@ -25,8 +24,9 @@ def infer(
     guidance_scale=4.0,
     top_k=450,
     top_p=0.95,
-    re=False,
+    re=True,
     re_max_depth=10,
+    re_start_iter=2,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
@@ -39,6 +39,8 @@ def infer(
         top_p=top_p,
         top_k=top_k,
         re=re,
+        re_max_depth=re_max_depth,
+        re_start_iter=re_start_iter,
         g_seed=seed,
     )[0]
 
@@ -73,6 +75,23 @@ with gr.Blocks(css=css) as demo:
 
         result = gr.Image(label="Result", show_label=False)
 
+        seed = gr.Number(
+            label="Seed",
+            minimum=0,
+            maximum=MAX_SEED,
+            value=0,
+        )
+
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+        guidance_scale = gr.Slider(
+            label="Guidance scale",
+            minimum=0.0,
+            maximum=7.5,
+            step=0.5,
+            value=4.,
+        )
+
         with gr.Accordion("Advanced Settings", open=False):
             negative_prompt = gr.Text(
                 label="Negative prompt",
@@ -81,30 +100,12 @@ with gr.Blocks(css=css) as demo:
                 visible=True,
             )
 
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=7.5,
-                    step=0.1,
-                    value=4.5,
-                )
             with gr.Row():
                 top_k = gr.Slider(
                     label="Sampling top k",
-                    minimum=1,
+                    minimum=10,
                     maximum=1000,
-                    step=10,
+                    step=20,
                     value=450,
                 )
                 top_p = gr.Slider(
@@ -114,15 +115,23 @@ with gr.Blocks(css=css) as demo:
                     step=0.05,
                     value=0.95,
                 )
+
+            re = gr.Checkbox(label="Rejection Sampling (RE)", value=True)
             with gr.Row():
-                re = gr.Checkbox(label="Rejection Sampling", value=False)
                 re_max_depth = gr.Slider(
-                    label="Rejection Sampling Depth",
+                    label="RE Depth",
                     minimum=0,
                     maximum=20,
-                    step=1,
+                    step=4,
                     value=10,
                 )
+                re_start_iter = gr.Slider(
+                    label="RE Start Scale",
+                    minimum=0,
+                    maximum=9,
+                    step=1,
+                    value=2,
+                )
 
     gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
     gr.on(
@@ -138,6 +147,7 @@ with gr.Blocks(css=css) as demo:
             top_p,
             re,
             re_max_depth,
+            re_start_iter,
         ],
         outputs=[result, seed],
     )
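
Note: with these changes, `infer` forwards the two new rejection-sampling knobs to the pipeline. A minimal sketch of the resulting call, assuming `prompt` and `negative_prompt` are passed positionally and using the new UI defaults; only the keywords visible in the diff (`top_p`, `top_k`, `re`, `re_max_depth`, `re_start_iter`, `g_seed`) are confirmed:

    # Sketch of the updated infer() -> pipeline call (assumed shape).
    image = pipe(
        prompt,              # assumed positional argument
        negative_prompt,     # assumed positional argument
        top_p=0.95,          # UI default
        top_k=450,           # UI default
        re=True,             # rejection sampling now on by default
        re_max_depth=10,     # redraw attempts per scale
        re_start_iter=2,     # first scale index where RE kicks in
        g_seed=seed,
    )[0]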
models/pipeline.py CHANGED
@@ -96,6 +96,7 @@ class TVARPipeline:
         more_smooth=False,
         re=False,
         re_max_depth=10,
+        re_start_iter=2,
         return_pil=True,
         encoded_prompt = None,
         encoded_null_prompt = None,
@@ -180,7 +181,7 @@
             idx_Bl = sample_with_top_k_top_p_(
                 logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
             )[:, :, 0]
-            if re:
+            if re and si >= re_start_iter:
                 selected_logits = torch.gather(logits_BlV, -1, idx_Bl.unsqueeze(-1))[:, :, 0]
                 mx = selected_logits.sum(dim=-1)[:, None]
                 for _ in range(re_max_depth):
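
Note: the diff truncates the rejection loop body. A minimal sketch of a plausible completion, assuming each redraw is kept only when the summed logits of its sampled tokens beat the best draw so far (`si` is the scale index of the coarse-to-fine loop; everything past the `for` line is an assumption):

    if re and si >= re_start_iter:
        # Score the current draw by the summed logits of its chosen tokens.
        selected_logits = torch.gather(logits_BlV, -1, idx_Bl.unsqueeze(-1))[:, :, 0]
        mx = selected_logits.sum(dim=-1)[:, None]
        for _ in range(re_max_depth):
            # Assumed loop body: redraw and keep the higher-scoring sample.
            new_idx_Bl = sample_with_top_k_top_p_(
                logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
            )[:, :, 0]
            new_logits = torch.gather(logits_BlV, -1, new_idx_Bl.unsqueeze(-1))[:, :, 0]
            new_mx = new_logits.sum(dim=-1)[:, None]
            keep = new_mx > mx                        # [B, 1], broadcasts over tokens
            idx_Bl = torch.where(keep, new_idx_Bl, idx_Bl)
            mx = torch.maximum(mx, new_mx)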
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 torch==2.5.0
+apex==0.9.10dev
 decord==0.6.0
 einops==0.8.0
 huggingface_hub==0.26.1
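
Note: how apex is consumed is not shown in this commit. Assuming the intent is NVIDIA Apex's fused kernels (a common optional dependency in VAR-style codebases; NVIDIA Apex is normally built from source, and the `apex` name on PyPI belongs to a different project), a guarded import with a plain-PyTorch fallback is the usual pattern:

    # Sketch: optional Apex fused RMSNorm with a fallback (assumed usage).
    try:
        from apex.normalization import FusedRMSNorm as RMSNorm  # NVIDIA Apex kernel
    except ImportError:
        from torch.nn import RMSNorm  # plain PyTorch fallback (torch >= 2.4)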