Maitreya Patel commited on
Commit
90ee823
1 Parent(s): 938774e

minor bug fixes

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. requirements.txt +0 -1
app.py CHANGED
@@ -43,7 +43,7 @@ class Ours:
43
  CLIPTextModelWithProjection.from_pretrained(
44
  "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
45
  projection_dim=1280,
46
- torch_dtype=torch.float16,
47
  )
48
  .eval()
49
  .requires_grad_(False)
@@ -55,7 +55,7 @@ class Ours:
55
 
56
  prior = PriorTransformer.from_pretrained(
57
  "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior",
58
- torch_dtype=torch.float16,
59
  )
60
 
61
  self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
@@ -63,11 +63,11 @@ class Ours:
63
  prior=prior,
64
  text_encoder=text_encoder,
65
  tokenizer=tokenizer,
66
- torch_dtype=torch.float16,
67
  ).to(device)
68
 
69
  self.pipe = DiffusionPipeline.from_pretrained(
70
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
71
  ).to(device)
72
 
73
  def inference(self, text, negative_text, steps, guidance_scale):
 
43
  CLIPTextModelWithProjection.from_pretrained(
44
  "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
45
  projection_dim=1280,
46
+ torch_dtype=torch.float32,
47
  )
48
  .eval()
49
  .requires_grad_(False)
 
55
 
56
  prior = PriorTransformer.from_pretrained(
57
  "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior",
58
+ torch_dtype=torch.float32,
59
  )
60
 
61
  self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
 
63
  prior=prior,
64
  text_encoder=text_encoder,
65
  tokenizer=tokenizer,
66
+ torch_dtype=torch.float32,
67
  ).to(device)
68
 
69
  self.pipe = DiffusionPipeline.from_pretrained(
70
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float32
71
  ).to(device)
72
 
73
  def inference(self, text, negative_text, steps, guidance_scale):
requirements.txt CHANGED
@@ -9,7 +9,6 @@ torch==2.0.0
9
  torchvision==0.15.1
10
  tqdm==4.66.1
11
  transformers==4.34.1
12
- gradio
13
  jmespath
14
  opencv-python
15
  PyWavelet
 
9
  torchvision==0.15.1
10
  tqdm==4.66.1
11
  transformers==4.34.1
 
12
  jmespath
13
  opencv-python
14
  PyWavelet