svjack committed
Commit 6e86f38
1 Parent(s): 27aee24

Create app.py

Files changed (1): app.py (+524, -0)
app.py ADDED
@@ -0,0 +1,524 @@
'''
pip install extcolors
'''

import os
import tensorflow as tf
# Must be set before the first hub.load() so TF Hub fetches the compressed model format.
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
import numpy as np
import PIL.Image
import gradio as gr
import tensorflow_hub as hub
import matplotlib.pyplot as plt
from real_esrgan_app import *  # provides inference() used for Real-ESRGAN upscaling

import requests
import io
import random
from PIL import Image, ImageDraw, ImageFont

from datasets import load_dataset
import pandas as pd
from time import sleep
from tqdm import tqdm

import extcolors
from gradio_client import Client

import cv2
import pathlib

API_TOKEN = os.environ.get("HF_READ_TOKEN")

'''
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
prompt_df = dataset["train"].to_pandas()
prompt_df = pd.read_csv("Stable-Diffusion-Prompts.csv")
'''

#DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1"
#DEFAULT_PROMPT = "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
DEFAULT_PROMPT = "X go to Istanbul"
DEFAULT_ROLE = "Superman"
DEFAULT_BOOK_COVER = "book_cover_dir/JMW_Turner_-_Nantes_from_the_Ile_Feydeau.jpg"

# Magenta arbitrary-image-stylization model from TF Hub, used for style transfer.
hub_module = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')

def tensor_to_image(tensor):
    tensor = tensor * 255
    tensor = np.array(tensor, dtype=np.uint8)
    if np.ndim(tensor) > 3:
        assert tensor.shape[0] == 1
        tensor = tensor[0]
    return PIL.Image.fromarray(tensor)


def perform_neural_transfer(content_image_input, style_image_input, super_resolution_type, hub_module=hub_module):
    content_image = content_image_input.astype(np.float32)[np.newaxis, ...] / 255.
    content_image = tf.image.resize(content_image, (400, 600))

    #style_image_input = style_urls[style_image_input]
    #style_image_input = plt.imread(style_image_input)
    style_image = style_image_input.astype(np.float32)[np.newaxis, ...] / 255.

    style_image = tf.image.resize(style_image, (256, 256))

    outputs = hub_module(tf.constant(content_image), tf.constant(style_image))
    stylized_image = outputs[0]

    stylized_image = tensor_to_image(stylized_image)
    # The input array is already uint8 in [0, 255]; wrap it directly instead of passing
    # it through tensor_to_image, which would rescale by 255 again and overflow.
    content_image_input = PIL.Image.fromarray(content_image_input.astype(np.uint8))
    stylized_image = stylized_image.resize(content_image_input.size)

    print("super_resolution_type :")
    print(super_resolution_type)

    if super_resolution_type not in ["base", "anime"]:
        return stylized_image
    else:
        print("call else :")
        # inference() comes from real_esrgan_app and applies Real-ESRGAN upscaling.
        stylized_image = inference(stylized_image, super_resolution_type)
        return stylized_image

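# A minimal usage sketch (assumption: both inputs are HxWx3 uint8 numpy arrays,
# e.g. from np.asarray(PIL.Image.open(...))); "none" skips Real-ESRGAN upscaling:
#
#   out = perform_neural_transfer(content_arr, style_arr, super_resolution_type="none")
#   out.save("stylized.png")
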
list_models = [
    #"SDXL-1.0",
    "Pixel-Art-XL",
    "SD-1.5",
    "OpenJourney-V4",
    "Anything-V4",
    "Disney-Pixar-Cartoon",
    "Dalle-3-XL",
    #"Midjourney-V4-XL",
]

#list_prompts = get_samples()

def generate_txt2img(current_model, prompt, is_negative=False, image_style="None style", steps=50, cfg_scale=7,
                     seed=None, API_TOKEN=API_TOKEN):

    '''
    if current_model == "SD-1.5":
        API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
    elif current_model == "SDXL-1.0":
        API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
    elif current_model == "OpenJourney-V4":
        API_URL = "https://api-inference.huggingface.co/models/prompthero/openjourney"
    elif current_model == "Anything-V4":
        API_URL = "https://api-inference.huggingface.co/models/xyn-ai/anything-v4.0"
    elif current_model == "Disney-Pixar-Cartoon":
        API_URL = "https://api-inference.huggingface.co/models/stablediffusionapi/disney-pixar-cartoon"
    elif current_model == "Pixel-Art-XL":
        API_URL = "https://api-inference.huggingface.co/models/nerijs/pixel-art-xl"
    elif current_model == "Dalle-3-XL":
        API_URL = "https://api-inference.huggingface.co/models/openskyml/dalle-3-xl"
    elif current_model == "Midjourney-V4-XL":
        API_URL = "https://api-inference.huggingface.co/models/openskyml/midjourney-v4-xl"
    '''
    if current_model == "SD-1.5":
        API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
    elif current_model == "OpenJourney-V4":
        API_URL = "https://api-inference.huggingface.co/models/prompthero/openjourney"
    elif current_model == "Anything-V4":
        API_URL = "https://api-inference.huggingface.co/models/xyn-ai/anything-v4.0"
    elif current_model == "Disney-Pixar-Cartoon":
        API_URL = "https://api-inference.huggingface.co/models/stablediffusionapi/disney-pixar-cartoon"
    elif current_model == "Pixel-Art-XL":
        API_URL = "https://api-inference.huggingface.co/models/nerijs/pixel-art-xl"
    elif current_model == "Dalle-3-XL":
        API_URL = "https://api-inference.huggingface.co/models/openskyml/dalle-3-xl"

    #API_TOKEN = os.environ.get("HF_READ_TOKEN")
    headers = {"Authorization": f"Bearer {API_TOKEN}"}

    if not isinstance(prompt, str):
        prompt = DEFAULT_PROMPT
    # Some styles append to the negative prompt below; coerce the default False
    # to a string so the concatenation does not raise a TypeError.
    if not isinstance(is_negative, str):
        is_negative = ""

    if image_style == "None style":
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647)
        }
    elif image_style == "Cinematic":
        payload = {
            "inputs": prompt + ", realistic, detailed, textured, skin, hair, eyes, by Alex Huguet, Mike Hill, Ian Spriggs, JaeCheol Park, Marek Denko",
            "is_negative": is_negative + ", abstract, cartoon, stylized",
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647)
        }
    elif image_style == "Digital Art":
        payload = {
            "inputs": prompt + ", faded , vintage , nostalgic , by Jose Villa , Elizabeth Messina , Ryan Brenizer , Jonas Peterson , Jasmine Star",
            "is_negative": is_negative + ", sharp , modern , bright",
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647)
        }
    elif image_style == "Portrait":
        payload = {
            "inputs": prompt + ", soft light, sharp, exposure blend, medium shot, bokeh, (hdr:1.4), high contrast, (cinematic, teal and orange:0.85), (muted colors, dim colors, soothing tones:1.3), low saturation, (hyperdetailed:1.2), (noir:0.4), (natural skin texture, hyperrealism, soft light, sharp:1.2)",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647)
        }

    image_bytes = requests.post(API_URL, headers=headers, json=payload).content
    image = Image.open(io.BytesIO(image_bytes))
    return image

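# Example call (assumption: HF_READ_TOKEN is set and the hosted model is warm):
#
#   img = generate_txt2img("Pixel-Art-XL", "Superman go to Istanbul")
#   img.save("panel.png")
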
from huggingface_hub import InferenceClient

event_reasoning_df = pd.DataFrame(
    [['Use the following events as a background to answer questions related to the cause and effect of time.', 'Ok'],

     ['What are the necessary preconditions for the next event?:X had a big meal.', 'X placed an order'],
     ['What could happen after the next event?:X had a big meal.', 'X becomes fat'],
     ['What is the motivation for the next event?:X had a big meal.', 'X is hungry'],
     ['What are your feelings after the following event?:X had a big meal.', "X tastes good"],

     ['What are the necessary preconditions for the next event?:X met his favorite star.', 'X bought a ticket'],
     ['What could happen after the next event?:X met his favorite star.', 'X is motivated'],
     ['What is the motivation for the next event?:X met his favorite star.', 'X wants to have some entertainment'],
     ['What are your feelings after the following event?:X met his favorite star.', "X is in a happy mood"],

     ['What are the necessary preconditions for the next event?: X to cheat', 'X has evil intentions'],
     ['What could happen after the next event?:X to cheat', 'X is accused'],
     ['What is the motivation for the next event?:X to cheat', 'X wants to get something for nothing'],
     ['What are your feelings after the following event?:X to cheat', "X is starving and freezing in prison"],

     ['What could happen after the next event?:X go to Istanbul', ''],
     ],
    columns=["User", "Assistant"]
)

Mistral_7B_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)

NEED_PREFIX = 'What are the necessary preconditions for the next event?'
EFFECT_PREFIX = 'What could happen after the next event?'
INTENT_PREFIX = 'What is the motivation for the next event?'
REACT_PREFIX = 'What are your feelings after the following event?'

def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

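# For one history turn ("Hi", "Ok") and message "Next?", format_prompt renders
# the Mistral-instruct string:
#
#   <s>[INST] Hi [/INST] Ok</s> [INST] Next? [/INST]
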
def generate(
    prompt, history, client=Mistral_7B_client,
    temperature=0.7, max_new_tokens=256, top_p=0.95, repetition_penalty=1.1,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Streaming generator: yields the accumulated output after each new token.
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output

# Few-shot history: every example row except the trailing unanswered one.
hist = event_reasoning_df.iloc[:-1, :].apply(
    lambda x: (x["User"], x["Assistant"]), axis=1
)

def produce_4_event(event_fact, hist=hist):
    NEED_PREFIX_prompt = "{}:{}".format(NEED_PREFIX, event_fact)
    EFFECT_PREFIX_prompt = "{}:{}".format(EFFECT_PREFIX, event_fact)
    INTENT_PREFIX_prompt = "{}:{}".format(INTENT_PREFIX, event_fact)
    REACT_PREFIX_prompt = "{}:{}".format(REACT_PREFIX, event_fact)
    # generate() is a generator; exhausting it with list() and taking the last
    # element gives the fully accumulated completion.
    NEED_PREFIX_output = list(generate(NEED_PREFIX_prompt, history=hist, max_new_tokens=2048))[-1]
    EFFECT_PREFIX_output = list(generate(EFFECT_PREFIX_prompt, history=hist, max_new_tokens=2048))[-1]
    INTENT_PREFIX_output = list(generate(INTENT_PREFIX_prompt, history=hist, max_new_tokens=2048))[-1]
    REACT_PREFIX_output = list(generate(REACT_PREFIX_prompt, history=hist, max_new_tokens=2048))[-1]
    NEED_PREFIX_output, EFFECT_PREFIX_output, INTENT_PREFIX_output, REACT_PREFIX_output = map(
        lambda x: x.replace("</s>", ""),
        [NEED_PREFIX_output, EFFECT_PREFIX_output, INTENT_PREFIX_output, REACT_PREFIX_output])
    return {
        NEED_PREFIX: NEED_PREFIX_output,
        EFFECT_PREFIX: EFFECT_PREFIX_output,
        INTENT_PREFIX: INTENT_PREFIX_output,
        REACT_PREFIX: REACT_PREFIX_output,
    }

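# Illustrative shape of the result (the answers are model-dependent; the values
# below are hypothetical):
#
#   produce_4_event("X go to Istanbul")
#   # {'What are the necessary preconditions for the next event?': 'X bought a ticket',
#   #  'What could happen after the next event?': '...',
#   #  'What is the motivation for the next event?': '...',
#   #  'What are your feelings after the following event?': '...'}
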
def transform_4_event_as_sd_prompts(event_fact, event_reasoning_dict, role_name="superman"):
    # Fall back to the literal "X" placeholder when no usable role name is given,
    # so the .replace() calls below are safe no-ops instead of crashing.
    if not (isinstance(role_name, str) and role_name.strip()):
        role_name = "X"
    req = {}
    for k, v in event_reasoning_dict.items():
        v_ = v.replace("X", role_name)
        req[k] = list(generate("Transform this as a prompt in stable diffusion: {}".format(v_),
                               history=[], max_new_tokens=2048))[-1].replace("</s>", "")
    event_fact_ = event_fact.replace("X", role_name)
    req["EVENT_FACT"] = list(generate("Transform this as a prompt in stable diffusion: {}".format(event_fact_),
                                      history=[], max_new_tokens=2048))[-1].replace("</s>", "")
    req_list = [
        req[INTENT_PREFIX], req[NEED_PREFIX],
        req["EVENT_FACT"],
        req[REACT_PREFIX], req[EFFECT_PREFIX]
    ]
    caption_list = [
        event_reasoning_dict[INTENT_PREFIX], event_reasoning_dict[NEED_PREFIX],
        event_fact,
        event_reasoning_dict[REACT_PREFIX], event_reasoning_dict[EFFECT_PREFIX]
    ]
    caption_list = list(map(lambda x: x.replace("X", role_name), caption_list))
    return caption_list, req_list

def batch_as_list(input_, batch_size=3):
    req = []
    for ele in input_:
        if not req or len(req[-1]) >= batch_size:
            req.append([ele])
        else:
            req[-1].append(ele)
    return req

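# e.g. batch_as_list([1, 2, 3, 4, 5], 3) -> [[1, 2, 3], [4, 5]]
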
def add_margin(pil_img, top, right, bottom, left, color):
    width, height = pil_img.size
    new_width = width + right + left
    new_height = height + top + bottom
    result = Image.new(pil_img.mode, (new_width, new_height), color)
    result.paste(pil_img, (left, top))
    return result

def add_caption_on_image(input_image, caption, marg_ratio=0.15, row_token_num=6):
    from uuid import uuid1
    assert hasattr(input_image, "save")
    max_image_size = max(input_image.size)
    marg_size = int(marg_ratio * max_image_size)
    # Pad the top and left with the image's dominant color, then draw the caption there.
    colors, pixel_count = extcolors.extract_from_image(input_image)
    input_image = add_margin(input_image, marg_size, 0, 0, marg_size, colors[0][0])
    '''
    tmp_name = "{}.png".format(uuid1())
    input_image.save(tmp_name)
    ImageCaptioner.add_captions(tmp_name,
        caption,
        overwrite = 1,
        size = int(marg_size / 4),
        align = "TOP_LEFT",
        output = tmp_name,
        color = "black",
    )
    output_image = Image.open(tmp_name)
    os.remove(tmp_name)
    '''
    font = ImageFont.truetype("DejaVuSerif-Italic.ttf", int(marg_size / 4))
    # Wrap the caption into rows of row_token_num whitespace-separated tokens.
    caption_token_list = list(map(lambda x: x.strip(), caption.split(" ")))
    caption_list = list(map(" ".join, batch_as_list(caption_token_list, row_token_num)))
    draw = ImageDraw.Draw(input_image)
    for line_num, line_caption in enumerate(caption_list):
        position = (
            int(marg_size / 4) * (line_num + 1) * 1.1,
            int(marg_size / 4) * (line_num + 1) * 1.1
        )
        draw.text(position, line_caption, fill="black", font=font)

    return input_image


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        # Center vertically; without the // 2 the paste lands too low and clips.
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        # paste() needs a 2-tuple box; the original passed a bare int here.
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

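# Sketch: expand2square(Image.new("RGB", (600, 400), "white"), (0, 0, 0))
# returns a 600x600 image with the original centered between black bands.
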
def generate_video(images, video_name='ppt.avi'):
    import cv2
    from uuid import uuid1
    # Round-trip each frame through a temporary PNG so OpenCV reads BGR data.
    im_names = []
    for im in images:
        name = "{}.png".format(uuid1())
        im.save(name)
        im_names.append(name)
    frame = cv2.imread(im_names[0])

    # Set the frame width/height from the first image.
    height, width, layers = frame.shape

    # fourcc 0 writes raw frames; 1 fps means one image per second of video.
    video = cv2.VideoWriter(video_name, 0, 1, (width, height))

    # Append the images to the video one by one.
    for name in im_names:
        video.write(cv2.imread(name))
        os.remove(name)

    # Deallocate memory taken for window creation.
    cv2.destroyAllWindows()
    video.release()  # release the generated video

def make_video_from_image_list(image_list, video_name="ppt.avi"):
    if os.path.exists(video_name):
        os.remove(video_name)
    assert all(map(lambda x: hasattr(x, "save"), image_list))
    # Pad every frame to the same square size (background = its dominant color)
    # so VideoWriter receives uniform dimensions.
    max_size = list(map(max, zip(*map(lambda x: x.size, image_list))))
    max_size = max(max_size)
    image_list = list(map(lambda x: expand2square(x,
                                                  extcolors.extract_from_image(x)[0][0][0]
                                                  ).resize((max_size, max_size)), image_list))

    generate_video(image_list, video_name=video_name)
    return video_name

'''
style_transfer_client = Client("https://svjack-super-resolution-neural-style-transfer.hf.space")
def style_transfer_func(content_img, style_img, style_transfer_client = style_transfer_client):
    from uuid import uuid1
    assert hasattr(content_img, "save")
    assert hasattr(style_img, "save")
    content_im_name = "{}.png".format(uuid1())
    style_im_name = "{}.png".format(uuid1())
    content_img.save(content_im_name)
    style_img.save(style_im_name)
    out = style_transfer_client.predict(
        content_im_name,
        style_im_name,
        "none",
        fn_index=1
    )
    os.remove(content_im_name)
    os.remove(style_im_name)
    return Image.open(out)
'''
def style_transfer_func(content_img, style_img):
    assert hasattr(content_img, "save")
    assert hasattr(style_img, "save")
    content_image_input = np.asarray(content_img)
    style_image_input = np.asarray(style_img)
    out = perform_neural_transfer(content_image_input, style_image_input, super_resolution_type="none")
    assert hasattr(out, "save")
    return out


def gen_images_from_event_fact(current_model, event_fact=DEFAULT_PROMPT, role_name=DEFAULT_ROLE,
                               style_pic=None
                               ):
    event_reasoning_dict = produce_4_event(event_fact)
    caption_list, event_reasoning_sd_list = transform_4_event_as_sd_prompts(event_fact,
                                                                            event_reasoning_dict,
                                                                            role_name=role_name
                                                                            )
    img_list = []
    for prompt in tqdm(event_reasoning_sd_list):
        im = generate_txt2img(current_model, prompt, is_negative=False, image_style="None style")
        img_list.append(im)
        sleep(2)  # throttle Inference API calls
    img_list = list(filter(lambda x: hasattr(x, "save"), img_list))
    if style_pic is not None and hasattr(style_pic, "size"):
        style_pic = Image.fromarray(style_pic.astype(np.uint8))
        print("perform styling.....")
        img_list_ = []
        for x in tqdm(img_list):
            img_list_.append(
                style_transfer_func(x, style_pic)
            )
        img_list = img_list_
    img_list = list(map(lambda t2: add_caption_on_image(t2[0], t2[1]), zip(*[img_list, caption_list])))
    # Put the event-fact panel first, then the remaining panels in their original order.
    img_mid = img_list[2]
    img_list_reordered = [img_mid]
    for ele in img_list:
        if ele not in img_list_reordered:
            img_list_reordered.append(ele)
    video_path = make_video_from_image_list(img_list_reordered)
    return video_path

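# End-to-end flow: event text -> four commonsense inferences (Mistral-7B) ->
# five SD prompts -> five generated panels -> optional style transfer with the
# book cover -> captioned frames -> one-frame-per-second video.
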
def image_click(images, evt: gr.SelectData):
    img_selected = images[evt.index]["name"]
    #print(img_selected)
    return img_selected

def get_book_covers():
    covers = pd.Series(
        list(pathlib.Path("book_cover_dir").rglob("*.jpg")) +
        list(pathlib.Path("book_cover_dir").rglob("*.png")) +
        list(pathlib.Path("book_cover_dir").rglob("*.jpeg"))
    ).map(str).map(lambda x: np.nan if x.split("/")[-1].startswith("_") else x).dropna().values.tolist()
    return covers

with gr.Blocks() as demo:
    favicon = '<img src="" width="48px" style="display: inline">'
    gr.Markdown(
        f"""<h1><center>🌻{favicon} AI Diffusion</center></h1>
        """
    )
    with gr.Row():
        with gr.Column(elem_id="prompt-container"):
            current_model = gr.Dropdown(label="Current Model", choices=list_models, value="Pixel-Art-XL")
            style_reference_input_gallery = gr.Gallery(get_book_covers(),
                                                       #width = 512,
                                                       height=512,
                                                       label="StoryBook Cover (click to use)")
        with gr.Column(elem_id="prompt-container"):
            #with gr.Row(elem_id="prompt-container"):
            style_reference_input_image = gr.Image(
                label="StoryBook Cover (upload your own or click one in the left gallery)",
                #width = 512,
                value=DEFAULT_BOOK_COVER,
                interactive=True,
            )

    with gr.Row():
        text_prompt = gr.Textbox(label="Prompt", placeholder="a cute dog", lines=1, elem_id="prompt-text-input", value=DEFAULT_PROMPT)
        role_name = gr.Textbox(label="Role", placeholder=DEFAULT_ROLE, lines=1, elem_id="prompt-text-input", value=DEFAULT_ROLE)
        text_button = gr.Button("Generate", variant='primary', elem_id="gen-button")

    with gr.Row():
        #image_output = gr.Image(type="pil", label="Output Image", elem_id="gallery")
        #image_output = gr.Gallery(label="Output Images", elem_id="gallery")
        video_output = gr.Video(label="Story Video", elem_id="gallery")

    #text_button.click(generate_txt2img, inputs=[current_model, text_prompt, negative_prompt, image_style], outputs=image_output)
    # Clicking a gallery cover copies it into the style-reference image input.
    style_reference_input_gallery.select(
        image_click, style_reference_input_gallery, style_reference_input_image
    )

    text_button.click(gen_images_from_event_fact, inputs=[current_model, text_prompt, role_name, style_reference_input_image],
                      outputs=video_output)

    #select_button.click(generate_txt2img, inputs=[current_model, select_prompt, negative_prompt, image_style], outputs=image_output)
    #demo.load(get_params, None, select_prompt)

demo.launch(show_api=False)