Luis Oala committed on
Commit
c904588
1 Parent(s): c5b0c77

Create server.py

Files changed (1)
  1. server.py +145 -0
server.py ADDED
@@ -0,0 +1,145 @@
+ import base64
+ from io import BytesIO
+
+ from fastapi import FastAPI
+ from PIL import Image
+ import torch as th
+
+ from glide_text2im.download import load_checkpoint
+ from glide_text2im.model_creation import (
+     create_model_and_diffusion,
+     model_and_diffusion_defaults,
+     model_and_diffusion_defaults_upsampler,
+ )
+
+ print("Loading models...")
+ app = FastAPI()
+
+ # This server supports both CPU and GPU.
+ # On CPU, generating one sample may take on the order of 20 minutes.
+ # On a GPU, it should be under a minute.
+ has_cuda = th.cuda.is_available()
+ device = th.device('cpu' if not has_cuda else 'cuda')
+
+ # Create base model.
+ options = model_and_diffusion_defaults()
+ options['use_fp16'] = has_cuda
+ options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling
+ model, diffusion = create_model_and_diffusion(**options)
+ model.eval()
+ if has_cuda:
+     model.convert_to_fp16()
+ model.to(device)
+ model.load_state_dict(load_checkpoint('base', device))
+ print('total base parameters', sum(x.numel() for x in model.parameters()))
+
+ # Create upsampler model.
+ options_up = model_and_diffusion_defaults_upsampler()
+ options_up['use_fp16'] = has_cuda
+ options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling
+ model_up, diffusion_up = create_model_and_diffusion(**options_up)
+ model_up.eval()
+ if has_cuda:
+     model_up.convert_to_fp16()
+ model_up.to(device)
+ model_up.load_state_dict(load_checkpoint('upsample', device))
+ print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))
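+
+ # Helper: convert a batch of samples in [-1, 1] with shape (N, C, H, W)
+ # into a single horizontally tiled PIL image.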
+ def get_images(batch: th.Tensor):
+     """ Convert a batch of images into one horizontally tiled PIL image. """
+     scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
+     reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
+     return Image.fromarray(reshaped.numpy())
+
+ # Create a classifier-free guidance sampling function.
+ guidance_scale = 3.0
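+
+ # model_fn runs the base model on a batch that is half conditional and half
+ # unconditional, then recombines the two eps predictions as
+ #   eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps),
+ # so larger guidance_scale values push samples more strongly toward the prompt.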
+ def model_fn(x_t, ts, **kwargs):
+     half = x_t[: len(x_t) // 2]
+     combined = th.cat([half, half], dim=0)
+     model_out = model(combined, ts, **kwargs)
+     # 3 eps channels plus the model's learned-variance channels.
+     eps, rest = model_out[:, :3], model_out[:, 3:]
+     cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
+     half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+     eps = th.cat([half_eps, half_eps], dim=0)
+     return th.cat([eps, rest], dim=1)
+
+ @app.get("/")
+ def read_root():
+     # Return a dict (the original set literal is not JSON-serializable).
+     return {"message": "glide!"}
+
+ # Bind the path parameter to the function argument: GET /<prompt>.
+ @app.get("/{prompt}")
+ def sample(prompt: str):
+     # Sampling parameters
+     batch_size = 1
+
+     # Tune this parameter to control the sharpness of 256x256 images.
+     # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+     upsample_temp = 0.997
+
+     ##############################
+     # Sample from the base model #
+     ##############################
+     # Create the text tokens to feed to the model.
+     tokens = model.tokenizer.encode(prompt)
+     tokens, mask = model.tokenizer.padded_tokens_and_mask(
+         tokens, options['text_ctx']
+     )
+
+     # Create the classifier-free guidance tokens (empty)
+     full_batch_size = batch_size * 2
+     uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
+         [], options['text_ctx']
+     )
+
+     # Pack the tokens together into model kwargs.
+     model_kwargs = dict(
+         tokens=th.tensor(
+             [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
+         ),
+         mask=th.tensor(
+             [mask] * batch_size + [uncond_mask] * batch_size,
+             dtype=th.bool,
+             device=device,
+         ),
+     )
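+     # The first batch_size rows hold the prompt tokens and the last batch_size
+     # rows the empty (unconditional) tokens, matching model_fn's split.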
+
+     # Sample from the base model.
+     model.del_cache()
+     samples = diffusion.p_sample_loop(
+         model_fn,
+         (full_batch_size, 3, options["image_size"], options["image_size"]),
+         device=device,
+         clip_denoised=True,
+         progress=True,
+         model_kwargs=model_kwargs,
+         cond_fn=None,
+     )[:batch_size]
+     model.del_cache()
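+     # samples: (batch_size, 3, 64, 64) tensor in [-1, 1]; the base model's
+     # output resolution comes from options["image_size"].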
+
+     ##############################
+     # Upsample the 64x64 samples #
+     ##############################
+     tokens = model_up.tokenizer.encode(prompt)
+     tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
+         tokens, options_up['text_ctx']
+     )
+
+     # Create the model conditioning dict.
+     model_kwargs = dict(
+         # Low-res image to upsample.
+         low_res=((samples + 1) * 127.5).round() / 127.5 - 1,
+         # Text tokens
+         tokens=th.tensor(
+             [tokens] * batch_size, device=device
+         ),
+         mask=th.tensor(
+             [mask] * batch_size,
+             dtype=th.bool,
+             device=device,
+         ),
+     )
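+     # (The low_res round trip through [0, 255] quantizes the base samples to
+     # 8-bit values, presumably to match what the upsampler saw in training.)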
+
+     # Sample from the upsampler model.
+     model_up.del_cache()
+     up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
+     up_samples = diffusion_up.ddim_sample_loop(
+         model_up,
+         up_shape,
+         noise=th.randn(up_shape, device=device) * upsample_temp,
+         device=device,
+         clip_denoised=True,
+         progress=True,
+         model_kwargs=model_kwargs,
+         cond_fn=None,
+     )[:batch_size]
+     model_up.del_cache()
+
+     # Encode the output image as base64 and return it.
+     image = get_images(up_samples)
+     image = to_base64(image)
+     return {"image": image}
+
+ def to_base64(pil_image):
+     buffered = BytesIO()
+     pil_image.save(buffered, format="JPEG")
+     # Decode to str so the response JSON carries a plain base64 string.
+     return base64.b64encode(buffered.getvalue()).decode("utf-8")
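
For reference, a minimal client sketch (not part of this commit; it assumes the server was started locally with uvicorn server:app on the default port 8000, and uses the requests library) that fetches a sample and decodes the returned base64 JPEG:

import base64
import requests

# The route takes the prompt as the path segment; requests URL-encodes it.
resp = requests.get("http://127.0.0.1:8000/an oil painting of a corgi")
resp.raise_for_status()

# The endpoint returns {"image": "<base64-encoded JPEG>"}.
with open("sample.jpg", "wb") as f:
    f.write(base64.b64decode(resp.json()["image"]))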