Zeqiang-Lai committed on
Commit 728f93e
1 Parent(s): 75d2986
Files changed (3)
  1. README.md +5 -0
  2. gradio_app.py +57 -4
  3. run.py +7 -2
README.md CHANGED
@@ -15,4 +15,9 @@ pinned: false
 python run.py --model_id runwayml/stable-diffusion-v1-5
 python run.py --model_id stabilityai/stable-diffusion-2-1
 python run.py --model_id stabilityai/stable-diffusion-xl-base-1.0
+python run.py --model_id PixArt-alpha/PixArt-XL-2-512x512
+python run.py --model_id PixArt-alpha/PixArt-XL-2-1024-MS
+python run.py --model_id playgroundai/playground-v2-1024px-aesthetic
+python run.py --model_id kandinsky-community/kandinsky-3
+python run.py --model_id Lykon/dreamshaper-8
 ```
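For reference, each command above presumably writes its images into a per-model folder under saved/, following the default save_path in run.py (os.path.join('saved', model_id.replace('/', '_'))). A minimal sketch of the directory names that would result, assuming that default is left unchanged:

```python
import os

# Model IDs added to the README in this commit; the folder-name rule below
# mirrors run.py's default save_path and is only an assumption about where
# each model's outputs end up.
model_ids = [
    "PixArt-alpha/PixArt-XL-2-512x512",
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    "playgroundai/playground-v2-1024px-aesthetic",
    "kandinsky-community/kandinsky-3",
    "Lykon/dreamshaper-8",
]
for model_id in model_ids:
    print(os.path.join("saved", model_id.replace("/", "_")))
# -> saved/PixArt-alpha_PixArt-XL-2-512x512, saved/kandinsky-community_kandinsky-3, ...
```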
gradio_app.py CHANGED
@@ -1,7 +1,60 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+import os
+import csv
+
+root = 'saved'  # one sub-directory of generated images per model/method
+prompt_path = 'assets/ViLG-300.csv'
+
+
+def load_prompts(path):
+    if os.path.basename(path) == 'ViLG-300.csv':
+        def csv_to_dict(file_path):
+            result_dict = {}
+            with open(file_path, 'r', encoding='utf-8') as csv_file:
+                csv_reader = csv.DictReader(csv_file, delimiter=',')
+                for row in csv_reader:
+                    prompt = row['\ufeffPrompt']  # BOM survives plain utf-8 decoding
+                    text = row['文本']      # Chinese text
+                    category = row['类别']  # category
+                    source = row['来源']    # source
+                    result_dict[prompt] = {'prompt': prompt, 'text': text, 'category': category, 'source': source}
+            return result_dict
+        data = list(csv_to_dict(path).keys())  # keep only the prompt strings
+    else:
+        raise NotImplementedError
+    return data
+
+
+prompts = load_prompts(prompt_path)
+
+
+def load_images(methods, idx):
+    idx = int(idx)
+    prompt = prompts[idx].strip()
+    images = []
+    for method in methods:
+        image = os.path.join(root, method, f'{idx}.jpg')  # saved/<method>/<idx>.jpg
+        images.append((image, method))
+    return prompt, images
+
+
+def load_methods():
+    methods = os.listdir(root)
+    return methods
+
+
+def main():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Text to Image Models Comparison")
+        with gr.Row():
+            idx = gr.Number(value=0, label='Index')
+            prompt = gr.Textbox()
+        methods = gr.Dropdown(multiselect=True, choices=load_methods(), value=load_methods(), label='Methods')
+        gallery = gr.Gallery(show_label=False, object_fit='fill', height=600, columns=5)
+        idx.change(load_images, [methods, idx], [prompt, gallery])
+        methods.change(load_images, [methods, idx], [prompt, gallery])
+        demo.launch()
+
+
+if __name__ == '__main__':
+    main()
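For a sense of what load_prompts expects from assets/ViLG-300.csv, here is a minimal, self-contained sketch. The two rows and the file name are made up; the column headers (Prompt, 文本, 类别, 来源) come from the code above, and the utf-8-sig write is an assumption that explains why the first header is read back as '\ufeffPrompt':

```python
import csv

# Hypothetical two-row stand-in for assets/ViLG-300.csv.
# Writing with utf-8-sig prepends a BOM; reading back with plain utf-8 leaves
# '\ufeff' glued to the first header, which is why the app indexes row['\ufeffPrompt'].
rows = [
    {"Prompt": "a red bicycle", "文本": "一辆红色的自行车", "类别": "object", "来源": "example"},
    {"Prompt": "a snowy mountain", "文本": "一座雪山", "类别": "scene", "来源": "example"},
]
with open("ViLG-300-sample.csv", "w", newline="", encoding="utf-8-sig") as f:
    writer = csv.DictWriter(f, fieldnames=["Prompt", "文本", "类别", "来源"])
    writer.writeheader()
    writer.writerows(rows)

# Parse it the same way load_prompts does.
with open("ViLG-300-sample.csv", "r", encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter=",")
    prompts = [row["\ufeffPrompt"] for row in reader]
print(prompts)  # ['a red bicycle', 'a snowy mountain']
```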
run.py CHANGED
@@ -28,14 +28,19 @@ def main(
     model_id="runwayml/stable-diffusion-v1-5",
     prompt_path="assets/ViLG-300.csv",
     save_path=None,
+    dtype='fp16',
 ):
     if save_path is None:
         save_path = os.path.join('saved', model_id.replace('/', '_'))
     os.makedirs(save_path, exist_ok=True)
 
     prompts = load_prompts(prompt_path)
-    pipeline = AutoPipelineForText2Image.from_pretrained(model_id)
-    pipeline.to(device='cuda', dtype=torch.float16)
+    pipeline = AutoPipelineForText2Image.from_pretrained(
+        model_id,
+        torch_dtype=torch.float32 if dtype == 'fp32' else torch.float16
+    )
+    pipeline.to(device='cuda')
+    pipeline.safety_checker = None
     for i, prompt in enumerate(prompts):
         print(f'{i}|{len(prompts)}: {prompt}')
         image = pipeline(prompt).images[0]
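The hunk ends right after the image is generated, so the save step is not shown. Below is a minimal, self-contained sketch of the loop with a hypothetical save call; the f"{i}.jpg" filename is an assumption chosen to match the saved/<method>/<idx>.jpg paths that gradio_app.py looks up:

```python
import os
import torch
from diffusers import AutoPipelineForText2Image

# Stand-ins for the values run.py derives from its arguments and prompt file.
model_id = "runwayml/stable-diffusion-v1-5"
dtype = "fp16"
prompts = ["a red bicycle", "a snowy mountain"]

save_path = os.path.join("saved", model_id.replace("/", "_"))
os.makedirs(save_path, exist_ok=True)

# fp32 only when requested; otherwise fp16, as in the diff above.
pipeline = AutoPipelineForText2Image.from_pretrained(
    model_id,
    torch_dtype=torch.float32 if dtype == "fp32" else torch.float16,
)
pipeline.to(device="cuda")

for i, prompt in enumerate(prompts):
    image = pipeline(prompt).images[0]
    # Hypothetical save call (not part of the diff): index-based names line up
    # with the gallery's expected layout, saved/<model folder>/<index>.jpg.
    image.save(os.path.join(save_path, f"{i}.jpg"))
```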