rynmurdock commited on
Commit
94aebbe
1 Parent(s): dd107f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -46
app.py CHANGED
@@ -7,34 +7,17 @@ import numpy as np
7
  from sklearn.svm import LinearSVC
8
  from sklearn import preprocessing
9
  import pandas as pd
10
- import kornia
11
- import torchvision
12
 
13
  import random
14
  import time
15
 
16
- from diffusers import LCMScheduler
17
- from diffusers.models import ImageProjection
18
- from patch_sdxl import SDEmb
19
  import torch
20
-
21
 
22
  prompt_list = [p for p in list(set(
23
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
24
 
25
-
26
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
27
- lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
28
-
29
- pipe = SDEmb.from_pretrained(model_id, variant="fp16")
30
- pipe.load_lora_weights(lcm_lora_id)
31
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
32
- pipe.to(device=DEVICE, dtype=torch.float16)
33
-
34
- pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
35
-
36
-
37
-
38
  calibrate_prompts = [
39
  "4k photo",
40
  'surrealist art',
@@ -57,20 +40,6 @@ ys = []
57
 
58
  start_time = time.time()
59
 
60
- output_hidden_state = False if isinstance(pipe.unet.encoder_hid_proj, ImageProjection) else True
61
-
62
-
63
- transform = kornia.augmentation.RandomResizedCrop(size=(224, 224), scale=(.3, .5))
64
- nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
65
- def patch_encode_image(image):
66
- image = torch.tensor(torchvision.transforms.functional.pil_to_tensor(image).to(torch.float16)).repeat(16, 1, 1, 1).to(DEVICE)
67
- image = image / 255
68
- patches = nom(transform(image))
69
- output, _ = pipe.encode_image(
70
- patches, DEVICE, 1, output_hidden_state
71
- )
72
- return output.mean(0, keepdim=True)
73
-
74
 
75
  glob_idx = 0
76
 
@@ -96,7 +65,6 @@ def next_image():
96
  pooled_embeds, _ = pipe.encode_image(
97
  image[0], DEVICE, 1, output_hidden_state
98
  )
99
- #pooled_embeds = patch_encode_image(image[0])
100
 
101
  embs.append(pooled_embeds)
102
  return image[0]
@@ -131,19 +99,10 @@ def next_image():
131
  prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
132
  print(prompt)
133
 
134
- image = pipe(
135
- prompt=prompt,
136
- ip_adapter_emb=im_emb,
137
- height=1024,
138
- width=1024,
139
- num_inference_steps=8,
140
- guidance_scale=0,
141
- ).images
142
-
143
- im_emb, _ = pipe.encode_image(
144
- image[0], DEVICE, 1, output_hidden_state
145
  )
146
- #im_emb = patch_encode_image(image[0])
147
 
148
  embs.append(im_emb)
149
 
 
7
  from sklearn.svm import LinearSVC
8
  from sklearn import preprocessing
9
  import pandas as pd
 
 
10
 
11
  import random
12
  import time
13
 
14
+ import replicate
 
 
15
  import torch
16
+ import pickle
17
 
18
  prompt_list = [p for p in list(set(
19
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  calibrate_prompts = [
22
  "4k photo",
23
  'surrealist art',
 
40
 
41
  start_time = time.time()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  glob_idx = 0
45
 
 
65
  pooled_embeds, _ = pipe.encode_image(
66
  image[0], DEVICE, 1, output_hidden_state
67
  )
 
68
 
69
  embs.append(pooled_embeds)
70
  return image[0]
 
99
  prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
100
  print(prompt)
101
 
102
+ image, im_emb = replicate.run(
103
+ "rynmurdock/zahir:43177e0594f3bc2e3560170ff0ffb6d1cacdddda1be25fbcd4348ef02b0b7d0f",
104
+ input={"prompt": prompt, 'im_emg': pickle.dumps(im_emb)}
 
 
 
 
 
 
 
 
105
  )
 
106
 
107
  embs.append(im_emb)
108