"""Local test driver for the Diffuser ``EndpointHandler``.

Builds a request payload from command-line arguments, runs it through the
handler in-process, and writes the decoded result image to disk.
"""
import argparse
import base64
from io import BytesIO

from PIL import Image

from handler import EndpointHandler, decode_base64_image


def local_predict(prompts, encode_image, output_path="local_output.png"):
    """Run the endpoint handler locally and save the generated image.

    Args:
        prompts: Prompt text, sent to the handler under the ``"inputs"`` key.
        encode_image: Base64-encoded init image as a ``str``. When falsy
            (e.g. ``""``), no ``"image"`` key is included in the request.
        output_path: File path the decoded result image is saved to.
            Defaults to ``"local_output.png"`` (the previously hard-coded path).

    Returns:
        The object returned by ``decode_base64_image`` for the response's
        ``"image"`` field. Presumably a PIL image, given the ``.save`` call;
        confirm against ``handler``.
    """
    # Init handler
    my_handler = EndpointHandler()

    # Build the request once; only attach the init image when one was given.
    payload = {"inputs": prompts}
    if encode_image:
        payload["image"] = encode_image
    response = my_handler(payload)

    # NOTE(review): assumes the handler response always carries an "image"
    # key; a KeyError here means the handler returned something else.
    image = decode_base64_image(response["image"])
    image.save(output_path)
    return image


def _encode_image_file(path):
    """Read the file at *path* and return its bytes base64-encoded as ``str``."""
    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode()


# NOTE: the first positional argument of ArgumentParser is ``prog``, so
# "Diffuser local test" appears as the program name in the usage line.
opt = argparse.ArgumentParser("Diffuser local test")
opt.add_argument("-prompts", "--prompts", default="", type=str, help="Diffuser prompts")
opt.add_argument("-image", "--image", default="", type=str, help="Init image")
opt.add_argument(
    "-output",
    "--output",
    default="local_output.png",
    type=str,
    help="Output image path",
)

if __name__ == '__main__':
    args = opt.parse_args()
    # An empty string tells local_predict to send no init image.
    encoded_string = _encode_image_file(args.image) if args.image else ""
    local_predict(args.prompts, encoded_string, args.output)