simonJJJ commited on
Commit
835f433
1 Parent(s): 7727b21

update draw box

Browse files
Files changed (1) hide show
  1. tokenization_qwen.py +146 -14
tokenization_qwen.py CHANGED
@@ -19,6 +19,10 @@ from PIL import ImageFont
19
  from PIL import ImageDraw
20
  from transformers import PreTrainedTokenizer, AddedToken
21
 
 
 
 
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
@@ -393,12 +397,8 @@ class QWenTokenizer(PreTrainedTokenizer):
393
  bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
394
  assert len(bbox) == 4
395
  output.append({'box': bbox})
396
-
397
- ref_idx = i - 1
398
- while ref_idx >= 0 and 'box' in list_format[ref_idx]:
399
- ref_idx -= 1
400
- if ref_idx >= 0 and 'ref' in list_format[ref_idx]:
401
- output[-1]['ref'] = list_format[ref_idx]['ref'].strip()
402
  return output
403
 
404
  def draw_bbox_on_latest_picture(
@@ -412,21 +412,153 @@ class QWenTokenizer(PreTrainedTokenizer):
412
  if image.startswith("http://") or image.startswith("https://"):
413
  image = Image.open(requests.get(image, stream=True).raw)
414
  else:
415
- image = Image.open(image)
416
- h, w = image.height, image.width
417
- image = image.convert("RGB")
 
 
 
418
 
419
  boxes = self._fetch_all_box_with_ref(response)
420
  if not boxes:
421
  return None
422
- fnt = ImageFont.truetype("SimSun.ttf", 50)
423
- draw = ImageDraw.Draw(image)
 
424
  for box in boxes:
 
 
425
  x1, y1, x2, y2 = box['box']
426
  x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
427
- draw.rectangle((x1, y1, x2, y2), outline='red', width=4)
428
  if 'ref' in box:
429
- draw.text((x1, y1), box['ref'], fill='yellow', font=fnt)
430
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
 
 
 
 
19
  from PIL import ImageDraw
20
  from transformers import PreTrainedTokenizer, AddedToken
21
 
22
+ import matplotlib.pyplot as plt
23
+ import matplotlib.colors as mcolors
24
+ from matplotlib.font_manager import FontProperties
25
+
26
  logger = logging.getLogger(__name__)
27
 
28
 
 
397
  bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
398
  assert len(bbox) == 4
399
  output.append({'box': bbox})
400
+ if i > 0 and 'ref' in list_format[i-1]:
401
+ output[-1]['ref'] = list_format[i-1]['ref'].strip()
 
 
 
 
402
  return output
403
 
404
  def draw_bbox_on_latest_picture(
 
412
  if image.startswith("http://") or image.startswith("https://"):
413
  image = Image.open(requests.get(image, stream=True).raw)
414
  else:
415
+ # image = Image.open(image)
416
+ image = plt.imread(image)
417
+ # h, w = image.height, image.width
418
+ # image = image.convert("RGB")
419
+ h, w = image.shape[0], image.shape[1]
420
+ visualizer = Visualizer(image)
421
 
422
  boxes = self._fetch_all_box_with_ref(response)
423
  if not boxes:
424
  return None
425
+ # fnt = ImageFont.truetype("SimSun.ttf", 50)
426
+ # draw = ImageDraw.Draw(image)
427
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
428
  for box in boxes:
429
+ if 'ref' in box: # random new color for new refexps
430
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
431
  x1, y1, x2, y2 = box['box']
432
  x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
433
+ visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
434
  if 'ref' in box:
435
+ visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
436
+ return visualizer.output
437
+
438
+
439
+ import colorsys
440
+ import logging
441
+ import math
442
+ import numpy as np
443
+ import matplotlib as mpl
444
+ import matplotlib.colors as mplc
445
+ import matplotlib.figure as mplfigure
446
+ import torch
447
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
448
+ from PIL import Image
449
+ import random
450
+
451
+ logger = logging.getLogger(__name__)
452
+
453
+
454
+ class VisImage:
455
+ def __init__(self, img, scale=1.0):
456
+ self.img = img
457
+ self.scale = scale
458
+ self.width, self.height = img.shape[1], img.shape[0]
459
+ self._setup_figure(img)
460
+
461
+ def _setup_figure(self, img):
462
+ fig = mplfigure.Figure(frameon=False)
463
+ self.dpi = fig.get_dpi()
464
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
465
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
466
+ fig.set_size_inches(
467
+ (self.width * self.scale + 1e-2) / self.dpi,
468
+ (self.height * self.scale + 1e-2) / self.dpi,
469
+ )
470
+ self.canvas = FigureCanvasAgg(fig)
471
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
472
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
473
+ ax.axis("off")
474
+ self.fig = fig
475
+ self.ax = ax
476
+ self.reset_image(img)
477
+
478
+ def reset_image(self, img):
479
+ img = img.astype("uint8")
480
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
481
 
482
+ def save(self, filepath):
483
+ self.fig.savefig(filepath)
484
+
485
+ def get_image(self):
486
+ canvas = self.canvas
487
+ s, (width, height) = canvas.print_to_buffer()
488
+
489
+ buffer = np.frombuffer(s, dtype="uint8")
490
+
491
+ img_rgba = buffer.reshape(height, width, 4)
492
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
493
+ return rgb.astype("uint8")
494
+
495
+
496
+ class Visualizer:
497
+ def __init__(self, img_rgb, metadata=None, scale=1.0):
498
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
499
+ self.output = VisImage(self.img, scale=scale)
500
+ self.cpu_device = torch.device("cpu")
501
+
502
+ # too small texts are useless, therefore clamp to 14
503
+ self._default_font_size = max(
504
+ np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
505
+ )
506
+
507
+ def draw_text(
508
+ self,
509
+ text,
510
+ position,
511
+ *,
512
+ font_size=None,
513
+ color="g",
514
+ horizontal_alignment="center",
515
+ rotation=0,
516
+ ):
517
+ if not font_size:
518
+ font_size = self._default_font_size
519
+
520
+ # since the text background is dark, we don't want the text to be dark
521
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
522
+ color[np.argmax(color)] = max(0.8, np.max(color))
523
+
524
+ x, y = position
525
+ self.output.ax.text(
526
+ x,
527
+ y,
528
+ text,
529
+ size=font_size * self.output.scale,
530
+ fontproperties=FontProperties(fname=r"SimSun.ttf"),
531
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
532
+ verticalalignment="top",
533
+ horizontalalignment=horizontal_alignment,
534
+ color=color,
535
+ zorder=10,
536
+ rotation=rotation,
537
+ )
538
+ return self.output
539
+
540
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
541
+
542
+ x0, y0, x1, y1 = box_coord
543
+ width = x1 - x0
544
+ height = y1 - y0
545
+
546
+ linewidth = max(self._default_font_size / 4, 1)
547
+
548
+ self.output.ax.add_patch(
549
+ mpl.patches.Rectangle(
550
+ (x0, y0),
551
+ width,
552
+ height,
553
+ fill=False,
554
+ edgecolor=edge_color,
555
+ linewidth=linewidth * self.output.scale,
556
+ alpha=alpha,
557
+ linestyle=line_style,
558
+ )
559
+ )
560
+ return self.output
561
 
562
+ def get_output(self):
563
+
564
+ return self.output