File size: 807 Bytes
38f925c
9b53dda
d04db90
38f925c
 
9b53dda
 
b4b6420
38f925c
 
b4b6420
38f925c
d04db90
9b53dda
b4b6420
 
9b53dda
 
 
 
 
 
 
 
 
 
 
 
38f925c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gradio as gr

import spaces
import torch
from PIL import Image
import numpy as np
import cv2
import os
from simple_lama_inpainting import SimpleLama

# Absolute filesystem path of 'model/big-lama.pt', resolved against the
# directory containing this file rather than the current working directory.
# NOTE: despite the "_url" suffix this is a local path, not a URL.
_module_dir = os.path.dirname(__file__)
big_lama_url = os.path.join(os.path.abspath(_module_dir), 'model/big-lama.pt')

@spaces.GPU
def lama_inpainting(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Inpaint ``image`` using ``mask`` with the LaMa checkpoint at ``big_lama_url``.

    This function is the callback of a Gradio interface whose inputs are two
    ``gr.Image`` components. Gradio delivers those as RGB arrays, and the
    model is called with an RGB PIL image, so the arrays are passed through
    without any BGR/RGB channel swap. The result is returned in the same
    channel order as the input.

    Args:
        image: Picture to repair, as an array accepted by
            ``PIL.Image.fromarray``.
        mask: Mask picture, as an array accepted by ``PIL.Image.fromarray``.
            It is reduced to a single-channel ("L") image before inference.
            NOTE(review): presumably non-zero pixels mark the region to
            fill -- confirm against the simple-lama-inpainting docs.

    Returns:
        The model output converted to a NumPy array.

    Raises:
        ValueError: If ``image`` or ``mask`` is ``None`` (an input slot was
            left empty).
    """
    if image is None or mask is None:
        # An empty Gradio image slot arrives as None; fail with a readable
        # message instead of an opaque error from the image conversion.
        raise ValueError("Both an image and a mask are required for inpainting.")

    # Set before constructing SimpleLama so it can pick up the local
    # checkpoint (SimpleLama is expected to honour LAMA_MODEL -- see its docs).
    os.environ['LAMA_MODEL'] = big_lama_url
    lama: SimpleLama = SimpleLama()

    # No cv2 colour conversion here: swapping channels would feed the model a
    # red/blue-swapped picture and skew the luma weights used for the mask.
    rgb_image = Image.fromarray(image).convert("RGB")
    gray_mask = Image.fromarray(mask).convert("L")

    res = lama(rgb_image, gray_mask)
    return np.array(res)


# Gradio UI: two image inputs (the picture and its mask) feed
# lama_inpainting, and the returned array is shown in a single image output.
_image_input = gr.Image(label="image")
_mask_input = gr.Image(label="mask")

inpaint = gr.Interface(
    fn=lama_inpainting,
    inputs=[_image_input, _mask_input],
    outputs=gr.Image(),
)

# Start the web app as soon as this file is executed.
inpaint.launch()