Jialu commited on
Commit
e72da55
1 Parent(s): feb2691

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+ from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
6
+
7
+ import gradio as gr
8
+
9
+ model_id = "stabilityai/stable-diffusion-2-1-base"
10
+ scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
11
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16)
12
+ pipe = pipe.to("cuda")
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ model, preprocess = clip.load("ViT-B/32", device=device)
16
+
17
+ def generate_image(text_prompt):
18
+ images = pipe(text_prompt, num_images_per_prompt=10).images
19
+ return images
20
+
21
+ def build_generation_block(prompt):
22
+ with gr.Row(variant="compact"):
23
+ text = gr.Textbox(
24
+ label=prompt,
25
+ show_label=False,
26
+ max_lines=1,
27
+ placeholder=prompt,
28
+ ).style(
29
+ container=False,
30
+ )
31
+ btn = gr.Button("Generate image").style(full_width=False)
32
+
33
+ gallery = gr.Gallery(
34
+ label="Generated images", show_label=False, elem_id="gallery"
35
+ ).style(columns=[5], rows=[2], object_fit="contain", height="auto")
36
+
37
+ btn.click(generate_image, text, gallery)
38
+
39
+ return text, gallery
40
+
41
+ def compute_association_score(image_null, image_pos, image_neg):
42
+ def compute_score(images):
43
+ # print(images[0])
44
+ features = [preprocess(Image.open(i['name'])) for i in images]
45
+ features = torch.stack(features).to(device)
46
+ with torch.no_grad():
47
+ image_features = model.encode_image(features)
48
+ image_features /= image_features.norm(dim=-1, keepdim=True)
49
+ return image_features.cpu().numpy()
50
+
51
+ emb_null = compute_score(image_null)
52
+ emb_pos = compute_score(image_pos)
53
+ emb_neg = compute_score(image_neg)
54
+
55
+ return np.mean(emb_pos @ emb_null.T) - np.mean(emb_neg @ emb_null.T)
56
+
57
+
58
+
59
+ with gr.Blocks() as demo:
60
+ with gr.Group():
61
+ gr.HTML("<h1 align='center'>T2IAT: Measuring Valence and Stereotypical Biases in Text-to-Image Generation")
62
+ gr.HTML("<h1 align='center'><strong style='color:#A52A2A'>ACL 2023 (Findings)</strong></h1>")
63
+ gr.HTML("<h2 align='center' style='color:#29A6A6'>Jialu Wang, Xinyue Gabby Liu, Zonglin Di, Yang Liu, Xin Eric Wang</h2>")
64
+ gr.HTML("<h2 align='center'>University of California, Santa Cruz</h2>")
65
+ gr.HTML("""
66
+ <h2 align="center">
67
+ <span style='display:inline'>
68
+ <a href="https://arxiv.org/abs/2306.00905" class="external-link button is-normal is-rounded is-dark">
69
+ <span class="icon">
70
+ <svg style="display:inline-block;font-size:inherit;height:1em;overflow:visible;vertical-align:-.125em" aria-hidden="true" focusable="false" data-prefix="fas" data-icon="file-pdf" role="img" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 384 512" data-fa-i2svg=""><path fill="currentColor" d="M181.9 256.1c-5-16-4.9-46.9-2-46.9 8.4 0 7.6 36.9 2 46.9zm-1.7 47.2c-7.7 20.2-17.3 43.3-28.4 62.7 18.3-7 39-17.2 62.9-21.9-12.7-9.6-24.9-23.4-34.5-40.8zM86.1 428.1c0 .8 13.2-5.4 34.9-40.2-6.7 6.3-29.1 24.5-34.9 40.2zM248 160h136v328c0 13.3-10.7 24-24 24H24c-13.3 0-24-10.7-24-24V24C0 10.7 10.7 0 24 0h200v136c0 13.2 10.8 24 24 24zm-8 171.8c-20-12.2-33.3-29-42.7-53.8 4.5-18.5 11.6-46.6 6.2-64.2-4.7-29.4-42.4-26.5-47.8-6.8-5 18.3-.4 44.1 8.1 77-11.6 27.6-28.7 64.6-40.8 85.8-.1 0-.1.1-.2.1-27.1 13.9-73.6 44.5-54.5 68 5.6 6.9 16 10 21.5 10 17.9 0 35.7-18 61.1-61.8 25.8-8.5 54.1-19.1 79-23.2 21.7 11.8 47.1 19.5 64 19.5 29.2 0 31.2-32 19.7-43.4-13.9-13.6-54.3-9.7-73.6-7.2zM377 105L279 7c-4.5-4.5-10.6-7-17-7h-6v128h128v-6.1c0-6.3-2.5-12.4-7-16.9zm-74.1 255.3c4.1-2.7-2.5-11.9-42.8-9 37.1 15.8 42.8 9 42.8 9z"></path></svg><!-- <i class="fas fa-file-pdf"></i> Font Awesome fontawesome.com -->
71
+ </span>
72
+ <span>Paper</span>
73
+ </a>
74
+ </span>
75
+ <span style='display:inline'>
76
+ <a href="https://github.com/eric-ai-lab/T2IAT" class="external-link button is-normal is-rounded is-dark">
77
+ <span class="icon">
78
+ <svg style="display:inline-block;font-size:inherit;height:1em;overflow:visible;vertical-align:-.125em" aria-hidden="true" focusable="false" data-prefix="fab" data-icon="github" role="img" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 496 512" data-fa-i2svg=""><path fill="currentColor" d="M165.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3.3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6zm-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5.3-6.2 2.3zm44.2-1.7c-2.9.7-4.9 2.6-4.6 4.9.3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9zM244.8 8C106.1 8 0 113.3 0 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C428.2 457.8 496 362.9 496 252 496 113.3 383.5 8 244.8 8zM97.2 352.9c-1.3 1-1 3.3.7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1zm-10.8-8.1c-.7 1.3.3 2.9 2.3 3.9 1.6 1 3.6.7 4.3-.7.7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3.7zm32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3.7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1zm-11.4-14.7c-1.6 1-1.6 3.6 0 5.9 1.6 2.3 4.3 3.3 5.6 2.3 1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2z"></path></svg><!-- <i class="fab fa-github" aria-hidden="true"></i> Font Awesome fontawesome.com -->
79
+ </span>
80
+ <span>Code</span>
81
+ </a>
82
+ </span>
83
+ </h2>
84
+ """)
85
+
86
+ gr.HTML("""
87
+ <div>
88
+ <p style="padding: 25px 200px; text-align: justify;">
89
+ <strong>Abstract:</strong> In the last few years, text-to-image generative models have gained remarkable success in generating images with unprecedented quality accompanied by a breakthrough of inference speed. Despite their rapid progress, human biases that manifest in the training examples, particularly with regard to common stereotypical biases, like gender and skin tone, still have been found in these generative models. In this work, we seek to measure more complex human biases exist in the task of text-to-image generations. Inspired by the well-known Implicit Association Test (IAT) from social psychology, we propose a novel Text-to-Image Association Test (T2IAT) framework that quantifies the implicit stereotypes between concepts and valence, and those in the images. We replicate the previously documented bias tests on generative models, including morally neutral tests on flowers and insects as well as demographic stereotypical tests on diverse social attributes. The results of these experiments demonstrate the presence of complex stereotypical behaviors in image generations.
90
+ </p>
91
+ </div>
92
+ """)
93
+
94
+ # gr.Image(
95
+ # "images/Text2ImgAssocationTest.png"
96
+ # ).style(
97
+ # height=300,
98
+ # weight=400
99
+ # )
100
+
101
+ with gr.Group():
102
+ gr.HTML("""
103
+ <h3>First step: generate images with neutral prompts</h3>
104
+ """)
105
+ text_null, gallery_null = build_generation_block("Enter the neutral prompt.")
106
+
107
+ with gr.Group():
108
+ gr.HTML("""
109
+ <h3>Second step: generate attribute-guided images by including the attributes into the prompts</h3>
110
+ """)
111
+ text_pos, gallery_pos = build_generation_block("Enter your prompt with attribute A.")
112
+ text_neg, gallery_neg = build_generation_block("Enter your prompt with attribute B.")
113
+
114
+ with gr.Group():
115
+ gr.HTML("<h3>Final step: compute the association score between your specified attributes!")
116
+
117
+ with gr.Row():
118
+ score = gr.Number(label="association score")
119
+ btn = gr.Button("Compute Association Score!")
120
+ btn.click(compute_association_score, [gallery_null, gallery_pos, gallery_neg], score)
121
+
122
+ gr.HTML("<p>The absolute value of the association score represents the strength of the bias between the compared attributes, A and B, subject to the concepts that users choose in image generation. The higher score, the stronger the association, and vice versa.</p>")
123
+
124
+ if __name__ == "__main__":
125
+ demo.queue(concurrency_count=3)
126
+ demo.launch(title='T2IAT')
127
+