# kolors / worker_runpod.py — RunPod serverless worker for the Kolors pipeline
# author: camenduru (Hugging Face commit aff32ce)
import os, json, requests, runpod
import torch, random
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler
# Credentials/endpoints injected by the deployment environment (None if unset).
discord_token = os.getenv('com_camenduru_discord_token')  # Discord bot token used to post the result image
web_uri = os.getenv('com_camenduru_web_uri')  # base URI of the notify web service
web_token = os.getenv('com_camenduru_web_token')  # authorization token for the notify endpoint
# Load the Kolors (SDXL + ChatGLM text encoder) pipeline once at worker start-up.
# All components are loaded in fp16 to halve memory; inference_mode disables
# autograd bookkeeping during loading.
with torch.inference_mode():
    ckpt_dir = f'/content/Kolors/weights/Kolors'
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False)
    # FIX: do NOT call pipe.to("cuda") before enable_model_cpu_offload() —
    # diffusers requires the pipeline to stay on CPU so the offload hooks can
    # move each component to the GPU only while it runs; pre-moving to CUDA
    # conflicts with the hooks and negates the memory savings.
    pipe.enable_model_cpu_offload()
def closestNumber(n, m):
    # Return the multiple of m nearest to n. Two candidate multiples bracket n:
    # the truncated one and its neighbour one step toward n's side; on a tie
    # the neighbour wins (e.g. closestNumber(12, 8) == 16).
    quotient = int(n / m)
    below = m * quotient
    step = 1 if n * m > 0 else -1
    beyond = m * (quotient + step)
    return below if abs(n - below) < abs(n - beyond) else beyond
@torch.inference_mode()
def generate(input):
    """RunPod serverless handler.

    Reads generation parameters from input["input"], renders one image with the
    module-level `pipe`, posts it to the originating Discord channel, notifies
    the web backend, and returns {"result": <attachment url>} on success or
    {"result": "ERROR"} on any failure.
    """
    values = input["input"]
    prompt = values['prompt']
    width = values['width']
    height = values['height']
    num_inference_steps = values['num_inference_steps']
    guidance_scale = values['guidance_scale']
    num_images_per_prompt = values['num_images_per_prompt']
    seed = values['seed']
    if seed == 0:
        # seed == 0 means "randomize"; full unsigned 64-bit range accepted by torch
        seed = random.randint(0, 18446744073709551615)
    image = pipe(
        prompt=prompt,
        # SDXL requires dimensions divisible by 8; snap to the nearest multiple
        width=closestNumber(width, 8),
        height=closestNumber(height, 8),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator(pipe.device).manual_seed(seed)).images[0]
    result = "/content/Kolors/scripts/outputs/kolors.jpg"
    image.save(result)
    response = None
    job_id = None
    try:
        # Routing metadata is stripped from `values` so only the generation
        # parameters are echoed back in the Discord message content.
        source_id = values.pop('source_id')
        source_channel = values.pop('source_channel')
        job_id = values.pop('job_id')
        default_filename = os.path.basename(result)
        # FIX: context manager closes the file handle (original leaked it)
        with open(result, "rb") as image_file:
            files = {default_filename: image_file.read()}
        payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
        response = requests.post(
            f"https://discord.com/api/v9/channels/{source_channel}/messages",
            data=payload,
            headers={"authorization": f"Bot {discord_token}"},
            files=files
        )
        response.raise_for_status()
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        # Always remove the rendered file; the worker filesystem is ephemeral.
        if os.path.exists(result):
            os.remove(result)
    if response and response.status_code == 200:
        # FIX: parse the Discord response once, outside any finally clause.
        # The original returned from inside `finally`, which silently swallowed
        # exceptions raised in the notify step and re-parsed the JSON a second time.
        try:
            attachment_url = response.json()['attachments'][0]['url']
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return {"result": "ERROR"}
        try:
            notify_payload = {"jobId": job_id, "result": attachment_url}
            requests.post(f"{web_uri}/api/notify", data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "authorization": f"{web_token}"})
        except Exception as e:
            # Best-effort notify: a failure here must not discard the result URL.
            print(f"An unexpected error occurred: {e}")
        return {"result": attachment_url}
    return {"result": "ERROR"}
runpod.serverless.start({"handler": generate})