vikhyatk committed
Commit 9ba2958
Parent: 0691a81

Upload Moondream

Files changed (3):
  1. model.safetensors +1 -1
  2. moondream.py +6 -13
  3. vision_encoder.py +74 -24
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:15daa0c9a135d2084520e47d517ef689e09ee71c63d606bd8b8ff209bacc3e34
+oid sha256:95f9c7ac62666e56e45a66d4c07a0d110be1c5cfbe4eef1ae86857ef2787ce19
 size 3715037856
moondream.py CHANGED
@@ -2,7 +2,6 @@ import torch
 from .vision_encoder import VisionEncoder
 from .configuration_moondream import MoondreamConfig
 from transformers import PreTrainedModel
-import re

 from .modeling_phi import PhiForCausalLM
 from .configuration_moondream import PhiConfig
@@ -62,16 +61,13 @@ class Moondream(PreTrainedModel):
         image_embeds,
         prompt,
         tokenizer,
-        eos_text="<END>",
         max_new_tokens=128,
         **kwargs,
     ):
-        eos_tokens = tokenizer(eos_text, add_special_tokens=False)[0].ids
-
         generate_config = {
-            "eos_token_id": eos_tokens,
+            "eos_token_id": tokenizer.eos_token_id,
             "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.eos_token_id,
+            "pad_token_id": tokenizer.bos_token_id,
             "max_new_tokens": max_new_tokens,
             **kwargs,
         }
@@ -97,12 +93,11 @@ class Moondream(PreTrainedModel):
         answer = self.generate(
             image_embeds,
             prompt,
-            eos_text="<END>",
             tokenizer=tokenizer,
             max_new_tokens=512,
             **kwargs,
         )[0]
-        cleaned_answer = re.sub("<$|<END$", "", answer).strip()
+        cleaned_answer = answer.strip()

         # Use the result_queue to pass the result if it is provided
         if result_queue:
@@ -117,8 +112,6 @@ class Moondream(PreTrainedModel):
         tokenizer,
         **kwargs,
     ):
-        eos_tokens = tokenizer("<END>", add_special_tokens=False)[0].ids
-
         image_embeds = self.encode_image(images)

         templated_prompts = [
@@ -159,9 +152,9 @@ class Moondream(PreTrainedModel):
         )

         generate_config = {
-            "eos_token_id": eos_tokens,
+            "eos_token_id": tokenizer.eos_token_id,
             "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.eos_token_id,
+            "pad_token_id": tokenizer.bos_token_id,
             "max_new_tokens": 512,
             **kwargs,
         }
@@ -174,6 +167,6 @@ class Moondream(PreTrainedModel):
         )

         return [
-            re.sub("<$|<END$", "", x).strip()
+            x.strip()
             for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         ]
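
Net effect of the moondream.py changes: generation now stops on the tokenizer's built-in EOS token rather than the literal "<END>" string, the BOS id doubles as the padding id, and because batch_decode already runs with skip_special_tokens=True, a plain .strip() replaces the re.sub("<$|<END$", ...) cleanup, so the re import can go. A minimal sketch of the resulting config, assuming a Hugging Face tokenizer; the checkpoint name below is illustrative and not taken from this commit:

# Sketch only (not part of the committed code): how the new generate_config is
# assembled from the tokenizer's own special tokens. The checkpoint name is an assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream1")

generate_config = {
    "eos_token_id": tokenizer.eos_token_id,  # was: token ids of the literal "<END>" text
    "bos_token_id": tokenizer.bos_token_id,
    "pad_token_id": tokenizer.bos_token_id,  # was: tokenizer.eos_token_id
    "max_new_tokens": 128,
}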
vision_encoder.py CHANGED
@@ -1,6 +1,6 @@
 import torch
+import torch.nn.functional as F
 from torch import nn
-from PIL import Image
 from einops import rearrange
 from torchvision.transforms.v2 import (
     Compose,
@@ -10,34 +10,92 @@ from torchvision.transforms.v2 import (
     ToDtype,
     Normalize,
 )
-import timm


-class VisualHolder(nn.Module):
-    def __init__(self, model):
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=16):
         super().__init__()
-        self.visual = model
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+
+        self.qkv = nn.Linear(dim, dim * 3)
+        self.proj = nn.Linear(dim, dim)
+
+        torch.nn.init.kaiming_normal_(
+            self.qkv.weight, mode="fan_in", nonlinearity="relu"
+        )
+        torch.nn.init.kaiming_normal_(
+            self.proj.weight, mode="fan_in", nonlinearity="relu"
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv.unbind(0)
+
+        x = F.scaled_dot_product_attention(q, k, v)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
+class VitBlock(nn.Module):
+    def __init__(self, embed_dim):
+        super().__init__()
+        self.attn = Attention(embed_dim)
+        self.mlp = MLP(embed_dim, 4304)
+        self.norm1 = nn.LayerNorm(embed_dim)
+        self.norm2 = nn.LayerNorm(embed_dim)

     def forward(self, x):
-        return self.visual(x)
+        x = x + self.attn(self.norm1(x))
+        x = x + self.mlp(self.norm2(x))
+        return x


-class ModelHolder(nn.Module):
-    def __init__(self, model):
+class VisionTransformer(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+        embed_len = 729
+        embed_dim = 1152
+
+        self.patch_embed = LinearPatchEmbedding()
+        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
+        self.blocks = nn.Sequential(*[VitBlock(embed_dim) for _ in range(27)])
+        self.norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x = x + self.pos_embed
+        for block in self.blocks:
+            x = block(x)
+        return self.norm(x)
+
+
+class EncoderWrapper(nn.Module):
+
+    def __init__(self):
         super().__init__()
-        self.model = model
+        self.model = nn.ModuleDict({"visual": VisionTransformer()})

     def forward(self, x):
-        return self.model(x)
+        return self.model["visual"](x)


 class LinearPatchEmbedding(nn.Module):
-    def __init__(self, conv):
+
+    def __init__(self):
         super().__init__()
         self.linear = nn.Linear(588, 1152)
-        self.linear.weight.data = conv.weight.data.view(1152, -1)
-        if conv.bias is not None:
-            self.linear.bias.data = conv.bias.data

     def forward(self, x):
         return self.linear(x)
@@ -49,13 +107,12 @@ class MLP(nn.Module):
         in_features: int,
         hidden_features: int = None,
         out_features: int = None,
-        act_layer: nn.Module = nn.GELU,
     ) -> None:
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
+        self.act = nn.GELU(approximate="tanh")
         self.fc2 = nn.Linear(hidden_features, out_features)

         torch.nn.init.kaiming_normal_(
@@ -94,14 +151,7 @@ class VisionEncoder(nn.Module):
     def __init__(self) -> None:
         super().__init__()

-        self.encoder = ModelHolder(
-            VisualHolder(timm.create_model("vit_so400m_patch14_siglip_384"))
-        )
-        self.encoder.model.visual.patch_embed = LinearPatchEmbedding(
-            self.encoder.model.visual.patch_embed.proj
-        )
-        self.encoder.model.visual.attn_pool = nn.Identity()
-
+        self.encoder = EncoderWrapper()
         self.projection = VisionProjection()

         self.preprocess = Compose(
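
The vision_encoder.py rewrite drops the timm and PIL dependencies: instead of wrapping timm's vit_so400m_patch14_siglip_384 in VisualHolder/ModelHolder, the SigLIP-style ViT is now written out inline (27 VitBlocks, embed_dim 1152, 16 attention heads, MLP width 4304, 729 learned position embeddings) and exposed through EncoderWrapper, which keeps the encoder.model.visual parameter prefix. The hand-written Attention relies on F.scaled_dot_product_attention, so PyTorch 2.0+ is required. Below is a quick shape-check sketch, assuming the updated vision_encoder.py is importable as a local module; reading 588 as 14x14x3 flattened patches and 729 as a 27x27 patch grid (i.e. a 378x378 input image) is an inference, not something stated in the diff:

# Shape check only; nothing here is part of the committed code.
import torch
from vision_encoder import VisionTransformer  # import path is an assumption

vit = VisionTransformer()
patches = torch.randn(1, 729, 588)  # (batch, patch positions, flattened patch pixels)
features = vit(patches)
print(features.shape)  # torch.Size([1, 729, 1152])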