Mikhaylov Alexey commited on
Commit
d4adb6a
·
1 Parent(s): 8b2dd2a

memory cache mask

Browse files
Files changed (1) hide show
  1. app.py +69 -29
app.py CHANGED
@@ -8,12 +8,23 @@ import os
8
  from scipy import ndimage
9
  # from dotenv import load_dotenv
10
  import numpy as np
 
 
11
 
12
  # load_dotenv()
13
 
14
 
15
- openai.api_key = os.getenv('OPENAI_API_KEY')
16
- removebg_api_key = os.getenv('REMOVEBG_API_KEY')
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def get_circle_footprint(size):
@@ -22,7 +33,8 @@ def get_circle_footprint(size):
22
  return fp
23
 
24
 
25
- def test(prompt, img_path, mask_margin):
 
26
  response = requests.post(
27
  'https://api.remove.bg/v1.0/removebg',
28
  files={'image_file': open(img_path, 'rb')},
@@ -34,39 +46,67 @@ def test(prompt, img_path, mask_margin):
34
  )
35
  if response.status_code == requests.codes.ok:
36
  zipFile = zipfile.ZipFile(io.BytesIO(response.content))
37
- maskIm = Image.open(io.BytesIO(zipFile.read('alpha.png')))
38
-
39
- alpha= maskIm.getchannel(0)
40
- if mask_margin > 0:
41
- inflated_alpha = ndimage.maximum_filter(input=np.array(alpha), footprint=get_circle_footprint(mask_margin))
42
- alpha = Image.fromarray(np.uint8(inflated_alpha))
43
- maskIm.paste((255),[0,0,maskIm.size[0],maskIm.size[1]])
44
- maskIm.putalpha(alpha)
45
-
46
- maskFile = io.BytesIO()
47
- maskIm.save(maskFile, format='PNG')
48
- maskFile.seek(0)
49
-
50
- response = openai.Image.create_edit(
51
- image=open(img_path, "rb"),
52
- mask=maskFile,
53
- prompt=prompt,
54
- n=1,
55
- size="512x512"
56
- )
57
- return response['data'][0]['url']
58
-
59
- # with open('no-bg.zip', 'wb') as out:
60
- # out.write(response.content)
61
  else:
62
  print("Error:", response.status_code, response.text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
67
- demo = gr.Interface(test, inputs = [
68
  'text',
69
- gr.Image(type='filepath', shape=(500,500), label='image'),
70
  gr.Slider(minimum=0, maximum=10, value=5, step=1, label="mask margin")
71
  ], outputs=["image"])
72
 
 
8
  from scipy import ndimage
9
  # from dotenv import load_dotenv
10
  import numpy as np
11
+ import hashlib
12
+ import queue
13
 
14
  # load_dotenv()
15
 
16
 
17
+ cache = dict()
18
+ que = queue.Queue(30)
19
+
20
+
21
+ def save_to_memory_cache(key, file):
22
+ print('save mask')
23
+ cache[key] = file
24
+ que.put(key)
25
+ if que.full():
26
+ rkey = que.get()
27
+ del cache[rkey]
28
 
29
 
30
  def get_circle_footprint(size):
 
33
  return fp
34
 
35
 
36
+ def request_mask(img_path):
37
+ removebg_api_key = os.getenv('REMOVEBG_API_KEY')
38
  response = requests.post(
39
  'https://api.remove.bg/v1.0/removebg',
40
  files={'image_file': open(img_path, 'rb')},
 
46
  )
47
  if response.status_code == requests.codes.ok:
48
  zipFile = zipfile.ZipFile(io.BytesIO(response.content))
49
+ maskImFile = zipFile.read('alpha.png')
50
+ return maskImFile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  else:
52
  print("Error:", response.status_code, response.text)
53
+ return None
54
+
55
+
56
+ def get_file_hash(path):
57
+ with open(path, 'rb') as inputfile:
58
+ fh = hashlib.sha256()
59
+ fb = inputfile.read(65536)
60
+ while len(fb) > 0:
61
+ fh.update(fb)
62
+ fb = inputfile.read(65536)
63
+ return fh.hexdigest()
64
+
65
+
66
+ def process_image(prompt, img_path, mask_margin):
67
+ openai.api_key = os.getenv('OPENAI_API_KEY')
68
 
69
+ hsh = get_file_hash(img_path)
70
+ print('hash',hsh)
71
 
72
+ maskImFile = None
73
+ if hsh in cache:
74
+ maskImFile = cache[hsh]
75
+ else:
76
+ maskImFile = request_mask(img_path)
77
+ if maskImFile != None:
78
+ save_to_memory_cache(hsh, maskImFile)
79
+ else:
80
+ print('no mask received')
81
+ return 'https://i.imgur.com/DUd0OWN.png'
82
+
83
+ maskIm = Image.open(io.BytesIO(maskImFile))
84
+
85
+ alpha = maskIm.getchannel(0)
86
+ if mask_margin > 0:
87
+ inflated_alpha = ndimage.maximum_filter(input=np.array(
88
+ alpha), footprint=get_circle_footprint(mask_margin))
89
+ alpha = Image.fromarray(np.uint8(inflated_alpha))
90
+ maskIm.paste((255), [0, 0, maskIm.size[0], maskIm.size[1]])
91
+ maskIm.putalpha(alpha)
92
+
93
+ maskFile = io.BytesIO()
94
+ maskIm.save(maskFile, format='PNG')
95
+ maskFile.seek(0)
96
+
97
+ response = openai.Image.create_edit(
98
+ image=open(img_path, "rb"),
99
+ mask=maskFile,
100
+ prompt=prompt,
101
+ n=1,
102
+ size="512x512"
103
+ )
104
+ return response['data'][0]['url']
105
 
106
  # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
107
+ demo = gr.Interface(process_image, inputs=[
108
  'text',
109
+ gr.Image(type='filepath', shape=(500, 500), label='image'),
110
  gr.Slider(minimum=0, maximum=10, value=5, step=1, label="mask margin")
111
  ], outputs=["image"])
112