camenduru commited on
Commit
aff32ce
1 Parent(s): db356bf

Create worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker_runpod.py +105 -0
worker_runpod.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, requests, runpod
2
+
3
+ import torch, random
4
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
5
+ from kolors.models.modeling_chatglm import ChatGLMModel
6
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
7
+ from diffusers import UNet2DConditionModel, AutoencoderKL
8
+ from diffusers import EulerDiscreteScheduler
9
+
10
+ discord_token = os.getenv('com_camenduru_discord_token')
11
+ web_uri = os.getenv('com_camenduru_web_uri')
12
+ web_token = os.getenv('com_camenduru_web_token')
13
+
14
+ with torch.inference_mode():
15
+ ckpt_dir = f'/content/Kolors/weights/Kolors'
16
+ text_encoder = ChatGLMModel.from_pretrained(
17
+ f'{ckpt_dir}/text_encoder',
18
+ torch_dtype=torch.float16).half()
19
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
20
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
21
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
22
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
23
+ pipe = StableDiffusionXLPipeline(
24
+ vae=vae,
25
+ text_encoder=text_encoder,
26
+ tokenizer=tokenizer,
27
+ unet=unet,
28
+ scheduler=scheduler,
29
+ force_zeros_for_empty_prompt=False)
30
+ pipe = pipe.to("cuda")
31
+ pipe.enable_model_cpu_offload()
32
+
33
+ def closestNumber(n, m):
34
+ q = int(n / m)
35
+ n1 = m * q
36
+ if (n * m) > 0:
37
+ n2 = m * (q + 1)
38
+ else:
39
+ n2 = m * (q - 1)
40
+ if abs(n - n1) < abs(n - n2):
41
+ return n1
42
+ return n2
43
+
44
+ @torch.inference_mode()
45
+ def generate(input):
46
+ values = input["input"]
47
+
48
+ prompt = values['prompt']
49
+ width = values['width']
50
+ height = values['height']
51
+ num_inference_steps = values['num_inference_steps']
52
+ guidance_scale = values['guidance_scale']
53
+ num_images_per_prompt = values['num_images_per_prompt']
54
+ seed = values['seed']
55
+
56
+ if seed == 0:
57
+ seed = random.randint(0, 18446744073709551615)
58
+
59
+ image = pipe(
60
+ prompt=prompt,
61
+ width=closestNumber(width, 8),
62
+ height=closestNumber(height, 8),
63
+ num_inference_steps=num_inference_steps,
64
+ guidance_scale=guidance_scale,
65
+ num_images_per_prompt=num_images_per_prompt,
66
+ generator=torch.Generator(pipe.device).manual_seed(seed)).images[0]
67
+ image.save(f'/content/Kolors/scripts/outputs/kolors.jpg')
68
+
69
+ result = "/content/Kolors/scripts/outputs/kolors.jpg"
70
+ response = None
71
+ try:
72
+ source_id = values['source_id']
73
+ del values['source_id']
74
+ source_channel = values['source_channel']
75
+ del values['source_channel']
76
+ job_id = values['job_id']
77
+ del values['job_id']
78
+ default_filename = os.path.basename(result)
79
+ files = {default_filename: open(result, "rb").read()}
80
+ payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
81
+ response = requests.post(
82
+ f"https://discord.com/api/v9/channels/{source_channel}/messages",
83
+ data=payload,
84
+ headers={"authorization": f"Bot {discord_token}"},
85
+ files=files
86
+ )
87
+ response.raise_for_status()
88
+ except Exception as e:
89
+ print(f"An unexpected error occurred: {e}")
90
+ finally:
91
+ if os.path.exists(result):
92
+ os.remove(result)
93
+
94
+ if response and response.status_code == 200:
95
+ try:
96
+ payload = {"jobId": job_id, "result": response.json()['attachments'][0]['url']}
97
+ requests.post(f"{web_uri}/api/notify", data=json.dumps(payload), headers={'Content-Type': 'application/json', "authorization": f"{web_token}"})
98
+ except Exception as e:
99
+ print(f"An unexpected error occurred: {e}")
100
+ finally:
101
+ return {"result": response.json()['attachments'][0]['url']}
102
+ else:
103
+ return {"result": "ERROR"}
104
+
105
+ runpod.serverless.start({"handler": generate})