ttengwang committed
Commit 108f2df
1 Parent(s): b7e072a

share ocr_reader to accelerate inference

app.py CHANGED
@@ -17,7 +17,7 @@ from caption_anything.text_refiner import build_text_refiner
 from caption_anything.segmenter import build_segmenter
 from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
 from segment_anything import sam_model_registry
-
+import easyocr
 
 args = parse_augment()
 args.segmenter = "huge"
@@ -30,6 +30,8 @@ else:
 
 shared_captioner = build_captioner(args.captioner, args.device, args)
 shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+ocr_lang = ["ch_tra", "en"]
+shared_ocr_reader = easyocr.Reader(ocr_lang)
 tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
 shared_chatbot_tools = build_chatbot_tools(tools_dict)
 
@@ -57,13 +59,13 @@ class ImageSketcher(gr.Image):
         return super().preprocess(x)
 
 
-def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
                                        session_id=None):
     segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
     captioner = captioner
     if session_id is not None:
         print('Init caption anything for session {}'.format(session_id))
-    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
 
 
 def init_openai_api_key(api_key=""):
@@ -146,6 +148,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         session_id=iface.app_id
     )
     model.segmenter.set_image(image_input)
@@ -154,6 +157,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     input_size = model.input_size
 
     if visual_chatgpt is not None:
+        print('upload_callback: add caption to chatGPT memory')
         new_image_path = get_new_image_name('chat_image', func_name='upload')
         image_input.save(new_image_path)
         visual_chatgpt.current_image = new_image_path
@@ -192,6 +196,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -213,6 +218,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     x, y = input_points[-1]
 
     if visual_chatgpt is not None:
+        print('inference_click: add caption to chatGPT memory')
         new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
         Image.open(out["crop_save_path"]).save(new_crop_save_path)
         point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
@@ -273,6 +279,7 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -325,6 +332,7 @@ def cap_everything(image_input, visual_chatgpt, text_refiner):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
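
The app.py side of this commit builds a single easyocr.Reader at start-up and passes it into every per-session CaptionAnything, so the OCR detection and recognition weights are loaded once per process instead of once per session. A minimal sketch of that sharing pattern, assuming easyocr is installed (the handle_request helper is hypothetical, not part of the repo):

```python
import easyocr

# Same language list the commit adds to app.py.
ocr_lang = ["ch_tra", "en"]

# Building the reader is the expensive step: it loads the detection and
# recognition models into memory. Do it once, at process start-up.
shared_ocr_reader = easyocr.Reader(ocr_lang)

def handle_request(image_path, reader=shared_ocr_reader):
    # Hypothetical per-request handler: reusing the shared reader means no
    # model load per call. readtext returns (bbox, text, confidence) tuples.
    return reader.readtext(image_path)
```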
caption_anything/captioner/blip2.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
 from caption_anything.utils.utils import is_platform_win, load_image
 from .base_captioner import BaseCaptioner
+import time
 
 class BLIP2Captioner(BaseCaptioner):
     def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
@@ -33,8 +34,7 @@ class BLIP2Captioner(BaseCaptioner):
         if not self.dialogue:
             inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
-            captions = self.processor.batch_decode(out.sequences, skip_special_tokens=True)
-            caption = [caption.strip() for caption in captions][0]
+            caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
             if self.enable_filter and filter:
                 print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
                 clip_score = self.filter_caption(image, caption, args['reference_caption'])
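
The blip2.py change replaces batch_decode over all generated sequences with a single decode of the first sequence; for a batch of one the two produce the same string, so this is a simplification rather than a behavioural change. A small stand-alone illustration using a generic Hugging Face tokenizer (gpt2 is only a stand-in for the BLIP-2 processor and assumes the tokenizer can be downloaded):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
sequences = tok(["a dog playing in the park"], return_tensors="pt").input_ids

# batch_decode on a one-element batch and decode on its first row agree.
a = tok.batch_decode(sequences, skip_special_tokens=True)[0].strip()
b = tok.decode(sequences[0], skip_special_tokens=True).strip()
assert a == b
```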
caption_anything/model.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 from PIL import Image
 import easyocr
 import copy
+import time
 from caption_anything.captioner import build_captioner, BaseCaptioner
 from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
 from caption_anything.text_refiner import build_text_refiner
@@ -16,14 +17,15 @@ from caption_anything.utils.utils import mask_painter_foreground_all, mask_paint
 from caption_anything.utils.densecap_painter import draw_bbox
 
 class CaptionAnything:
-    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
+    def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
         self.args = args
         self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
         self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
         self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
+        self.ocr_lang = ["ch_tra", "en"]
+        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)
 
-        self.lang = ["ch_tra", "en"]
-        self.reader = easyocr.Reader(self.lang)
+
         self.text_refiner = None
         if not args.disable_gpt:
             if text_refiner is not None:
@@ -31,6 +33,7 @@ class CaptionAnything:
             elif api_key != "":
                 self.init_refiner(api_key)
         self.require_caption_prompt = args.captioner == 'blip2'
+        print('text_refiner init time: ', time.time() - t0)
 
     @property
     def image_embedding(self):
@@ -213,7 +216,7 @@ class CaptionAnything:
     def parse_ocr(self, image, thres=0.2):
         width, height = get_image_shape(image)
         image = load_image(image, return_type='numpy')
-        bounds = self.reader.readtext(image)
+        bounds = self.ocr_reader.readtext(image)
         bounds = [bound for bound in bounds if bound[2] > thres]
         print('Process OCR Text:\n', bounds)
 
@@ -257,7 +260,7 @@ class CaptionAnything:
 if __name__ == "__main__":
     from caption_anything.utils.parser import parse_augment
     args = parse_augment()
-    image_path = 'image/ocr/Untitled.png'
+    image_path = 'result/wt/memes/87226084.jpg'
     image = Image.open(image_path)
     prompts = [
         {
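
In model.py, parse_ocr now routes through the injected (or fallback-built) reader and keeps only detections above a confidence threshold. A self-contained sketch of that filtering step; filter_ocr_bounds is a hypothetical helper and the example tuples merely imitate easyocr's (bbox, text, confidence) output format:

```python
def filter_ocr_bounds(bounds, thres=0.2):
    # Keep detections whose confidence (index 2) exceeds the threshold,
    # mirroring what parse_ocr does after self.ocr_reader.readtext(image).
    return [bound for bound in bounds if bound[2] > thres]

example = [
    ([(0, 0), (80, 0), (80, 20), (0, 20)], "hello", 0.93),
    ([(0, 25), (80, 25), (80, 45), (0, 45)], "n0ise", 0.08),
]
assert filter_ocr_bounds(example) == example[:1]
```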