Update modeling_GOT.py
Browse files- modeling_GOT.py +72 -74
modeling_GOT.py
CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
-
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
@@ -575,87 +575,86 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
575 |
)
|
576 |
|
577 |
|
578 |
-
|
579 |
-
|
|
|
580 |
|
581 |
-
|
582 |
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
|
605 |
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
# left_num = outputs.count('\left')
|
611 |
|
612 |
-
|
613 |
-
|
614 |
|
615 |
|
616 |
-
|
617 |
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
|
651 |
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
# new_web = lines[0] + gt + lines[1]
|
656 |
|
657 |
-
|
658 |
-
|
659 |
|
660 |
def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
|
661 |
|
@@ -807,13 +806,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
807 |
|
808 |
if render:
|
809 |
print('==============rendering===============')
|
|
|
810 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
811 |
|
812 |
if outputs.endswith(stop_str):
|
813 |
outputs = outputs[:-len(stop_str)]
|
814 |
outputs = outputs.strip()
|
815 |
|
816 |
-
html_path = "./render_tools/" + "content-mmd-to-html.html"
|
817 |
html_path_2 = save_render_file
|
818 |
right_num = outputs.count('\\right')
|
819 |
left_num = outputs.count('\left')
|
@@ -831,10 +830,9 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
831 |
|
832 |
gt = gt[:-2]
|
833 |
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
new_web = lines[0] + 'const text =' + gt + lines[1]
|
838 |
|
839 |
with smart_open(html_path_2, 'w') as web_f_new:
|
840 |
web_f_new.write(new_web)
|
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
+
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
|
|
575 |
)
|
576 |
|
577 |
|
578 |
+
if render:
|
579 |
+
print('==============rendering===============')
|
580 |
+
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
|
581 |
|
582 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
583 |
|
584 |
+
if outputs.endswith(stop_str):
|
585 |
+
outputs = outputs[:-len(stop_str)]
|
586 |
+
outputs = outputs.strip()
|
587 |
+
|
588 |
+
if '**kern' in outputs:
|
589 |
+
import verovio
|
590 |
+
from cairosvg import svg2png
|
591 |
+
import cv2
|
592 |
+
import numpy as np
|
593 |
+
tk = verovio.toolkit()
|
594 |
+
tk.loadData(outputs)
|
595 |
+
tk.setOptions({"pageWidth": 2100, "footer": 'none',
|
596 |
+
'barLineWidth': 0.5, 'beamMaxSlope': 15,
|
597 |
+
'staffLineWidth': 0.2, 'spacingStaff': 6})
|
598 |
+
tk.getPageCount()
|
599 |
+
svg = tk.renderToSVG()
|
600 |
+
svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
|
601 |
+
|
602 |
+
svg_to_html(svg, save_render_file)
|
603 |
+
|
604 |
+
if ocr_type == 'format' and '**kern' not in outputs:
|
605 |
|
606 |
|
607 |
+
if '\\begin{tikzpicture}' not in outputs:
|
608 |
+
html_path_2 = save_render_file
|
609 |
+
right_num = outputs.count('\\right')
|
610 |
+
left_num = outputs.count('\left')
|
|
|
611 |
|
612 |
+
if right_num != left_num:
|
613 |
+
outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
|
614 |
|
615 |
|
616 |
+
outputs = outputs.replace('"', '``').replace('$', '')
|
617 |
|
618 |
+
outputs_list = outputs.split('\n')
|
619 |
+
gt= ''
|
620 |
+
for out in outputs_list:
|
621 |
+
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
|
622 |
|
623 |
+
gt = gt[:-2]
|
624 |
+
|
625 |
+
|
626 |
+
lines = content_mmd_to_html
|
627 |
+
lines = lines.split("const text =")
|
628 |
+
new_web = lines[0] + 'const text =' + gt + lines[1]
|
629 |
+
|
630 |
+
else:
|
631 |
+
html_path_2 = save_render_file
|
632 |
+
outputs = outputs.translate(translation_table)
|
633 |
+
outputs_list = outputs.split('\n')
|
634 |
+
gt= ''
|
635 |
+
for out in outputs_list:
|
636 |
+
if out:
|
637 |
+
if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
|
638 |
+
while out[-1] == ' ':
|
639 |
+
out = out[:-1]
|
640 |
+
if out is None:
|
641 |
+
break
|
642 |
|
643 |
+
if out:
|
644 |
+
if out[-1] != ';':
|
645 |
+
gt += out[:-1] + ';\n'
|
646 |
+
else:
|
647 |
+
gt += out + '\n'
|
648 |
+
else:
|
649 |
+
gt += out + '\n'
|
650 |
|
651 |
|
652 |
+
lines = tik_html
|
653 |
+
lines = lines.split("const text =")
|
654 |
+
new_web = lines[0] + gt + lines[1]
|
|
|
655 |
|
656 |
+
with smart_open(html_path_2, 'w') as web_f_new:
|
657 |
+
web_f_new.write(new_web)
|
658 |
|
659 |
def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
|
660 |
|
|
|
806 |
|
807 |
if render:
|
808 |
print('==============rendering===============')
|
809 |
+
from .render_tools import content_mmd_to_html
|
810 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
811 |
|
812 |
if outputs.endswith(stop_str):
|
813 |
outputs = outputs[:-len(stop_str)]
|
814 |
outputs = outputs.strip()
|
815 |
|
|
|
816 |
html_path_2 = save_render_file
|
817 |
right_num = outputs.count('\\right')
|
818 |
left_num = outputs.count('\left')
|
|
|
830 |
|
831 |
gt = gt[:-2]
|
832 |
|
833 |
+
lines = content_mmd_to_html
|
834 |
+
lines = lines.split("const text =")
|
835 |
+
new_web = lines[0] + 'const text =' + gt + lines[1]
|
|
|
836 |
|
837 |
with smart_open(html_path_2, 'w') as web_f_new:
|
838 |
web_f_new.write(new_web)
|