aiinference222 / handler.py
mart9992's picture
d
0893e31
import io
import os
import random
import string
import sys
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import pydash as _
import boto3
import time
import subprocess
from GroundedSegmentAnything import installation
S3_REGION = "fra1"
S3_ACCESS_ID = "0RN7BZXS59HYSBD3VB79"
S3_ACCESS_SECRET = "hfSPgBlWl5jsGHa2xuByVkSpancgVeA2CVQf2EMp"
S3_ENDPOINT_URL = "https://s3.solarcom.ch"
S3_BUCKET_NAME = "pissnelke"
s3_session = boto3.session.Session()
s3 = s3_session.client(
service_name="s3",
region_name=S3_REGION,
aws_access_key_id=S3_ACCESS_ID,
aws_secret_access_key=S3_ACCESS_SECRET,
endpoint_url=S3_ENDPOINT_URL,
)
get_nude_function = None
class EndpointHandler():
def __init__(self, path=""):
# get_nude(Image.open("girl.png"))
os.environ['path'] = path
print("running apt-get update && apt-get install ffmpeg libsm6 libxext6 -y")
command = "apt-get update && apt-get install ffmpeg libsm6 libxext6 -y"
process = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()
print("ran apt-get update && apt-get install ffmpeg libsm6 libxext6 -y")
print("path", path)
global get_nude_function
from main import get_nude
get_nude_function = get_nude
def get_pipe(self, loras, lora_weights):
pipe = _.clone_deep(self.base_pipe)
lora_names = _.map_(loras, lambda x: x.split(".")[0])
lora_weights = _.map_(lora_weights, lambda x: float(x))
if len(lora_weights) > 0 and len(lora_weights) == len(lora_names):
pipe.set_adapters(lora_names, adapter_weights=lora_weights)
return pipe
def __call__(self, data):
original_image_res = requests.get(data.get("original_link"))
original_pil = Image.open(BytesIO(original_image_res.content))
replicate_api_key = data.get("replicate_api_key", "")
with_small_tits = data.get("with_small_tits", False)
with_big_tits = data.get("with_big_tits", False)
nude_pils = get_nude_function(cfg_scale=data.get("cfg_scale"), generate_max_size=data.get("generate_max_size"), original_max_size=data.get(
"original_max_size"), original_pil=original_pil, positive_prompt=data.get("positive_prompt"), steps=data.get("steps"), replicate_api_key=replicate_api_key, with_small_tits=with_small_tits, with_big_tits=with_big_tits)
filenames = []
for image in nude_pils:
byte_arr = io.BytesIO()
image.save(byte_arr, format='PNG')
byte_arr = byte_arr.getvalue()
random_string = ''.join(random.choice(
string.ascii_letters + string.digits) for i in range(20))
image_filename = random_string + ".jpeg"
s3.put_object(Body=byte_arr, Bucket=S3_BUCKET_NAME,
Key=image_filename)
filenames.append(image_filename)
return {
"filenames": filenames
}