jositonaranja commited on
Commit
6e86492
1 Parent(s): 520a5a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -43,40 +43,13 @@ def show_images(batch: th.Tensor):
43
  reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
44
  display(Image.fromarray(reshaped.numpy()))
45
  # Sampling parameters
46
- prompt = "an oil painting of a corgi"
47
  batch_size = 1
48
  guidance_scale = 3.0
49
 
50
  # Tune this parameter to control the sharpness of 256x256 images.
51
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
52
  upsample_temp = 0.997
53
- ##############################
54
- # Sample from the base model #
55
- ##############################
56
 
57
- # Create the text tokens to feed to the model.
58
- tokens = model.tokenizer.encode(prompt)
59
- tokens, mask = model.tokenizer.padded_tokens_and_mask(
60
- tokens, options['text_ctx']
61
- )
62
-
63
- # Create the classifier-free guidance tokens (empty)
64
- full_batch_size = batch_size * 2
65
- uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
66
- [], options['text_ctx']
67
- )
68
-
69
- # Pack the tokens together into model kwargs.
70
- model_kwargs = dict(
71
- tokens=th.tensor(
72
- [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
73
- ),
74
- mask=th.tensor(
75
- [mask] * batch_size + [uncond_mask] * batch_size,
76
- dtype=th.bool,
77
- device=device,
78
- ),
79
- )
80
 
81
  # Create a classifier-free guidance sampling function
82
  def model_fn(x_t, ts, **kwargs):
@@ -89,7 +62,35 @@ def model_fn(x_t, ts, **kwargs):
89
  eps = th.cat([half_eps, half_eps], dim=0)
90
  return th.cat([eps, rest], dim=1)
91
 
92
- def run():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
 
95
  print('run():')
 
43
  reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
44
  display(Image.fromarray(reshaped.numpy()))
45
  # Sampling parameters
 
46
  batch_size = 1
47
  guidance_scale = 3.0
48
 
49
  # Tune this parameter to control the sharpness of 256x256 images.
50
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
51
  upsample_temp = 0.997
 
 
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # Create a classifier-free guidance sampling function
55
  def model_fn(x_t, ts, **kwargs):
 
62
  eps = th.cat([half_eps, half_eps], dim=0)
63
  return th.cat([eps, rest], dim=1)
64
 
65
+ def run(prompt):
66
+
67
+ ##############################
68
+ # Sample from the base model #
69
+ ##############################
70
+
71
+ # Create the text tokens to feed to the model.
72
+ tokens = model.tokenizer.encode(prompt)
73
+ tokens, mask = model.tokenizer.padded_tokens_and_mask(
74
+ tokens, options['text_ctx']
75
+ )
76
+
77
+ # Create the classifier-free guidance tokens (empty)
78
+ full_batch_size = batch_size * 2
79
+ uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
80
+ [], options['text_ctx']
81
+ )
82
+
83
+ # Pack the tokens together into model kwargs.
84
+ model_kwargs = dict(
85
+ tokens=th.tensor(
86
+ [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
87
+ ),
88
+ mask=th.tensor(
89
+ [mask] * batch_size + [uncond_mask] * batch_size,
90
+ dtype=th.bool,
91
+ device=device,
92
+ ),
93
+ )
94
 
95
 
96
  print('run():')