"""Layered, scrollable "infinite" RGBA canvas for a Pyodide/browser front end.

The drawing surface is stored as square patches of ``patch_size`` pixels in a
dict keyed by patch index.  A viewport-sized buffer is copied in and out of
those patches as the view moves, and is rendered on a stack of HTML canvases
looked up as ``#canvas0`` .. ``#canvasN``.
"""
import base64
import json
import io
import numpy as np
from PIL import Image
from pyodide import to_js, create_proxy
import gc
from js import (
    console,
    document,
    devicePixelRatio,
    ImageData,
    Uint8ClampedArray,
    CanvasRenderingContext2D as Context2d,
    requestAnimationFrame,
    update_overlay,
    setup_overlay,
    window
)

# Values of the "#mode" element that select a mouse mode.
PAINT_SELECTION = "selection"
IMAGE_SELECTION = "canvas"
BRUSH_SELECTION = "eraser"

# Mouse states stored in ``InfCanvas.mouse_state``.
NOP_MODE = 0
PAINT_MODE = 1
IMAGE_MODE = 2
BRUSH_MODE = 3


def hold_canvas():
    """Placeholder; currently does nothing."""
    pass


def prepare_canvas(width, height, canvas) -> Context2d:
    """Size ``canvas`` to ``width`` x ``height``, clear it, and return its 2D context.

    Both the CSS size (``style.width``/``style.height``) and the backing-store
    size (``width``/``height``) are set to the same pixel values.
    """
    ctx = canvas.getContext("2d")

    canvas.style.width = f"{width}px"
    canvas.style.height = f"{height}px"

    canvas.width = width
    canvas.height = height

    ctx.clearRect(0, 0, width, height)

    return ctx


def multi_canvas(layer, width=800, height=600):
    """Return ``layer`` ``CanvasProxy`` objects wrapping ``#canvas0`` .. ``#canvas{layer-1}``."""
    return [
        CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
        for i in range(layer)
    ]


class CanvasProxy:
    """Thin Python wrapper around one HTML canvas element and its 2D context."""

    def __init__(self, canvas, width=800, height=600) -> None:
        self.canvas = canvas
        self.ctx = prepare_canvas(width, height, canvas)
        self.width = width
        self.height = height

    def clear_rect(self, x, y, w, h):
        """Clear the rectangle ``(x, y, w, h)``."""
        self.ctx.clearRect(x, y, w, h)

    def clear(self):
        """Clear the whole canvas, using the element's current width/height."""
        self.clear_rect(0, 0, self.canvas.width, self.canvas.height)

    def stroke_rect(self, x, y, w, h):
        """Outline the rectangle ``(x, y, w, h)`` with the current stroke style."""
        self.ctx.strokeRect(x, y, w, h)

    def fill_rect(self, x, y, w, h):
        """Fill the rectangle ``(x, y, w, h)`` with the current fill style."""
        self.ctx.fillRect(x, y, w, h)

    def put_image_data(self, image, x, y):
        """Blit ``image`` onto the canvas with its top-left corner at ``(x, y)``.

        ``image`` is indexed ``(height, width, channel)``; its raw bytes are
        handed to ``ImageData`` unchanged.
        """
        height, width, _ = image.shape
        pixels = Uint8ClampedArray.new(to_js(image.tobytes()))
        self.ctx.putImageData(ImageData.new(pixels, width, height), x, y)

    def draw_image(self, canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
        """Copy a source rectangle of ``canvas`` into a destination rectangle here.

        Arguments are forwarded unchanged to the 9-argument ``drawImage``.
        """
        self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)

    @property
    def stroke_style(self):
        """Stroke style of the underlying 2D context."""
        return self.ctx.strokeStyle

    @stroke_style.setter
    def stroke_style(self, value):
        self.ctx.strokeStyle = value

    @property
    def fill_style(self):
        """Fill style of the underlying 2D context."""
        # Fixed: the getter previously returned ``strokeStyle``.
        return self.ctx.fillStyle

    @fill_style.setter
    def fill_style(self, value):
        self.ctx.fillStyle = value


# RGBA for masking
class InfCanvas:
    """Viewport onto a patch-backed RGBA image, with a movable selection box.

    Canvas layers, as used by the methods below:

    * ``canvas[0]``  -- checkerboard background (``clear_background``)
    * ``canvas[1]``  -- the viewport buffer (``draw_buffer``)
    * ``canvas[2]``  -- pending selection image / enlarged buffer while panning
    * ``canvas[-2]`` -- eraser brush preview
    * ``canvas[-1]`` -- selection-box outline; also receives mouse events
    """

    def __init__(
        self,
        width,
        height,
        selection_size=256,
        grid_size=64,
        patch_size=4096,
        test_mode=False,
    ) -> None:
        # Kept as an assert so callers relying on AssertionError are unaffected.
        assert selection_size < min(height, width)
        self.width = width
        self.height = height
        self.display_width = width
        self.display_height = height
        self.canvas = multi_canvas(5, width=width, height=height)
        setup_overlay(width, height)
        # place at center
        self.view_pos = [patch_size // 2 - width // 2, patch_size // 2 - height // 2]
        self.cursor = [
            width // 2 - selection_size // 2,
            height // 2 - selection_size // 2,
        ]
        # Patch storage: (patch_x, patch_y) -> (patch_size, patch_size, 4) uint8
        self.data = {}
        self.grid_size = grid_size
        self.selection_size_w = selection_size
        self.selection_size_h = selection_size
        self.patch_size = patch_size
        # note that for image data, the height comes before width
        self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
        self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
        self.sel_buffer_bak = np.zeros(
            (selection_size, selection_size, 4), dtype=np.uint8
        )
        self.sel_dirty = False
        self.buffer_dirty = False
        self.mouse_pos = [-1, -1]
        self.mouse_state = 0
        self.test_mode = test_mode
        self.buffer_updated = False
        self.image_move_freq = 1
        self.show_brush = False
        self.scale = 1.0
        self.eraser_size = 32

    def _blank_selection(self):
        """Return a fully transparent array the size of the selection box."""
        return np.zeros(
            (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
        )

    def reset_large_buffer(self):
        """Restore layer 2 to viewport size, make it visible, and clear it."""
        layer = self.canvas[2]
        layer.canvas.width = self.width
        layer.canvas.height = self.height
        layer.canvas.style.display = "block"
        layer.clear()

    def draw_eraser(self, x, y):
        """Draw the square eraser preview centred on ``(x, y)``."""
        size = self.eraser_size
        left, top = x - size // 2, y - size // 2
        layer = self.canvas[-2]
        layer.clear()
        layer.fill_style = "#ffffff"
        layer.fill_rect(left, top, size, size)
        layer.stroke_rect(left, top, size, size)

    def use_eraser(self, x, y):
        """Zero the buffer pixels under the eraser square centred on ``(x, y)``."""
        if self.sel_dirty:
            # Commit a pending selection first so it can be erased as well.
            self.write_selection_to_buffer()
            self.draw_buffer()
            self.canvas[2].clear()
        self.buffer_dirty = True
        bx0 = int(x) - self.eraser_size // 2
        by0 = int(y) - self.eraser_size // 2
        bx1, by1 = bx0 + self.eraser_size, by0 + self.eraser_size
        # Clamp to the viewport so negative indices do not wrap around.
        bx0, by0 = max(0, bx0), max(0, by0)
        bx1, by1 = min(self.width, bx1), min(self.height, by1)
        self.buffer[by0:by1, bx0:bx1, :] = 0
        self.draw_buffer()
        self.draw_selection_box()

    def setup_mouse(self):
        """Attach mouse and wheel listeners to the top canvas layer."""
        self.image_move_cnt = 0
        # Initialised here so the pan branch can never read it before it is set.
        self.cached_view_pos = tuple(self.view_pos)

        def get_mouse_mode():
            """Map the "#mode" element's value to a ``*_MODE`` constant."""
            mode = document.querySelector("#mode").value
            if mode == PAINT_SELECTION:
                return PAINT_MODE
            elif mode == IMAGE_SELECTION:
                return IMAGE_MODE
            return BRUSH_MODE

        def get_event_pos(event):
            """Convert client coordinates to canvas-pixel coordinates."""
            canvas = self.canvas[-1].canvas
            rect = canvas.getBoundingClientRect()
            x = (canvas.width * (event.clientX - rect.left)) / rect.width
            y = (canvas.height * (event.clientY - rect.top)) / rect.height
            return x, y

        def end_interaction():
            """Return to NOP mode; after a pan, resync and redraw every layer."""
            last_state = self.mouse_state
            self.mouse_state = NOP_MODE
            self.image_move_cnt = 0
            if last_state == IMAGE_MODE:
                self.update_view_pos(0, 0)
                self.clear_background()
                self.draw_buffer()
                self.reset_large_buffer()
                self.draw_selection_box()
                gc.collect()

        def hide_brush():
            """Remove the eraser preview if it is currently shown."""
            if self.show_brush:
                self.canvas[-2].clear()
                self.show_brush = False

        def handle_mouse_down(event):
            self.mouse_state = get_mouse_mode()
            if self.mouse_state == BRUSH_MODE:
                x, y = get_event_pos(event)
                self.use_eraser(x, y)

        def handle_mouse_out(event):
            end_interaction()
            hide_brush()

        def handle_mouse_up(event):
            end_interaction()

        async def handle_mouse_move(event):
            x, y = get_event_pos(event)
            x0, y0 = self.mouse_pos
            xo = x - x0
            yo = y - y0
            if self.mouse_state == PAINT_MODE:
                self.update_cursor(int(xo), int(yo))
                if self.buffer_updated:
                    self.draw_buffer()
                    self.buffer_updated = False
                self.draw_selection_box()
            elif self.mouse_state == IMAGE_MODE:
                self.image_move_cnt += 1
                if self.image_move_cnt == self.image_move_freq:
                    # Start of a pan: render an enlarged region around the
                    # view into the hidden layer 2 for later moves to reuse.
                    self.draw_buffer()
                    self.canvas[2].clear()
                    self.draw_selection_box()
                    self.update_view_pos(int(xo), int(yo))
                    self.cached_view_pos = tuple(self.view_pos)
                    self.canvas[2].canvas.style.display = "none"
                    large_buffer = self.data2array(
                        self.view_pos[0] - self.width // 2,
                        self.view_pos[1] - self.height // 2,
                        min(self.width * 2, self.patch_size),
                        min(self.height * 2, self.patch_size),
                    )
                    self.canvas[2].canvas.width = large_buffer.shape[1]
                    self.canvas[2].canvas.height = large_buffer.shape[0]
                    self.canvas[2].put_image_data(large_buffer, 0, 0)
                else:
                    # Later moves: shift the view only and blit the matching
                    # window out of the cached enlarged layer.
                    self.update_view_pos(int(xo), int(yo), False)
                    self.canvas[1].clear()
                    self.canvas[1].draw_image(
                        self.canvas[2].canvas,
                        self.width // 2 + (self.view_pos[0] - self.cached_view_pos[0]),
                        self.height // 2 + (self.view_pos[1] - self.cached_view_pos[1]),
                        self.width,
                        self.height,
                        0,
                        0,
                        self.width,
                        self.height,
                    )
                self.clear_background()
            elif self.mouse_state == BRUSH_MODE:
                self.use_eraser(x, y)

            mode = document.querySelector("#mode").value
            if mode == BRUSH_SELECTION:
                self.draw_eraser(x, y)
                self.show_brush = True
            else:
                hide_brush()
            self.mouse_pos[0] = x
            self.mouse_pos[1] = y

        async def handle_mouse_wheel(event):
            x, y = get_event_pos(event)
            self.mouse_pos[0] = x
            self.mouse_pos[1] = y
            console.log(to_js(self.mouse_pos))
            # Zoom is requested by posting a message to the window.
            if event.deltaY > 10:
                window.postMessage(to_js(["click", "zoom_out", self.mouse_pos[0], self.mouse_pos[1]]), "*")
            elif event.deltaY < -10:
                window.postMessage(to_js(["click", "zoom_in", self.mouse_pos[0], self.mouse_pos[1]]), "*")
            return False

        top = self.canvas[-1].canvas
        top.addEventListener("mousedown", create_proxy(handle_mouse_down))
        top.addEventListener("mousemove", create_proxy(handle_mouse_move))
        top.addEventListener("mouseup", create_proxy(handle_mouse_up))
        top.addEventListener("mouseout", create_proxy(handle_mouse_out))
        top.addEventListener("wheel", create_proxy(handle_mouse_wheel), False)

    def clear_background(self):
        """Redraw the checkerboard on layer 0, offset by the view position."""
        # fake transparent background
        h, w, step = self.height, self.width, self.grid_size
        stride = step * 2
        x0, y0 = self.view_pos
        x0 = (-x0) % stride
        y0 = (-y0) % stride
        if y0 >= step:
            val0, val1 = stride, step
        else:
            val0, val1 = step, stride
        background = self.canvas[0]
        background.fill_style = "#ffffff"
        background.fill_rect(0, 0, w, h)
        background.fill_style = "#aaaaaa"
        for y in range(y0 - stride, h + step, step):
            # Alternate the starting column per row to form the checker pattern.
            start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
            for x in range(start, w + step, stride):
                background.fill_rect(x, y, step, step)
        background.stroke_rect(0, 0, w, h)

    def refine_selection(self):
        """Cap the selection size at the viewport size and re-clamp the cursor."""
        h = min(self.selection_size_h, self.height)
        w = min(self.selection_size_w, self.width)
        # NOTE(review): ``v * 8 // 8`` is an identity for ints; possibly meant
        # to round to a multiple of 8 -- left unchanged, confirm intent.
        self.selection_size_h = h * 8 // 8
        self.selection_size_w = w * 8 // 8
        # update_cursor() ignores a zero offset, so nudge by one pixel to reach
        # its clamping code.
        self.update_cursor(1, 0)

    def update_scale(self, scale, mx=-1, my=-1):
        """Resize the viewport to ``display size * scale`` and redraw.

        The request is ignored when the scaled viewport would reach
        ``2 * patch_size - 128`` pixels or would not exceed the selection box.
        When both ``mx`` and ``my`` are non-negative the view position is
        shifted by the difference between ``(mx, my)`` and that point rescaled.
        """
        self.sync_to_data()
        scaled_width = int(self.display_width * scale)
        scaled_height = int(self.display_height * scale)
        if max(scaled_height, scaled_width) >= self.patch_size * 2 - 128:
            return
        if scaled_height <= self.selection_size_h or scaled_width <= self.selection_size_w:
            return
        if mx >= 0 and my >= 0:
            scaled_mx = mx / self.scale * scale
            scaled_my = my / self.scale * scale
            self.view_pos[0] += int(mx - scaled_mx)
            self.view_pos[1] += int(my - scaled_my)
        self.scale = scale
        for item in self.canvas:
            item.canvas.width = scaled_width
            item.canvas.height = scaled_height
            item.clear()
        update_overlay(scaled_width, scaled_height)
        self.width = scaled_width
        self.height = scaled_height
        self.data2buffer()
        self.clear_background()
        self.draw_buffer()
        self.update_cursor(1, 0)
        self.draw_selection_box()

    def update_view_pos(self, xo, yo, update=True):
        """Move the view by ``(-xo, -yo)`` after flushing pending edits.

        With ``update`` false the buffer is not reloaded from the patches.
        """
        if self.sel_dirty:
            self.write_selection_to_buffer()
        if self.buffer_dirty:
            self.buffer2data()
        self.view_pos[0] -= xo
        self.view_pos[1] -= yo
        if update:
            self.data2buffer()

    def update_cursor(self, xo, yo):
        """Move the selection box by ``(xo, yo)``, clamped to the viewport.

        A zero offset is a no-op.  A pending selection image is committed to
        the buffer before the cursor moves.
        """
        if abs(xo) + abs(yo) == 0:
            return
        if self.sel_dirty:
            self.write_selection_to_buffer()
        self.cursor[0] += xo
        self.cursor[1] += yo
        self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
        self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)

    def data2buffer(self):
        """Load the viewport buffer from the patches at the current view."""
        x, y = self.view_pos
        h, w = self.height, self.width
        if h != self.buffer.shape[0] or w != self.buffer.shape[1]:
            self.buffer = np.zeros((self.height, self.width, 4), dtype=np.uint8)
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]

    def data2array(self, x, y, w, h):
        """Return a new ``(h, w, 4)`` array read from the patches at ``(x, y)``."""
        ret = np.zeros((h, w, 4), dtype=np.uint8)
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i, w, h)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
        return ret

    def buffer2data(self):
        """Write the viewport buffer back into the patches; clears ``buffer_dirty``."""
        x, y = self.view_pos
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
        self.buffer_dirty = False

    def select(self, x, y, idx, width=0, height=0):
        """Map one corner of a window at ``(x, y)`` onto the patch containing it.

        ``idx`` picks the corner: 0 top-left, 1 bottom-left, 2 top-right,
        3 bottom-right.  From that corner the window is extended toward its
        interior until the patch edge or the window extent is reached.  The
        window is ``width`` x ``height``, or the viewport size when ``width``
        is 0.

        Returns ``((xs0, xs1), (ys0, ys1)), ((xd0, xd1), (yd0, yd1)), patch``:
        the region inside the patch, the matching region inside the window,
        and the patch array, which is created zero-filled if missing.
        """
        if width == 0:
            w, h = self.width, self.height
        else:
            w, h = width, height
        corners = [(0, 0), (0, h), (w, 0), (w, h)]
        size = self.patch_size
        if idx == 0:
            x0, y0 = x % size, y % size
            x1 = min(x0 + w, size)
            y1 = min(y0 + h, size)
        elif idx == 1:
            y += h
            x0, y0 = x % size, y % size
            x1 = min(x0 + w, size)
            y1 = max(y0 - h, 0)
        elif idx == 2:
            x += w
            x0, y0 = x % size, y % size
            x1 = max(x0 - w, 0)
            y1 = min(y0 + h, size)
        else:
            x += w
            y += h
            x0, y0 = x % size, y % size
            x1 = max(x0 - w, 0)
            y1 = max(y0 - h, 0)
        key = (x // size, y // size)
        # Allocate only when the patch is missing; a ``setdefault`` default
        # would build a full zero patch on every call.
        if key not in self.data:
            self.data[key] = np.zeros((size, size, 4), dtype=np.uint8)
        cur = self.data[key]
        x0_img, y0_img = corners[idx]
        x1_img = x0_img + x1 - x0
        y1_img = y0_img + y1 - y0

        def ordered(a, b):
            # Ranges that grow backwards from a corner need their ends swapped.
            return (a, b) if a < b else (b, a)

        return (
            (ordered(x0, x1), ordered(y0, y1)),
            (ordered(x0_img, x1_img), ordered(y0_img, y1_img)),
            cur,
        )

    def draw_buffer(self):
        """Repaint layer 1 from the viewport buffer."""
        self.canvas[1].clear()
        self.canvas[1].put_image_data(self.buffer, 0, 0)

    def fill_selection(self, img):
        """Set ``img`` as the pending selection image and mark it dirty."""
        self.sel_buffer = img
        self.sel_dirty = True

    def draw_selection_box(self):
        """Draw the pending selection image, if any, and the selection outline."""
        x0, y0 = self.cursor
        w, h = self.selection_size_w, self.selection_size_h
        if self.sel_dirty:
            self.canvas[2].clear()
            self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
        top = self.canvas[-1]
        top.clear()
        top.stroke_style = "#0a0a0a"
        top.stroke_rect(x0, y0, w, h)
        top.stroke_style = "#ffffff"
        # Spread the outer outlines further apart when zoomed above 1.0.
        offset = round(self.scale) if self.scale > 1.0 else 1
        top.stroke_rect(x0 - offset, y0 - offset, w + offset * 2, h + offset * 2)
        top.stroke_style = "#000000"
        top.stroke_rect(x0 - offset * 2, y0 - offset * 2, w + offset * 4, h + offset * 4)

    def write_selection_to_buffer(self):
        """Commit the pending selection image into the buffer at the cursor."""
        x0, y0 = self.cursor
        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
        self.buffer[y0:y1, x0:x1] = self.sel_buffer
        self.sel_dirty = False
        self.sel_buffer = self._blank_selection()
        self.buffer_dirty = True
        self.buffer_updated = True

    def read_selection_from_buffer(self):
        """Point ``sel_buffer`` at the buffer region under the selection box.

        The result is a slice of ``self.buffer`` (a view), not a copy.
        """
        x0, y0 = self.cursor
        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
        self.sel_buffer = self.buffer[y0:y1, x0:x1]
        self.sel_dirty = False

    def base64_to_numpy(self, base64_str):
        """Decode a base64-encoded image into an array.

        If decoding fails, a selection-sized array filled with
        ``[255, 0, 0, 255]`` is returned instead of raising.
        """
        try:
            data = base64.b64decode(str(base64_str))
            pil = Image.open(io.BytesIO(data))
            ret = np.array(pil)
        except Exception:
            # Best-effort fallback; KeyboardInterrupt/SystemExit now propagate.
            ret = np.tile(
                np.array([255, 0, 0, 255], dtype=np.uint8),
                (self.selection_size_h, self.selection_size_w, 1),
            )
        return ret

    def numpy_to_base64(self, arr):
        """Encode ``arr`` as a PNG and return it as an ASCII base64 string."""
        out_buffer = io.BytesIO()
        Image.fromarray(arr).save(out_buffer, format="PNG")
        return base64.b64encode(out_buffer.getvalue()).decode("ascii")

    def sync_to_data(self):
        """Flush the pending selection and any buffer edits into the patches."""
        if self.sel_dirty:
            self.write_selection_to_buffer()
            self.canvas[2].clear()
            self.draw_buffer()
        if self.buffer_dirty:
            self.buffer2data()

    def sync_to_buffer(self):
        """Commit the pending selection into the buffer and redraw it."""
        if self.sel_dirty:
            self.canvas[2].clear()
            self.write_selection_to_buffer()
            self.draw_buffer()

    def resize(self, width, height, scale=None, **kwargs):
        """Set the display size, re-prepare every layer, and apply ``scale``.

        ``scale`` defaults to 1 when ``None``; extra keyword arguments are
        accepted and ignored.
        """
        self.display_width = width
        self.display_height = height
        for canvas in self.canvas:
            prepare_canvas(width=width, height=height, canvas=canvas.canvas)
        setup_overlay(width, height)
        if scale is None:
            scale = 1
        self.update_scale(scale)

    def save(self):
        """Serialise the view state and all non-empty patches to a JSON string.

        Patches are stored under ``"data"`` as base64 PNGs keyed by ``"x,y"``.
        """
        self.sync_to_data()
        state = {
            "width": self.display_width,
            "height": self.display_height,
            "selection_width": self.selection_size_w,
            "selection_height": self.selection_size_h,
            "view_pos": self.view_pos[:],
            "cursor": self.cursor[:],
            "scale": self.scale,
        }
        state["data"] = {
            f"{key[0]},{key[1]}": self.numpy_to_base64(patch)
            for key, patch in self.data.items()
            if patch.any()
        }
        return json.dumps(state)

    def load(self, state_json):
        """Restore the canvas from a JSON string produced by ``save``."""
        self.reset()
        state = json.loads(state_json)
        self.display_width = state["width"]
        self.display_height = state["height"]
        self.selection_size_w = state["selection_width"]
        self.selection_size_h = state["selection_height"]
        self.view_pos = state["view_pos"][:]
        self.cursor = state["cursor"][:]
        self.scale = state["scale"]
        self.resize(state["width"], state["height"], scale=state["scale"])
        for k, v in state["data"].items():
            key = tuple(map(int, k.split(",")))
            self.data[key] = self.base64_to_numpy(v)
        self.data2buffer()
        self.display()

    def display(self):
        """Redraw the background, the buffer, and the selection box."""
        self.clear_background()
        self.draw_buffer()
        self.draw_selection_box()

    def reset(self):
        """Drop all patches, zero the buffers, and move the view to (0, 0)."""
        self.data.clear()
        self.buffer *= 0
        self.buffer_dirty = False
        self.buffer_updated = False
        self.sel_buffer *= 0
        self.sel_dirty = False
        self.view_pos = [0, 0]
        self.clear_background()
        # Layer 0 was just redrawn and the top layer is left as it is.
        for layer in self.canvas[1:-1]:
            layer.clear()

    def export(self):
        """Return everything drawn so far, cropped to its non-transparent bounds.

        Returns a transparent selection-sized array when no pixel has a
        non-zero alpha value.
        """
        self.sync_to_data()
        filled = {key: patch for key, patch in self.data.items() if patch.any()}
        if not filled:
            return self._blank_selection()
        # Bound the stitched image by the non-empty patches only.  Seeding the
        # bounds with 0 would also allocate every patch between the origin and
        # the painted area, and the crop below discards that area anyway.
        xmin = min(xi for xi, _ in filled)
        xmax = max(xi for xi, _ in filled)
        ymin = min(yi for _, yi in filled)
        ymax = max(yi for _, yi in filled)
        size = self.patch_size
        image = np.zeros(
            ((ymax - ymin + 1) * size, (xmax - xmin + 1) * size, 4), dtype=np.uint8
        )
        for (xi, yi), patch in filled.items():
            y0 = (yi - ymin) * size
            x0 = (xi - xmin) * size
            image[y0 : y0 + size, x0 : x0 + size] = patch
        # Crop to the bounding box of pixels whose alpha channel is non-zero.
        ylst, xlst = image[:, :, -1].nonzero()
        if len(ylst) == 0:
            return self._blank_selection()
        yt, xt = ylst.min(), xlst.min()
        yb, xb = ylst.max(), xlst.max()
        return image[yt : yb + 1, xt : xb + 1]