File size: 2,919 Bytes
fdb9dcc
 
 
 
 
 
 
 
c6342ce
fdb9dcc
d4adb6a
 
fdb9dcc
c6342ce
fdb9dcc
 
d4adb6a
 
 
 
 
 
 
 
 
 
 
fdb9dcc
 
 
 
 
 
 
 
d4adb6a
 
fdb9dcc
 
 
 
 
 
 
 
 
 
 
d4adb6a
 
fdb9dcc
 
d4adb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdb9dcc
d4adb6a
 
fdb9dcc
d4adb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdb9dcc
 
d4adb6a
fdb9dcc
d4adb6a
fdb9dcc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import requests
import zipfile
import io
from PIL import Image
import openai
import os
from scipy import ndimage
# from dotenv import load_dotenv
import numpy as np
import hashlib
import queue

# load_dotenv()


# Maximum number of remove.bg masks kept in memory. Entries are keyed by the
# SHA-256 of the uploaded image, so re-running the same image with a different
# prompt or margin reuses the mask instead of calling remove.bg again.
MASK_CACHE_SIZE = 30

# key -> mask bytes
cache = dict()
# Keys in insertion order (oldest first); used to choose what to evict.
# Unbounded on purpose: the size limit is enforced by save_to_memory_cache,
# so put() can never block even if two requests insert at the same time.
que = queue.Queue()


def save_to_memory_cache(key, file):
    """Store ``file`` under ``key`` in the in-memory mask cache.

    Eviction is first-in-first-out: once more than ``MASK_CACHE_SIZE`` keys
    are tracked, the oldest ones are dropped. Saving a key that is already
    cached only replaces its value; it is not queued a second time.

    Args:
        key: cache key (the image hash in this app).
        file: value to cache (the raw mask bytes in this app).
    """
    print('save mask')
    if key in cache:
        # Already tracked in the FIFO; queueing it again would make a later
        # eviction try to remove the same key twice.
        cache[key] = file
        return
    cache[key] = file
    que.put(key)
    while que.qsize() > MASK_CACHE_SIZE:
        # pop(..., None): tolerate a key that was already removed by a
        # concurrent eviction instead of raising KeyError.
        cache.pop(que.get(), None)


def get_circle_footprint(size):
    """Build a boolean disk-shaped footprint for ``ndimage`` filters.

    The result is a ``(2 * size, 2 * size)`` array that is True where the
    squared distance from index ``(size, size)`` is strictly less than
    ``size ** 2``.

    Args:
        size: disk radius in pixels.

    Returns:
        A 2-D boolean numpy array.
    """
    # Signed offset of every row/column index from the center index `size`.
    offsets = np.arange(size * 2) - size
    squared_distance = offsets[:, None] ** 2 + offsets[None, :] ** 2
    return squared_distance < size ** 2


def request_mask(img_path, timeout=None):
    """Upload an image to remove.bg and return the ``alpha.png`` it produces.

    The API key is read from the ``REMOVEBG_API_KEY`` environment variable.
    The request asks for ``format=zip``; the ``alpha.png`` member of the
    returned archive is extracted and returned as raw bytes.

    Args:
        img_path: path of the image file to upload.
        timeout: optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` waits indefinitely, which
            matches the previous behavior.

    Returns:
        The ``alpha.png`` bytes, or ``None`` when remove.bg answers with a
        non-OK status (the status code and response body are printed).
    """
    removebg_api_key = os.getenv('REMOVEBG_API_KEY')
    # Context manager so the upload handle is closed even if the request
    # raises; the handle was previously never closed.
    with open(img_path, 'rb') as image_file:
        response = requests.post(
            'https://api.remove.bg/v1.0/removebg',
            files={'image_file': image_file},
            data={
                'size': 'auto',
                'format': 'zip'
            },
            headers={'X-Api-Key': removebg_api_key},
            timeout=timeout,
        )
    if response.status_code != requests.codes.ok:
        print("Error:", response.status_code, response.text)
        return None
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        return archive.read('alpha.png')


def get_file_hash(path):
    """Return the hexadecimal SHA-256 digest of the file at ``path``.

    The file is read in 64 KiB chunks, so large uploads are hashed without
    loading them into memory at once.
    """
    chunk_size = 65536
    digest = hashlib.sha256()
    with open(path, 'rb') as stream:
        # iter() with a b'' sentinel stops at end of file.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def _build_edit_mask(mask_png, mask_margin):
    """Turn remove.bg mask bytes into an in-memory PNG for the OpenAI edit call.

    Channel 0 of the remove.bg image becomes the alpha channel of an
    otherwise white (255) image. When ``mask_margin`` > 0, that channel is
    first passed through a maximum filter with a disk footprint of radius
    ``mask_margin``. This grows its non-zero (opaque) region by roughly that
    many pixels.
    NOTE(review): presumably the subject is bright in ``alpha.png``, so the
    subject stays opaque and the background becomes transparent. Confirm
    against remove.bg output.

    Args:
        mask_png: encoded image bytes (``alpha.png`` from remove.bg).
        mask_margin: dilation radius in pixels; coerced to ``int``.

    Returns:
        A ``BytesIO`` holding the PNG, rewound to position 0.
    """
    maskIm = Image.open(io.BytesIO(mask_png))

    alpha = maskIm.getchannel(0)
    # The UI slider may deliver a float; the footprint needs an integer radius.
    margin = int(mask_margin)
    if margin > 0:
        inflated_alpha = ndimage.maximum_filter(
            input=np.array(alpha), footprint=get_circle_footprint(margin))
        alpha = Image.fromarray(np.uint8(inflated_alpha))
    maskIm.paste(255, [0, 0, maskIm.size[0], maskIm.size[1]])
    maskIm.putalpha(alpha)

    maskFile = io.BytesIO()
    maskIm.save(maskFile, format='PNG')
    maskFile.seek(0)
    return maskFile


def process_image(prompt, img_path, mask_margin):
    """Gradio handler: edit the uploaded image with OpenAI using a remove.bg mask.

    Steps:
    1. Hash the image.
    2. Reuse a cached remove.bg mask for that hash, or request one.
    3. Build the transparency mask.
    4. Call ``openai.Image.create_edit`` for a single 512x512 result.

    Args:
        prompt: text prompt for the edit.
        img_path: path of the uploaded image.
        mask_margin: pixels by which to grow the opaque mask region.

    Returns:
        URL of the generated image. If remove.bg returned no mask, a fixed
        placeholder image URL is returned instead.
    """
    openai.api_key = os.getenv('OPENAI_API_KEY')

    hsh = get_file_hash(img_path)
    print('hash', hsh)

    # Single lookup: a separate membership test and index could race with an
    # eviction. Only non-None masks are ever cached, so None means a miss.
    maskImFile = cache.get(hsh)
    if maskImFile is None:
        maskImFile = request_mask(img_path)
        if maskImFile is None:
            print('no mask received')
            return 'https://i.imgur.com/DUd0OWN.png'
        save_to_memory_cache(hsh, maskImFile)

    maskFile = _build_edit_mask(maskImFile, mask_margin)

    # Context manager so the image handle is closed after the upload; it was
    # previously never closed.
    with open(img_path, "rb") as image_file:
        response = openai.Image.create_edit(
            image=image_file,
            mask=maskFile,
            prompt=prompt,
            n=1,
            size="512x512"
        )
    return response['data'][0]['url']

# Gradio UI: prompt text, uploaded image (passed to process_image as a file
# path), and the mask-margin slider; the output is the edited image.
demo = gr.Interface(process_image, inputs=[
    'text',
    gr.Image(type='filepath', shape=(500, 500), label='image'),
    gr.Slider(minimum=0, maximum=10, value=5, step=1, label="mask margin")
], outputs=["image"])

# Launch only when run as a script, so importing this module (tests, tooling,
# gradio reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()