lsb committed
Commit 829ed2e
1 Parent(s): 12d035e

fastai unet to segformer
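In short: the fastai CamVid UNet (camvid-512.pkl) is swapped for nvidia/segformer-b0-finetuned-cityscapes-768-768 from transformers, the hand-maintained 32-entry CamVid vocabulary and its parallel 0/1 array are dropped in favor of a lookup table built from the model's own config.label2id, and the inpainting pipeline gains LCM-LoRA so 4 denoising steps replace the previous 20. A minimal sketch of the lookup-table idea, using a truncated stand-in for the real 19-class Cityscapes mapping (the ids below are illustrative, not read from the checkpoint):

    import numpy as np

    # Illustrative subset of seg_model.config.label2id; the real Cityscapes
    # checkpoint exposes 19 classes, and these ids are assumed for the demo.
    label2id = {"road": 0, "sidewalk": 1, "traffic light": 6,
                "traffic sign": 7, "sky": 10, "person": 11, "car": 13}

    ban_cars_mask = np.zeros(max(label2id.values()) + 1, dtype=np.uint8)
    for c in ["car", "road", "sidewalk", "traffic light", "traffic sign"]:
        ban_cars_mask[label2id[c]] = 1  # 1 = repaint this class

    # Fancy-indexing a per-pixel class map through the table yields a
    # binary inpainting mask in one shot.
    class_map = np.array([[0, 11],
                          [13, 10]])  # stand-in argmax output: road, person, car, sky
    print(ban_cars_mask[class_map] * 255)  # [[255 0] [255 0]]

This is the same fancy-indexing that get_seg_mask applies to the real argmax output in the diff below, before blurring and re-binarizing the mask.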

Files changed (1):
  1. app.py +49 -36
app.py CHANGED
@@ -1,79 +1,92 @@
 import gradio as gr
 
 import torch
-from fastai.vision.all import *
-from PIL import ImageFilter, ImageEnhance, ImageDraw
-from diffusers.utils import make_image_grid
+from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
+from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
+from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
+
 from tqdm import tqdm
-from diffusers import AutoPipelineForInpainting, LCMScheduler, DDIMScheduler
-from diffusers import StableDiffusionInpaintPipeline, ControlNetModel
 import numpy as np
-from PIL import Image
 from datetime import datetime
 
-preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+# ideally:
+# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+# but segformer does not work on mps lolololol
+preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
 preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32
 
-def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
-
-segmodel = load_learner("camvid-512.pkl")
+seg_model_img_size = 768
+seg_model_size = 0
 
-if preferred_dtype == torch.float16:
-    segmodel = segmodel.to_fp16()
+seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
+seg_model = SegformerForSemanticSegmentation.from_pretrained(
+    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
+).to(preferred_device).to(preferred_dtype)
 
 inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
     "runwayml/stable-diffusion-inpainting",
     variant="fp16",
     torch_dtype=preferred_dtype,
+    safety_checker=None,
 ).to(preferred_device)
 
-working_size = (512, 512)
+inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config)
+inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+inpainting_pipeline.fuse_lora()
+
+working_size = (seg_model_img_size, seg_model_img_size)
 
 default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
 
-seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
-                  'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
-                  'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',
-                  'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',
-                  'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',
-                  'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',
-                  'VegetationMisc', 'Void', 'Wall']
-
-ban_cars_mask = np.array([0, 0, 0, 0, 0, 1,
-                          0, 0, 1, 0, 1,
-                          1, 1, 0, 0,
-                          1, 0, 1, 1, 1,
-                          1, 0, 1, 1,
-                          1, 0, 0, 0, 1,
-                          0, 1, 0], dtype=np.uint8)
+seg_vocabulary = seg_model.config.label2id
+print(f"vocab: {seg_vocabulary}")
+
+ban_cars_mask = [0] * len(seg_vocabulary)
+banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
+for c in banned_classes:
+    ban_cars_mask[seg_vocabulary[c]] = 1
+ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
+
 
 def get_seg_mask(img):
-    mask = segmodel.predict(img)[0]
-    return mask
+    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)
+    outputs = seg_model(**inputs)
+    logits = outputs.logits[0]
+    mask = Image.fromarray(ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()] * 255)
+    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(5))).enhance(9000)
+    return blurred_widened_mask
 
 
 def app(img, prompt):
     start_time = datetime.now().timestamp()
     old_size = Image.fromarray(img).size
     img = np.array(Image.fromarray(img).resize(working_size))
-    mask = ban_cars_mask[get_seg_mask(img)] * 255
+    mask = get_seg_mask(img)
+    mask.save("mask.jpg")
     mask_time = datetime.now().timestamp()
-    print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape)
+    #print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape, mask.dtype, img.dtype)
     overlay_img = inpainting_pipeline(
         prompt=prompt,
-        image=img,
-        mask_image=mask,
+        image=Image.fromarray(img),
+        mask_image=mask,
         strength=0.95,
-        num_inference_steps=20,
+        num_inference_steps=4,
     ).images[0]
+    #overlay_img.save("overlay_raw.jpg")
     end_time = datetime.now().timestamp()
     draw = ImageDraw.Draw(overlay_img)
     # replace spaces with newlines after many words to line break prompt
     prompt = " ".join([prompt.split(" ")[i] if (i+1) % 5 else prompt.split(" ")[i] + "\n" for i in range(len(prompt.split(" ")))])
 
-    draw.text((50, 10), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms\n<{prompt}>", fill=(123, 0, 123))
+    draw.text((10, 50), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms\n<{prompt}>", fill=(0, 255, 0))
+    #overlay_img.save("overlay_with_text.jpg")
     return overlay_img
 
+### kick the tires before we start
+
+for i in tqdm(range(2)):
+    app(np.array(Image.fromarray(np.zeros((1024,1024,3), dtype=np.uint8))), default_inpainting_prompt).save("zeros_inpainting_oneshot.jpg")
+
 #ideally:
 #iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
 iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
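A note on the step-count change above: num_inference_steps can drop from 20 to 4 because the scheduler is swapped for LCMScheduler and the latent-consistency/lcm-lora-sdv1-5 weights are fused into the UNet. A minimal sketch of that wiring in isolation (float32 on CPU is an assumption for portability; the app itself runs float16 on CUDA):

    import torch
    from diffusers import LCMScheduler, StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float32,  # assumption: CPU demo; the app uses float16 on CUDA
        safety_checker=None,
    )
    # Latent-consistency sampling: swap the scheduler, then fuse the LCM-LoRA
    # weights into the UNet so a handful of denoising steps suffices.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    pipe.fuse_lora()

One caveat worth flagging in get_seg_mask: on CUDA the SegFormer weights are cast to float16, but the extractor's output is only moved to the device, not to the dtype, so the forward pass can hit a float-vs-half mismatch; casting the pixel values as well (e.g. inputs["pixel_values"] = inputs["pixel_values"].to(preferred_dtype)) would line them up. The mask also comes out at SegFormer's reduced logit resolution (1/4 of the preprocessed input on each side), so the code leans on the inpainting pipeline's internal resizing to bring it back to working_size.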