Ziqi committed
Commit f1d1956
Parent: cb90e16
app.py CHANGED
@@ -78,17 +78,36 @@ def show_warning(warning_text: str) -> gr.Blocks:
         gr.Markdown(warning_text)
     return demo
 
+def set_example_image(example: list):
+    return gr.update(value=example[0])
 
 def create_inference_demo(func: inference_fn) -> gr.Blocks:
     with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column():
+                exemplar_img = gr.Image(
+                    label='Exemplar Image',
+                    type='pil',
+                    interactive=False
+                )
+                # paths = sorted(pathlib.Path('exemplars').glob('*.jpg'))
+                # exemplar_dataset = gr.Dataset(components=[exemplar_img],
+                #                               samples=[[path.as_posix()]
+                #                                        for path in paths])
+                exemplar_dataset = gr.Dataset(
+                    components=[exemplar_img],
+                    samples=[
+                        ['exemplars/painted_on.jpg'],
+                        ['exemplars/carved_by.jpg'],
+                        ['exemplars/inside.jpg']
+                    ]
+                )
+
                 model_id = gr.Dropdown(
                     choices=['painted_on', 'carved_by', 'inside'],
                     value='painted_on',
                     label='Relation',
                     visible=True)
-                # reload_button = gr.Button('Reload Weight List')
                 prompt = gr.Textbox(
                     label='Prompt',
                     max_lines=1,
@@ -120,27 +139,49 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
                 result = gr.Image(label='Result')
 
 
-        prompt.submit(fn=func,
-                      inputs=[
-                          model_id,
-                          prompt,
-                          num_samples,
-                          guidance_scale,
-                          ddim_steps
-                      ],
-                      outputs=result,
-                      queue=False)
-
-        run_button.click(fn=func,
-                         inputs=[
-                             model_id,
-                             prompt,
-                             num_samples,
-                             guidance_scale,
-                             ddim_steps
-                         ],
-                         outputs=result,
-                         queue=False)
+        exemplar_dataset.click(fn=set_example_image,
+                               inputs=exemplar_dataset,
+                               outputs=exemplar_dataset.components,
+                               queue=False)
+        prompt.submit(
+            fn=func,
+            # inputs=[
+            #     model_id,
+            #     prompt,
+            #     num_samples,
+            #     guidance_scale,
+            #     ddim_steps
+            # ],
+            inputs=[
+                exemplar_dataset,
+                prompt,
+                num_samples,
+                guidance_scale,
+                ddim_steps
+            ],
+            outputs=result,
+            queue=False
+        )
+
+        run_button.click(
+            fn=func,
+            # inputs=[
+            #     model_id,
+            #     prompt,
+            #     num_samples,
+            #     guidance_scale,
+            #     ddim_steps
+            # ],
+            inputs=[
+                exemplar_dataset,
+                prompt,
+                num_samples,
+                guidance_scale,
+                ddim_steps
+            ],
+            outputs=result,
+            queue=False
+        )
     return demo
 
 
exemplars/carved_by.jpg ADDED
exemplars/inside.jpg ADDED
exemplars/painted_on.jpg ADDED
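
How the new exemplar picker fits together, as a self-contained sketch (assumptions: Gradio 3.x; `fake_inference` is a stand-in for the real `inference_fn` and only echoes the relation it would use instead of running the diffusion pipeline):

# Sketch of the exemplar-picker wiring added in app.py above.
import pathlib

import gradio as gr


def set_example_image(example: list):
    # A gr.Dataset click passes the selected row as a list of component values.
    return gr.update(value=example[0])


def fake_inference(examples: list, prompt: str) -> str:
    # `examples` is the clicked Dataset row, e.g. ['exemplars/painted_on.jpg'].
    if not examples:
        return 'click an exemplar first'
    relation = pathlib.Path(examples[0]).stem  # e.g. 'painted_on'
    return f'would generate "{prompt}" with relation <{relation}>'


with gr.Blocks() as demo:
    exemplar_img = gr.Image(label='Exemplar Image', type='pil', interactive=False)
    exemplar_dataset = gr.Dataset(components=[exemplar_img],
                                  samples=[['exemplars/painted_on.jpg'],
                                           ['exemplars/carved_by.jpg'],
                                           ['exemplars/inside.jpg']])
    prompt = gr.Textbox(label='Prompt', max_lines=1)
    result = gr.Textbox(label='Result')

    # Clicking a row previews the exemplar image...
    exemplar_dataset.click(fn=set_example_image,
                           inputs=exemplar_dataset,
                           outputs=exemplar_dataset.components,
                           queue=False)
    # ...and the same Dataset value is forwarded as the first input of the
    # inference function, replacing the old model_id dropdown.
    prompt.submit(fn=fake_inference,
                  inputs=[exemplar_dataset, prompt],
                  outputs=result,
                  queue=False)

if __name__ == '__main__':
    demo.launch()

Clicking a row both previews the exemplar and supplies the value that the inference function later maps to a relation, which is why the old `model_id` dropdown inputs are commented out in the event wiring above.
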
inference.py CHANGED
@@ -42,12 +42,14 @@ def make_image_grid(imgs, rows, cols):
 
 
 def inference_fn(
-    model_id: str,
+    examples: list,
     prompt: str,
     num_samples: int,
     guidance_scale: float,
     ddim_steps: int,
 ) -> PIL.Image.Image:
+    # select model_id
+    model_id = pathlib.Path(examples[0]).stem
 
     # create inference pipeline
     if torch.cuda.is_available():
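
The net effect in inference_fn: the relation id now comes from the clicked exemplar's file name rather than from the dropdown. A tiny sketch of that mapping (it assumes `pathlib` is already imported at the top of inference.py, which this hunk does not show; `select_model_id` is a hypothetical helper used only for illustration):

# Sketch only: how the exemplar path picked in the UI becomes model_id.
import pathlib


def select_model_id(examples: list) -> str:
    # The Dataset forwards the clicked row as a list of component values,
    # e.g. ['exemplars/carved_by.jpg']; the file stem is the relation id.
    return pathlib.Path(examples[0]).stem


assert select_model_id(['exemplars/painted_on.jpg']) == 'painted_on'
assert select_model_id(['exemplars/carved_by.jpg']) == 'carved_by'
assert select_model_id(['exemplars/inside.jpg']) == 'inside'

Because the stems match the old dropdown choices ('painted_on', 'carved_by', 'inside'), the downstream pipeline loading keyed on model_id is unchanged.
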