File size: 2,665 Bytes
157c221
72dddd7
157c221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72dddd7
0893e31
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
import os
import time
from io import BytesIO

import boto3
import requests
from PIL import Image


# S3-compatible object storage that hosts input images so Replicate can fetch
# them by URL.
#
# Credentials are read from the environment instead of being hard-coded in
# source; any keys that were previously committed here must be treated as
# leaked and rotated. When S3_ACCESS_ID / S3_ACCESS_SECRET are unset, None is
# passed to boto3, which then falls back to its default credential chain
# (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY, shared config files, ...).
S3_REGION = os.environ.get("S3_REGION", "fra1")
S3_ACCESS_ID = os.environ.get("S3_ACCESS_ID")
S3_ACCESS_SECRET = os.environ.get("S3_ACCESS_SECRET")
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "https://s3.solarcom.ch")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "pissnelke")

# Module-level client shared by every upload/delete in this file.
s3_session = boto3.session.Session()
s3 = s3_session.client(
    service_name="s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_ID,
    aws_secret_access_key=S3_ACCESS_SECRET,
    endpoint_url=S3_ENDPOINT_URL,
)

logger = logging.getLogger(__name__)

# Replicate predictions endpoint and the pinned model version used for masking.
_REPLICATE_PREDICTIONS_URL = "https://api.replicate.com/v1/predictions"
_MASK_MODEL_VERSION = "ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c"
# Per-request network timeout (seconds) so a stalled connection cannot hang forever.
_HTTP_TIMEOUT = 30


def _upload_input_image(input_pil):
    """Upload *input_pil* as PNG under a random ``target/`` key and return that key."""
    s3filepath = f"target/{os.urandom(20).hex()}.png"
    input_buffer = BytesIO()
    # Encode as PNG so the stored bytes match the ".png" key and so images with
    # an alpha channel (which JPEG cannot store) are accepted as well.
    input_pil.save(input_buffer, "PNG")
    input_buffer.seek(0)
    s3.put_object(
        Bucket=S3_BUCKET_NAME,
        Key=s3filepath,
        Body=input_buffer,
        ContentType="image/png",
    )
    return s3filepath


def _wait_for_prediction(prediction_url, headers, poll_interval, timeout):
    """Poll a Replicate prediction until it succeeds and return its JSON payload.

    Raises:
        Exception: if the prediction ends as ``failed`` or ``canceled``.
        TimeoutError: if it is still running after *timeout* seconds
            (``None`` disables the limit).
        requests.HTTPError: if a status request returns an HTTP error.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        response = requests.get(prediction_url, headers=headers, timeout=_HTTP_TIMEOUT)
        response.raise_for_status()
        prediction = response.json()
        status = prediction.get("status")

        if status == "succeeded":
            return prediction
        # "canceled" is terminal too; without this check the loop would spin forever.
        if status in ("failed", "canceled"):
            raise Exception(prediction.get("error") or f"Replicate prediction {status}")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Replicate prediction still {status!r} after {timeout} seconds"
            )

        time.sleep(poll_interval)  # avoid hammering the API between status checks


def get_mask_replicate(
    input_pil,
    positive_prompt,
    expand_by=0,
    negative_prompt="",
    replicate_api_key="",
    *,
    poll_interval=1.0,
    timeout=600.0,
):
    """Run the Replicate masking model on *input_pil* and return its output image.

    The image is uploaded to the S3 bucket so the model can fetch it by URL.
    A prediction is then created and polled until it finishes. The third
    element of the prediction's ``output`` list is downloaded and opened with
    PIL. The temporary S3 upload is deleted afterwards (best effort).

    Args:
        input_pil: PIL image to send to the model.
        positive_prompt: sent to the model as ``mask_prompt``.
        expand_by: sent as ``adjustment_factor``; presumably grows/shrinks the
            mask in model-defined units. TODO: confirm against the model docs.
        negative_prompt: sent as ``negative_mask_prompt``.
        replicate_api_key: Replicate API token; when empty, the
            ``REPLICATE_API_TOKEN`` environment variable is used instead.
        poll_interval: seconds to sleep between prediction status checks.
        timeout: maximum seconds to wait for the prediction to finish
            (``None`` waits indefinitely).

    Returns:
        The downloaded output as a ``PIL.Image.Image``.

    Raises:
        Exception: if the prediction fails or is canceled.
        TimeoutError: if the prediction does not finish within *timeout*.
        requests.HTTPError: if a Replicate or output-download request fails.
    """
    api_key = replicate_api_key or os.environ.get("REPLICATE_API_TOKEN", "")
    headers = {"Authorization": f"Token {api_key}"}

    s3filepath = _upload_input_image(input_pil)
    try:
        data = {
            "version": _MASK_MODEL_VERSION,
            "input": {
                # NOTE(review): assumes objects in this bucket are publicly
                # readable at this URL; a presigned URL would avoid that.
                "image": f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3filepath}",
                "mask_prompt": positive_prompt,
                "negative_mask_prompt": negative_prompt,
                "adjustment_factor": expand_by,
            },
        }
        response = requests.post(
            _REPLICATE_PREDICTIONS_URL, json=data, headers=headers, timeout=_HTTP_TIMEOUT
        )
        # Surface auth/validation errors directly instead of a KeyError on 'id'.
        response.raise_for_status()
        response_data = response.json()
        logger.debug("Replicate prediction created: %s", response_data)

        prediction = _wait_for_prediction(
            f"{_REPLICATE_PREDICTIONS_URL}/{response_data['id']}",
            headers,
            poll_interval,
            timeout,
        )
    finally:
        # Once the prediction has finished (or we have given up on it), the
        # uploaded input is no longer needed. Cleanup is best effort and must
        # not mask the original error.
        try:
            s3.delete_object(Bucket=S3_BUCKET_NAME, Key=s3filepath)
        except Exception:
            logger.warning(
                "Could not delete temporary S3 object %s", s3filepath, exc_info=True
            )

    # NOTE(review): index 2 is hard-coded; presumably the model returns several
    # images and the third is the final mask. Confirm against the model's schema.
    output_link = prediction["output"][2]
    output_response = requests.get(output_link, timeout=_HTTP_TIMEOUT)
    output_response.raise_for_status()
    return Image.open(BytesIO(output_response.content))

if __name__ == "__main__":
    # Demo: run get_mask_replicate on sport.jpg with clothing prompts (face
    # excluded) and save the returned image to hallo.png.
    # The Replicate token is read from the environment instead of being
    # hard-coded; any token previously committed here must be revoked.
    api_token = os.environ.get("REPLICATE_API_TOKEN", "")
    if not api_token:
        raise SystemExit("Set REPLICATE_API_TOKEN to run this demo.")

    with Image.open("sport.jpg") as verrueckt_pil:
        x = get_mask_replicate(
            verrueckt_pil,
            "bra . blouse . skirt . dress",
            negative_prompt="face",
            expand_by=10,
            replicate_api_key=api_token,
        )
    x.save("hallo.png")
    print(x)