# Page header residue (not code): Circularmachines's picture · "app.py updates" · commit 458490b · raw / history / blame · 2.59 kB
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset
# Image dataset backing the app; only the "train" split is loaded.
# NOTE(review): Streamlit re-executes the script on each interaction, so these
# three loads presumably repeat on every rerun — consider caching if the
# installed Streamlit version offers it (confirm before changing).
ds = load_dataset("Circularmachines/batch_indexing_machine_100_small_imgs", split="train")
# Precomputed arrays read from the current working directory.
# `pred` is not referenced anywhere in the visible code; `keep_bool` is indexed
# with 16-px patch coordinates inside find().
# NOTE(review): shapes and dtypes of both arrays are not visible here — confirm.
pred=np.load('pred.npy')
keep_bool=np.load('keep_bool.npy')
#st.set_page_config(
# page_title="Streamlit Image Coordinates: Image Update",
# page_icon="🎯",
# layout="wide",
#)
#"# :dart: Streamlit Image Coordinates: Image Update"
# Seed the per-session state on the first run only; on later reruns the keys
# already exist and keep whatever the callbacks / click handler stored.
#   point   - top-left pixel of the currently selected patch
#   img     - index of the dataset image shown in the main column
#   draw    - whether the patch outline should be drawn
#   sideimg - dataset indices of the thumbnails in the side column
_SESSION_DEFAULTS = (
    ("point", (200,200)),
    ("img", 0),
    ("draw", False),
    ("sideimg", [0,1,2,3]),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def button_click():
    """Handle the 'Change Batch' button.

    Stores a new random image index (an integer from 0 to 99) under "img"
    and switches the patch outline off by clearing "draw".
    """
    state = st.session_state
    # Hide the outline first; the previous selection belongs to the old image.
    state["draw"] = False
    state["img"] = np.random.randint(100)
def find():
    """Handle the 'Find similar parts' button.

    Converts the stored pixel position into 16-px patch indices, then writes
    those indices and the corresponding ``keep_bool`` entry to the page.
    """
    pixel = st.session_state["point"]
    patch_x = pixel[0] // 16
    patch_y = pixel[1] // 16
    patch = (patch_x, patch_y)
    # A flattened index (patch[0]*36+patch[1]) was tried here and left disabled.
    st.write(patch)
    # NOTE(review): keep_bool is indexed [x, y]; confirm the array is not laid
    # out [row, col], in which case the two indices would need swapping.
    st.write(keep_bool[patch_x, patch_y])
# for i in range(4):
# st.session_state["sideimg"][i]+=1
# st.image(ds[0]['image'])
def get_ellipse_coords(point, patch_size=16):
    """Return the bounding box of the square patch anchored at *point*.

    Despite the name, the box is anchored at its top-left corner rather than
    centred on *point*, and the caller in this file passes it to
    ``draw.rectangle``.

    Args:
        point: ``(x, y)`` pair giving the top-left corner of the patch.
        patch_size: Side length of the patch in pixels. Defaults to 16, the
            grid size used elsewhere in this file.

    Returns:
        A 4-tuple ``(x0, y0, x1, y1)`` where ``(x1, y1)`` is ``(x0, y0)``
        shifted by ``patch_size`` along both axes.
    """
    left, top = point[0], point[1]
    return (left, top, left + patch_size, top + patch_size)
# Page layout: a wide main column (5 parts) next to a narrow side column (1 part).
# NOTE(review): indentation is missing in this copy of the file, so the exact
# nesting below (in particular whether the buttons sit inside `with col1:`)
# cannot be determined from here — confirm against the real source.
col1, col2 = st.columns([5,1])
with col1:
# Fetch the image selected by the "img" index held in session state.
current_image=ds[st.session_state["img"]]['image']#.resize(size=(384,384))
draw = ImageDraw.Draw(current_image)
if st.session_state["draw"]:
# Outline the currently selected patch with a green rectangle.
#for point in st.session_state["points"]:
point=st.session_state["point"]
coords = get_ellipse_coords(point)
draw.rectangle(coords, outline="green",width=2)
# Render the image through the click-capturing component; `value` is checked
# for None before its "x"/"y" entries are read.
value = streamlit_image_coordinates(current_image, key="pil")
if value is not None:
# Snap the clicked pixel down to the top-left corner of its 16-px grid cell.
point = value["x"]//16*16, value["y"]//16*16
if point != st.session_state["point"]:
# A different cell was clicked: remember it, enable the outline, and rerun
# the script so the rectangle is drawn on the next pass.
st.session_state["point"]=point
st.session_state["draw"]=True
# NOTE(review): newer Streamlit releases replace experimental_rerun with
# st.rerun — confirm against the pinned Streamlit version.
st.experimental_rerun()
#subcol1, subcol2 = st.columns(2)
#with subcol1:
#st.button('Previous Frame', on_click=button_click)
# Action buttons; each one runs its callback when clicked.
st.button('Change Batch', on_click=button_click)
st.button('Find similar parts', on_click=find)
# Echo the current session values (image index, selected point, draw flag).
st.write(st.session_state["img"])
st.write(st.session_state["point"])
st.write(st.session_state["draw"])
with col2:
# Side column: show three dataset images, each reduced by keeping every 4th
# pixel along both spatial axes.
# NOTE(review): "sideimg" holds four indices but only the first three are
# displayed — confirm this is intended.
for i in range(3):
st.image(np.array(ds[st.session_state["sideimg"][i]]['image'])[::4,::4,:])