boris committed
Commit
4e4a30f
1 Parent(s): 3a3bee8

refactor: move to tools

dev/inference/README.md DELETED
@@ -1 +0,0 @@
- Scripts to generate predictions for assessment and reporting.
 
 
dev/inference/wandb-examples-from-backend.py DELETED
@@ -1,76 +0,0 @@
- #!/usr/bin/env python
- # coding: utf-8
-
- from PIL import Image, ImageDraw, ImageFont
- import wandb
- import os
-
- from dalle_mini.backend import ServiceError, get_images_from_backend
- from dalle_mini.helpers import captioned_strip
-
- os.environ["WANDB_SILENT"] = "true"
- os.environ["WANDB_CONSOLE"] = "off"
-
- def log_to_wandb(prompts):
-     try:
-         backend_url = os.environ["BACKEND_SERVER"]
-         for _ in range(1):
-             for prompt in prompts:
-                 print(f"Getting selections for: {prompt}")
-                 # make a separate run per prompt
-                 with wandb.init(
-                     entity='wandb',
-                     project='hf-flax-dalle-mini',
-                     job_type='predictions',  # tags=['openai'],
-                     config={'prompt': prompt}
-                 ):
-                     imgs = []
-                     selected = get_images_from_backend(prompt, backend_url)
-                     strip = captioned_strip(selected, prompt)
-                     imgs.append(wandb.Image(strip))
-                     wandb.log({"images": imgs})
-     except ServiceError as error:
-         print(f"Service unavailable, status: {error.status_code}")
-     except KeyError:
-         print("Error: BACKEND_SERVER unset")
-
- prompts = [
-     # "white snow covered mountain under blue sky during daytime",
-     # "aerial view of beach during daytime",
-     # "aerial view of beach at night",
-     # "a farmhouse surrounded by beautiful flowers",
-     # "an armchair in the shape of an avocado",
-     # "young woman riding her bike trough a forest",
-     # "a unicorn is passing by a rainbow in a field of flowers",
-     # "illustration of a baby shark swimming around corals",
-     # "painting of an oniric forest glade surrounded by tall trees",
-     # "sunset over green mountains",
-     # "a forest glade surrounded by tall trees in a sunny Spring morning",
-     # "fishing village under the moonlight in a serene sunset",
-     # "cartoon of a carrot with big eyes",
-     # "still life in the style of Kandinsky",
-     # "still life in the style of Picasso",
-     # "a graphite sketch of a gothic cathedral",
-     # "a graphite sketch of Elon Musk",
-     # "a watercolor pond with green leaves and yellow flowers",
-     # "a logo of a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps",
-     # "happy celebration in a small village in Africa",
-     # "a logo of an armchair in the shape of an avocado"
-     # "Pele and Maradona in a hypothetical match",
-     # "Mohammed Ali and Mike Tyson in a hypothetical match",
-     # "a storefront that has the word 'openai' written on it",
-     # "a pentagonal green clock",
-     # "a collection of glasses is sitting on a table",
-     # "a small red block sitting on a large green block",
-     # "an extreme close-up view of a capybara sitting in a field",
-     # "a cross-section view of a walnut",
-     # "a professional high-quality emoji of a lovestruck cup of boba",
-     # "a photo of san francisco's golden gate bridge",
-     # "an illustration of a baby daikon radish in a tutu walking a dog",
-     # "a picture of the Eiffel tower on the Moon",
-     # "a colorful stairway to heaven",
-     "this is a detailed high-resolution scan of a human brain"
- ]
-
- for _ in range(1):
-     log_to_wandb(prompts)
 
dev/inference/wandb-examples.py DELETED
@@ -1,163 +0,0 @@
- #!/usr/bin/env python
- # coding: utf-8
-
- import random
-
- import jax
- from flax.training.common_utils import shard
- from flax.jax_utils import replicate, unreplicate
-
- from transformers.models.bart.modeling_flax_bart import *
- from transformers import BartTokenizer, FlaxBartForConditionalGeneration
-
- import os
-
- from PIL import Image
- import numpy as np
- import matplotlib.pyplot as plt
-
- import torch
- import torchvision.transforms as T
- import torchvision.transforms.functional as TF
- from torchvision.transforms import InterpolationMode
-
- from dalle_mini.model import CustomFlaxBartForConditionalGeneration
- from vqgan_jax.modeling_flax_vqgan import VQModel
-
- # ## CLIP Scoring
- from transformers import CLIPProcessor, FlaxCLIPModel
-
- import wandb
- import os
-
- from dalle_mini.helpers import captioned_strip
-
-
- os.environ["WANDB_SILENT"] = "true"
- os.environ["WANDB_CONSOLE"] = "off"
-
- # TODO: used for legacy support
- BASE_MODEL = 'facebook/bart-large-cnn'
-
- # set id to None so our latest images don't get overwritten
- id = None
- run = wandb.init(id=id,
-                  entity='wandb',
-                  project="hf-flax-dalle-mini",
-                  job_type="predictions",
-                  resume="allow"
- )
- artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest', type='bart_model')
- artifact_dir = artifact.download()
-
- # create our model
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
-
- # TODO: legacy support (earlier models)
- tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
- model.config.force_bos_token_to_be_generated = False
- model.config.forced_bos_token_id = None
- model.config.forced_eos_token_id = None
-
- vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
-
- def custom_to_pil(x):
-     x = np.clip(x, 0., 1.)
-     x = (255*x).astype(np.uint8)
-     x = Image.fromarray(x)
-     if not x.mode == "RGB":
-         x = x.convert("RGB")
-     return x
-
- def generate(input, rng, params):
-     return model.generate(
-         **input,
-         max_length=257,
-         num_beams=1,
-         do_sample=True,
-         prng_key=rng,
-         eos_token_id=50000,
-         pad_token_id=50000,
-         params=params,
-     )
-
- def get_images(indices, params):
-     return vqgan.decode_code(indices, params=params)
-
- def plot_images(images):
-     fig = plt.figure(figsize=(40, 20))
-     columns = 4
-     rows = 2
-     plt.subplots_adjust(hspace=0, wspace=0)
-
-     for i in range(1, columns*rows +1):
-         fig.add_subplot(rows, columns, i)
-         plt.imshow(images[i-1])
-     plt.gca().axes.get_yaxis().set_visible(False)
-     plt.show()
-
- def stack_reconstructions(images):
-     w, h = images[0].size[0], images[0].size[1]
-     img = Image.new("RGB", (len(images)*w, h))
-     for i, img_ in enumerate(images):
-         img.paste(img_, (i*w,0))
-     return img
-
- p_generate = jax.pmap(generate, "batch")
- p_get_images = jax.pmap(get_images, "batch")
-
- bart_params = replicate(model.params)
- vqgan_params = replicate(vqgan.params)
-
- clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-
- def hallucinate(prompt, num_images=64):
-     prompt = [prompt] * jax.device_count()
-     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
-     inputs = shard(inputs)
-
-     all_images = []
-     for i in range(num_images // jax.device_count()):
-         key = random.randint(0, 1e7)
-         rng = jax.random.PRNGKey(key)
-         rngs = jax.random.split(rng, jax.local_device_count())
-         indices = p_generate(inputs, rngs, bart_params).sequences
-         indices = indices[:, :, 1:]
-
-         images = p_get_images(indices, vqgan_params)
-         images = np.squeeze(np.asarray(images), 1)
-         for image in images:
-             all_images.append(custom_to_pil(image))
-     return all_images
-
- def clip_top_k(prompt, images, k=8):
-     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
-     # FIXME: image should be resized and normalized prior to being processed by CLIP
-     outputs = clip(**inputs)
-     logits = outputs.logits_per_text
-     scores = np.array(logits[0]).argsort()[-k:][::-1]
-     return [images[score] for score in scores]
-
- def log_to_wandb(prompts):
-     strips = []
-     for prompt in prompts:
-         print(f"Generating candidates for: {prompt}")
-         images = hallucinate(prompt, num_images=32)
-         selected = clip_top_k(prompt, images, k=8)
-         strip = captioned_strip(selected, prompt)
-         strips.append(wandb.Image(strip))
-     wandb.log({"images": strips})
-
- prompts = prompts = [
-     "white snow covered mountain under blue sky during daytime",
-     "aerial view of beach during daytime",
-     "aerial view of beach at night",
-     "an armchair in the shape of an avocado",
-     "young woman riding her bike trough a forest",
-     "rice fields by the mediterranean coast",
-     "white houses on the hill of a greek coastline",
-     "illustration of a shark with a baby shark",
- ]
-
- log_to_wandb(prompts)
 
{dev → tools}/inference/inference_pipeline.ipynb RENAMED
File without changes
dev/inference/wandb-backend.ipynb → tools/inference/log_inference_samples.ipynb RENAMED
File without changes
{dev → tools}/inference/samples.txt RENAMED
File without changes