|
import gradio as gr
|
|
import numpy as np
|
|
import random
|
|
import time
|
|
import json
|
|
import os
|
|
from loguru import logger
|
|
from decouple import config
|
|
import io
|
|
import torch
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
from PIL import Image
|
|
|
|
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
import spaces
|
|
|
|
# --- CUDA diagnostics -------------------------------------------------------
# Printed once at import time so the startup log records the GPU environment.
print(f"Is CUDA available: {torch.cuda.is_available()}")
device = torch.cuda.get_device_name(torch.cuda.current_device())
print(f"CUDA device: {device}")
print(torch.version.cuda)
print(device)

# --- SAM-HQ model setup -----------------------------------------------------
# Checkpoint path and registry key for the model variant loaded below.
sam_checkpoint = "sam-hq/models/sam_hq_vit_h.pth"
model_type = "vit_h"

# From here on `device` is the torch device string, no longer the GPU name
# that was printed above.
device = "cuda"

# Build the model from its checkpoint, move it to the GPU, and wrap it in the
# predictor shared by every request handled by this module.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
|
|
|
|
@spaces.GPU(duration=5)
def generate_image(prompt, image):
    """Segment ``image`` from point prompts and return the mask as an image.

    The prompt is parsed and validated *before* the image is handed to the
    predictor, so a malformed request fails fast without spending GPU time
    on the image embedding.

    Args:
        prompt: JSON string encoding an object with two keys:
            ``"input_points"`` (a list of ``[x, y]`` pairs) and
            ``"input_labels"`` (one label per point). Both are forwarded to
            ``predictor.predict`` as ``point_coords`` / ``point_labels``.
        image: Image passed to ``predictor.set_image``. The Gradio interface
            in this file supplies it as an RGB numpy array.

    Returns:
        A ``PIL.Image.Image`` built from an ``(H, W, 3)`` uint8 array that is
        255 where the predicted mask is set and 0 elsewhere.

    Raises:
        ValueError: If ``image`` is ``None``, if the points are not an
            ``(N, 2)`` array, or if the labels are not a flat sequence with
            one entry per point.
        json.JSONDecodeError: If ``prompt`` is not valid JSON.
        KeyError: If ``"input_points"`` or ``"input_labels"`` is missing.
    """
    if image is None:
        raise ValueError("An input image is required.")

    request = json.loads(prompt)
    input_points = np.array(request['input_points'])
    input_labels = np.array(request['input_labels'])

    # Reject malformed prompts here with a readable message instead of an
    # opaque failure inside the predictor.
    if input_points.ndim != 2 or input_points.shape[1] != 2:
        raise ValueError(
            "'input_points' must be a non-empty list of [x, y] pairs, "
            f"got an array of shape {input_points.shape}."
        )
    if input_labels.ndim != 1 or len(input_labels) != len(input_points):
        raise ValueError(
            "'input_labels' must be a flat list with one label per point: "
            f"got {input_labels.shape} labels for "
            f"{len(input_points)} point(s)."
        )

    predictor.set_image(image)

    # multimask_output=False requests a single mask, so only index 0 of the
    # leading axis is used below.
    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=None,
        multimask_output=False,
        hq_token_only=True,
    )

    # Paint the selected pixels white on a black RGB canvas.
    selected = masks[0]
    rgb_array = np.zeros((*selected.shape, 3), dtype=np.uint8)
    rgb_array[selected] = 255

    return Image.fromarray(rgb_array)
|
|
|
|
|
|
if __name__ == "__main__":
    # Inputs, in the order generate_image expects them: the JSON point prompt
    # as free text, then the image to segment.
    prompt_input = "text"
    image_input = gr.Image(image_mode='RGB', type="numpy")

    # Single output: the rendered mask image.
    mask_output = gr.Image(type="numpy", image_mode='RGB')

    demo = gr.Interface(
        fn=generate_image,
        inputs=[prompt_input, image_input],
        outputs=[mask_output],
    )

    demo.launch(debug=True)
    # Logged only once launch() has returned.
    logger.debug('demo.launch()')
|
|
|