|
import os, json, requests, runpod |
|
|
|
import torch, random |
|
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline |
|
from kolors.models.modeling_chatglm import ChatGLMModel |
|
from kolors.models.tokenization_chatglm import ChatGLMTokenizer |
|
from diffusers import UNet2DConditionModel, AutoencoderKL |
|
from diffusers import EulerDiscreteScheduler |
|
|
|
# Credentials and endpoints are read from the process environment once at
# import time; each name is None when its variable is not set.
discord_token = os.environ.get('com_camenduru_discord_token')  # used as the Discord "Bot" token
web_uri = os.environ.get('com_camenduru_web_uri')  # base URL for the /api/notify callback
web_token = os.environ.get('com_camenduru_web_token')  # sent as the callback's authorization header
|
|
|
# Build the Kolors pipeline once at import time so every job reuses the same
# loaded components.
with torch.inference_mode():
    ckpt_dir = '/content/Kolors/weights/Kolors'

    # Each sub-model lives in its own sub-directory of the checkpoint and is
    # converted to half precision after loading.
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16,
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()

    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    ).to("cuda")

    # NOTE(review): the pipeline is moved to CUDA and model CPU offload is
    # enabled right after -- confirm both steps are intended together.
    pipe.enable_model_cpu_offload()
|
|
|
def closestNumber(n, m):
    """Return the multiple of ``m`` that is closest to ``n``.

    Two candidates are compared: the multiple obtained by truncating ``n / m``
    toward zero, and the next multiple one step further from zero. When ``n``
    is exactly halfway between them, the candidate further from zero wins.

    Args:
        n: The value to snap (int or float).
        m: The step whose multiples are candidates; must be non-zero.

    Returns:
        The closest multiple of ``m``.

    Raises:
        ZeroDivisionError: If ``m`` is zero.
    """
    if isinstance(n, int) and isinstance(m, int):
        # Exact truncating division. Going through float division (n / m)
        # rounds the quotient once the operands exceed 2**53, which can make
        # both candidates land far away from n.
        quotient = abs(n) // abs(m)
        if (n < 0) != (m < 0):
            quotient = -quotient
    else:
        quotient = int(n / m)

    toward_zero = m * quotient
    # The second candidate sits one step further from zero than the first.
    step = 1 if (n * m) > 0 else -1
    away_from_zero = m * (quotient + step)

    if abs(n - toward_zero) < abs(n - away_from_zero):
        return toward_zero
    return away_from_zero
|
|
|
@torch.inference_mode()
def generate(input):
    """Render one image for a job, post it to Discord, and notify the web backend.

    Args:
        input: Job dict whose ``"input"`` entry provides ``prompt``, ``width``,
            ``height``, ``num_inference_steps``, ``guidance_scale``,
            ``num_images_per_prompt`` and ``seed``, plus the routing keys
            ``source_id``, ``source_channel`` and ``job_id``. The dict is not
            modified.

    Returns:
        ``{"result": <attachment URL>}`` when the Discord upload answers with
        HTTP 200 and an attachment URL can be read from the reply; otherwise
        ``{"result": "ERROR"}``.

    Raises:
        KeyError: If one of the generation parameters is missing.
    """
    values = input["input"]

    prompt = values['prompt']
    width = values['width']
    height = values['height']
    num_inference_steps = values['num_inference_steps']
    guidance_scale = values['guidance_scale']
    num_images_per_prompt = values['num_images_per_prompt']
    seed = values['seed']

    # A seed of 0 requests a random seed; the upper bound is 2**64 - 1.
    if seed == 0:
        seed = random.randint(0, 18446744073709551615)

    # Width and height are snapped to the nearest multiple of 8 before being
    # handed to the pipeline. Only the first generated image is kept.
    image = pipe(
        prompt=prompt,
        width=closestNumber(width, 8),
        height=closestNumber(height, 8),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator(pipe.device).manual_seed(seed),
    ).images[0]

    result = "/content/Kolors/scripts/outputs/kolors.jpg"
    image.save(result)

    routing_keys = ('source_id', 'source_channel', 'job_id')
    response = None
    job_id = None
    try:
        source_id = values['source_id']
        source_channel = values['source_channel']
        job_id = values['job_id']
        # The routing fields are left out of the echoed parameters; building a
        # filtered copy avoids mutating the caller's job payload.
        echoed = {key: value for key, value in values.items() if key not in routing_keys}

        default_filename = os.path.basename(result)
        # Read the image through a context manager so the handle is closed.
        with open(result, "rb") as image_file:
            files = {default_filename: image_file.read()}

        payload = {"content": f"{json.dumps(echoed)} <@{source_id}>"}
        response = requests.post(
            f"https://discord.com/api/v9/channels/{source_channel}/messages",
            data=payload,
            headers={"authorization": f"Bot {discord_token}"},
            files=files,
        )
        response.raise_for_status()
    except Exception as e:
        # Best effort: any upload failure is reported as an "ERROR" result below.
        print(f"An unexpected error occurred: {e}")
    finally:
        # The rendered file is temporary and is removed whether or not the upload worked.
        if os.path.exists(result):
            os.remove(result)

    if not (response and response.status_code == 200):
        return {"result": "ERROR"}

    # Extract the URL exactly once. Re-evaluating it inside a ``finally: return``
    # would re-raise on a malformed reply and would swallow other exceptions.
    try:
        url = response.json()['attachments'][0]['url']
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return {"result": "ERROR"}

    try:
        payload = {"jobId": job_id, "result": url}
        requests.post(
            f"{web_uri}/api/notify",
            data=json.dumps(payload),
            headers={'Content-Type': 'application/json', "authorization": f"{web_token}"},
        )
    except Exception as e:
        # The web callback is best effort; the URL is still returned.
        print(f"An unexpected error occurred: {e}")

    return {"result": url}
|
|
|
# Hand `generate` to the RunPod serverless runtime as the job handler.
runpod.serverless.start(dict(handler=generate))