CassianK committed (verified)
Commit 00613e2 · 1 Parent(s): f498e99

Upload 12 files

DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py ADDED
@@ -0,0 +1,42 @@
+ # TODO: change modes
+ # Tiny: base_size = 512, image_size = 512, crop_mode = False
+ # Small: base_size = 640, image_size = 640, crop_mode = False
+ # Base: base_size = 1024, image_size = 1024, crop_mode = False
+ # Large: base_size = 1280, image_size = 1280, crop_mode = False
+ # Gundam: base_size = 1024, image_size = 640, crop_mode = True
+
+ BASE_SIZE = 1024
+ IMAGE_SIZE = 640
+ CROP_MODE = True
+ MIN_CROPS = 2
+ MAX_CROPS = 6  # max: 9; if your GPU memory is small, 6 is recommended
+ MAX_CONCURRENCY = 100  # if you have limited GPU memory, lower the concurrency count
+ NUM_WORKERS = 64  # image pre-processing (resize/padding) workers
+ PRINT_NUM_VIS_TOKENS = False
+ SKIP_REPEAT = True
+ MODEL_PATH = 'deepseek-ai/DeepSeek-OCR'  # change to your model path
+
+ # TODO: change INPUT_PATH
+ # .pdf: run_dpsk_ocr_pdf.py
+ # .jpg, .png, .jpeg: run_dpsk_ocr_image.py
+ # OmniDocBench images path: run_dpsk_ocr_eval_batch.py
+
+ INPUT_PATH = ''
+ OUTPUT_PATH = ''
+
+ PROMPT = '<image>\n<|grounding|>Convert the document to markdown.'
+ # PROMPT = '<image>\nFree OCR.'
+ # TODO: commonly used prompts
+ # document: <image>\n<|grounding|>Convert the document to markdown.
+ # other image: <image>\n<|grounding|>OCR this image.
+ # without layouts: <image>\nFree OCR.
+ # figures in document: <image>\nParse the figure.
+ # general: <image>\nDescribe this image in detail.
+ # rec: <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
+ # '先天下之忧而忧'
+ # .......
+
+
+ from transformers import AutoTokenizer
+
+ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
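
Note: the resolution modes listed at the top of config.py are selected by setting BASE_SIZE, IMAGE_SIZE, and CROP_MODE together. A minimal sketch of how such a preset table could be expressed (the MODE_PRESETS and apply_mode names below are illustrative and not part of this repository):

MODE_PRESETS = {
    'tiny':   {'base_size': 512,  'image_size': 512,  'crop_mode': False},
    'small':  {'base_size': 640,  'image_size': 640,  'crop_mode': False},
    'base':   {'base_size': 1024, 'image_size': 1024, 'crop_mode': False},
    'large':  {'base_size': 1280, 'image_size': 1280, 'crop_mode': False},
    'gundam': {'base_size': 1024, 'image_size': 640,  'crop_mode': True},
}

def apply_mode(name: str):
    # Return the (BASE_SIZE, IMAGE_SIZE, CROP_MODE) triple for a named preset.
    preset = MODE_PRESETS[name.lower()]
    return preset['base_size'], preset['image_size'], preset['crop_mode']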
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/__init__.py ADDED
File without changes
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/build_linear.py ADDED
@@ -0,0 +1,174 @@
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import copy
+
+
+ class MlpProjector(nn.Module):
+
+     def __init__(self, cfg):
+
+         super().__init__()
+
+         self.cfg = cfg
+
+         if cfg.projector_type == "identity":
+             modules = nn.Identity()
+
+         elif cfg.projector_type == "linear":
+             modules = nn.Linear(cfg.input_dim, cfg.n_embed)
+
+         elif cfg.projector_type == "mlp_gelu":
+             mlp_depth = cfg.get("depth", 1)
+             modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
+             mlp_depth = cfg.get("depth", 1)
+             mlp_ratio = cfg.get("mlp_ratio", 1)
+             modules = [
+                 nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
+                 nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
+             ]
+             for _ in range(1, mlp_depth - 1):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif cfg.projector_type == "downsample_mlp_gelu":
+             mlp_depth = cfg.get("depth", 1)
+             mlp_ratio = cfg.get("mlp_ratio", 1)
+             modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
+             for _ in range(1, mlp_depth - 1):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+             mlp_depth = cfg.get("depth", 1)
+             self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+             self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+
+             modules = []
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
+             mlp_depth = cfg.get("depth", 1)
+             channel_div = cfg.get("channel_div", 0.5)
+             self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
+             self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
+
+             modules = []
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif cfg.projector_type == "low_high_split_mlp_gelu":
+             mlp_depth = cfg.get("depth", 1)
+             modules = []
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
+             modules = nn.Sequential(*modules)
+             self.high_layers = nn.Sequential(*modules)
+             self.low_layers = copy.deepcopy(modules)
+
+         else:
+             raise ValueError(f"Unknown projector type: {cfg.projector_type}")
+
+         if cfg.get("token_pooling", False):
+             self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
+
+         if cfg.get("conv_fusion_high_low_features", False):
+             self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
+         self.layers = modules
+
+     def forward(self, x):
+         if self.cfg.get("token_pooling", False):
+             batch_size, wxh, channels = x.shape
+             w = h = int(wxh**0.5)
+             x = x.view(batch_size, w, h, channels)
+             x = x.permute(0, 3, 1, 2)
+             # import ipdb; ipdb.set_trace()
+             patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
+             batch_size, channels, h_patches, w_patches, _, _ = patches.size()
+             # concatenate along the channel dimension
+             patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
+
+             # pass through the linear layer
+             patches = patches.permute(0, 2, 1, 3).contiguous()
+             patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
+
+             x = self.token_pooling_layer(patches)
+
+         if self.cfg.get("conv_fusion_high_low_features", False):
+             x = self.fusion_layer(x[:, 0]) + x[:, 1]
+
+         if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
+             high_x, low_x = x[0], x[1]
+             high_x = self.high_up_proj(high_x)
+             low_x = self.low_up_proj(low_x)
+             x = torch.concat([high_x, low_x], dim=-1)
+
+         if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
+             high_x = x[..., :self.cfg.input_dim[0]]
+             low_x = x[..., self.cfg.input_dim[0]:]
+             high_x = self.high_up_proj(high_x)
+             low_x = self.low_up_proj(low_x)
+             x = torch.concat([high_x, low_x], dim=-1)
+
+         if self.cfg.projector_type == 'low_high_split_mlp_gelu':
+             high_x, low_x = x[0], x[1]
+             high_x = self.high_layers(high_x)
+             low_x = self.low_layers(low_x)
+             x = torch.concat([high_x, low_x], dim=-1)
+             return x
+
+         if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
+             bs, hw, input_dim = x.shape
+             h = w = int((hw) ** 0.5)
+
+             """compute padding"""
+             if h % self.cfg.downsample_ratio:
+                 pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
+             else:
+                 pad = 0
+             x = x.reshape(bs, h, w, input_dim)
+             if pad > 0:
+                 x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
+
+             """4 to 1 concat"""
+             x = x.permute(0, 3, 1, 2)  # B, C, H, W
+             x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)  # B, C*4, HW // 4
+             x = x.permute(0, 2, 1)
+
+         return self.layers(x)
+
+     @staticmethod
+     def get_flops_per_sample(cfg):
+         if cfg.projector_type == "linear":
+             fwd = 2 * cfg.input_dim * cfg.n_embed
+
+         elif "mlp_gelu" in cfg.projector_type:
+             mlp_depth = cfg.get("depth", 1)
+             downsample_ratio = cfg.get("downsample_ratio", 1)
+             input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
+             input_dim = input_dim * downsample_ratio * downsample_ratio
+             fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
+         else:
+             fwd = 0
+
+         return fwd * 3
+
+
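
For orientation, a minimal usage sketch of the projector above, mirroring the projector_type="linear" configuration that deepseek_ocr.py builds later in this commit (the tensor shapes are illustrative only):

import torch
from addict import Dict
from deepencoder.build_linear import MlpProjector

# addict's Dict gives both attribute access and dict-style .get(), which MlpProjector expects.
proj = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=1280))
vision_tokens = torch.randn(1, 256, 2048)   # [batch, num_vision_tokens, input_dim]
print(proj(vision_tokens).shape)            # torch.Size([1, 256, 1280])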
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/clip_sdpa.py ADDED
@@ -0,0 +1,504 @@
1
+ from contextlib import nullcontext
2
+ import math
3
+ from typing import Optional, Tuple
4
+ # from megatron.model import LayerNorm
5
+ from easydict import EasyDict as adict
6
+ import torch
7
+ from torch.nn import functional as F
8
+ from torch import nn
9
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
10
+ # from optimus import flash_attn_func
11
+ # from megatron.core import tensor_parallel
12
+ # from megatron.core import parallel_state as mpu
13
+ # from megatron.core.utils import make_viewless_tensor, divide
14
+ # from megatron.model.fused_rms_norm import RMSNorm
15
+ # from megatron.model.transformer import (
16
+ # FlashSelfAttention,
17
+ # NoopTransformerLayer,
18
+ # _cfg_to_kwargs,
19
+ # )
20
+ # from megatron.model.enums import AttnMaskType, AttnType
21
+ # from megatron.model.fused_softmax import FusedScaleMaskSoftmax
22
+ # from megatron.model.utils import attention_mask_func
23
+
24
+ # from megatron.model.module import MegatronModule
25
+
26
+ # try:
27
+ # from einops import rearrange
28
+ # except ImportError:
29
+ # rearrange = None
30
+
31
+ # from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
32
+
33
+ # try:
34
+ # # flash attention 2.x
35
+ # from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
36
+ # except ImportError:
37
+ # try:
38
+ # # flash attention 1.x
39
+ # from flash_attn.flash_attn_interface import flash_attn_unpadded_func
40
+ # except ImportError:
41
+ # flash_attn_unpadded_func = None
42
+
43
+ # try:
44
+ # from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
45
+ # except ImportError:
46
+ # flash_attn_unpadded_relative_attention_bias_func = None
47
+
48
+ # try:
49
+ # from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
50
+ # except ImportError:
51
+ # mask_flash_attn_unpadded_func = None
52
+
53
+
54
+ class LayerNormfp32(torch.nn.LayerNorm):
55
+ """Subclass torch's LayerNorm to handle fp16."""
56
+
57
+ def forward(self, x: torch.Tensor):
58
+ orig_type = x.dtype
59
+ ret = super().forward(x.type(torch.float32))
60
+ return ret.type(orig_type)
61
+
62
+
63
+ def get_abs_pos(abs_pos, tgt_size):
64
+ # abs_pos: L, C
65
+ # tgt_size: M
66
+ # return: M, C
67
+
68
+ # print(tgt_size)
69
+ # print(abs_pos.shape)
70
+ # exit()
71
+ dim = abs_pos.size(-1)
72
+ # print(dim)
73
+ abs_pos_new = abs_pos.squeeze(0)
74
+ cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
75
+
76
+
77
+
78
+ src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
79
+ tgt_size = int(math.sqrt(tgt_size))
80
+ dtype = abs_pos.dtype
81
+
82
+ if src_size != tgt_size:
83
+ old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
84
+ 2).contiguous()
85
+ old_pos_embed = old_pos_embed.to(torch.float32)
86
+ new_pos_embed = F.interpolate(
87
+ old_pos_embed,
88
+ size=(tgt_size, tgt_size),
89
+ mode='bicubic',
90
+ antialias=True,
91
+ align_corners=False,
92
+ ).to(dtype)
93
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
94
+ new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
95
+ vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
96
+ vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
97
+ return vision_pos_embed
98
+ else:
99
+ return abs_pos
100
+
101
+ @torch.jit.script
102
+ def quick_gelu(x):
103
+ return x * torch.sigmoid(1.702 * x)
104
+
105
+
106
+
107
+ class CLIPVisionEmbeddings(nn.Module):
108
+ def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
109
+ super().__init__()
110
+ self.embed_dim = hidden_size
111
+ self.image_size = image_size
112
+ self.patch_size = patch_size
113
+
114
+ self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
115
+
116
+ self.patch_embedding = torch.nn.Conv2d(
117
+ in_channels=num_channels,
118
+ out_channels=self.embed_dim,
119
+ kernel_size=self.patch_size,
120
+ stride=self.patch_size,
121
+ bias=False,
122
+ )
123
+
124
+ self.num_patches = (self.image_size // self.patch_size) ** 2
125
+ self.num_positions = self.num_patches + 1
126
+ self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
127
+ self.register_buffer(
128
+ "position_ids", torch.arange(self.num_positions).expand((1, -1))
129
+ )
130
+
131
+ def forward(self, pixel_values, patch_embeds):
132
+ batch_size = pixel_values.shape[0]
133
+ # patch_embeds = self.patch_embedding(
134
+ # pixel_values
135
+ # ) # shape = [*, width, grid, grid]
136
+
137
+
138
+ if patch_embeds is not None:
139
+ patch_embeds = patch_embeds
140
+ # print(patch_embeds.shape)
141
+ else:
142
+ patch_embeds = self.patch_embedding(pixel_values)
143
+ # print(111111)
144
+ # shape = [*, width, grid, grid]
145
+ # patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
146
+
147
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
148
+
149
+
150
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
151
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
152
+
153
+ # x = torch.cat([cls_token, x], dim=1)
154
+ embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
155
+ # embeddings = embeddings + self.position_embedding(self.position_ids)
156
+ return embeddings
157
+
158
+
159
+ class NoTPFeedForward(nn.Module):
160
+ def __init__(
161
+ self,
162
+ cfg,
163
+ dim: int,
164
+ hidden_dim: int,
165
+ ):
166
+ super().__init__()
167
+
168
+ self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
169
+ self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
170
+
171
+ def forward(self, x):
172
+ output = self.fc2(quick_gelu(self.fc1(x)))
173
+ return output
174
+
175
+
176
+ # from optimus.flash_attn_interface import flash_attn_qkvpacked_func
177
+
178
+
179
+ # class NoTPAttention(nn.Module):
180
+ # def __init__(self, cfg):
181
+ # super().__init__()
182
+ # self.num_heads = cfg.num_attention_heads
183
+ # self.n_local_heads = cfg.num_attention_heads
184
+ # self.head_dim = cfg.hidden_size // cfg.num_attention_heads
185
+ # self.max_seq_len = cfg.seq_length
186
+ # self.use_flash_attention = cfg.use_flash_attn
187
+
188
+ # self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
189
+ # self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
190
+
191
+ # # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
192
+
193
+ # self.attn_drop = cfg.attention_dropout
194
+
195
+ # def forward(
196
+ # self,
197
+ # x: torch.Tensor,
198
+ # ):
199
+ # bsz, seqlen, _ = x.shape
200
+ # xqkv = self.qkv_proj(x)
201
+ # xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
202
+
203
+ # if self.use_flash_attention:
204
+ # output = flash_attn_qkvpacked_func(xqkv)
205
+ # output = output.view(bsz, seqlen, -1)
206
+ # else:
207
+ # xq, xk, xv = torch.split(xqkv, 1, dim=2)
208
+ # xq = xq.squeeze(2)
209
+ # xk = xk.squeeze(2)
210
+ # xv = xv.squeeze(2)
211
+ # # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
212
+
213
+ # # (B, num_head, S, head_size)
214
+ # xq = xq.permute(0, 2, 1, 3)
215
+ # xk = xk.permute(0, 2, 1, 3)
216
+ # xv = xv.permute(0, 2, 1, 3)
217
+
218
+ # output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
219
+ # utput = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
220
+ # output = self.out_proj(output)
221
+ # return output
222
+
223
+
224
+ # from optimus.flash_attn_interface import flash_attn_qkvpacked_func
225
+
226
+
227
+ class NoTPAttention(torch.nn.Module):
228
+ def __init__(self, cfg):
229
+ super().__init__()
230
+ self.num_heads = cfg.num_attention_heads
231
+ self.n_local_heads = cfg.num_attention_heads
232
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
233
+ self.max_seq_len = cfg.seq_length
234
+ self.use_flash_attention = cfg.use_flash_attn
235
+
236
+ self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
237
+ self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
238
+
239
+ # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
240
+
241
+ self.attn_drop = cfg.attention_dropout
242
+
243
+ def forward(
244
+ self,
245
+ x: torch.Tensor,
246
+ ):
247
+ bsz, seqlen, _ = x.shape
248
+ xqkv = self.qkv_proj(x)
249
+ xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
250
+
251
+ if self.use_flash_attention:
252
+ output = flash_attn_qkvpacked_func(xqkv)
253
+ output = output.view(bsz, seqlen, -1)
254
+ # xq, xk, xv = torch.split(xqkv, 1, dim=2)
255
+ # xq = xq.squeeze(2)
256
+ # xk = xk.squeeze(2)
257
+ # xv = xv.squeeze(2)
258
+ # # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
259
+
260
+ # # (B, num_head, S, head_size)
261
+ # xq = xq.permute(0, 2, 1, 3)
262
+ # xk = xk.permute(0, 2, 1, 3)
263
+ # xv = xv.permute(0, 2, 1, 3)
264
+ # # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
265
+ # output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
266
+ # output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
267
+ # output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
268
+ else:
269
+ # output = flash_attn_qkvpacked_func(xqkv)
270
+ xq, xk, xv = torch.split(xqkv, 1, dim=2)
271
+ xq = xq.squeeze(2)
272
+ xk = xk.squeeze(2)
273
+ xv = xv.squeeze(2)
274
+ # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
275
+
276
+ # (B, num_head, S, head_size)
277
+ xq = xq.permute(0, 2, 1, 3)
278
+ xk = xk.permute(0, 2, 1, 3)
279
+ xv = xv.permute(0, 2, 1, 3)
280
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
281
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
282
+ output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
283
+ output = self.out_proj(output)
284
+ return output
285
+
286
+ class NoTPTransformerBlock(nn.Module):
287
+ def __init__(self, cfg, layer_id: int, multiple_of=256):
288
+ super().__init__()
289
+
290
+ self.n_heads = cfg.num_attention_heads
291
+ self.dim = cfg.hidden_size
292
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
293
+ self.self_attn = NoTPAttention(cfg)
294
+ self.mlp = NoTPFeedForward(
295
+ cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
296
+ )
297
+ self.layer_id = layer_id
298
+ self.layer_norm1 = torch.nn.LayerNorm(
299
+ cfg.hidden_size, eps=cfg.layernorm_epsilon
300
+ )
301
+ self.layer_norm2 = torch.nn.LayerNorm(
302
+ cfg.hidden_size, eps=cfg.layernorm_epsilon
303
+ )
304
+
305
+ def forward(self, x: torch.Tensor):
306
+ residual = self.self_attn.forward(self.layer_norm1(x))
307
+ h = x + residual
308
+ out = h + self.mlp.forward(self.layer_norm2(h))
309
+ return out
310
+
311
+
312
+ class NoTPTransformer(nn.Module):
313
+ def __init__(self, cfg):
314
+ super().__init__()
315
+
316
+ self.cfg = cfg
317
+ # self.recompute_list = self.cfg.get("recompute_list", [])
318
+ self.num_layers = cfg.num_layers # _get_num_layers(cfg)
319
+
320
+ self.layers = torch.nn.ModuleList()
321
+ for layer_id in range(self.num_layers):
322
+ self.layers.append(
323
+ NoTPTransformerBlock(
324
+ cfg,
325
+ layer_id + 1,
326
+ )
327
+ )
328
+
329
+ def forward(
330
+ self,
331
+ hidden_states,
332
+ ):
333
+
334
+ for lid, layer in enumerate(self.layers):
335
+ # if lid in self.recompute_list:
336
+ # def custom(layer_id):
337
+ # def custom_forward(*args, **kwargs):
338
+ # x_ = self.layers[layer_id](*args, **kwargs)
339
+ # return x_
340
+
341
+ # return custom_forward
342
+
343
+ # assert hidden_states.requires_grad == True, logger.warning(
344
+ # "When using recalculation, the input must have grad fn"
345
+ # )
346
+ # hidden_states = tensor_parallel.checkpoint(
347
+ # custom(lid),
348
+ # False,
349
+ # hidden_states.contiguous()
350
+ # )
351
+ # else:
352
+ hidden_states = layer(hidden_states)
353
+
354
+ return hidden_states
355
+
356
+
357
+ # from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
358
+
359
+ class VitModel(nn.Module):
360
+ def __init__(
361
+ self,
362
+ cfg,
363
+ freeze_embed=False,
364
+ freeze_pre_norm=False
365
+ ) -> None:
366
+ super().__init__()
367
+
368
+ self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
369
+
370
+ if freeze_embed:
371
+ for name, param in self.embeddings.named_parameters():
372
+ param.requires_grad = False
373
+
374
+ self.transformer = NoTPTransformer(cfg=cfg)
375
+
376
+ if cfg.get("fp32norm", False):
377
+ print("Load fp32 layernorm for ViT.")  # no logger is configured or imported in this module
378
+ self.pre_layrnorm = LayerNormfp32(
379
+ cfg.hidden_size,
380
+ eps=cfg.get("pre_layernorm_epsilon", 1e-5),
381
+ )
382
+ else:
383
+ self.pre_layrnorm = torch.nn.LayerNorm(
384
+ cfg.hidden_size,
385
+ eps=cfg.get("pre_layernorm_epsilon", 1e-5),
386
+ )
387
+
388
+ # self.pre_layrnorm = RMSNorm(
389
+ # cfg.hidden_size,
390
+ # eps=cfg.get("pre_layernorm_epsilon", 1e-5),
391
+ # sequence_parallel=False,
392
+ # use_fp32=True,
393
+ # use_optimus=True,
394
+ # )
395
+
396
+ if freeze_pre_norm:
397
+ for name, param in self.pre_layrnorm.named_parameters():
398
+ param.requires_grad = False
399
+
400
+ for p in self.parameters():
401
+ p.micro_dp = True
402
+
403
+ def set_input_tensor(self, input_tensor):
404
+ if not isinstance(input_tensor, list):
405
+ input_tensor = [input_tensor]
406
+ self.transformer.set_input_tensor(input_tensor[0])
407
+
408
+ def __str__(self) -> str:
409
+ return "open_clip"
410
+
411
+ def forward(
412
+ self,
413
+ x,
414
+ patch_embeds
415
+ ):
416
+ x = self.embeddings(x, patch_embeds)
417
+ hidden_states = self.pre_layrnorm(x)
418
+
419
+ # hidden_states, dis = local_dp_scatter(hidden_states)
420
+ output = self.transformer(hidden_states)
421
+
422
+ # output = local_dp_reduce(output, dis)
423
+
424
+ return output
425
+
426
+
427
+ vit_model_cfg = adict(
428
+ num_layers=24,
429
+ hidden_size=1024,
430
+ num_heads = 16,
431
+ num_attention_heads=16,
432
+ ffn_hidden_size=4096,
433
+ seq_length=256,
434
+ max_position_embeddings=256,
435
+ use_flash_attn=False,
436
+ understand_projector_stride=2,
437
+ hidden_dropout = 0.0,
438
+ attention_dropout = 0.0,
439
+ no_persist_layer_norm = False,
440
+ layernorm_epsilon = 1e-5,
441
+ pre_layernorm_epsilon = 1e-5,
442
+ image_size = 224,
443
+ patch_size = 14,
444
+ recompute_list = []
445
+ )
446
+
447
+ def build_clip_l():
448
+ return VitModel(
449
+ cfg=vit_model_cfg,
450
+ freeze_embed=False,
451
+ freeze_pre_norm=False,
452
+ )
453
+
454
+
455
+ if __name__ == '__main__':
456
+
457
+
458
+ from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b
459
+
460
+
461
+
462
+ vit_model_cfg = adict(
463
+ num_layers=24,
464
+ hidden_size=1024,
465
+ num_attention_heads=16,
466
+ ffn_hidden_size=4096,
467
+ seq_length=256,
468
+ max_position_embeddings=256,
469
+ use_flash_attn=False,
470
+ understand_projector_stride=2,
471
+ hidden_dropout = 0.0,
472
+ attention_dropout = 0.0,
473
+ no_persist_layer_norm = False,
474
+ layernorm_epsilon = 1e-5,
475
+ pre_layernorm_epsilon = 1e-5,
476
+ image_size = 224,
477
+ patch_size = 14,
478
+ recompute_list = []
479
+ )
480
+
481
+ sam_model = build_sam_vit_b()
482
+
483
+
484
+ vision_model = VitModel(
485
+ cfg=vit_model_cfg,
486
+ freeze_embed=False,
487
+ freeze_pre_norm=False,
488
+ )
489
+
490
+ # model = VitModel(1344)
491
+ # x = torch.zeros(2, 3, 224, 224)
492
+ x = torch.zeros(2, 3, 1024, 1024)
493
+
494
+
495
+ with torch.no_grad():
496
+ # y = vision_model(x)
497
+ patch_embed = sam_model(x)
498
+ print(patch_embed.shape)
499
+ y = vision_model(x, patch_embed)
500
+ print(y.shape)
501
+
502
+ image_feature = torch.add(y[:, 1:], patch_embed.flatten(2).permute(0, 2, 1))
503
+
504
+ print(image_feature.shape)
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/sam_vary_sdpa.py ADDED
@@ -0,0 +1,528 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+ from functools import partial
13
+ from flash_attn import flash_attn_qkvpacked_func
14
+ # from .common import LayerNorm2d, MLPBlock
15
+
16
+ # from mmgpt.model.vision_encoder.flash_4 import _attention_rel_h_rel_w
17
+
18
+
19
+ def get_abs_pos(abs_pos, tgt_size):
20
+
21
+ dtype = abs_pos.dtype
22
+
23
+ src_size = abs_pos.size(1)
24
+
25
+ if src_size != tgt_size:
26
+ old_pos_embed = abs_pos.permute(0, 3, 1, 2)
27
+ old_pos_embed = old_pos_embed.to(torch.float32)
28
+ new_pos_embed = F.interpolate(
29
+ old_pos_embed,
30
+ size=(tgt_size, tgt_size),
31
+ mode='bicubic',
32
+ antialias=True,
33
+ align_corners=False,
34
+ ).to(dtype)
35
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
36
+ return new_pos_embed
37
+ else:
38
+ return abs_pos
39
+
40
+
41
+
42
+
43
+ class MLPBlock(nn.Module):
44
+ def __init__(
45
+ self,
46
+ embedding_dim: int,
47
+ mlp_dim: int,
48
+ act: Type[nn.Module] = nn.GELU,
49
+ ) -> None:
50
+ super().__init__()
51
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
52
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
53
+ self.act = act()
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ return self.lin2(self.act(self.lin1(x)))
57
+
58
+
59
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
60
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
61
+ class LayerNorm2d(nn.Module):
62
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
63
+ super().__init__()
64
+ self.weight = nn.Parameter(torch.ones(num_channels))
65
+ self.bias = nn.Parameter(torch.zeros(num_channels))
66
+ self.eps = eps
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ u = x.mean(1, keepdim=True)
70
+ s = (x - u).pow(2).mean(1, keepdim=True)
71
+ x = (x - u) / torch.sqrt(s + self.eps)
72
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
73
+ return x
74
+
75
+
76
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
77
+ class ImageEncoderViT(nn.Module):
78
+ def __init__(
79
+ self,
80
+ img_size: int = 1024,
81
+ patch_size: int = 16,
82
+ in_chans: int = 3,
83
+ embed_dim: int = 768,
84
+ depth: int = 12,
85
+ num_heads: int = 12,
86
+ mlp_ratio: float = 4.0,
87
+ out_chans: int = 256,
88
+ qkv_bias: bool = True,
89
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
90
+ act_layer: Type[nn.Module] = nn.GELU,
91
+ use_abs_pos: bool = True,
92
+ use_rel_pos: bool = False,
93
+ rel_pos_zero_init: bool = True,
94
+ window_size: int = 0,
95
+ global_attn_indexes: Tuple[int, ...] = (),
96
+ ) -> None:
97
+ """
98
+ Args:
99
+ img_size (int): Input image size.
100
+ patch_size (int): Patch size.
101
+ in_chans (int): Number of input image channels.
102
+ embed_dim (int): Patch embedding dimension.
103
+ depth (int): Depth of ViT.
104
+ num_heads (int): Number of attention heads in each ViT block.
105
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
106
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
107
+ norm_layer (nn.Module): Normalization layer.
108
+ act_layer (nn.Module): Activation layer.
109
+ use_abs_pos (bool): If True, use absolute positional embeddings.
110
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
111
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
112
+ window_size (int): Window size for window attention blocks.
113
+ global_attn_indexes (list): Indexes for blocks using global attention.
114
+ """
115
+ super().__init__()
116
+ self.img_size = img_size
117
+
118
+ self.patch_embed = PatchEmbed(
119
+ kernel_size=(patch_size, patch_size),
120
+ stride=(patch_size, patch_size),
121
+ in_chans=in_chans,
122
+ embed_dim=embed_dim,
123
+ )
124
+
125
+ self.pos_embed: Optional[nn.Parameter] = None
126
+ if use_abs_pos:
127
+ # Initialize absolute positional embedding with pretrain image size.
128
+ self.pos_embed = nn.Parameter(
129
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
130
+ )
131
+
132
+ self.blocks = nn.ModuleList()
133
+ for i in range(depth):
134
+ block = Block(
135
+ dim=embed_dim,
136
+ num_heads=num_heads,
137
+ mlp_ratio=mlp_ratio,
138
+ qkv_bias=qkv_bias,
139
+ norm_layer=norm_layer,
140
+ act_layer=act_layer,
141
+ use_rel_pos=use_rel_pos,
142
+ rel_pos_zero_init=rel_pos_zero_init,
143
+ window_size=window_size if i not in global_attn_indexes else 0,
144
+ input_size=(img_size // patch_size, img_size // patch_size),
145
+ )
146
+ self.blocks.append(block)
147
+
148
+ self.neck = nn.Sequential(
149
+ nn.Conv2d(
150
+ embed_dim,
151
+ out_chans,
152
+ kernel_size=1,
153
+ bias=False,
154
+ ),
155
+ LayerNorm2d(out_chans),
156
+ nn.Conv2d(
157
+ out_chans,
158
+ out_chans,
159
+ kernel_size=3,
160
+ padding=1,
161
+ bias=False,
162
+ ),
163
+ LayerNorm2d(out_chans),
164
+ )
165
+
166
+ self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
167
+ self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
168
+
169
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
170
+ x = self.patch_embed(x)
171
+ if self.pos_embed is not None:
172
+ # x = x + self.pos_embed
173
+ x = x + get_abs_pos(self.pos_embed, x.size(1))
174
+
175
+ for blk in self.blocks:
176
+ x = blk(x)
177
+
178
+ neck_output = self.neck(x.permute(0, 3, 1, 2))
179
+ conv2_output = self.net_2(neck_output)
180
+ # print(f"conv2_output shape: {conv2_output.shape}")
181
+ conv3_output = self.net_3(conv2_output)
182
+
183
+ return conv3_output
184
+
185
+
186
+ class Block(nn.Module):
187
+ """Transformer blocks with support of window attention and residual propagation blocks"""
188
+
189
+ def __init__(
190
+ self,
191
+ dim: int,
192
+ num_heads: int,
193
+ mlp_ratio: float = 4.0,
194
+ qkv_bias: bool = True,
195
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
196
+ act_layer: Type[nn.Module] = nn.GELU,
197
+ use_rel_pos: bool = False,
198
+ rel_pos_zero_init: bool = True,
199
+ window_size: int = 0,
200
+ input_size: Optional[Tuple[int, int]] = None,
201
+ ) -> None:
202
+ """
203
+ Args:
204
+ dim (int): Number of input channels.
205
+ num_heads (int): Number of attention heads in each ViT block.
206
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
207
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
208
+ norm_layer (nn.Module): Normalization layer.
209
+ act_layer (nn.Module): Activation layer.
210
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
211
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
212
+ window_size (int): Window size for window attention blocks. If it equals 0, then
213
+ use global attention.
214
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
215
+ positional parameter size.
216
+ """
217
+ super().__init__()
218
+ self.norm1 = norm_layer(dim)
219
+ self.attn = Attention(
220
+ dim,
221
+ num_heads=num_heads,
222
+ qkv_bias=qkv_bias,
223
+ use_rel_pos=use_rel_pos,
224
+ rel_pos_zero_init=rel_pos_zero_init,
225
+ input_size=input_size if window_size == 0 else (window_size, window_size),
226
+ )
227
+
228
+ self.norm2 = norm_layer(dim)
229
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
230
+
231
+ self.window_size = window_size
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ shortcut = x
235
+ x = self.norm1(x)
236
+ # Window partition
237
+ if self.window_size > 0:
238
+ H, W = x.shape[1], x.shape[2]
239
+ x, pad_hw = window_partition(x, self.window_size)
240
+
241
+ x = self.attn(x)
242
+ # Reverse window partition
243
+ if self.window_size > 0:
244
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
245
+
246
+ x = shortcut + x
247
+ x = x + self.mlp(self.norm2(x))
248
+
249
+ return x
250
+
251
+
252
+ class Attention(nn.Module):
253
+ """Multi-head Attention block with relative position embeddings."""
254
+
255
+ def __init__(
256
+ self,
257
+ dim: int,
258
+ num_heads: int = 8,
259
+ qkv_bias: bool = True,
260
+ use_rel_pos: bool = False,
261
+ rel_pos_zero_init: bool = True,
262
+ input_size: Optional[Tuple[int, int]] = None,
263
+ ) -> None:
264
+ """
265
+ Args:
266
+ dim (int): Number of input channels.
267
+ num_heads (int): Number of attention heads.
268
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
269
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
270
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
271
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
272
+ positional parameter size.
273
+ """
274
+ super().__init__()
275
+ self.num_heads = num_heads
276
+ head_dim = dim // num_heads
277
+ self.scale = head_dim**-0.5
278
+
279
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
280
+ self.proj = nn.Linear(dim, dim)
281
+
282
+ self.use_rel_pos = use_rel_pos
283
+ if self.use_rel_pos:
284
+ assert (
285
+ input_size is not None
286
+ ), "Input size must be provided if using relative positional encoding."
287
+ # initialize relative positional embeddings
288
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
289
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
290
+
291
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
292
+ B, H, W, _ = x.shape
293
+ # qkv with shape (3, B, nHead, H * W, C)
294
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
295
+ # q, k, v with shape (B * nHead, H * W, C)
296
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
297
+
298
+ rel_h, rel_w = None, None
299
+ if self.use_rel_pos:
300
+ rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
301
+
302
+ q = q.view(B, self.num_heads, H * W, -1)
303
+ k = k.view(B, self.num_heads, H * W, -1)
304
+ v = v.view(B, self.num_heads, H * W, -1)
305
+
306
+ if self.use_rel_pos:
307
+ rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
308
+ rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
309
+ attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
310
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
311
+ # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
312
+ else:
313
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
314
+ # qkv = torch.stack([q, k, v], dim=1).transpose(1, 3).reshape(B, H * W, 3, self.num_heads, -1)
315
+ # x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False).transpose(1, 2)
316
+
317
+
318
+
319
+ x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
320
+
321
+ x = self.proj(x)
322
+
323
+ return x
324
+
325
+
326
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
327
+ """
328
+ Partition into non-overlapping windows with padding if needed.
329
+ Args:
330
+ x (tensor): input tokens with [B, H, W, C].
331
+ window_size (int): window size.
332
+
333
+ Returns:
334
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
335
+ (Hp, Wp): padded height and width before partition
336
+ """
337
+ B, H, W, C = x.shape
338
+
339
+ pad_h = (window_size - H % window_size) % window_size
340
+ pad_w = (window_size - W % window_size) % window_size
341
+ if pad_h > 0 or pad_w > 0:
342
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
343
+ Hp, Wp = H + pad_h, W + pad_w
344
+
345
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
346
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
347
+ return windows, (Hp, Wp)
348
+
349
+
350
+ def window_unpartition(
351
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
352
+ ) -> torch.Tensor:
353
+ """
354
+ Window unpartition into original sequences and removing padding.
355
+ Args:
356
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
357
+ window_size (int): window size.
358
+ pad_hw (Tuple): padded height and width (Hp, Wp).
359
+ hw (Tuple): original height and width (H, W) before padding.
360
+
361
+ Returns:
362
+ x: unpartitioned sequences with [B, H, W, C].
363
+ """
364
+ Hp, Wp = pad_hw
365
+ H, W = hw
366
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
367
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
368
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
369
+
370
+ if Hp > H or Wp > W:
371
+ x = x[:, :H, :W, :].contiguous()
372
+ return x
373
+
374
+
375
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
376
+ """
377
+ Get relative positional embeddings according to the relative positions of
378
+ query and key sizes.
379
+ Args:
380
+ q_size (int): size of query q.
381
+ k_size (int): size of key k.
382
+ rel_pos (Tensor): relative position embeddings (L, C).
383
+
384
+ Returns:
385
+ Extracted positional embeddings according to relative positions.
386
+ """
387
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
388
+ # Interpolate rel pos if needed.
389
+ if rel_pos.shape[0] != max_rel_dist:
390
+ # Interpolate rel pos.
391
+ dtype = rel_pos.dtype
392
+ rel_pos = rel_pos.to(torch.float32)
393
+ rel_pos_resized = F.interpolate(
394
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
395
+ size=max_rel_dist,
396
+ mode="linear",
397
+ ).to(dtype)
398
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
399
+ else:
400
+ rel_pos_resized = rel_pos
401
+
402
+ # Scale the coords with short length if shapes for q and k are different.
403
+ q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
404
+ k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
405
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
406
+
407
+ return rel_pos_resized[relative_coords.long()]
408
+
409
+
410
+ def add_decomposed_rel_pos(
411
+ q: torch.Tensor,
412
+ rel_pos_h: torch.Tensor,
413
+ rel_pos_w: torch.Tensor,
414
+ q_size: Tuple[int, int],
415
+ k_size: Tuple[int, int],
416
+ ) -> torch.Tensor:
417
+ """
418
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
419
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
420
+ Args:
421
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
422
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
423
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
424
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
425
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
426
+
427
+ Returns:
428
+ attn (Tensor): attention map with added relative positional embeddings.
429
+ """
430
+ q_h, q_w = q_size
431
+ k_h, k_w = k_size
432
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
433
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
434
+
435
+ B, _, dim = q.shape
436
+ r_q = q.reshape(B, q_h, q_w, dim)
437
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
438
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
439
+ rel_h = rel_h.unsqueeze(-1)
440
+ rel_w = rel_w.unsqueeze(-2)
441
+ rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
442
+ rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
443
+
444
+ return rel_h, rel_w
445
+
446
+
447
+ class PatchEmbed(nn.Module):
448
+ """
449
+ Image to Patch Embedding.
450
+ """
451
+
452
+ def __init__(
453
+ self,
454
+ kernel_size: Tuple[int, int] = (16, 16),
455
+ stride: Tuple[int, int] = (16, 16),
456
+ padding: Tuple[int, int] = (0, 0),
457
+ in_chans: int = 3,
458
+ embed_dim: int = 768,
459
+ ) -> None:
460
+ """
461
+ Args:
462
+ kernel_size (Tuple): kernel size of the projection layer.
463
+ stride (Tuple): stride of the projection layer.
464
+ padding (Tuple): padding size of the projection layer.
465
+ in_chans (int): Number of input image channels.
466
+ embed_dim (int): Patch embedding dimension.
467
+ """
468
+ super().__init__()
469
+
470
+ self.proj = nn.Conv2d(
471
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
472
+ )
473
+
474
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
475
+ x = self.proj(x)
476
+ # B C H W -> B H W C
477
+ x = x.permute(0, 2, 3, 1)
478
+ return x
479
+
480
+
481
+ def build_sam_vit_b(checkpoint=None):
482
+ return _build_sam(
483
+ encoder_embed_dim=768,
484
+ encoder_depth=12,
485
+ encoder_num_heads=12,
486
+ encoder_global_attn_indexes=[2, 5, 8, 11],
487
+ checkpoint=checkpoint,
488
+ )
489
+
490
+
491
+ def _build_sam(
492
+ encoder_embed_dim,
493
+ encoder_depth,
494
+ encoder_num_heads,
495
+ encoder_global_attn_indexes,
496
+ checkpoint=None,
497
+ ):
498
+ prompt_embed_dim = 256
499
+ image_size = 1024
500
+ vit_patch_size = 16
501
+ image_embedding_size = image_size // vit_patch_size
502
+ image_encoder=ImageEncoderViT(
503
+ depth=encoder_depth,
504
+ embed_dim=encoder_embed_dim,
505
+ img_size=image_size,
506
+ mlp_ratio=4,
507
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
508
+ num_heads=encoder_num_heads,
509
+ patch_size=vit_patch_size,
510
+ qkv_bias=True,
511
+ use_rel_pos=True,
512
+ global_attn_indexes=encoder_global_attn_indexes,
513
+ window_size=14,
514
+ out_chans=prompt_embed_dim,
515
+ )
516
+
517
+ if checkpoint is not None:
518
+ # with open(checkpoint, "rb") as f:
519
+ state_dict = torch.load(checkpoint)
520
+ # print(state_dict.keys())
521
+ # for key in state_dict:
522
+ # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
523
+ # ocr-anyting
524
+ # image_encoder.load_state_dict(state_dict, strict=True)
525
+ # tob
526
+ image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
527
+ print(checkpoint)
528
+ return image_encoder
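
As a quick sanity check of the window-attention helpers in this file, window_partition and window_unpartition should round-trip exactly, with the padding added by the former removed by the latter (illustrative snippet, assuming both functions are in scope):

import torch

x = torch.randn(1, 50, 50, 8)                 # [B, H, W, C], H and W not divisible by the window size
windows, pad_hw = window_partition(x, 14)     # pads to 56x56 -> (16, 14, 14, 8) windows
y = window_unpartition(windows, 14, pad_hw, (50, 50))
assert y.shape == x.shape and torch.allclose(x, y)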
DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepseek_ocr.py ADDED
@@ -0,0 +1,582 @@
1
+
2
+ """Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
3
+ import math
4
+ from collections.abc import Iterable, Mapping, Sequence
5
+ from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from transformers import BatchFeature
12
+
13
+ from vllm.config import VllmConfig
14
+ from vllm.model_executor import SamplingMetadata
15
+ from vllm.model_executor.layers.quantization import QuantizationConfig
16
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
17
+ from vllm.multimodal import MULTIMODAL_REGISTRY
18
+ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
19
+ MultiModalKwargs, NestedTensors)
20
+ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
21
+ ImageSize, MultiModalDataItems)
22
+ from vllm.multimodal.processing import (BaseMultiModalProcessor,
23
+ BaseProcessingInfo, PromptReplacement,
24
+ PromptUpdate)
25
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder
26
+ from vllm.sequence import IntermediateTensors
27
+ from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
28
+ MlpProjectorConfig,
29
+ VisionEncoderConfig)
30
+ from process.image_process import (
31
+ DeepseekOCRProcessor, count_tiles)
32
+ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
33
+ # from vllm.utils import is_list_of
34
+
35
+ from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
36
+ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
37
+ init_vllm_registered_model, maybe_prefix,
38
+ merge_multimodal_embeddings)
39
+
40
+ from deepencoder.sam_vary_sdpa import build_sam_vit_b
41
+ from deepencoder.clip_sdpa import build_clip_l
42
+ from deepencoder.build_linear import MlpProjector
43
+ from addict import Dict
44
+ # import time
45
+ from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, PRINT_NUM_VIS_TOKENS, PROMPT
46
+ # The image token id may be various
47
+ _IMAGE_TOKEN = "<image>"
48
+
49
+
50
+ class DeepseekOCRProcessingInfo(BaseProcessingInfo):
51
+
52
+ def get_hf_config(self):
53
+ return self.ctx.get_hf_config(DeepseekVLV2Config)
54
+
55
+ def get_hf_processor(self, **kwargs: object):
56
+ return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)
57
+
58
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
59
+ return {"image": None}
60
+
61
+ def get_num_image_tokens(self,
62
+ *,
63
+ image_width: int,
64
+ image_height: int,
65
+ cropping: bool = True) -> int:
66
+ hf_processor = self.get_hf_processor()
67
+
68
+
69
+ # image_size = hf_processor.image_size
70
+ # patch_size = hf_processor.patch_size
71
+ # downsample_ratio = hf_processor.downsample_ratio
72
+
73
+ image_size = IMAGE_SIZE
74
+ base_size = BASE_SIZE
75
+ patch_size = 16
76
+ downsample_ratio = 4
77
+
78
+ if CROP_MODE:
79
+ if image_width <= 640 and image_height <= 640:
80
+ crop_ratio = [1, 1]
81
+ else:
82
+ # images_crop_raw, crop_ratio = hf_processor.dynamic_preprocess(image)
83
+
84
+ # find the closest aspect ratio to the target
85
+ crop_ratio = count_tiles(image_width, image_height, image_size=IMAGE_SIZE)
86
+
87
+ # print('===========')
88
+ # print('crop_ratio ', crop_ratio)
89
+ # print('============')
90
+
91
+ num_width_tiles, num_height_tiles = crop_ratio
92
+ else:
93
+ num_width_tiles = num_height_tiles = 1
94
+
95
+ h = w = math.ceil((base_size // patch_size) / downsample_ratio)
96
+
97
+ h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
98
+
99
+ global_views_tokens = h * (w + 1)
100
+ if num_width_tiles >1 or num_height_tiles>1:
101
+ local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
102
+ else:
103
+ local_views_tokens = 0
104
+
105
+
106
+ return global_views_tokens + local_views_tokens + 1
107
+
108
+ def get_image_size_with_most_features(self) -> ImageSize:
109
+
110
+ if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
111
+ return ImageSize(width=1024*2, height=1024*2)
112
+ return ImageSize(width=640*2, height=640*2)
113
+
114
+
115
+ class DeepseekOCRDummyInputsBuilder(
116
+ BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
117
+
118
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
119
+ num_images = mm_counts.get("image", 0)
120
+
121
+ processor = self.info.get_hf_processor()
122
+ image_token = processor.image_token
123
+
124
+ return image_token * num_images
125
+
126
+ def get_dummy_mm_data(
127
+ self,
128
+ seq_len: int,
129
+ mm_counts: Mapping[str, int],
130
+ ) -> MultiModalDataDict:
131
+ num_images = mm_counts.get("image", 0)
132
+
133
+ max_image_size = self.info.get_image_size_with_most_features()
134
+
135
+ if '<image>' in PROMPT:
136
+ return {
137
+ "image":
138
+ DeepseekOCRProcessor().tokenize_with_images(images = self._get_dummy_images(width=max_image_size.width,
139
+ height=max_image_size.height,
140
+ num_images=num_images), bos=True, eos=True, cropping=CROP_MODE)
141
+ }
142
+ else:
143
+ return {
144
+ "image": []
145
+ }
146
+
147
+
148
+
149
+
150
+ class DeepseekOCRMultiModalProcessor(
151
+ BaseMultiModalProcessor[DeepseekOCRProcessingInfo]):
152
+
153
+
154
+ def _call_hf_processor(
155
+ self,
156
+ prompt: str,
157
+ mm_data: Mapping[str, object],
158
+ mm_kwargs: Mapping[str, object],
159
+ ) -> BatchFeature:
160
+
161
+
162
+ # print(mm_data)
163
+ if mm_data:
164
+ processed_outputs = self.info.ctx.call_hf_processor(
165
+ self.info.get_hf_processor(**mm_kwargs),
166
+ dict(prompt=prompt, **mm_data),
167
+ mm_kwargs,
168
+ )
169
+
170
+ else:
171
+ tokenizer = self.info.get_tokenizer()
172
+ processed_outputs = tokenizer(prompt,
173
+ add_special_tokens=True,
174
+ return_tensors="pt")
175
+
176
+ return processed_outputs
177
+
178
+ def _get_mm_fields_config(
179
+ self,
180
+ hf_inputs: BatchFeature,
181
+ hf_processor_mm_kwargs: Mapping[str, object],
182
+ ) -> Mapping[str, MultiModalFieldConfig]:
183
+ return dict(
184
+ pixel_values=MultiModalFieldConfig.batched("image"),
185
+ images_spatial_crop=MultiModalFieldConfig.batched("image"),
186
+ # image_embeds=MultiModalFieldConfig.batched("image2"),
187
+ images_crop=MultiModalFieldConfig.batched("image"),
188
+ )
189
+
190
+ def _get_prompt_updates(
191
+ self,
192
+ mm_items: MultiModalDataItems,
193
+ hf_processor_mm_kwargs: Mapping[str, object],
194
+ out_mm_kwargs: MultiModalKwargs,
195
+ ) -> Sequence[PromptUpdate]:
196
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
197
+
198
+ image_token_id = hf_processor.image_token_id
199
+ assert isinstance(image_token_id, int)
200
+
201
+ def get_replacement_deepseek_vl2(item_idx: int):
202
+ images = mm_items.get_items(
203
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
204
+
205
+
206
+
207
+ if isinstance(images, ImageEmbeddingItems):
208
+ num_image_tokens = images.get_feature_size(item_idx)
209
+ else:
210
+
211
+
212
+ width = images[0][-1][0][0]
213
+ height = images[0][-1][0][1]
214
+
215
+ num_image_tokens = self.info.get_num_image_tokens(
216
+ image_width=width,
217
+ image_height=height,
218
+ # flag = True,
219
+ cropping=CROP_MODE,
220
+ )
221
+ return [image_token_id] * num_image_tokens
222
+
223
+ return [
224
+ PromptReplacement(
225
+ modality="image",
226
+ target=[image_token_id],
227
+ replacement=get_replacement_deepseek_vl2,
228
+ )
229
+ ]
230
+
231
+ def _cached_apply_hf_processor(
232
+ self,
233
+ prompt: Union[str, list[int]],
234
+ mm_data_items: MultiModalDataItems,
235
+ hf_processor_mm_kwargs: Mapping[str, object],
236
+ ) -> tuple[list[int], MultiModalKwargs, bool]:
237
+ # The processor logic is different for len(images) <= 2 vs > 2
238
+ # Since the processing cache assumes that the processor output is
239
+ # invariant of how many images are passed per prompt, we only
240
+ # perform caching for the most common case
241
+ if mm_data_items.get_count("image", strict=False) > 2:
242
+ # This code path corresponds to the cache being disabled
243
+ return self._apply_hf_processor_main(
244
+ prompt=prompt,
245
+ mm_items=mm_data_items,
246
+ hf_processor_mm_kwargs=hf_processor_mm_kwargs,
247
+ enable_hf_prompt_update=True,
248
+ )
249
+
250
+ return super()._cached_apply_hf_processor(
251
+ prompt=prompt,
252
+ mm_data_items=mm_data_items,
253
+ hf_processor_mm_kwargs=hf_processor_mm_kwargs,
254
+ )
255
+
256
+
257
+ @MULTIMODAL_REGISTRY.register_processor(
258
+ DeepseekOCRMultiModalProcessor,
259
+ info=DeepseekOCRProcessingInfo,
260
+ dummy_inputs=DeepseekOCRDummyInputsBuilder)
261
+ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
262
+
263
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
264
+ "language.": "language_model.",
265
+ })
266
+
267
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
268
+ super().__init__()
269
+
270
+ config: DeepseekVLV2Config = vllm_config.model_config.hf_config
271
+ quant_config = vllm_config.quant_config
272
+ multimodal_config = vllm_config.model_config.multimodal_config
273
+
274
+ # config.model_type ='deepseek_vl_v2'
275
+
276
+ self.config = config
277
+ self.multimodal_config = multimodal_config
278
+
279
+
280
+ self.vision_config = config.vision_config
281
+ self.projector_config = config.projector_config
282
+ self.text_config = config.text_config
283
+
284
+ model_config = vllm_config.model_config
285
+ tokenizer = cached_tokenizer_from_config(model_config)
286
+ self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
287
+
288
+ self.sam_model = build_sam_vit_b()
289
+ self.vision_model = build_clip_l()
290
+
291
+ n_embed = 1280
292
+ self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
293
+ self.tile_tag = config.tile_tag
294
+ self.global_view_pos = config.global_view_pos
295
+
296
+ # self.sam_model = torch.compile(self.sam_model, mode="reduce-overhead")
297
+ # self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
298
+ # self.projector = torch.compile(self.projector, mode="max-autotune")
299
+
300
+
301
+
302
+
303
+ # special token for image token sequence format
304
+ embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
305
+ if self.tile_tag == "2D":
306
+ # <|view_separator|>, <|\n|>
307
+ self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
308
+ self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
309
+ else:
310
+ raise ValueError(
311
+ f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
312
+ )
313
+
314
+ if self.text_config.topk_method == "noaux_tc":
315
+ architectures = ["DeepseekV3ForCausalLM"]
316
+ elif not self.text_config.use_mla:
317
+ architectures = ["DeepseekForCausalLM"]
318
+ else:
319
+ architectures = ["DeepseekV2ForCausalLM"]
320
+
321
+ self.language_model = init_vllm_registered_model(
322
+ vllm_config=vllm_config,
323
+ hf_config=self.text_config,
324
+ prefix=maybe_prefix(prefix, "language"),
325
+ architectures=architectures,
326
+ )
327
+
328
+ self.make_empty_intermediate_tensors = (
329
+ self.language_model.make_empty_intermediate_tensors)
330
+
331
+
332
+
333
+ def _parse_and_validate_image_input(
334
+ self, **kwargs: object):
335
+
336
+ pixel_values = kwargs.pop("pixel_values", None)
337
+ images_spatial_crop = kwargs.pop("images_spatial_crop", None)
338
+ images_crop = kwargs.pop("images_crop", None)
339
+
340
+
341
+ if pixel_values is None or torch.sum(pixel_values).item() == 0:
342
+ return None
343
+
344
+ if pixel_values is not None:
345
+ if not isinstance(pixel_values, (torch.Tensor, list)):
346
+ raise ValueError("Incorrect type of pixel values. "
347
+ f"Got type: {type(pixel_values)}")
348
+
349
+ if not isinstance(images_spatial_crop, (torch.Tensor, list)):
350
+ raise ValueError("Incorrect type of image sizes. "
351
+ f"Got type: {type(images_spatial_crop)}")
352
+
353
+ if not isinstance(images_crop, (torch.Tensor, list)):
354
+ raise ValueError("Incorrect type of image crop. "
355
+ f"Got type: {type(images_crop)}")
356
+
357
+ return [pixel_values, images_crop, images_spatial_crop]
358
+
359
+
360
+ raise AssertionError("This line should be unreachable.")
361
+
362
+
363
+
364
+ def _pixel_values_to_embedding(
365
+ self,
366
+ pixel_values: torch.Tensor,
367
+ images_crop: torch.Tensor,
368
+ images_spatial_crop: torch.Tensor,
369
+ ) -> NestedTensors:
370
+
371
+ # Pixel_values (global view): [n_image, batch_size, 3, height, width]
372
+ # images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
373
+ # images_crop (local view): [n_image, batch_size, num_patches, 3, h, w]
374
+ # split the pixel and image_crop, all batch_size = 1
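+ # Illustrative example (batch_size = 1): a grid of 2 tiles across x 3 tiles down gives
+ # images_spatial_crop[jdx][0] == [2, 3], images_crop[jdx][0] stacks 6 local tiles of shape
+ # [3, IMAGE_SIZE, IMAGE_SIZE], and pixel_values[jdx] holds the single padded global view
+ # of shape [1, 3, BASE_SIZE, BASE_SIZE].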
375
+
376
+ images_in_this_batch = []
377
+
378
+
379
+ # print(type(images_crop))
380
+
381
+ # print(pixel_values.shape)
382
+
383
+
384
+ with torch.no_grad():
385
+ for jdx in range(images_spatial_crop.size(0)):
386
+ # with torch.set_grad_enabled(False):
387
+ patches = images_crop[jdx][0].to(torch.bfloat16) # batch_size = 1
388
+ image_ori = pixel_values[jdx]
389
+ crop_shape = images_spatial_crop[jdx][0]
390
+
391
+ if torch.sum(patches).item() != 0: # if all values = 0, no crop
392
+ # P, C, H, W = patches.shape
393
+ # crop_flag = 1
394
+ local_features_1 = self.sam_model(patches)
395
+ #TODO del patches
396
+ # torch.compiler.cudagraph_mark_step_begin()
397
+ local_features_2 = self.vision_model(patches, local_features_1)
398
+
399
+
400
+ local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
401
+ local_features = self.projector(local_features)
402
+
403
+
404
+ global_features_1 = self.sam_model(image_ori)
405
+ global_features_2 = self.vision_model(image_ori, global_features_1)
406
+ global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
407
+ global_features = self.projector(global_features)
408
+
409
+ if PRINT_NUM_VIS_TOKENS:
410
+ print('=====================')
411
+ print('BASE: ', global_features.shape)
412
+ print('PATCHES: ', local_features.shape)
413
+ print('=====================')
414
+
415
+ _, hw, n_dim = global_features.shape
416
+ h = w = int(hw ** 0.5)
417
+
418
+ _2, hw2, n_dim2 = local_features.shape
419
+ h2 = w2 = int(hw2 ** 0.5)
420
+
421
+ width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
422
+
423
+ global_features = global_features.view(h, w, n_dim)
424
+
425
+ global_features = torch.cat(
426
+ [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
427
+ )
428
+
429
+ global_features = global_features.view(-1, n_dim)
430
+
431
+
432
+ local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
433
+ local_features = torch.cat(
434
+ [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
435
+ )
436
+ local_features = local_features.view(-1, n_dim2)
437
+
438
+ global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
439
+
440
+ else:
441
+ global_features_1 = self.sam_model(image_ori)
442
+ global_features_2 = self.vision_model(image_ori, global_features_1)
443
+ global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
444
+ global_features = self.projector(global_features)
445
+
446
+ if PRINT_NUM_VIS_TOKENS:
447
+ print('=====================')
448
+ print('BASE: ', global_features.shape)
449
+ print('NO PATCHES')
450
+ print('=====================')
451
+
452
+ _, hw, n_dim = global_features.shape
453
+ h = w = int(hw ** 0.5)
454
+
455
+ global_features = global_features.view(h, w, n_dim)
456
+
457
+ global_features = torch.cat(
458
+ [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
459
+ )
460
+
461
+ global_features = global_features.view(-1, n_dim)
462
+
463
+ global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
464
+
465
+ images_in_this_batch.append(global_local_features)
466
+
467
+ return images_in_this_batch
468
+
469
+ def _process_image_input(
470
+ self, image_input) -> torch.Tensor:
471
+
472
+
473
+ # image_input: [pixel_values, images_crop, images_spatial_crop]
474
+
475
+ pixel_values = image_input[0].to(torch.bfloat16)
476
+ # print(image_input[1][0].shape)
477
+ # print(type(image_input[1]))
478
+ # exit()
479
+
480
+ # images_crop = image_input[1].to(torch.bfloat16)
481
+ images_crop = image_input[1]
482
+ # images_crop = image_input[1]
483
+ images_spatial_crop = image_input[2].to(dtype=torch.long)
484
+
485
+ # local_start = time.time()
486
+ vision_features = self._pixel_values_to_embedding(
487
+ pixel_values=pixel_values, images_crop = images_crop, images_spatial_crop=images_spatial_crop)
488
+
489
+ # local_total_time = time.time() - local_start
490
+
491
+ # print('encoder_time: ', local_total_time)
492
+ # exit()
493
+ return vision_features
494
+
495
+ def get_language_model(self) -> torch.nn.Module:
496
+ return self.language_model
497
+
498
+ def get_multimodal_embeddings(
499
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
500
+ image_input = self._parse_and_validate_image_input(**kwargs)
501
+ if image_input is None:
502
+ return None
503
+ vision_embeddings = self._process_image_input(image_input)
504
+ return vision_embeddings
505
+
506
+
507
+
508
+ def get_input_embeddings(
509
+ self,
510
+ input_ids: torch.Tensor,
511
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
512
+ ) -> torch.Tensor:
513
+
514
+
515
+
516
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
517
+
518
+
519
+ if multimodal_embeddings is not None:
520
+ inputs_embeds = merge_multimodal_embeddings(
521
+ input_ids, inputs_embeds, multimodal_embeddings,
522
+ self.image_token_id)
523
+ # print(len(multimodal_embeddings))
524
+ # print(input_ids.shape)
525
+ # print(type(inputs_embeds))
526
+ # print(inputs_embeds.shape)
527
+
528
+ return inputs_embeds
529
+
530
+ def forward(self,
531
+ input_ids: torch.Tensor,
532
+ positions: torch.Tensor,
533
+ intermediate_tensors: Optional[IntermediateTensors] = None,
534
+ inputs_embeds: Optional[torch.Tensor] = None,
535
+ **kwargs: object):
536
+
537
+ if intermediate_tensors is not None:
538
+ inputs_embeds = None
539
+
540
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
541
+ # condition is for v0 compatibility
542
+ elif inputs_embeds is None:
543
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
544
+ inputs_embeds = self.get_input_embeddings(input_ids,
545
+ vision_embeddings)
546
+ input_ids = None
547
+
548
+ hidden_states = self.language_model(input_ids,
549
+ positions,
550
+ intermediate_tensors,
551
+ inputs_embeds=inputs_embeds)
552
+
553
+ return hidden_states
554
+
555
+ def compute_logits(
556
+ self,
557
+ hidden_states: torch.Tensor,
558
+ sampling_metadata: SamplingMetadata,
559
+ ) -> Optional[torch.Tensor]:
560
+ return self.language_model.compute_logits(hidden_states,
561
+ sampling_metadata)
562
+
563
+
564
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
565
+ processed_weights = []
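+ # Route checkpoint tensors: vision/projector/separator weights only drop their first
+ # 'model.' prefix, while everything else gains a 'language.' prefix that the
+ # hf_to_vllm_mapper defined above rewrites to 'language_model.'.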
566
+
567
+ for name, tensor in weights:
568
+ if 'sam_model' in name or 'vision_model' in name or 'projector' in name or 'image_newline' in name or 'view_seperator' in name:
569
+ new_name = name.replace('model.', '', 1)
570
+ else:
571
+ new_name = 'language.' + name
572
+
573
+ processed_weights.append((new_name, tensor))
574
+
575
+ loader = AutoWeightsLoader(self)
576
+ autoloaded_weights = loader.load_weights(processed_weights, mapper=self.hf_to_vllm_mapper)
577
+
578
+
579
+
580
+
581
+
582
+ return autoloaded_weights
DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/__init__.py ADDED
File without changes
DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py ADDED
@@ -0,0 +1,502 @@
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from PIL import Image, ImageOps
7
+ from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
8
+ from transformers.processing_utils import ProcessorMixin
9
+ from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER
10
+
11
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
12
+ best_ratio_diff = float('inf')
13
+ best_ratio = (1, 1)
14
+ area = width * height
15
+ for ratio in target_ratios:
16
+ target_aspect_ratio = ratio[0] / ratio[1]
17
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
18
+ if ratio_diff < best_ratio_diff:
19
+ best_ratio_diff = ratio_diff
20
+ best_ratio = ratio
21
+ elif ratio_diff == best_ratio_diff:
22
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
23
+ best_ratio = ratio
24
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
25
+ return best_ratio
26
+
27
+
28
+ def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
29
+ aspect_ratio = orig_width / orig_height
30
+
31
+ # calculate the existing image aspect ratio
32
+ target_ratios = set(
33
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
34
+ i * j <= max_num and i * j >= min_num)
35
+ # print(target_ratios)
36
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
37
+
38
+ # find the closest aspect ratio to the target
39
+ target_aspect_ratio = find_closest_aspect_ratio(
40
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
41
+
42
+ return target_aspect_ratio
43
+
44
+
45
+ def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
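+ # e.g. (illustrative settings: image_size=640, min_num=2, max_num=6): a 1280x640 input best
+ # matches the (2, 1) grid, so it is resized to 1280x640 and split into two 640x640 tiles.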
46
+ orig_width, orig_height = image.size
47
+ aspect_ratio = orig_width / orig_height
48
+
49
+ # calculate the existing image aspect ratio
50
+ target_ratios = set(
51
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
52
+ i * j <= max_num and i * j >= min_num)
53
+ # print(target_ratios)
54
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
55
+
56
+ # find the closest aspect ratio to the target
57
+ target_aspect_ratio = find_closest_aspect_ratio(
58
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
59
+
60
+ # print(target_aspect_ratio)
61
+ # calculate the target width and height
62
+ target_width = image_size * target_aspect_ratio[0]
63
+ target_height = image_size * target_aspect_ratio[1]
64
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
65
+
66
+ # resize the image
67
+ resized_img = image.resize((target_width, target_height))
68
+ processed_images = []
69
+ for i in range(blocks):
70
+ box = (
71
+ (i % (target_width // image_size)) * image_size,
72
+ (i // (target_width // image_size)) * image_size,
73
+ ((i % (target_width // image_size)) + 1) * image_size,
74
+ ((i // (target_width // image_size)) + 1) * image_size
75
+ )
76
+ # split the image
77
+ split_img = resized_img.crop(box)
78
+ processed_images.append(split_img)
79
+ assert len(processed_images) == blocks
80
+ if use_thumbnail and len(processed_images) != 1:
81
+ thumbnail_img = image.resize((image_size, image_size))
82
+ processed_images.append(thumbnail_img)
83
+ return processed_images, target_aspect_ratio
84
+
85
+
86
+
87
+
88
+
89
+ class ImageTransform:
90
+
91
+ def __init__(self,
92
+ mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
93
+ std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
94
+ normalize: bool = True):
95
+ self.mean = mean
96
+ self.std = std
97
+ self.normalize = normalize
98
+
99
+ transform_pipelines = [T.ToTensor()]
100
+
101
+ if normalize:
102
+ transform_pipelines.append(T.Normalize(mean, std))
103
+
104
+ self.transform = T.Compose(transform_pipelines)
105
+
106
+ def __call__(self, pil_img: Image.Image):
107
+ x = self.transform(pil_img)
108
+ return x
109
+
110
+
111
+ class DeepseekOCRProcessor(ProcessorMixin):
112
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
113
+ attributes = ["tokenizer"]
114
+
115
+ def __init__(
116
+ self,
117
+ tokenizer: LlamaTokenizerFast = TOKENIZER,
118
+ candidate_resolutions: Tuple[Tuple[int, int]] = [[1024, 1024]],
119
+ patch_size: int = 16,
120
+ downsample_ratio: int = 4,
121
+ image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
122
+ image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
123
+ normalize: bool = True,
124
+ image_token: str = "<image>",
125
+ pad_token: str = "<|▁pad▁|>",
126
+ add_special_token: bool = False,
127
+ sft_format: str = "deepseek",
128
+ mask_prompt: bool = True,
129
+ ignore_id: int = -100,
130
+ **kwargs,
131
+ ):
132
+
133
+ # self.candidate_resolutions = candidate_resolutions # placeholder no use
134
+ self.image_size = IMAGE_SIZE
135
+ self.base_size = BASE_SIZE
136
+ # self.patch_size = patch_size
137
+ self.patch_size = 16
138
+ self.image_mean = image_mean
139
+ self.image_std = image_std
140
+ self.normalize = normalize
141
+ # self.downsample_ratio = downsample_ratio
142
+ self.downsample_ratio = 4
143
+
144
+ self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
145
+
146
+
147
+ self.tokenizer = tokenizer
148
+ # self.tokenizer = add_special_token(tokenizer)
149
+ self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference
150
+
151
+ # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
152
+ if self.tokenizer.pad_token is None:
153
+ self.tokenizer.add_special_tokens({'pad_token': pad_token})
154
+
155
+ # add image token
156
+ # image_token_id = self.tokenizer.vocab.get(image_token)
157
+ # if image_token_id is None:
158
+ # special_tokens = [image_token]
159
+ # special_tokens_dict = {"additional_special_tokens": special_tokens}
160
+ # self.tokenizer.add_special_tokens(special_tokens_dict)
161
+ self.image_token_id = self.tokenizer.vocab.get(image_token)
162
+
163
+ # add five special tokens for grounding-related tasks
164
+ # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
165
+ # special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
166
+ # special_tokens_dict = {"additional_special_tokens": special_tokens}
167
+
168
+ # special_tokens = ['<image>','<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>', '<td>', '</td>', '<tr>', '</tr>']
169
+ # special_tokens_dict = {"additional_special_tokens": special_tokens}
170
+ # self.tokenizer.add_special_tokens(special_tokens_dict)
171
+
172
+ # # add special tokens for SFT data
173
+ # special_tokens = ["<|User|>", "<|Assistant|>"]
174
+ # special_tokens_dict = {"additional_special_tokens": special_tokens}
175
+ # self.tokenizer.add_special_tokens(special_tokens_dict)
176
+
177
+ self.image_token = image_token
178
+ self.pad_token = pad_token
179
+ self.add_special_token = add_special_token
180
+ self.sft_format = sft_format
181
+ self.mask_prompt = mask_prompt
182
+ self.ignore_id = ignore_id
183
+
184
+ super().__init__(
185
+ tokenizer,
186
+ **kwargs,
187
+ )
188
+
189
+
190
+
191
+
192
+ # def select_best_resolution(self, image_size):
193
+ # # used for cropping
194
+ # original_width, original_height = image_size
195
+ # best_fit = None
196
+ # max_effective_resolution = 0
197
+ # min_wasted_resolution = float("inf")
198
+
199
+ # for width, height in self.candidate_resolutions:
200
+ # scale = min(width / original_width, height / original_height)
201
+ # downscaled_width, downscaled_height = int(
202
+ # original_width * scale), int(original_height * scale)
203
+ # effective_resolution = min(downscaled_width * downscaled_height,
204
+ # original_width * original_height)
205
+ # wasted_resolution = (width * height) - effective_resolution
206
+
207
+ # if effective_resolution > max_effective_resolution or (
208
+ # effective_resolution == max_effective_resolution
209
+ # and wasted_resolution < min_wasted_resolution):
210
+ # max_effective_resolution = effective_resolution
211
+ # min_wasted_resolution = wasted_resolution
212
+ # best_fit = (width, height)
213
+
214
+ # return best_fit
215
+
216
+ @property
217
+ def bos_id(self):
218
+ return self.tokenizer.bos_token_id
219
+
220
+ @property
221
+ def eos_id(self):
222
+ return self.tokenizer.eos_token_id
223
+
224
+ @property
225
+ def pad_id(self):
226
+ return self.tokenizer.pad_token_id
227
+
228
+ def encode(self, text: str, bos: bool = True, eos: bool = False):
229
+ t = self.tokenizer.encode(text, add_special_tokens=False)
230
+
231
+ if bos:
232
+ t = [self.bos_id] + t
233
+ if eos:
234
+ t = t + [self.eos_id]
235
+
236
+ return t
237
+
238
+ def decode(self, t: List[int], **kwargs) -> str:
239
+ return self.tokenizer.decode(t, **kwargs)
240
+
241
+ def process_one(
242
+ self,
243
+ prompt: str,
244
+ images: List,
245
+ inference_mode: bool = True,
246
+ **kwargs,
247
+ ):
248
+ """
249
+
250
+ Args:
251
+ prompt (str): the formatted prompt;
252
+ conversations (List[Dict]): conversations with a list of messages;
253
+ images (List[ImageType]): the list of images;
254
+ inference_mode (bool): if True, then remove the last eos token;
255
+ system_prompt (str): the system prompt;
256
+ **kwargs:
257
+
258
+ Returns:
259
+ outputs (BaseProcessorOutput): the output of the processor,
260
+ - input_ids (torch.LongTensor): [N + image tokens]
261
+ - target_ids (torch.LongTensor): [N + image tokens]
262
+ - pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
263
+ - image_id (int): the id of the image token
264
+ - num_image_tokens (List[int]): the number of image tokens
265
+ """
266
+
267
+ assert (prompt is not None and images is not None
268
+ ), "prompt and images must be used at the same time."
269
+
270
+ sft_format = prompt
271
+
272
+ input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
273
+
274
+
275
+ return {
276
+ "input_ids": input_ids,
277
+ "pixel_values": pixel_values,
278
+ "images_crop": images_crop,
279
+ "images_seq_mask": images_seq_mask,
280
+ "images_spatial_crop": images_spatial_crop,
281
+ "num_image_tokens": num_image_tokens,
282
+ }
283
+
284
+
285
+ # prepare = BatchFeature(
286
+ # data=dict(
287
+ # input_ids=input_ids,
288
+ # pixel_values=pixel_values,
289
+ # images_crop = images_crop,
290
+ # images_seq_mask=images_seq_mask,
291
+ # images_spatial_crop=images_spatial_crop,
292
+ # num_image_tokens=num_image_tokens,
293
+ # ),
294
+ # tensor_type="pt",
295
+ # )
296
+ # return prepare
297
+
298
+ def __call__(
299
+ self,
300
+ *,
301
+ prompt: str,
302
+ images: List,
303
+ inference_mode: bool = True,
304
+ **kwargs,
305
+ ):
306
+ """
307
+
308
+ Args:
309
+ prompt (str): the formatted prompt;
310
+ images (List[ImageType]): the list of images;
311
+ inference_mode (bool): if True, then remove the last eos token;
312
+ **kwargs:
313
+
314
+ Returns:
315
+ outputs (BaseProcessorOutput): the output of the processor,
316
+ - input_ids (torch.LongTensor): [N + image tokens]
317
+ - images (torch.FloatTensor): [n_images, 3, H, W]
318
+ - image_id (int): the id of the image token
319
+ - num_image_tokens (List[int]): the number of image tokens
320
+ """
321
+
322
+ prepare = self.process_one(
323
+ prompt=prompt,
324
+ images=images,
325
+ inference_mode=inference_mode,
326
+ )
327
+
328
+ return prepare
329
+
330
+ def tokenize_with_images(
331
+ self,
332
+ # conversation: str,
333
+ images: List[Image.Image],
334
+ bos: bool = True,
335
+ eos: bool = True,
336
+ cropping: bool = True,
337
+ ):
338
+ """Tokenize text with <image> tags."""
339
+
340
+ # print(conversation)
341
+ conversation = PROMPT
342
+ assert conversation.count(self.image_token) == len(images)
343
+ text_splits = conversation.split(self.image_token)
344
+ images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
345
+ image_shapes = []
346
+ num_image_tokens = []
347
+ tokenized_str = []
348
+ # print('image: ', len(images))
349
+ for text_sep, image in zip(text_splits, images):
350
+ """encode text_sep"""
351
+ tokenized_sep = self.encode(text_sep, bos=False, eos=False)
352
+ tokenized_str += tokenized_sep
353
+ images_seq_mask += [False] * len(tokenized_sep)
354
+
355
+ """select best resolution for anyres"""
356
+ # if cropping:
357
+ # best_width, best_height = self.select_best_resolution(image.size)
358
+ # else:
359
+ # best_width, best_height = self.image_size, self.image_size
360
+
361
+ image_shapes.append(image.size)
362
+
363
+ if image.size[0] <= 640 and image.size[1] <= 640:
364
+ crop_ratio = [1, 1]
365
+ else:
366
+ if cropping:
367
+ # print('image-size: ', image.size)
368
+ # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
369
+ # print('image ', image.size)
370
+ # print('open_size:', image.size)
371
+ images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
372
+ # print('crop_ratio: ', crop_ratio)
373
+ else:
374
+ # best_width, best_height = self.image_size, self.image_size
375
+ crop_ratio = [1, 1]
376
+ # print(image.size, (best_width, best_height)) # check the select_best_resolutions func
377
+
378
+ # print(crop_ratio)
379
+ """process the global view"""
380
+
381
+ # if cropping
382
+ if self.image_size <= 640 and not cropping:
383
+ # print('directly resize')
384
+ image = image.resize((self.image_size, self.image_size))
385
+
386
+ global_view = ImageOps.pad(image, (self.base_size, self.base_size),
387
+ color=tuple(int(x * 255) for x in self.image_transform.mean))
388
+ images_list.append(self.image_transform(global_view))
389
+
390
+ """record height / width crop num"""
391
+ # width_crop_num, height_crop_num = best_width // self.image_size, best_height // self.image_size
392
+ num_width_tiles, num_height_tiles = crop_ratio
393
+ images_spatial_crop.append([num_width_tiles, num_height_tiles])
394
+
395
+
396
+
397
+
398
+ if num_width_tiles > 1 or num_height_tiles > 1:
399
+ """process the local views"""
400
+ # local_view = ImageOps.pad(image, (best_width, best_height),
401
+ # color=tuple(int(x * 255) for x in self.image_transform.mean))
402
+ # for i in range(0, best_height, self.image_size):
403
+ # for j in range(0, best_width, self.image_size):
404
+ # images_crop_list.append(
405
+ # self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
406
+ for i in range(len(images_crop_raw)):
407
+ images_crop_list.append(self.image_transform(images_crop_raw[i]))
408
+
409
+ # """process the global view"""
410
+ # global_view = ImageOps.pad(image, (self.image_size, self.image_size),
411
+ # color=tuple(int(x * 255) for x in self.image_transform.mean))
412
+ # images_list.append(self.image_transform(global_view))
413
+
414
+ # """process the local views"""
415
+ # local_view = ImageOps.pad(image, (best_width, best_height),
416
+ # color=tuple(int(x * 255) for x in self.image_transform.mean))
417
+ # for i in range(0, best_height, self.image_size):
418
+ # for j in range(0, best_width, self.image_size):
419
+ # images_list.append(
420
+ # self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
421
+
422
+ # """add image tokens"""
423
+ """add image tokens"""
424
+ num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
425
+ num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
426
+
427
+
428
+ tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
429
+ tokenized_image += [self.image_token_id]
430
+ if num_width_tiles > 1 or num_height_tiles > 1:
431
+ tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles) + [self.image_token_id]) * (
432
+ num_queries * num_height_tiles)
433
+ tokenized_str += tokenized_image
434
+ images_seq_mask += [True] * len(tokenized_image)
435
+ num_image_tokens.append(len(tokenized_image))
436
+
437
+ """process the last text split"""
438
+ tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
439
+ tokenized_str += tokenized_sep
440
+ images_seq_mask += [False] * len(tokenized_sep)
441
+
442
+ """add the bos and eos tokens"""
443
+ if bos:
444
+ tokenized_str = [self.bos_id] + tokenized_str
445
+ images_seq_mask = [False] + images_seq_mask
446
+ if eos:
447
+ tokenized_str = tokenized_str + [self.eos_id]
448
+ images_seq_mask = images_seq_mask + [False]
449
+
450
+ assert len(tokenized_str) == len(
451
+ images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to images_seq_mask's length {len(images_seq_mask)}"
452
+
453
+
454
+
455
+ masked_tokenized_str = []
456
+ for token_index in tokenized_str:
457
+ if token_index != self.image_token_id:
458
+ masked_tokenized_str.append(token_index)
459
+ else:
460
+ masked_tokenized_str.append(self.ignore_id)
461
+
462
+ assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
463
+ (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
464
+ f"images_seq_mask's length {len(images_seq_mask)}, are not equal")
465
+
466
+ input_ids = torch.LongTensor(tokenized_str)
467
+ target_ids = torch.LongTensor(masked_tokenized_str)
468
+ images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
469
+
470
+ # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
471
+ target_ids[(input_ids < 0) |
472
+ (input_ids == self.image_token_id)] = self.ignore_id
473
+ input_ids[input_ids < 0] = self.pad_id
474
+
475
+ inference_mode = True
476
+
477
+ if inference_mode:
478
+ # Remove the ending eos token
479
+ assert input_ids[-1] == self.eos_id
480
+ input_ids = input_ids[:-1]
481
+ target_ids = target_ids[:-1]
482
+ images_seq_mask = images_seq_mask[:-1]
483
+
484
+ if len(images_list) == 0:
485
+ pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
486
+ images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
487
+ images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
488
+ else:
489
+ pixel_values = torch.stack(images_list, dim=0)
490
+ images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
491
+ if images_crop_list:
492
+ images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
493
+ else:
494
+ images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
495
+
496
+ input_ids = input_ids.unsqueeze(0)
497
+
498
+
499
+ return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
500
+
501
+
502
+ AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCRProcessor)
DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/ngram_norepeat.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from transformers import LogitsProcessor
3
+ from transformers.generation.logits_process import _calc_banned_ngram_tokens
4
+ from typing import List, Set
5
+
6
+
7
+ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
8
+
9
+ def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
10
+ if not isinstance(ngram_size, int) or ngram_size <= 0:
11
+ raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
12
+ if not isinstance(window_size, int) or window_size <= 0:
13
+ raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
14
+ self.ngram_size = ngram_size
15
+ self.window_size = window_size
16
+ self.whitelist_token_ids = whitelist_token_ids or set()
17
+
18
+ def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
19
+ if len(input_ids) < self.ngram_size:
20
+ return scores
21
+
22
+ current_prefix = tuple(input_ids[-(self.ngram_size - 1):])
23
+
24
+ search_start = max(0, len(input_ids) - self.window_size)
25
+ search_end = len(input_ids) - self.ngram_size + 1
26
+
27
+ banned_tokens = set()
28
+ for i in range(search_start, search_end):
29
+ ngram = tuple(input_ids[i:i + self.ngram_size])
30
+ if ngram[:-1] == current_prefix:
31
+ banned_tokens.add(ngram[-1])
32
+
33
+ banned_tokens = banned_tokens - self.whitelist_token_ids
34
+
35
+ if banned_tokens:
36
+ scores = scores.clone()
37
+ for token in banned_tokens:
38
+ scores[token] = -float("inf")
39
+
40
+ return scores
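For reference, the run scripts later in this upload wire this processor into vLLM sampling; a minimal sketch of that usage (mirroring run_dpsk_ocr_image.py, not an extra file in this commit):

    from vllm import SamplingParams
    from process.ngram_norepeat import NoRepeatNGramLogitsProcessor

    # ban any token that would complete a 30-gram already seen in the last 90 generated tokens,
    # except the whitelisted <td>/</td> ids (128821, 128822) needed for table output
    no_repeat = NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90,
                                             whitelist_token_ids={128821, 128822})
    sampling_params = SamplingParams(temperature=0.0, max_tokens=8192,
                                     logits_processors=[no_repeat],
                                     skip_special_tokens=False)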
DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_eval_batch.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import re
3
+ from tqdm import tqdm
4
+ import torch
5
+ if torch.version.cuda == '11.8':
6
+ os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
7
+ os.environ['VLLM_USE_V1'] = '0'
8
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
9
+
10
+ from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, MAX_CONCURRENCY, CROP_MODE, NUM_WORKERS
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ import glob
13
+ from PIL import Image
14
+ from deepseek_ocr import DeepseekOCRForCausalLM
15
+
16
+ from vllm.model_executor.models.registry import ModelRegistry
17
+
18
+ from vllm import LLM, SamplingParams
19
+ from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
20
+ from process.image_process import DeepseekOCRProcessor
21
+ ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
22
+
23
+
24
+ llm = LLM(
25
+ model=MODEL_PATH,
26
+ hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
27
+ block_size=256,
28
+ enforce_eager=False,
29
+ trust_remote_code=True,
30
+ max_model_len=8192,
31
+ swap_space=0,
32
+ max_num_seqs = MAX_CONCURRENCY,
33
+ tensor_parallel_size=1,
34
+ gpu_memory_utilization=0.9,
35
+ )
36
+
37
+ logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=40, window_size=90, whitelist_token_ids={128821, 128822})] # window keeps the check fast; whitelist_token_ids: <td>, </td>
38
+
39
+ sampling_params = SamplingParams(
40
+ temperature=0.0,
41
+ max_tokens=8192,
42
+ logits_processors=logits_processors,
43
+ skip_special_tokens=False,
44
+ )
45
+
46
+ class Colors:
47
+ RED = '\033[31m'
48
+ GREEN = '\033[32m'
49
+ YELLOW = '\033[33m'
50
+ BLUE = '\033[34m'
51
+ RESET = '\033[0m'
52
+
53
+ def clean_formula(text):
54
+
55
+ formula_pattern = r'\\\[(.*?)\\\]'
56
+
57
+ def process_formula(match):
58
+ formula = match.group(1)
59
+
60
+ formula = re.sub(r'\\quad\s*\([^)]*\)', '', formula)
61
+
62
+ formula = formula.strip()
63
+
64
+ return r'\[' + formula + r'\]'
65
+
66
+ cleaned_text = re.sub(formula_pattern, process_formula, text)
67
+
68
+ return cleaned_text
69
+
70
+ def re_match(text):
71
+ pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
72
+ matches = re.findall(pattern, text, re.DOTALL)
73
+
74
+
75
+ # mathes_image = []
76
+ mathes_other = []
77
+ for a_match in matches:
78
+ mathes_other.append(a_match[0])
79
+ return matches, mathes_other
80
+
81
+ def process_single_image(image):
82
+ """Build one vLLM generate request for a single PIL image."""
83
+ prompt_in = prompt
84
+ cache_item = {
85
+ "prompt": prompt_in,
86
+ "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
87
+ }
88
+ return cache_item
89
+
90
+
91
+ if __name__ == "__main__":
92
+
93
+ # INPUT_PATH = OmniDocBench images path
94
+
95
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
96
+
97
+ # print('image processing until processing prompts.....')
98
+
99
+ print(f'{Colors.RED}glob images.....{Colors.RESET}')
100
+
101
+ images_path = glob.glob(f'{INPUT_PATH}/*')
102
+
103
+ images = []
104
+
105
+ for image_path in images_path:
106
+ image = Image.open(image_path).convert('RGB')
107
+ images.append(image)
108
+
109
+ prompt = PROMPT
110
+
111
+ # batch_inputs = []
112
+
113
+
114
+ # for image in tqdm(images):
115
+
116
+ # prompt_in = prompt
117
+ # cache_list = [
118
+ # {
119
+ # "prompt": prompt_in,
120
+ # "multi_modal_data": {"image": Image.open(image).convert('RGB')},
121
+ # }
122
+ # ]
123
+ # batch_inputs.extend(cache_list)
124
+
125
+ with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
126
+ batch_inputs = list(tqdm(
127
+ executor.map(process_single_image, images),
128
+ total=len(images),
129
+ desc="Pre-processed images"
130
+ ))
131
+
132
+
133
+
134
+
135
+ outputs_list = llm.generate(
136
+ batch_inputs,
137
+ sampling_params=sampling_params
138
+ )
139
+
140
+
141
+ output_path = OUTPUT_PATH
142
+
143
+ os.makedirs(output_path, exist_ok=True)
144
+
145
+ for output, image in zip(outputs_list, images_path):
146
+
147
+ content = output.outputs[0].text
148
+ mmd_det_path = os.path.join(output_path, os.path.basename(image).replace('.jpg', '_det.md'))
149
+
150
+ with open(mmd_det_path, 'w', encoding='utf-8') as afile:
151
+ afile.write(content)
152
+
153
+ content = clean_formula(content)
154
+ matches_ref, mathes_other = re_match(content)
155
+ for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
156
+ content = content.replace(a_match_other, '').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n').replace('<center>', '').replace('</center>', '')
157
+
158
+ mmd_path = os.path.join(output_path, os.path.basename(image).replace('.jpg', '.md'))
159
+
160
+ with open(mmd_path, 'w', encoding='utf-8') as afile:
161
+ afile.write(content)
DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_image.py ADDED
@@ -0,0 +1,303 @@
1
+ import asyncio
2
+ import re
3
+ import os
4
+
5
+ import torch
6
+ if torch.version.cuda == '11.8':
7
+ os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
8
+
9
+ os.environ['VLLM_USE_V1'] = '0'
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
11
+
12
+ from vllm import AsyncLLMEngine, SamplingParams
13
+ from vllm.engine.arg_utils import AsyncEngineArgs
14
+ from vllm.model_executor.models.registry import ModelRegistry
15
+ import time
16
+ from deepseek_ocr import DeepseekOCRForCausalLM
17
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+ from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
21
+ from process.image_process import DeepseekOCRProcessor
22
+ from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
23
+
24
+
25
+
26
+ ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
27
+
28
+ def load_image(image_path):
29
+
30
+ try:
31
+ image = Image.open(image_path)
32
+
33
+ corrected_image = ImageOps.exif_transpose(image)
34
+
35
+ return corrected_image
36
+
37
+ except Exception as e:
38
+ print(f"error: {e}")
39
+ try:
40
+ return Image.open(image_path)
41
+ except:
42
+ return None
43
+
44
+
45
+ def re_match(text):
46
+ pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
47
+ matches = re.findall(pattern, text, re.DOTALL)
48
+
49
+
50
+ mathes_image = []
51
+ mathes_other = []
52
+ for a_match in matches:
53
+ if '<|ref|>image<|/ref|>' in a_match[0]:
54
+ mathes_image.append(a_match[0])
55
+ else:
56
+ mathes_other.append(a_match[0])
57
+ return matches, mathes_image, mathes_other
58
+
59
+
60
+ def extract_coordinates_and_label(ref_text, image_width, image_height):
61
+
62
+
63
+ try:
64
+ label_type = ref_text[1]
65
+ cor_list = eval(ref_text[2])
66
+ except Exception as e:
67
+ print(e)
68
+ return None
69
+
70
+ return (label_type, cor_list)
71
+
72
+
73
+ def draw_bounding_boxes(image, refs):
74
+
75
+ image_width, image_height = image.size
76
+ img_draw = image.copy()
77
+ draw = ImageDraw.Draw(img_draw)
78
+
79
+ overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
80
+ draw2 = ImageDraw.Draw(overlay)
81
+
82
+ # use PIL's built-in default font for labels
83
+ font = ImageFont.load_default()
84
+
85
+ img_idx = 0
86
+
87
+ for i, ref in enumerate(refs):
88
+ try:
89
+ result = extract_coordinates_and_label(ref, image_width, image_height)
90
+ if result:
91
+ label_type, points_list = result
92
+
93
+ color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
94
+
95
+ color_a = color + (20, )
96
+ for points in points_list:
97
+ x1, y1, x2, y2 = points
98
+
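+ # detection coordinates are emitted on a 0-999 normalized grid; rescale to pixel space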
99
+ x1 = int(x1 / 999 * image_width)
100
+ y1 = int(y1 / 999 * image_height)
101
+
102
+ x2 = int(x2 / 999 * image_width)
103
+ y2 = int(y2 / 999 * image_height)
104
+
105
+ if label_type == 'image':
106
+ try:
107
+ cropped = image.crop((x1, y1, x2, y2))
108
+ cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
109
+ except Exception as e:
110
+ print(e)
111
+ pass
112
+ img_idx += 1
113
+
114
+ try:
115
+ if label_type == 'title':
116
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
117
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
118
+ else:
119
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
120
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
121
+
122
+ text_x = x1
123
+ text_y = max(0, y1 - 15)
124
+
125
+ text_bbox = draw.textbbox((0, 0), label_type, font=font)
126
+ text_width = text_bbox[2] - text_bbox[0]
127
+ text_height = text_bbox[3] - text_bbox[1]
128
+ draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
129
+ fill=(255, 255, 255, 30))
130
+
131
+ draw.text((text_x, text_y), label_type, font=font, fill=color)
132
+ except:
133
+ pass
134
+ except:
135
+ continue
136
+ img_draw.paste(overlay, (0, 0), overlay)
137
+ return img_draw
138
+
139
+
140
+ def process_image_with_refs(image, ref_texts):
141
+ result_image = draw_bounding_boxes(image, ref_texts)
142
+ return result_image
143
+
144
+
145
+
146
+
147
+ async def stream_generate(image=None, prompt=''):
148
+
149
+
150
+ engine_args = AsyncEngineArgs(
151
+ model=MODEL_PATH,
152
+ hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
153
+ block_size=256,
154
+ max_model_len=8192,
155
+ enforce_eager=False,
156
+ trust_remote_code=True,
157
+ tensor_parallel_size=1,
158
+ gpu_memory_utilization=0.75,
159
+ )
160
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
161
+
162
+ logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids= {128821, 128822})] #whitelist: <td>, </td>
163
+
164
+ sampling_params = SamplingParams(
165
+ temperature=0.0,
166
+ max_tokens=8192,
167
+ logits_processors=logits_processors,
168
+ skip_special_tokens=False,
169
+ # ignore_eos=False,
170
+
171
+ )
172
+
173
+ request_id = f"request-{int(time.time())}"
174
+
175
+ printed_length = 0
176
+
177
+ if image and '<image>' in prompt:
178
+ request = {
179
+ "prompt": prompt,
180
+ "multi_modal_data": {"image": image}
181
+ }
182
+ elif prompt:
183
+ request = {
184
+ "prompt": prompt
185
+ }
186
+ else:
187
+ raise ValueError('prompt is empty')
188
+ async for request_output in engine.generate(
189
+ request, sampling_params, request_id
190
+ ):
191
+ if request_output.outputs:
192
+ full_text = request_output.outputs[0].text
193
+ new_text = full_text[printed_length:]
194
+ print(new_text, end='', flush=True)
195
+ printed_length = len(full_text)
196
+ final_output = full_text
197
+ print('\n')
198
+
199
+ return final_output
200
+
201
+
202
+
203
+
204
+ if __name__ == "__main__":
205
+
206
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
207
+ os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
208
+
209
+ image = load_image(INPUT_PATH).convert('RGB')
210
+
211
+
212
+ if '<image>' in PROMPT:
213
+
214
+ image_features = DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)
215
+ else:
216
+ image_features = ''
217
+
218
+ prompt = PROMPT
219
+
220
+ result_out = asyncio.run(stream_generate(image_features, prompt))
221
+
222
+
223
+ save_results = 1
224
+
225
+ if save_results and '<image>' in prompt:
226
+ print('='*15 + 'save results:' + '='*15)
227
+
228
+ image_draw = image.copy()
229
+
230
+ outputs = result_out
231
+
232
+ with open(f'{OUTPUT_PATH}/result_ori.mmd', 'w', encoding = 'utf-8') as afile:
233
+ afile.write(outputs)
234
+
235
+ matches_ref, matches_images, mathes_other = re_match(outputs)
236
+ # print(matches_ref)
237
+ result = process_image_with_refs(image_draw, matches_ref)
238
+
239
+
240
+ for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
241
+ outputs = outputs.replace(a_match_image, f'![](images/' + str(idx) + '.jpg)\n')
242
+
243
+ for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
244
+ outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
245
+
246
+ # if 'structural formula' in conversation[0]['content']:
247
+ # outputs = '<smiles>' + outputs + '</smiles>'
248
+ with open(f'{OUTPUT_PATH}/result.mmd', 'w', encoding = 'utf-8') as afile:
249
+ afile.write(outputs)
250
+
251
+ if 'line_type' in outputs:
252
+ import matplotlib.pyplot as plt
253
+ from matplotlib.patches import Circle
254
+ lines = eval(outputs)['Line']['line']
255
+
256
+ line_type = eval(outputs)['Line']['line_type']
257
+ # print(lines)
258
+
259
+ endpoints = eval(outputs)['Line']['line_endpoint']
260
+
261
+ fig, ax = plt.subplots(figsize=(3,3), dpi=200)
262
+ ax.set_xlim(-15, 15)
263
+ ax.set_ylim(-15, 15)
264
+
265
+ for idx, line in enumerate(lines):
266
+ try:
267
+ p0 = eval(line.split(' -- ')[0])
268
+ p1 = eval(line.split(' -- ')[-1])
269
+
270
+ if line_type[idx] == '--':
271
+ ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k', linestyle='--') # dashed line
272
+ else:
273
+ ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
274
+
275
+ ax.scatter(p0[0], p0[1], s=5, color = 'k')
276
+ ax.scatter(p1[0], p1[1], s=5, color = 'k')
277
+ except:
278
+ pass
279
+
280
+ for endpoint in endpoints:
281
+
282
+ label = endpoint.split(': ')[0]
283
+ (x, y) = eval(endpoint.split(': ')[1])
284
+ ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
285
+ fontsize=5, fontweight='light')
286
+
287
+ try:
288
+ if 'Circle' in eval(outputs).keys():
289
+ circle_centers = eval(outputs)['Circle']['circle_center']
290
+ radius = eval(outputs)['Circle']['radius']
291
+
292
+ for center, r in zip(circle_centers, radius):
293
+ center = eval(center.split(': ')[1])
294
+ circle = Circle(center, radius=r, fill=False, edgecolor='black', linewidth=0.8)
295
+ ax.add_patch(circle)
296
+ except:
297
+ pass
298
+
299
+
300
+ plt.savefig(f'{OUTPUT_PATH}/geo.jpg')
301
+ plt.close()
302
+
303
+ result.save(f'{OUTPUT_PATH}/result_with_boxes.jpg')
DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_pdf.py ADDED
@@ -0,0 +1,330 @@
1
+ import os
2
+ import fitz
3
+ import img2pdf
4
+ import io
5
+ import re
6
+ from tqdm import tqdm
7
+ import torch
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+
11
+ if torch.version.cuda == '11.8':
12
+ os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
13
+ os.environ['VLLM_USE_V1'] = '0'
14
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
15
+
16
+
17
+ from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
18
+
19
+ from PIL import Image, ImageDraw, ImageFont
20
+ import numpy as np
21
+ from deepseek_ocr import DeepseekOCRForCausalLM
22
+
23
+ from vllm.model_executor.models.registry import ModelRegistry
24
+
25
+ from vllm import LLM, SamplingParams
26
+ from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
27
+ from process.image_process import DeepseekOCRProcessor
28
+
29
+ ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
30
+
31
+
32
+ llm = LLM(
33
+ model=MODEL_PATH,
34
+ hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
35
+ block_size=256,
36
+ enforce_eager=False,
37
+ trust_remote_code=True,
38
+ max_model_len=8192,
39
+ swap_space=0,
40
+ max_num_seqs=MAX_CONCURRENCY,
41
+ tensor_parallel_size=1,
42
+ gpu_memory_utilization=0.9,
43
+ disable_mm_preprocessor_cache=True
44
+ )
45
+
46
+ logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822})] # window keeps the check fast; whitelist_token_ids: <td>, </td>
47
+
48
+ sampling_params = SamplingParams(
49
+ temperature=0.0,
50
+ max_tokens=8192,
51
+ logits_processors=logits_processors,
52
+ skip_special_tokens=False,
53
+ include_stop_str_in_output=True,
54
+ )
55
+
56
+
57
+ class Colors:
58
+ RED = '\033[31m'
59
+ GREEN = '\033[32m'
60
+ YELLOW = '\033[33m'
61
+ BLUE = '\033[34m'
62
+ RESET = '\033[0m'
63
+
64
+ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
65
+ """
66
+ pdf2images
67
+ """
68
+ images = []
69
+
70
+ pdf_document = fitz.open(pdf_path)
71
+
72
+ zoom = dpi / 72.0
73
+ matrix = fitz.Matrix(zoom, zoom)
74
+
75
+ for page_num in range(pdf_document.page_count):
76
+ page = pdf_document[page_num]
77
+
78
+ pixmap = page.get_pixmap(matrix=matrix, alpha=False)
79
+ Image.MAX_IMAGE_PIXELS = None
80
+
81
+ if image_format.upper() == "PNG":
82
+ img_data = pixmap.tobytes("png")
83
+ img = Image.open(io.BytesIO(img_data))
84
+ else:
85
+ img_data = pixmap.tobytes("png")
86
+ img = Image.open(io.BytesIO(img_data))
87
+ if img.mode in ('RGBA', 'LA'):
88
+ background = Image.new('RGB', img.size, (255, 255, 255))
89
+ background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
90
+ img = background
91
+
92
+ images.append(img)
93
+
94
+ pdf_document.close()
95
+ return images
96
+
97
+ def pil_to_pdf_img2pdf(pil_images, output_path):
98
+
99
+ if not pil_images:
100
+ return
101
+
102
+ image_bytes_list = []
103
+
104
+ for img in pil_images:
105
+ if img.mode != 'RGB':
106
+ img = img.convert('RGB')
107
+
108
+ img_buffer = io.BytesIO()
109
+ img.save(img_buffer, format='JPEG', quality=95)
110
+ img_bytes = img_buffer.getvalue()
111
+ image_bytes_list.append(img_bytes)
112
+
113
+ try:
114
+ pdf_bytes = img2pdf.convert(image_bytes_list)
115
+ with open(output_path, "wb") as f:
116
+ f.write(pdf_bytes)
117
+
118
+ except Exception as e:
119
+ print(f"error: {e}")
120
+
121
+
122
+
123
+ def re_match(text):
124
+ pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
125
+ matches = re.findall(pattern, text, re.DOTALL)
126
+
127
+
128
+ mathes_image = []
129
+ mathes_other = []
130
+ for a_match in matches:
131
+ if '<|ref|>image<|/ref|>' in a_match[0]:
132
+ mathes_image.append(a_match[0])
133
+ else:
134
+ mathes_other.append(a_match[0])
135
+ return matches, mathes_image, mathes_other
136
+
137
+
138
+ def extract_coordinates_and_label(ref_text, image_width, image_height):
139
+
140
+
141
+ try:
142
+ label_type = ref_text[1]
143
+ cor_list = eval(ref_text[2])
144
+ except Exception as e:
145
+ print(e)
146
+ return None
147
+
148
+ return (label_type, cor_list)
149
+
150
+
151
+ def draw_bounding_boxes(image, refs, jdx):
152
+
153
+ image_width, image_height = image.size
154
+ img_draw = image.copy()
155
+ draw = ImageDraw.Draw(img_draw)
156
+
157
+ overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
158
+ draw2 = ImageDraw.Draw(overlay)
159
+
160
+ # except IOError:
161
+ font = ImageFont.load_default()
162
+
163
+ img_idx = 0
164
+
165
+ for i, ref in enumerate(refs):
166
+ try:
167
+ result = extract_coordinates_and_label(ref, image_width, image_height)
168
+ if result:
169
+ label_type, points_list = result
170
+
171
+ color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
172
+
173
+ color_a = color + (20, )
174
+ for points in points_list:
175
+ x1, y1, x2, y2 = points
176
+
177
+ x1 = int(x1 / 999 * image_width)
178
+ y1 = int(y1 / 999 * image_height)
179
+
180
+ x2 = int(x2 / 999 * image_width)
181
+ y2 = int(y2 / 999 * image_height)
182
+
183
+ if label_type == 'image':
184
+ try:
185
+ cropped = image.crop((x1, y1, x2, y2))
186
+ cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
187
+ except Exception as e:
188
+ print(e)
189
+ pass
190
+ img_idx += 1
191
+
192
+ try:
193
+ if label_type == 'title':
194
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
195
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
196
+ else:
197
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
198
+ draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
199
+
200
+ text_x = x1
201
+ text_y = max(0, y1 - 15)
202
+
203
+ text_bbox = draw.textbbox((0, 0), label_type, font=font)
204
+ text_width = text_bbox[2] - text_bbox[0]
205
+ text_height = text_bbox[3] - text_bbox[1]
206
+ draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
207
+ fill=(255, 255, 255, 30))
208
+
209
+ draw.text((text_x, text_y), label_type, font=font, fill=color)
210
+ except:
211
+ pass
212
+ except:
213
+ continue
214
+ img_draw.paste(overlay, (0, 0), overlay)
215
+ return img_draw
216
+
217
+
218
+ def process_image_with_refs(image, ref_texts, jdx):
219
+ result_image = draw_bounding_boxes(image, ref_texts, jdx)
220
+ return result_image
221
+
222
+
223
+ def process_single_image(image):
224
+ """single image"""
225
+ prompt_in = prompt
226
+ cache_item = {
227
+ "prompt": prompt_in,
228
+ "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
229
+ }
230
+ return cache_item
231
+
232
+
233
+ if __name__ == "__main__":
234
+
235
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
236
+ os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
237
+
238
+ print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
239
+
240
+
241
+ images = pdf_to_images_high_quality(INPUT_PATH)
242
+
243
+
244
+ prompt = PROMPT
245
+
246
+ # batch_inputs = []
247
+
248
+ with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
249
+ batch_inputs = list(tqdm(
250
+ executor.map(process_single_image, images),
251
+ total=len(images),
252
+ desc="Pre-processed images"
253
+ ))
254
+
255
+
256
+ # for image in tqdm(images):
257
+
258
+ # prompt_in = prompt
259
+ # cache_list = [
260
+ # {
261
+ # "prompt": prompt_in,
262
+ # "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
263
+ # }
264
+ # ]
265
+ # batch_inputs.extend(cache_list)
266
+
267
+
268
+ outputs_list = llm.generate(
269
+ batch_inputs,
270
+ sampling_params=sampling_params
271
+ )
272
+
273
+
274
+ output_path = OUTPUT_PATH
275
+
276
+ os.makedirs(output_path, exist_ok=True)
277
+
278
+
279
+ mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
280
+ mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
281
+ pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
282
+ contents_det = ''
283
+ contents = ''
284
+ draw_images = []
285
+ jdx = 0
286
+ for output, img in zip(outputs_list, images):
287
+ content = output.outputs[0].text
288
+
289
+ if '<|end▁of▁sentence|>' in content: # repeat no eos
290
+ content = content.replace('<|end▁of▁sentence|>', '')
291
+ else:
292
+ if SKIP_REPEAT:
293
+ continue
294
+
295
+
296
+ page_num = f'\n<--- Page Split --->'
297
+
298
+ contents_det += content + f'\n{page_num}\n'
299
+
300
+ image_draw = img.copy()
301
+
302
+ matches_ref, matches_images, mathes_other = re_match(content)
303
+ # print(matches_ref)
304
+ result_image = process_image_with_refs(image_draw, matches_ref, jdx)
305
+
306
+
307
+ draw_images.append(result_image)
308
+
309
+
310
+ for idx, a_match_image in enumerate(matches_images):
311
+ content = content.replace(a_match_image, f'![](images/' + str(jdx) + '_' + str(idx) + '.jpg)\n')
312
+
313
+ for idx, a_match_other in enumerate(mathes_other):
314
+ content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
315
+
316
+
317
+ contents += content + f'\n{page_num}\n'
318
+
319
+
320
+ jdx += 1
321
+
322
+ with open(mmd_det_path, 'w', encoding='utf-8') as afile:
323
+ afile.write(contents_det)
324
+
325
+ with open(mmd_path, 'w', encoding='utf-8') as afile:
326
+ afile.write(contents)
327
+
328
+
329
+ pil_to_pdf_img2pdf(draw_images, pdf_out_path)
330
+