"""Gradio demo for interactive and automatic image segmentation.

The UI callbacks (``get_select_coords``, ``interactive_seg``, ``auto_seg``,
``clear_point``, ``reset_state``) and the page text (``title``,
``description``, ``descriptionend``) live in ``utils.py``. That file is
downloaded at start-up from the ``JunchuanYu/files_for_segmentRS`` dataset
repository on the Hugging Face Hub into the current directory.

Environment:
    HUB_TOKEN: Hugging Face access token passed to ``hf_hub_download``.
        Optional: when unset, ``hf_hub_download`` falls back to any cached
        login. NOTE(review): a token is presumably required if the dataset
        repo is private -- confirm.
"""
import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import hf_hub_download

# .get() instead of [...]: a missing variable should fall back to the cached
# Hub login rather than abort start-up with a bare KeyError.
token = os.environ.get('HUB_TOKEN')
loc = hf_hub_download(
    repo_id="JunchuanYu/files_for_segmentRS",
    filename="utils.py",
    repo_type="dataset",
    local_dir='.',
    token=token,
)

# hf_hub_download returns the path of the downloaded *file*. sys.path must
# contain the *directory* holding it for ``import utils`` to resolve. Putting
# it first makes the downloaded module win over any other ``utils`` on the path.
utils_dir = os.path.dirname(os.path.abspath(loc))
if utils_dir not in sys.path:
    sys.path.insert(0, utils_dir)

from utils import (  # noqa: E402  (module only exists after the download above)
    auto_seg,
    clear_point,
    description,
    descriptionend,
    get_select_coords,
    interactive_seg,
    reset_state,
    title,
)

with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)

    # Per-session lists that get_select_coords reads and returns on every
    # click on the input image. Presumably these are the clicked x / y
    # coordinates and their Positive/Negative labels -- confirm in utils.py.
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])

    with gr.Row():
        # Main working area: controls, images and auto-segmentation sliders.
        with gr.Column(scale=13):
            with gr.Row():
                with gr.Column():
                    # gr.inputs.* was deprecated in Gradio 3 and removed in 4.
                    # gr.Radio takes the initial selection through ``value``.
                    mode = gr.Radio(
                        choices=['Positive', 'Negative'],
                        type="value",
                        value='Positive',
                        label='Types of sampling methods',
                    )
                with gr.Column():
                    clear_bn = gr.Button("Clear Selection")
                    interseg_button = gr.Button("Interactive Segment", variant='primary')
            with gr.Row():
                input_img = gr.Image(label="Input")
                gallery = gr.Image(label="Points")
                # Each click on the input image updates the "Points" preview
                # and the three per-session state lists.
                input_img.select(
                    get_select_coords,
                    [input_img, mode, x, y, label],
                    [gallery, x, y, label],
                )
            with gr.Row():
                output_img = gr.Image(label="Result")
                mask_img = gr.Image(label="Mask")
            with gr.Row():
                with gr.Column():
                    thresh = gr.Slider(
                        minimum=0.8, maximum=1, value=0.90, step=0.01,
                        interactive=True, label="Threshold",
                    )
                with gr.Column():
                    points = gr.Slider(
                        minimum=16, maximum=96, value=32, step=16,
                        interactive=True, label="Points/Side",
                    )

        # Narrow side column: example images and the global action buttons.
        with gr.Column(scale=2, min_width=8):
            # Sorted so the example gallery order does not depend on the
            # filesystem's directory listing order.
            example = gr.Examples(
                examples=[[s, 0.9, 32] for s in sorted(glob.glob('./images/*'))],
                fn=auto_seg,
                inputs=[input_img, thresh, points],
                outputs=[output_img],
                cache_examples=False,
                examples_per_page=5,
            )
            autoseg_button = gr.Button("Auto Segment", variant="primary")
            emptyBtn = gr.Button("Restart", variant="secondary")

    # Event wiring.
    interseg_button.click(
        interactive_seg,
        inputs=[input_img, x, y, label],
        outputs=[output_img, mask_img],
    )
    autoseg_button.click(
        auto_seg,
        inputs=[input_img, thresh, points],
        outputs=[mask_img],
    )
    clear_bn.click(
        clear_point,
        outputs=[gallery, mode, x, y, label],
        show_progress=True,
    )
    emptyBtn.click(
        reset_state,
        outputs=[input_img, gallery, output_img, mask_img, thresh, points, mode, x, y, label],
        show_progress=True,
    )

    gr.Markdown(descriptionend)

if __name__ == "__main__":
    demo.launch(debug=False, show_api=False)