Commit ab9e9c4 (parent e3d2366), committed by multimodalart

Add normalisation

Files changed (1): app.py (+5 -1)
app.py CHANGED
@@ -4,6 +4,7 @@ import sys
 
 from IPython import display
 import torch
+from torchvision import transforms
 from torchvision import utils as tv_utils
 from torchvision.transforms import functional as TF
 import gradio as gr
@@ -28,7 +29,10 @@ _, side_y, side_x = model.shape
 model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
 model = model.half().cuda().eval().requires_grad_(False)
 clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0]
-
+clip_model.eval().requires_grad_(False)
+normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                 std=[0.26862954, 0.26130258, 0.27577711])
+make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
 def run_all(prompt, steps, n_images, weight, clip_guided):
     import random
     seed = int(random.randint(0, 2147483647))
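For context, the added `normalize` uses CLIP's standard preprocessing mean/std, and `make_cutouts` produces 16 cutouts at `clip_model.visual.input_resolution`. A minimal sketch of how these objects typically feed CLIP guidance follows; the helper name `clip_image_embed` and the assumption that the model output is an NCHW tensor in [-1, 1] are illustrative and are not part of this commit:

# Illustrative sketch only; assumes clip_model, normalize and make_cutouts
# are defined as in the diff above, and that `pred` is an NCHW tensor in [-1, 1].
def clip_image_embed(pred):
    cutouts = make_cutouts((pred + 1) / 2)   # rescale to [0, 1] and take cutouts
    clip_in = normalize(cutouts)             # apply CLIP's mean/std normalisation
    return clip_model.encode_image(clip_in).float()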