lsb commited on
Commit
cd4ffa3
1 Parent(s): 2fb22dd

redact, v1

Browse files
Files changed (1) hide show
  1. app.py +54 -3
app.py CHANGED
@@ -1,7 +1,58 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return f"Hello {name}!"
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, ["string"], "string", live=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
  import gradio as gr
2
 
3
+ import torch
4
+ from fastai.vision.all import *
5
+ from PIL import ImageFilter, ImageEnhance
6
+ from diffusers.utils import make_image_grid
7
+ from tqdm import tqdm
8
+ from diffusers import AutoPipelineForInpainting, LCMScheduler, DDIMScheduler
9
+ from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel
10
+ import numpy as np
11
+ from PIL import Image
12
 
13
+ preferred_dtype = torch.float32
14
+ preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
17
+
18
+ segmodel = load_learner("camvid-256.pkl")
19
+
20
+ seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
21
+ 'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
22
+ 'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',
23
+ 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',
24
+ 'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',
25
+ 'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',
26
+ 'VegetationMisc', 'Void', 'Wall']
27
+
28
+ ban_cars_mask = np.array([0, 0, 0, 0, 0, 1,
29
+ 0, 0, 1, 0, 1,
30
+ 1, 1, 0, 0,
31
+ 1, 0, 1, 1, 1,
32
+ 1, 0, 1, 1,
33
+ 1, 0, 0, 0, 1,
34
+ 0, 1, 0], dtype=np.uint8)
35
+
36
+ def get_seg_mask(img):
37
+ mask = segmodel.predict(img)[0]
38
+ return mask
39
+
40
+ def display_mask(img, mask):
41
+ # Convert the grayscale mask to RGB
42
+ mask_rgb = np.stack([np.zeros_like(mask), mask, np.zeros_like(mask)], axis=-1)
43
+ # Convert the image to PIL format
44
+ img_pil = Image.fromarray(img)
45
+ # Convert the mask to PIL format
46
+ mask_pil = Image.fromarray((mask_rgb * 255).astype(np.uint8))
47
+ # Overlay the mask on the image
48
+ overlaid_img = Image.blend(img_pil, mask_pil, alpha=0.5)
49
+ return overlaid_img
50
+
51
+ def redact_image(img):
52
+ img = img.resize((256, 256))
53
+ mask = get_seg_mask(img)
54
+ car_mask = ban_cars_mask[mask]
55
+ return display_mask(img, car_mask)
56
+
57
+ iface = gr.Interface(fn=redact_image, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
58
  iface.launch()