import gradio as gr
import torch
from torch import nn
import torchvision.transforms as T
from linea.models import build_linea
from linea.util.slconfig import DictAction, SLConfig
from PIL import Image, ImageDraw
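# Map each model size to its config file in the LINEA repo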
LINEA_MODELS = {
"LINEA-N": './linea/configs/linea/linea_hgnetv2_n.py',
"LINEA-S": './linea/configs/linea/linea_hgnetv2_s.py',
"LINEA-M": './linea/configs/linea/linea_hgnetv2_m.py',
"LINEA-L": './linea/configs/linea/linea_hgnetv2_l.py'
}
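# Preprocessing: resize to the 640x640 model input and normalize
# (the mean/std presumably match the training-set statistics)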
transforms = T.Compose(
[
T.Resize((640, 640)),
T.ToTensor(),
T.Normalize(mean=[0.538, 0.494, 0.453], std=[0.257, 0.263, 0.273]),
]
)
example_images = [
["assets/example1.jpg"],
["assets/example2.jpg"],
["assets/example3.jpg"],
["assets/example4.jpg"],
]
description = """
<h1 align="center">
<ins>LINEA</ins>
<br>
Fast and accurate line detection using scalable transformers
</h1>
<h2 align="center">
<a href="https://www.linkedin.com/in/sebastianjr/">Sebastian Janampa</a>
and
<a href="https://www.linkedin.com/in/marios-pattichis-207b0119/">Marios Pattichis</a>
</h2>
<h2 align="center">
<a href="https://github.com/SebastianJanampa/LINEA.git">GitHub</a> |
<a href="https://colab.research.google.com/github/SebastianJanampa/LINEA/blob/master/LINEA_tutorial.ipynb">Colab</a>
</h2>
## Getting Started
LINEA is a family of transformer models that detect line segments in an image.
Its key component is a new attention mechanism called **line attention**.
To get started, upload an image or select one of the examples below.
You can choose between different model sizes, adjust the confidence threshold, and visualize the results.
"""
def create_model(model_size):
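    """Build a LINEA model from its config file and load the pretrained weights."""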
cfg = SLConfig.fromfile(LINEA_MODELS[model_size])
cfg.pretrained = False
model, postprocessor = build_linea(cfg)
letter = model_size[-1].lower()
url = f"https://github.com/SebastianJanampa/storage/releases/download/LINEA/linea_hgnetv2_{letter}.pth"
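    # Download the pretrained checkpoint (torch.hub caches it after the first call)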
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", file_name=f"linea_hgnetv2_{letter}.pth"
)
model.load_state_dict(state_dict['model'], strict=True)
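    # Wrap the model and postprocessor so a single forward pass returns final predictions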
class Model(nn.Module):
def __init__(self):
super().__init__()
self.model = model.deploy()
self.postprocessor = postprocessor.deploy()
def forward(self, images, orig_target_sizes):
outputs = self.model(images)
outputs = self.postprocessor(outputs, orig_target_sizes)
return outputs
model = Model()
model.eval()
return model
def draw(images, lines, scores, thrh):
    """Draw the lines (and their confidence scores) above the threshold on each image."""
    for i, im in enumerate(images):
        drawer = ImageDraw.Draw(im)
        scr = scores[i]
        line = lines[i][scr > thrh]
        scrs = scr[scr > thrh]
        for j, l in enumerate(line):
            drawer.line(list(l), fill="red", width=5)
            drawer.text(
                (l[0], l[1]),
                text=f"{round(scrs[j].item(), 2)}",
                fill="blue",
            )
    return images
def filter_lines(lines, scores, threshold):
    """Keep only the lines whose confidence score exceeds the threshold."""
    filtered_lines, filtered_scores = [], []
    for line, scr in zip(lines, scores):
        idx = scr > threshold
        filtered_lines.append(line[idx])
        filtered_scores.append(scr[idx])
    return filtered_lines, filtered_scores
def format_output(lines, scores):
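    """Format the filtered lines and scores as a human-readable report."""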
    n = len(lines[0])
    txt = f"{n} line(s) detected\n"
txt += "Detected lines:\n"
for line, scr in zip(lines[0], scores[0]):
txt += f"\tx1: {line[0].item():.2f}"
txt += f"\ty1: {line[1].item():.2f}"
txt += f"\tx2: {line[2].item():.2f}"
txt += f"\ty2: {line[3].item():.2f}"
txt += f"\tscore: {scr.item():.2f}\n"
return txt
def process_results(
    image_path,
    model_size,
    threshold
):
    """Process the image and return the detected lines."""
if image_path is None:
raise gr.Error("Please upload an image first.")
model = create_model(model_size)
im_pil = Image.open(image_path).convert("RGB")
w, h = im_pil.size
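    # Keep the original size so the postprocessor can map predictions back from the 640x640 input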
orig_size = torch.tensor([[w, h]])
im_data = transforms(im_pil).unsqueeze(0)
output = model(im_data, orig_size)
lines, scores = output
result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter_lines(lines, scores, threshold)
return format_output(filtered_lines, filtered_scores), result_images[0], (lines, scores)
def update_threshold(
    image_path,
    raw_results,
    threshold
):
    """Redraw and refilter the cached predictions when the threshold slider moves."""
    if image_path is None or raw_results is None:
        raise gr.Error("Please run the detection first.")
    lines, scores = raw_results
    im_pil = Image.open(image_path).convert("RGB")
    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter_lines(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0]
def update_model(
    image_path,
    model_size,
    threshold
):
    """Rebuild the model when a new size is selected and rerun the detection."""
    create_model(model_size)
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    return process_results(image_path, model_size, threshold)
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown(description)
with gr.Row():
with gr.Column():
gr.Markdown("""## Input Image""")
image_path = gr.Image(label="Upload image", type="filepath")
model_size = gr.Dropdown(
choices=list(LINEA_MODELS.keys()), label="Choose a LINEA model.", value="LINEA-M"
)
threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
interactive=True,
value=0.30,
)
submit_btn = gr.Button("Detect Lines")
gr.Examples(examples=example_images, inputs=[image_path, model_size])
with gr.Column():
gr.Markdown("""## Results""")
image_output = gr.Image(label="Detected Lines")
text_output = gr.Textbox(label="Predicted lines", type="text", lines=5)
    # Cache the raw predictions so threshold changes can reuse them without rerunning the model
raw_results = gr.State()
plot_inputs = [
raw_results,
threshold
]
submit_btn.click(
fn=process_results,
inputs=[image_path, model_size] + plot_inputs[1:],
outputs=[text_output, image_output, raw_results],
)
    # Update the results when the confidence threshold changes
threshold.change(fn=update_threshold, inputs=[image_path] + plot_inputs, outputs=[text_output, image_output])
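    # update_model is otherwise unused; wiring it here (assumed intent, not in the original file)
    # reruns the detection whenever a different model size is selected
    model_size.change(fn=update_model, inputs=[image_path, model_size, threshold], outputs=[text_output, image_output, raw_results])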
demo.launch()