gauge / test_rect.py
tjw's picture
remove ipython.display
a68f143
raw
history blame
4.88 kB
# %%
import matplotlib.style
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
from pathlib import Path
from PIL import Image
from PIL import ImageDraw
import numpy as np
from collections import namedtuple
import sys
# Report the interpreter version for diagnostics; note that annotations such as
# `str|None` used later in this file are evaluated at definition time and need Python 3.10+.
print(sys.version_info)
#%%
class Florence:
def __init__(self, model_id:str, hack=False):
if hack:
return
self.model = (
AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, torch_dtype="auto"
)
.eval()
.cuda()
)
self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
self.model_id = model_id
def run(self, img:Image, task_prompt:str, extra_text:str|None=None):
model, processor = self.model, self.processor
prompt = task_prompt + (extra_text if extra_text else "")
inputs = processor(text=prompt, images=img, return_tensors="pt").to(
"cuda", torch.float16
)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
generated_text,
task=task_prompt,
image_size=(img.width, img.height),
)
return parsed_answer
def model_init():
    """Load the base and the fine-tuned Florence-2-large wrappers.

    Returns:
        ``(base, fine_tuned)`` :class:`Florence` instances, loaded in that order.
    """
    checkpoints = ("microsoft/Florence-2-large", "microsoft/Florence-2-large-ft")
    base, fine_tuned = (Florence(ckpt, hack=False) for ckpt in checkpoints)
    return base, fine_tuned
# Florence-2 task prompt tokens, passed as ``task_prompt`` to ``Florence.run``.
TASK_OD = "<OD>"
TASK_SEGMENTATION = '<REFERRING_EXPRESSION_SEGMENTATION>'
# Plain image captioning; distinct from TASK_GROUNDING, which grounds a given phrase.
TASK_CAPTION = "<CAPTION>"
TASK_OCR = "<OCR_WITH_REGION>"
TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
#%%
from skimage.measure import LineModelND, ransac
def get_polygons(fl: Florence, img2: Image, prompt):
    """Segment the region described by ``prompt`` in ``img2``.

    Runs Florence-2 referring-expression segmentation and unwraps the answer.

    Returns:
        ``result['polygons'][0]`` of the single task entry. ``read_meter``
        flattens it into (x, y) pairs, so presumably it is a list of flat
        coordinate lists (confirm against the Florence-2 processor).

    Raises:
        AssertionError: unless the answer has exactly one task entry that
            contains exactly one ``'polygons'`` instance.
    """
    answer = fl.run(img2, TASK_SEGMENTATION, prompt)
    assert len(answer) == 1
    _, result = answer.popitem()
    assert 'polygons' in result
    instances = result['polygons']
    assert len(instances) == 1
    return instances[0]
def get_ocr(fl: Florence, img2: Image):
    """Run Florence-2 OCR-with-region on ``img2`` and return the task's result.

    ``read_meter`` reads the returned value through its ``'quad_boxes'`` and
    ``'labels'`` keys.

    Raises:
        AssertionError: if the answer does not hold exactly one task entry.
    """
    answer = fl.run(img2, TASK_OCR)
    assert len(answer) == 1
    return answer.popitem()[1]
# JPEG inputs processed by main(), and the dial labels (as strings) that OCR
# text is matched against.
imgs = [*Path('images/other').glob('*.jpg')]
meter_labels = [str(value) for value in range(0, 600, 100)]
def read_meter(img, fl: Florence, fl_ft: Florence):
    """Estimate the gauge value indicated by the red triangular pointer.

    Pipeline:
      1. Segment the red pointer with ``fl``.
      2. OCR the image with both ``fl`` and ``fl_ft``; keep boxes whose text is
         one of ``meter_labels``.
      3. Fit a line through the label-box centres, take each centre's signed
         position along it, then fit a second line mapping label value to that
         position.
      4. Project the pointer polygon onto the scale line, convert its mean
         position back to a value and clamp it to the label range.

    Annotations are drawn onto the image in place: green label boxes, the
    yellow fitted scale with tick marks, and a red circle at the pointer.

    Args:
        img: A PIL image, or a ``str``/``Path`` that is opened (and printed).
        fl: Base Florence-2 wrapper (segmentation and first OCR pass).
        fl_ft: Fine-tuned Florence-2 wrapper (second OCR pass).

    Returns:
        ``(reading, annotated_image)``.

    Raises:
        ValueError: if fewer than two scale labels are recognised, in which
            case no scale line can be fitted.
    """
    if isinstance(img, (str, Path)):
        print(img)
        img = Image.open(img)
    red_polygons = get_polygons(fl, img, 'red triangle pointer')

    # Collect scale-label boxes from both OCR passes. Generation is decoded
    # with special tokens kept, so an OCR label (typically the first) can carry
    # a leading '</s>'; strip it and surrounding whitespace before matching.
    # On duplicate labels the fine-tuned pass (processed last) wins.
    ocr_text = {}
    for ocr in (get_ocr(fl, img), get_ocr(fl_ft, img)):
        for quad_box, label in zip(ocr['quad_boxes'], ocr['labels']):
            label = label.replace('</s>', '').strip()
            if label in meter_labels:
                ocr_text[int(label)] = quad_box
    if len(ocr_text) < 2:
        raise ValueError(
            "need at least two recognised scale labels to fit the scale, "
            f"got {sorted(ocr_text)}"
        )

    draw = ImageDraw.Draw(img)
    for label, quad_box in ocr_text.items():
        draw.polygon(quad_box, outline='green', width=3)
        draw.text((quad_box[0], quad_box[1] - 10), str(label), fill='green', anchor='ls')

    # Each quad box is 8 numbers = 4 (x, y) corners; its centre is their mean.
    text_centers = np.array(list(ocr_text.values())).reshape(-1, 4, 2).mean(axis=1)
    scale_line = LineModelND()
    scale_line.estimate(text_centers)
    origin, direction = scale_line.params

    # Signed position of every label centre along the scale line.
    text_pos = (text_centers - origin) @ direction

    # Second fit: label value -> position along the scale line.
    value_fit = LineModelND()
    label_values = np.array(list(ocr_text.keys()))
    value_fit.estimate(np.stack([label_values, text_pos], axis=1))
    scale_values = [int(s) for s in meter_labels]
    tick_pos = value_fit.predict(scale_values)[:, 1]

    x0, y0 = tick_pos[0] * direction + origin
    x1, y1 = tick_pos[-1] * direction + origin
    draw.line((x0, y0, x1, y1), fill='yellow', width=3)
    for pos in tick_pos:
        x, y = pos * direction + origin
        draw.ellipse((x - 5, y - 5, x + 5, y + 5), outline='yellow', width=3)

    # Mean position of the pointer outline along the scale, mapped back to a
    # value and clamped to the printed label range.
    red_coords = np.concatenate(red_polygons).reshape(-1, 2)
    red_pos = ((red_coords - origin) @ direction).mean()
    red_i = np.clip(value_fit.predict_x([red_pos]), min(scale_values), max(scale_values))
    red_pos = value_fit.predict_y(red_i)[0]
    red_center = red_pos * direction + origin
    draw.ellipse(
        (red_center[0] - 5, red_center[1] - 5, red_center[0] + 5, red_center[1] + 5),
        outline='red',
        width=3,
    )
    return red_i[0], img
def main():
    """Read every image in ``imgs`` with both Florence-2 models and show the result.

    Prints each file name and its estimated reading, then shows the annotated
    image with the module-level ``display`` (bound by the ``__main__`` guard
    to IPython's ``display``).
    """
    fl, fl_ft = model_init()
    for img_fn in imgs:
        print(img_fn)
        # Close the underlying file once this image has been processed and shown.
        with Image.open(img_fn) as img:
            red_i, img2 = read_meter(img, fl, fl_ft)
            print(red_i)
            display(img2)
if __name__ == "__main__":
    # Imported only when run as a script or notebook cell, so importing this
    # module does not require IPython; binds the global `display` used by main().
    from IPython.display import display
    main()
#%%