File size: 1,797 Bytes
de1c0ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import runpod
import json
import os
import requests
import time
import base64
from io import BytesIO
from PIL import Image


# RunPod credentials and the id of the serverless endpoint to call.
# Values are read from the environment (RUNPOD_API_KEY / RUNPOD_MODEL_ID) so
# secrets need not be hard-coded in source; each falls back to "" when the
# variable is unset, which matches the previous placeholder values.
runpod.api_key = os.environ.get("RUNPOD_API_KEY", "")
RUNPOD_MODEL_ID = os.environ.get("RUNPOD_MODEL_ID", "")

# Module-level endpoint handle, created at import time and used by get_result().
endpoint = runpod.Endpoint(RUNPOD_MODEL_ID)

def get_result(user_image, clothes_image, body_part):
    """Send a person image and a garment image to the RunPod endpoint.

    Both images are PNG-encoded to base64 and sent, together with
    ``body_part`` and ``is_checked_crop=True``, through a blocking
    ``endpoint.run_sync`` call. The elapsed time of that call is printed.

    Args:
        user_image: PIL image sent as the ``human_img_b64`` field.
        clothes_image: PIL image sent as the ``garm_img_b64`` field.
        body_part: Forwarded unchanged as the ``body_part`` field; the
            accepted values are defined by the remote worker, not here.

    Returns:
        ``(output_image, mask_image)``: PIL images decoded from the
        ``output_image`` and ``mask_image`` base64 fields of the response.

    Raises:
        RuntimeError: if the endpoint returns no output, or an output that
            lacks either image field (e.g. because the job failed).
    """
    # Named ``payload`` rather than ``input`` to avoid shadowing the builtin.
    payload = {
        "human_img_b64": pil_to_b64(user_image),
        "garm_img_b64": pil_to_b64(clothes_image),
        "body_part": body_part,
        "is_checked_crop": True,
    }

    # First way to call the serverless API: the runpod SDK.
    # perf_counter is monotonic, so the timing is not skewed by clock changes.
    start = time.perf_counter()
    result = endpoint.run_sync(payload)
    print('Time taken: ', time.perf_counter() - start)

    # NOTE(review): run_sync presumably yields None (or an error payload) for
    # a failed job; report that clearly instead of an opaque TypeError/KeyError.
    try:
        output_image_b64 = result['output_image']
        mask_image_b64 = result['mask_image']
    except (TypeError, KeyError) as err:
        raise RuntimeError(
            f"RunPod endpoint returned no usable output: {repr(result)[:500]}"
        ) from err

    output_image = b64_to_pil(output_image_b64)
    mask_image = b64_to_pil(mask_image_b64)
    return output_image, mask_image

def b64_to_pil(base64_string):
    """Turn a base64-encoded image string back into a PIL image.

    The decoded bytes are wrapped in an in-memory buffer and handed to
    ``Image.open``; PIL reads pixel data lazily, so a corrupt payload may
    only raise once the image is actually used.
    """
    raw_bytes = base64.b64decode(base64_string)
    return Image.open(BytesIO(raw_bytes))

def pil_to_b64(pil_img):
    """Serialize an image as PNG and return the bytes as a base64 text string."""
    with BytesIO() as png_buffer:
        pil_img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
    
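# Second way to call the serverless API: a raw HTTP request instead of the SDK.
# RunPod's synchronous route is /runsync (use /run to submit an async job), and
# the request body must wrap the fields under an "input" key. `payload` here is
# the dict of request fields (human_img_b64, garm_img_b64, body_part,
# is_checked_crop) built inside get_result().
# url = f'https://api.runpod.ai/v2/{RUNPOD_MODEL_ID}/runsync' # or /run for async
# headers = {
#     'accept': 'application/json',
#     'Content-Type': 'application/json',
#     'Authorization': f'Bearer {runpod.api_key}'
# }
# start = time.time()
# response = requests.post(url, headers=headers, data=json.dumps({"input": payload}))
# print('\n')
# print(response.json())
# end = time.time()
# print('Time taken: ', end-start)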