qe2 committed (verified)
Commit 7c27d36 · 1 Parent(s): 9125686

Upload 31 files
.gitattributes CHANGED
@@ -1,35 +1,37 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ moondream2-mmproj-f16.gguf filter=lfs diff=lfs merge=lfs -text
37
+ moondream2-text-model-f16.gguf filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,75 @@
1
  ---
2
- license: cc-by-nc-sa-4.0
 
3
  ---
 
1
  ---
2
+ license: apache-2.0
3
+ pipeline_tag: image-text-to-text
4
  ---
5
+
6
+ Moondream is a small vision language model designed to run efficiently everywhere.
7
+
8
+ [Website](https://moondream.ai/) / [Demo](https://moondream.ai/playground) / [GitHub](https://github.com/vikhyat/moondream)
9
+
10
+ This repository contains the latest (**2025-06-21**) release of Moondream, as well as [historical releases](https://huggingface.co/vikhyatk/moondream2/blob/main/versions.txt). The model is updated frequently, so we recommend specifying a revision as shown below if you're using it in a production application.
11
+
12
+
13
+ ### Usage
14
+
15
+ ```python
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+ from PIL import Image
18
+
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ "vikhyatk/moondream2",
21
+ revision="2025-06-21",
22
+ trust_remote_code=True,
23
+ # Uncomment to run on GPU.
24
+ # device_map={"": "cuda"}
25
+ )
26
+
27
+ # Captioning
28
+ print("Short caption:")
29
+ print(model.caption(image, length="short")["caption"])
30
+
31
+ print("\nNormal caption:")
32
+ for t in model.caption(image, length="normal", stream=True)["caption"]:
33
+ # Streaming generation example, supported for caption() and detect()
34
+ print(t, end="", flush=True)
35
+ print(model.caption(image, length="normal"))
36
+
37
+ # Visual Querying
38
+ print("\nVisual query: 'How many people are in the image?'")
39
+ print(model.query(image, "How many people are in the image?")["answer"])
40
+
41
+ # Object Detection
42
+ print("\nObject detection: 'face'")
43
+ objects = model.detect(image, "face")["objects"]
44
+ print(f"Found {len(objects)} face(s)")
45
+
46
+ # Pointing
47
+ print("\nPointing: 'person'")
48
+ points = model.point(image, "person")["points"]
49
+ print(f"Found {len(points)} person(s)")
50
+ ```
51
+
52
+ ### Changelog
53
+
54
+ **2025-06-21**
55
+
56
+ (release notes coming soon)
57
+
58
+ **2025-04-15** ([full release notes](https://moondream.ai/blog/moondream-2025-04-14-release))
59
+
60
+ 1. Improved chart understanding (ChartQA up from 74.8 to 77.5, 82.2 with PoT)
61
+ 2. Added temperature and nucleus sampling to reduce repetitive outputs
62
+ 3. Better OCR for documents and tables (prompt with “Transcribe the text” or “Transcribe the text in natural reading order”)
63
+ 4. Object detection supports document layout detection (figure, formula, text, etc)
64
+ 5. UI understanding (ScreenSpot F1\@0.5 up from 53.3 to 60.3)
65
+ 6. Improved text understanding (DocVQA up from 76.5 to 79.3, TextVQA up from 74.6 to 76.3)
66
+
67
+ **2025-03-27** ([full release notes](https://moondream.ai/blog/moondream-2025-03-27-release))
68
+
69
+ 1. Added support for long-form captioning
70
+ 2. Open vocabulary image tagging
71
+ 3. Improved counting accuracy (e.g. CountBenchQA increased from 80 to 86.4)
72
+ 4. Improved text understanding (e.g. OCRBench increased from 58.3 to 61.2)
73
+ 5. Improved object detection, especially for small objects (e.g. COCO up from 30.5 to 51.2)
74
+ 6. Fixed token streaming bug affecting multi-byte unicode characters
75
+ 7. gpt-fast style `compile()` now supported in HF Transformers implementation
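Note that the README usage snippet above references an `image` variable without creating it. A minimal sketch of how one might supply it before running that snippet; the file path is a placeholder, not part of the repository:

```python
from PIL import Image

# Placeholder path; point this at any local image before running the usage snippet.
image = Image.open("example.jpg")
```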
added_tokens.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "\t\t": 50294,
3
+ "\t\t\t": 50293,
4
+ "\t\t\t\t": 50292,
5
+ "\t\t\t\t\t": 50291,
6
+ "\t\t\t\t\t\t": 50290,
7
+ "\t\t\t\t\t\t\t": 50289,
8
+ "\t\t\t\t\t\t\t\t": 50288,
9
+ "\t\t\t\t\t\t\t\t\t": 50287,
10
+ " ": 50286,
11
+ " ": 50285,
12
+ " ": 50284,
13
+ " ": 50283,
14
+ " ": 50282,
15
+ " ": 50281,
16
+ " ": 50280,
17
+ " ": 50279,
18
+ " ": 50278,
19
+ " ": 50277,
20
+ " ": 50276,
21
+ " ": 50275,
22
+ " ": 50274,
23
+ " ": 50273,
24
+ " ": 50272,
25
+ " ": 50271,
26
+ " ": 50270,
27
+ " ": 50269,
28
+ " ": 50268,
29
+ " ": 50267,
30
+ " ": 50266,
31
+ " ": 50265,
32
+ " ": 50264,
33
+ " ": 50263,
34
+ " ": 50262,
35
+ " ": 50261,
36
+ " ": 50260,
37
+ " ": 50259,
38
+ " ": 50258,
39
+ " ": 50257
40
+ }
config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "architectures": [
3
+ "HfMoondream"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "hf_moondream.HfConfig",
7
+ "AutoModelForCausalLM": "hf_moondream.HfMoondream"
8
+ },
9
+ "config": {},
10
+ "model_type": "moondream1",
11
+ "torch_dtype": "bfloat16",
12
+ "transformers_version": "4.52.4"
13
+ }
config.py ADDED
@@ -0,0 +1,94 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class TextConfig:
7
+ dim: int = 2048
8
+ ff_dim: int = 8192
9
+ n_layers: int = 24
10
+ vocab_size: int = 51200
11
+ max_context: int = 2048
12
+ n_heads: int = 32
13
+ n_kv_heads: int = 32
14
+ prefix_attn: int = 730
15
+ group_size: Optional[int] = None
16
+
17
+
18
+ @dataclass(frozen=True)
19
+ class VisionConfig:
20
+ enc_dim: int = 1152
21
+ enc_patch_size: int = 14
22
+ enc_n_layers: int = 27
23
+ enc_ff_dim: int = 4304
24
+ enc_n_heads: int = 16
25
+ proj_out_dim: int = 2048
26
+ crop_size: int = 378
27
+ in_channels: int = 3
28
+ max_crops: int = 12
29
+ overlap_margin: int = 4
30
+ proj_inner_dim: int = 8192
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class RegionConfig:
35
+ dim: int = 2048
36
+ coord_feat_dim: int = 256
37
+ coord_out_dim: int = 1024
38
+ size_feat_dim: int = 512
39
+ size_out_dim: int = 2048
40
+ inner_dim: int = 8192
41
+ group_size: Optional[int] = None
42
+
43
+
44
+ @dataclass(frozen=True)
45
+ class TokenizerConfig:
46
+ bos_id: int = 0
47
+ eos_id: int = 0
48
+ answer_id: int = 3
49
+ thinking_id: int = 4
50
+ coord_id: int = 5
51
+ size_id: int = 6
52
+ start_ground_points_id: int = 7
53
+ end_ground_id: int = 9
54
+ templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
55
+ default_factory=lambda: {
56
+ "caption": {
57
+ "short": [1, 32708, 2, 12492, 3],
58
+ "normal": [1, 32708, 2, 6382, 3],
59
+ "long": [1, 32708, 2, 4059, 3],
60
+ },
61
+ "query": {"prefix": [1, 15381, 2], "suffix": [3]},
62
+ "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
63
+ "point": {"prefix": [1, 2581, 2], "suffix": [3]},
64
+ }
65
+ )
66
+
67
+
68
+ @dataclass(frozen=True)
69
+ class MoondreamConfig:
70
+ text: TextConfig = TextConfig()
71
+ vision: VisionConfig = VisionConfig()
72
+ region: RegionConfig = RegionConfig()
73
+ tokenizer: TokenizerConfig = TokenizerConfig()
74
+
75
+ @classmethod
76
+ def from_dict(cls, config_dict: dict):
77
+ text_config = TextConfig(**config_dict.get("text", {}))
78
+ vision_config = VisionConfig(**config_dict.get("vision", {}))
79
+ region_config = RegionConfig(**config_dict.get("region", {}))
80
+ tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
81
+ return cls(
82
+ text=text_config,
83
+ vision=vision_config,
84
+ region=region_config,
85
+ tokenizer=tokenizer_config,
86
+ )
87
+
88
+ def to_dict(self):
89
+ return {
90
+ "text": self.text.__dict__,
91
+ "vision": self.vision.__dict__,
92
+ "region": self.region.__dict__,
93
+ "tokenizer": self.tokenizer.__dict__,
94
+ }
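For reference, a short sketch of how `MoondreamConfig.from_dict` and `to_dict` defined above compose; the override value is illustrative and the remaining defaults come from the dataclasses:

```python
# Illustrative round trip through the config dataclasses above.
config = MoondreamConfig.from_dict({"text": {"max_context": 4096}})
assert config.text.max_context == 4096   # overridden field
assert config.vision.crop_size == 378    # untouched defaults still apply
restored = MoondreamConfig.from_dict(config.to_dict())
assert restored == config
```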
configuration_moondream.py ADDED
@@ -0,0 +1,96 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class PhiConfig(PretrainedConfig):
5
+ model_type = "phi"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=51200,
11
+ hidden_size=2048,
12
+ intermediate_size=8192,
13
+ num_hidden_layers=24,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=None,
16
+ resid_pdrop=0.0,
17
+ embd_pdrop=0.0,
18
+ attention_dropout=0.0,
19
+ hidden_act="gelu_new",
20
+ max_position_embeddings=2048,
21
+ initializer_range=0.02,
22
+ layer_norm_eps=1e-5,
23
+ use_cache=True,
24
+ tie_word_embeddings=False,
25
+ rope_theta=10000.0,
26
+ rope_scaling=None,
27
+ partial_rotary_factor=0.5,
28
+ bos_token_id=1,
29
+ eos_token_id=2,
30
+ **kwargs,
31
+ ):
32
+ self.vocab_size = vocab_size
33
+ self.hidden_size = hidden_size
34
+ self.intermediate_size = intermediate_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.num_attention_heads = num_attention_heads
37
+
38
+ if num_key_value_heads is None:
39
+ num_key_value_heads = num_attention_heads
40
+
41
+ self.num_key_value_heads = num_key_value_heads
42
+ self.resid_pdrop = resid_pdrop
43
+ self.embd_pdrop = embd_pdrop
44
+ self.attention_dropout = attention_dropout
45
+ self.hidden_act = hidden_act
46
+ self.max_position_embeddings = max_position_embeddings
47
+ self.initializer_range = initializer_range
48
+ self.layer_norm_eps = layer_norm_eps
49
+ self.use_cache = use_cache
50
+ self.rope_theta = rope_theta
51
+ self.rope_scaling = rope_scaling
52
+ self.partial_rotary_factor = partial_rotary_factor
53
+ self._rope_scaling_validation()
54
+
55
+ super().__init__(
56
+ bos_token_id=bos_token_id,
57
+ eos_token_id=eos_token_id,
58
+ tie_word_embeddings=tie_word_embeddings,
59
+ **kwargs,
60
+ )
61
+
62
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
63
+ def _rope_scaling_validation(self):
64
+ """
65
+ Validate the `rope_scaling` configuration.
66
+ """
67
+ if self.rope_scaling is None:
68
+ return
69
+
70
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
71
+ raise ValueError(
72
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
73
+ f"got {self.rope_scaling}"
74
+ )
75
+ rope_scaling_type = self.rope_scaling.get("type", None)
76
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
77
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
78
+ raise ValueError(
79
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
80
+ )
81
+ if (
82
+ rope_scaling_factor is None
83
+ or not isinstance(rope_scaling_factor, float)
84
+ or rope_scaling_factor <= 1.0
85
+ ):
86
+ raise ValueError(
87
+ f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
88
+ )
89
+
90
+
91
+ class MoondreamConfig(PretrainedConfig):
92
+ model_type = "moondream1"
93
+
94
+ def __init__(self, **kwargs):
95
+ self.text_config = PhiConfig(**kwargs.pop("text_config", {}))
96
+ super().__init__(**kwargs)
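A brief sketch of the `rope_scaling` validation above; the values are illustrative:

```python
# Accepted: exactly two fields, `type` in {"linear", "dynamic"} and a float factor > 1.
ok = PhiConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Rejected: the factor must be strictly greater than 1.0.
try:
    PhiConfig(rope_scaling={"type": "linear", "factor": 1.0})
except ValueError as err:
    print(err)
```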
fourier_features.py ADDED
@@ -0,0 +1,18 @@
1
+ # Adapted from https://github.com/crowsonkb/k-diffusion/blob/transformer-model-v2/k_diffusion/layers.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+
8
+ class FourierFeatures(nn.Module):
9
+ def __init__(self, in_features, out_features, std=1.0):
10
+ super().__init__()
11
+ assert out_features % 2 == 0
12
+ self.register_buffer(
13
+ "weight", torch.randn([out_features // 2, in_features]) * std
14
+ )
15
+
16
+ def forward(self, input):
17
+ f = 2 * math.pi * input @ self.weight.T
18
+ return torch.cat([f.cos(), f.sin()], dim=-1)
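A small shape check for `FourierFeatures` as defined above; the sizes are arbitrary examples:

```python
import torch

# 2 input features -> 16 output features (8 cosine + 8 sine terms).
ff = FourierFeatures(in_features=2, out_features=16)
coords = torch.rand(5, 2)      # e.g. five normalized (x, y) coordinates
print(ff(coords).shape)        # torch.Size([5, 16])
```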
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.44.0"
4
+ }
handler.py ADDED
@@ -0,0 +1,58 @@
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from PIL import Image
3
+ import torch
4
+ from io import BytesIO
5
+ import base64
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, model_dir):
9
+ self.model_id = "vikhyatk/moondream2"
10
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True)
11
+ self.tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2", trust_remote_code=True)
12
+
13
+ # Check if CUDA (GPU support) is available and then set the device to GPU or CPU
14
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self.model.to(self.device)
16
+
17
+ def preprocess_image(self, encoded_image):
18
+ """Decode and preprocess the input image."""
19
+ decoded_image = base64.b64decode(encoded_image)
20
+ img = Image.open(BytesIO(decoded_image)).convert("RGB")
21
+ return img
22
+
23
+ def __call__(self, data):
24
+ """Handle the incoming request."""
25
+ try:
26
+ # Extract the inputs from the data
27
+ inputs = data.pop("inputs", data)
28
+ input_image = inputs['image']
29
+ question = inputs.get('question', "move to the red ball")
30
+
31
+ # Preprocess the image
32
+ img = self.preprocess_image(input_image)
33
+
34
+ # Perform inference
35
+ enc_image = self.model.encode_image(img).to(self.device)
36
+ answer = self.model.answer_question(enc_image, question, self.tokenizer)
37
+
38
+ # If the output is a tensor, move it back to CPU and convert to list
39
+ if isinstance(answer, torch.Tensor):
40
+ answer = answer.cpu().numpy().tolist()
41
+
42
+ # Create the response
43
+ response = {
44
+ "statusCode": 200,
45
+ "body": {
46
+ "answer": answer
47
+ }
48
+ }
49
+ return response
50
+ except Exception as e:
51
+ # Handle any errors
52
+ response = {
53
+ "statusCode": 500,
54
+ "body": {
55
+ "error": str(e)
56
+ }
57
+ }
58
+ return response
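A hedged sketch of invoking this handler locally; the image path and question are placeholders:

```python
import base64

handler = EndpointHandler(model_dir=None)  # model_dir is unused in __init__ above
with open("example.jpg", "rb") as f:       # placeholder path
    payload = {
        "inputs": {
            "image": base64.b64encode(f.read()).decode("utf-8"),
            "question": "What is in this image?",
        }
    }
print(handler(payload)["body"])
```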
hf_moondream.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import PreTrainedModel, PretrainedConfig
5
+ from typing import Union
6
+
7
+ from .config import MoondreamConfig
8
+ from .moondream import MoondreamModel
9
+
10
+ # Files sometimes don't get loaded without these...
11
+ from .image_crops import *
12
+ from .vision import *
13
+ from .text import *
14
+ from .region import *
15
+ from .utils import *
16
+
17
+
18
+ def extract_question(text):
19
+ prefix = "<image>\n\nQuestion: "
20
+ suffix = "\n\nAnswer:"
21
+
22
+ if text.startswith(prefix) and text.endswith(suffix):
23
+ return text[len(prefix) : -len(suffix)]
24
+ else:
25
+ return None
26
+
27
+
28
+ class HfConfig(PretrainedConfig):
29
+ _auto_class = "AutoConfig"
30
+ model_type = "moondream1"
31
+
32
+ def __init__(self, **kwargs):
33
+ super().__init__(**kwargs)
34
+ self.config = {}
35
+
36
+
37
+ class HfMoondream(PreTrainedModel):
38
+ _auto_class = "AutoModelForCausalLM"
39
+ config_class = HfConfig
40
+
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.model = MoondreamModel(
44
+ MoondreamConfig.from_dict(config.config), setup_caches=False
45
+ )
46
+ self._is_kv_cache_setup = False
47
+
48
+ def _setup_caches(self):
49
+ if not self._is_kv_cache_setup:
50
+ self.model._setup_caches()
51
+ self._is_kv_cache_setup = True
52
+
53
+ @property
54
+ def encode_image(self):
55
+ self._setup_caches()
56
+ return self.model.encode_image
57
+
58
+ @property
59
+ def query(self):
60
+ self._setup_caches()
61
+ return self.model.query
62
+
63
+ @property
64
+ def caption(self):
65
+ self._setup_caches()
66
+ return self.model.caption
67
+
68
+ @property
69
+ def detect(self):
70
+ self._setup_caches()
71
+ return self.model.detect
72
+
73
+ @property
74
+ def point(self):
75
+ self._setup_caches()
76
+ return self.model.point
77
+
78
+ @property
79
+ def detect_gaze(self):
80
+ self._setup_caches()
81
+ return self.model.detect_gaze
82
+
83
+ def answer_question(
84
+ self,
85
+ image_embeds,
86
+ question,
87
+ tokenizer=None,
88
+ chat_history="",
89
+ result_queue=None,
90
+ max_new_tokens=256,
91
+ **kwargs
92
+ ):
93
+ answer = self.query(image_embeds, question)["answer"].strip()
94
+
95
+ if result_queue is not None:
96
+ result_queue.put(answer)
97
+ return answer
98
+
99
+ def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
100
+ answers = []
101
+ for image, prompt in zip(images, prompts):
102
+ answers.append(self.query(image, prompt)["answer"].strip())
103
+ return answers
104
+
105
+ def _unsupported_exception(self):
106
+ raise NotImplementedError(
107
+ "This method is not supported in the latest version of moondream. "
108
+ "Consider upgrading to the updated API spec, or alternately pin "
109
+ "to 'revision=2024-08-26'."
110
+ )
111
+
112
+ def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
113
+ """
114
+ Function definition remains unchanged for backwards compatibility.
115
+ Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
116
+ """
117
+ prompt_extracted = extract_question(prompt)
118
+ if prompt_extracted is not None:
119
+ answer = self.model.query(
120
+ image=image_embeds, question=prompt_extracted, stream=False
121
+ )["answer"]
122
+ else:
123
+ image_embeds = self.encode_image(image_embeds)
124
+ prompt_tokens = torch.tensor(
125
+ [self.model.tokenizer.encode(prompt).ids],
126
+ device=self.device,
127
+ )
128
+
129
+ def generator():
130
+ for token in self.model._generate_answer(
131
+ prompt_tokens,
132
+ image_embeds.kv_cache,
133
+ image_embeds.pos,
134
+ max_new_tokens,
135
+ ):
136
+ yield token
137
+
138
+ answer = "".join(list(generator()))
139
+
140
+ return [answer]
141
+
142
+ def get_input_embeddings(self) -> nn.Embedding:
143
+ """
144
+ Lazily wrap the raw parameter `self.model.text.wte` in a real
145
+ `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
146
+ **shares** the weight tensor—no copy is made.
147
+ """
148
+ if not hasattr(self, "_input_embeddings"):
149
+ self._input_embeddings = nn.Embedding.from_pretrained(
150
+ self.model.text.wte, # tensor created in text.py
151
+ freeze=True, # set to False if you need it trainable
152
+ )
153
+ return self._input_embeddings
154
+
155
+ def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
156
+ """
157
+ Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
158
+ embeddings and keeps everything tied to `self.model.text.wte`.
159
+ """
160
+ # 1. point the low-level parameter to the new weight matrix
161
+ self.model.text.wte = value.weight
162
+ # 2. keep a reference for get_input_embeddings()
163
+ self._input_embeddings = value
164
+
165
+ def input_embeds(
166
+ self,
167
+ input_ids: Union[torch.LongTensor, list, tuple],
168
+ *,
169
+ device: torch.device | None = None
170
+ ) -> torch.FloatTensor:
171
+ """
172
+ Back-compat wrapper that turns token IDs into embeddings.
173
+
174
+ Example:
175
+ ids = torch.tensor([[1, 2, 3]])
176
+ embeds = model.input_embeds(ids) # (1, 3, hidden_dim)
177
+ """
178
+ if not torch.is_tensor(input_ids):
179
+ input_ids = torch.as_tensor(input_ids)
180
+ if device is not None:
181
+ input_ids = input_ids.to(device)
182
+
183
+ return self.get_input_embeddings()(input_ids)
image_crops.py ADDED
@@ -0,0 +1,231 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+ from typing import TypedDict
6
+
7
+ try:
8
+ import pyvips
9
+
10
+ HAS_VIPS = True
11
+ except ImportError:
12
+ from PIL import Image
13
+
14
+ HAS_VIPS = False
15
+
16
+
17
+ def select_tiling(
18
+ height: int, width: int, crop_size: int, max_crops: int
19
+ ) -> tuple[int, int]:
20
+ """
21
+ Determine the optimal number of tiles to cover an image with overlapping crops.
22
+ """
23
+ if height <= crop_size or width <= crop_size:
24
+ return (1, 1)
25
+
26
+ # Minimum required tiles in each dimension
27
+ min_h = math.ceil(height / crop_size)
28
+ min_w = math.ceil(width / crop_size)
29
+
30
+ # If minimum required tiles exceed max_crops, return proportional distribution
31
+ if min_h * min_w > max_crops:
32
+ ratio = math.sqrt(max_crops / (min_h * min_w))
33
+ return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
34
+
35
+ # Perfect aspect-ratio tiles that satisfy max_crops
36
+ h_tiles = math.floor(math.sqrt(max_crops * height / width))
37
+ w_tiles = math.floor(math.sqrt(max_crops * width / height))
38
+
39
+ # Ensure we meet minimum tile requirements
40
+ h_tiles = max(h_tiles, min_h)
41
+ w_tiles = max(w_tiles, min_w)
42
+
43
+ # If we exceeded max_crops, scale down the larger dimension
44
+ if h_tiles * w_tiles > max_crops:
45
+ if w_tiles > h_tiles:
46
+ w_tiles = math.floor(max_crops / h_tiles)
47
+ else:
48
+ h_tiles = math.floor(max_crops / w_tiles)
49
+
50
+ return (max(1, h_tiles), max(1, w_tiles))
51
+
52
+
53
+ class OverlapCropOutput(TypedDict):
54
+ crops: np.ndarray
55
+ tiling: tuple[int, int]
56
+
57
+
58
+ def overlap_crop_image(
59
+ image: np.ndarray,
60
+ overlap_margin: int,
61
+ max_crops: int,
62
+ base_size: tuple[int, int] = (378, 378),
63
+ patch_size: int = 14,
64
+ ) -> OverlapCropOutput:
65
+ """
66
+ Process an image using an overlap-and-resize cropping strategy with margin handling.
67
+
68
+ This function takes an input image and creates multiple overlapping crops with
69
+ consistent margins. It produces:
70
+ 1. A single global crop resized to base_size
71
+ 2. Multiple overlapping local crops that maintain high resolution details
72
+ 3. A patch ordering matrix that tracks correspondence between crops
73
+
74
+ The overlap strategy ensures:
75
+ - Smooth transitions between adjacent crops
76
+ - No loss of information at crop boundaries
77
+ - Proper handling of features that cross crop boundaries
78
+ - Consistent patch indexing across the full image
79
+
80
+ Args:
81
+ image (np.ndarray): Input image as numpy array with shape (H,W,C)
82
+ base_size (tuple[int,int]): Target size for crops, default (378,378)
83
+ patch_size (int): Size of patches in pixels, default 14
84
+ overlap_margin (int): Margin size in patch units, default 4
85
+ max_crops (int): Maximum number of crops allowed, default 12
86
+
87
+ Returns:
88
+ OverlapCropOutput: Dictionary containing:
89
+ - crops: A numpy array containing the global crop of the full image (index 0)
90
+ followed by the overlapping cropped regions (indices 1+)
91
+ - tiling: Tuple of (height,width) tile counts
92
+ """
93
+ original_h, original_w = image.shape[:2]
94
+
95
+ # Convert margin from patch units to pixels
96
+ margin_pixels = patch_size * overlap_margin
97
+ total_margin_pixels = margin_pixels * 2 # Both sides
98
+
99
+ # Calculate crop parameters
100
+ crop_patches = base_size[0] // patch_size # patches per crop dimension
101
+ crop_window_patches = crop_patches - (2 * overlap_margin) # usable patches
102
+ crop_window_size = crop_window_patches * patch_size # usable size in pixels
103
+
104
+ # Determine tiling
105
+ tiling = select_tiling(
106
+ original_h - total_margin_pixels,
107
+ original_w - total_margin_pixels,
108
+ crop_window_size,
109
+ max_crops,
110
+ )
111
+
112
+ # Pre-allocate crops.
113
+ n_crops = tiling[0] * tiling[1] + 1 # 1 = global crop
114
+ crops = np.zeros(
115
+ (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
116
+ )
117
+
118
+ # Resize image to fit tiling
119
+ target_size = (
120
+ tiling[0] * crop_window_size + total_margin_pixels,
121
+ tiling[1] * crop_window_size + total_margin_pixels,
122
+ )
123
+
124
+ if HAS_VIPS:
125
+ # Convert to vips for resizing
126
+ vips_image = pyvips.Image.new_from_array(image)
127
+ scale_x = target_size[1] / image.shape[1]
128
+ scale_y = target_size[0] / image.shape[0]
129
+ resized = vips_image.resize(scale_x, vscale=scale_y)
130
+ image = resized.numpy()
131
+
132
+ # Create global crop
133
+ scale_x = base_size[1] / vips_image.width
134
+ scale_y = base_size[0] / vips_image.height
135
+ global_vips = vips_image.resize(scale_x, vscale=scale_y)
136
+ crops[0] = global_vips.numpy()
137
+ else:
138
+ # Fallback to PIL
139
+ pil_img = Image.fromarray(image)
140
+ resized = pil_img.resize(
141
+ (int(target_size[1]), int(target_size[0])),
142
+ resample=Image.Resampling.LANCZOS,
143
+ )
144
+ image = np.asarray(resized)
145
+
146
+ # Create global crop
147
+ global_pil = pil_img.resize(
148
+ (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
149
+ )
150
+ crops[0] = np.asarray(global_pil)
151
+
152
+ for i in range(tiling[0]):
153
+ for j in range(tiling[1]):
154
+ # Calculate crop coordinates
155
+ y0 = i * crop_window_size
156
+ x0 = j * crop_window_size
157
+
158
+ # Extract crop with padding if needed
159
+ y_end = min(y0 + base_size[0], image.shape[0])
160
+ x_end = min(x0 + base_size[1], image.shape[1])
161
+
162
+ crop_region = image[y0:y_end, x0:x_end]
163
+ crops[
164
+ 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
165
+ ] = crop_region
166
+
167
+ return {"crops": crops, "tiling": tiling}
168
+
169
+
170
+ def reconstruct_from_crops(
171
+ crops: torch.Tensor,
172
+ tiling: tuple[int, int],
173
+ overlap_margin: int,
174
+ patch_size: int = 14,
175
+ ) -> torch.Tensor:
176
+ """
177
+ Reconstruct the original image from overlapping crops into a single seamless image.
178
+
179
+ Takes a list of overlapping image crops along with their positional metadata and
180
+ reconstructs them into a single coherent image by carefully stitching together
181
+ non-overlapping regions. Handles both numpy arrays and PyTorch tensors.
182
+
183
+ Args:
184
+ crops: List of image crops as numpy arrays or PyTorch tensors with shape
185
+ (H,W,C)
186
+ tiling: Tuple of (height,width) indicating crop grid layout
187
+ patch_size: Size in pixels of each patch, default 14
188
+ overlap_margin: Number of overlapping patches on each edge, default 4
189
+
190
+ Returns:
191
+ Reconstructed image as numpy array or PyTorch tensor matching input type,
192
+ with shape (H,W,C) where H,W are the original image dimensions
193
+ """
194
+ tiling_h, tiling_w = tiling
195
+ crop_height, crop_width = crops[0].shape[:2]
196
+ margin_pixels = overlap_margin * patch_size
197
+
198
+ # Calculate output size (only adding margins once)
199
+ output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
200
+ output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
201
+
202
+ reconstructed = torch.zeros(
203
+ (output_h, output_w, crops[0].shape[2]),
204
+ device=crops[0].device,
205
+ dtype=crops[0].dtype,
206
+ )
207
+
208
+ for i, crop in enumerate(crops):
209
+ tile_y = i // tiling_w
210
+ tile_x = i % tiling_w
211
+
212
+ # For each tile, determine which part to keep
213
+ # Keep left margin only for first column
214
+ x_start = 0 if tile_x == 0 else margin_pixels
215
+ # Keep right margin only for last column
216
+ x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
217
+ # Keep top margin only for first row
218
+ y_start = 0 if tile_y == 0 else margin_pixels
219
+ # Keep bottom margin only for last row
220
+ y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
221
+
222
+ # Calculate where this piece belongs in the output
223
+ out_x = tile_x * (crop_width - 2 * margin_pixels)
224
+ out_y = tile_y * (crop_height - 2 * margin_pixels)
225
+
226
+ # Place the piece
227
+ reconstructed[
228
+ out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
229
+ ] = crop[y_start:y_end, x_start:x_end]
230
+
231
+ return reconstructed
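An illustrative round trip through the cropping utilities above on a random image; the dimensions are arbitrary, and the reconstruction operates on the local crops (index 1 onward) rather than the global crop:

```python
import numpy as np
import torch

image = np.random.randint(0, 255, (800, 1200, 3), dtype=np.uint8)
out = overlap_crop_image(image, overlap_margin=4, max_crops=12)
print(out["crops"].shape, out["tiling"])   # (n_crops, 378, 378, 3), (h_tiles, w_tiles)

local_crops = torch.from_numpy(out["crops"][1:])
stitched = reconstruct_from_crops(local_crops, out["tiling"], overlap_margin=4)
print(stitched.shape)                      # matches the internally resized image size
```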
layers.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Literal, Optional
7
+
8
+ try:
9
+ from torchao import quantize_
10
+ from torchao.quantization import int4_weight_only
11
+ except ImportError:
12
+
13
+ def quantize_(model, quant_mode):
14
+ raise ImportError(
15
+ "torchao is not installed. Please install it with `pip install torchao`."
16
+ )
17
+
18
+ def int4_weight_only(group_size):
19
+ raise ImportError(
20
+ "torchao is not installed. Please install it with `pip install torchao`."
21
+ )
22
+
23
+
24
+ def gelu_approx(x):
25
+ return F.gelu(x, approximate="tanh")
26
+
27
+
28
+ @dataclass
29
+ class LinearWeights:
30
+ weight: torch.Tensor
31
+ bias: torch.Tensor
32
+
33
+
34
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
35
+ return F.linear(x, w.weight, w.bias)
36
+
37
+
38
+ def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
39
+ _step = W_q.shape[0]
40
+ W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
41
+ W_r[:_step] = (W_q & 0b11110000) >> 4
42
+ W_r[_step:] = W_q & 0b00001111
43
+ W_r.sub_(zero).mul_(scale)
44
+ return W_r.reshape(orig_shape)
45
+
46
+
47
+ class QuantizedLinear(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features: int,
51
+ out_features: int,
52
+ dtype: torch.dtype,
53
+ ):
54
+ # TODO: Take group_size as an input instead of hardcoding it here.
55
+ super().__init__()
56
+ self.in_features = in_features
57
+ self.out_features = out_features
58
+ self.weight = nn.ParameterDict(
59
+ {
60
+ "packed": nn.Parameter(
61
+ torch.empty(
62
+ out_features * in_features // (128 * 2), 128, dtype=torch.uint8
63
+ ),
64
+ requires_grad=False,
65
+ ),
66
+ "scale": nn.Parameter(
67
+ torch.empty(out_features * in_features // 128, 1),
68
+ requires_grad=False,
69
+ ),
70
+ "zero_point": nn.Parameter(
71
+ torch.empty(out_features * in_features // 128, 1),
72
+ requires_grad=False,
73
+ ),
74
+ }
75
+ )
76
+ self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
77
+ self.unpacked = False
78
+
79
+ def unpack(self):
80
+ if self.unpacked:
81
+ return
82
+
83
+ self.weight = nn.Parameter(
84
+ dequantize_tensor(
85
+ self.weight["packed"],
86
+ self.weight["scale"],
87
+ self.weight["zero_point"],
88
+ (self.out_features, self.in_features),
89
+ torch.bfloat16,
90
+ )
91
+ )
92
+ with torch.device("meta"):
93
+ self.linear = nn.Linear(
94
+ self.in_features, self.out_features, dtype=torch.bfloat16
95
+ )
96
+ self.linear.weight = self.weight
97
+ self.linear.bias = nn.Parameter(
98
+ self.bias.to(torch.bfloat16), requires_grad=False
99
+ )
100
+
101
+ del self.weight, self.bias
102
+ quantize_(self, int4_weight_only(group_size=128))
103
+ self.unpacked = True
104
+ torch.cuda.empty_cache()
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ if not self.unpacked:
108
+ self.unpack()
109
+ return self.linear(x)
110
+
111
+
112
+ @dataclass
113
+ class LayerNormWeights:
114
+ weight: torch.Tensor
115
+ bias: torch.Tensor
116
+
117
+
118
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
119
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
120
+
121
+
122
+ @dataclass
123
+ class MLPWeights:
124
+ fc1: LinearWeights
125
+ fc2: LinearWeights
126
+ act: Literal["gelu_approx"] = "gelu_approx"
127
+
128
+
129
+ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
130
+ x0 = w.fc1(x)
131
+ if lora is not None:
132
+ x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
133
+ x = x0 + x1
134
+ else:
135
+ x = x0
136
+
137
+ x = gelu_approx(x)
138
+
139
+ x0 = w.fc2(x)
140
+ if lora is not None:
141
+ x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
142
+ x = x0 + x1
143
+ else:
144
+ x = x0
145
+
146
+ return x
147
+
148
+
149
+ @dataclass
150
+ class AttentionWeights:
151
+ qkv: LinearWeights
152
+ proj: LinearWeights
153
+
154
+
155
+ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
156
+ bsz, q_len, d_model = x.shape
157
+ head_dim = d_model // n_heads
158
+
159
+ q, k, v = [
160
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
161
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
162
+ ]
163
+ out = F.scaled_dot_product_attention(q, k, v)
164
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
165
+ out = linear(out, w.proj)
166
+ return out
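A toy-sized sketch of the functional attention helper defined above; all dimensions are illustrative:

```python
import torch

d_model, n_heads = 64, 4
weights = AttentionWeights(
    qkv=LinearWeights(torch.randn(3 * d_model, d_model), torch.zeros(3 * d_model)),
    proj=LinearWeights(torch.randn(d_model, d_model), torch.zeros(d_model)),
)
x = torch.randn(1, 8, d_model)              # batch 1, sequence length 8
print(attn(x, weights, n_heads).shape)      # torch.Size([1, 8, 64])
```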
lora.py ADDED
@@ -0,0 +1,82 @@
1
+ import functools
2
+ import os
3
+ import shutil
4
+ import torch
5
+
6
+ from pathlib import Path
7
+ from urllib.request import Request, urlopen
8
+ from typing import Optional
9
+
10
+
11
+ def variant_cache_dir():
12
+ hf_hub_cache = os.environ.get("HF_HUB_CACHE")
13
+ if hf_hub_cache is not None:
14
+ return Path(hf_hub_cache) / "md_variants"
15
+
16
+ hf_home = os.environ.get("HF_HOME")
17
+ if hf_home is not None:
18
+ return Path(hf_home) / "hub" / "md_variants"
19
+
20
+ return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
21
+
22
+
23
+ def cached_variant_path(variant_id: str):
24
+ variant, *rest = variant_id.split("/", 1)
25
+ step = rest[0] if rest else "final"
26
+
27
+ cache_dir = variant_cache_dir() / variant
28
+ os.makedirs(cache_dir, exist_ok=True)
29
+ dest = cache_dir / f"{step}.pt"
30
+ if dest.exists():
31
+ return dest
32
+
33
+ md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
34
+
35
+ headers = {"User-Agent": "moondream-torch"}
36
+ api_key = os.getenv("MOONDREAM_API_KEY")
37
+ if api_key is not None:
38
+ headers["X-Moondream-Auth"] = api_key
39
+
40
+ req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
41
+ with urlopen(req) as r, open(dest, "wb") as f:
42
+ shutil.copyfileobj(r, f)
43
+ return dest
44
+
45
+
46
+ def nest(flat):
47
+ tree = {}
48
+ for k, v in flat.items():
49
+ parts = k.split(".")
50
+ d = tree
51
+ for p in parts[:-1]:
52
+ d = d.setdefault(p, {})
53
+ d[parts[-1]] = v
54
+ return tree
55
+
56
+
57
+ @functools.lru_cache(maxsize=5)
58
+ def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
59
+ if variant_id is None:
60
+ return None
61
+
62
+ state_dict = torch.load(
63
+ cached_variant_path(variant_id), map_location=device, weights_only=True
64
+ )
65
+
66
+ # TODO: Move these into the training code that saves checkpoints...
67
+ rename_rules = [
68
+ ("text_model.transformer.h", "text.blocks"),
69
+ (".mixer", ".attn"),
70
+ (".out_proj", ".proj"),
71
+ (".Wqkv", ".qkv"),
72
+ (".parametrizations.weight.0", ""),
73
+ ]
74
+ new_state_dict = {}
75
+ for key, tensor in state_dict.items():
76
+ new_key = key
77
+ for old, new in rename_rules:
78
+ if old in new_key:
79
+ new_key = new_key.replace(old, new)
80
+ new_state_dict[new_key] = tensor
81
+
82
+ return nest(new_state_dict)
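A quick illustration of the `nest` helper above, which turns flat dotted state-dict keys into a nested tree; the keys and values are placeholders:

```python
flat = {
    "text.blocks.0.attn.qkv.weight": "W",
    "text.blocks.0.attn.qkv.bias": "b",
}
print(nest(flat))
# {'text': {'blocks': {'0': {'attn': {'qkv': {'weight': 'W', 'bias': 'b'}}}}}}
```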
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2329af3a8706e93cd82d0556cb3df5d5da70ccc4a8385a8ec1423af45272f431
3
+ size 3854538968
modeling_phi.py ADDED
@@ -0,0 +1,1463 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import (
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ get_torch_version,
39
+ is_flash_attn_2_available,
40
+ is_flash_attn_greater_or_equal_2_10,
41
+ is_torchdynamo_compiling,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_moondream import PhiConfig
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "PhiConfig"
55
+
56
+
57
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
58
+ def _prepare_4d_causal_attention_mask_with_cache_position(
59
+ attention_mask: torch.Tensor,
60
+ sequence_length: int,
61
+ target_length: int,
62
+ dtype: torch.dtype,
63
+ device: torch.device,
64
+ min_dtype: float,
65
+ cache_position: torch.Tensor,
66
+ batch_size: int,
67
+ ):
68
+ """
69
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
70
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
71
+
72
+ Args:
73
+ attention_mask (`torch.Tensor`):
74
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
75
+ sequence_length (`int`):
76
+ The sequence length being processed.
77
+ target_length (`int`):
78
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
79
+ dtype (`torch.dtype`):
80
+ The dtype to use for the 4D attention mask.
81
+ device (`torch.device`):
82
+ The device to place the 4D attention mask on.
83
+ min_dtype (`float`):
84
+ The minimum value representable with the dtype `dtype`.
85
+ cache_position (`torch.Tensor`):
86
+ Indices depicting the position of the input sequence tokens in the sequence.
87
+ batch_size (`torch.Tensor`):
88
+ Batch size.
89
+ """
90
+ if attention_mask is not None and attention_mask.dim() == 4:
91
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
92
+ causal_mask = attention_mask
93
+ else:
94
+ causal_mask = torch.full(
95
+ (sequence_length, target_length),
96
+ fill_value=min_dtype,
97
+ dtype=dtype,
98
+ device=device,
99
+ )
100
+ if sequence_length != 1:
101
+ causal_mask = torch.triu(causal_mask, diagonal=1)
102
+ causal_mask *= torch.arange(
103
+ target_length, device=device
104
+ ) > cache_position.reshape(-1, 1)
105
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
106
+ if attention_mask is not None:
107
+ causal_mask = (
108
+ causal_mask.clone()
109
+ ) # copy to contiguous memory for in-place edit
110
+ mask_length = attention_mask.shape[-1]
111
+ padding_mask = (
112
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
113
+ )
114
+ padding_mask = padding_mask == 0
115
+ causal_mask[:, :, :, :mask_length] = causal_mask[
116
+ :, :, :, :mask_length
117
+ ].masked_fill(padding_mask, min_dtype)
118
+
119
+ return causal_mask
120
+
121
+
122
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
123
+ class PhiRotaryEmbedding(nn.Module):
124
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
125
+ super().__init__()
126
+
127
+ self.dim = dim
128
+ self.max_position_embeddings = max_position_embeddings
129
+ self.base = base
130
+ inv_freq = 1.0 / (
131
+ self.base
132
+ ** (
133
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
134
+ / self.dim
135
+ )
136
+ )
137
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
138
+
139
+ # Build here to make `torch.jit.trace` work.
140
+ self._set_cos_sin_cache(
141
+ seq_len=max_position_embeddings,
142
+ device=self.inv_freq.device,
143
+ dtype=torch.get_default_dtype(),
144
+ )
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+ t = torch.arange(
149
+ self.max_seq_len_cached, device=device, dtype=torch.int64
150
+ ).type_as(self.inv_freq)
151
+
152
+ freqs = torch.outer(t, self.inv_freq)
153
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
154
+ emb = torch.cat((freqs, freqs), dim=-1)
155
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
156
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
157
+
158
+ def forward(self, x, seq_len=None):
159
+ # x: [bs, num_attention_heads, seq_len, head_size]
160
+ if seq_len > self.max_seq_len_cached:
161
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
162
+
163
+ return (
164
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
165
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
166
+ )
167
+
168
+
169
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
170
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
171
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
172
+
173
+ def __init__(
174
+ self,
175
+ dim,
176
+ max_position_embeddings=2048,
177
+ base=10000,
178
+ device=None,
179
+ scaling_factor=1.0,
180
+ ):
181
+ self.scaling_factor = scaling_factor
182
+ super().__init__(dim, max_position_embeddings, base, device)
183
+
184
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
185
+ self.max_seq_len_cached = seq_len
186
+ t = torch.arange(
187
+ self.max_seq_len_cached, device=device, dtype=torch.int64
188
+ ).type_as(self.inv_freq)
189
+ t = t / self.scaling_factor
190
+
191
+ freqs = torch.outer(t, self.inv_freq)
192
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
193
+ emb = torch.cat((freqs, freqs), dim=-1)
194
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
195
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
196
+
197
+
198
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
199
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
200
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
201
+
202
+ def __init__(
203
+ self,
204
+ dim,
205
+ max_position_embeddings=2048,
206
+ base=10000,
207
+ device=None,
208
+ scaling_factor=1.0,
209
+ ):
210
+ self.scaling_factor = scaling_factor
211
+ super().__init__(dim, max_position_embeddings, base, device)
212
+
213
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
214
+ self.max_seq_len_cached = seq_len
215
+
216
+ if seq_len > self.max_position_embeddings:
217
+ base = self.base * (
218
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
219
+ - (self.scaling_factor - 1)
220
+ ) ** (self.dim / (self.dim - 2))
221
+ inv_freq = 1.0 / (
222
+ base
223
+ ** (
224
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
225
+ / self.dim
226
+ )
227
+ )
228
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
229
+
230
+ t = torch.arange(
231
+ self.max_seq_len_cached, device=device, dtype=torch.int64
232
+ ).type_as(self.inv_freq)
233
+
234
+ freqs = torch.outer(t, self.inv_freq)
235
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
236
+ emb = torch.cat((freqs, freqs), dim=-1)
237
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
238
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
239
+
240
+
241
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
242
+ def rotate_half(x):
243
+ """Rotates half the hidden dims of the input."""
244
+ x1 = x[..., : x.shape[-1] // 2]
245
+ x2 = x[..., x.shape[-1] // 2 :]
246
+ return torch.cat((-x2, x1), dim=-1)
247
+
248
+
249
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
250
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
251
+ """Applies Rotary Position Embedding to the query and key tensors.
252
+
253
+ Args:
254
+ q (`torch.Tensor`): The query tensor.
255
+ k (`torch.Tensor`): The key tensor.
256
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
257
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
258
+ position_ids (`torch.Tensor`):
259
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
260
+ used to pass offsetted position ids when working with a KV-cache.
261
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
262
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
263
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
264
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
265
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
266
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
267
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
268
+ Returns:
269
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
270
+ """
271
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
272
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
273
+ q_embed = (q * cos) + (rotate_half(q) * sin)
274
+ k_embed = (k * cos) + (rotate_half(k) * sin)
275
+ return q_embed, k_embed
276
+
277
+
278
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
279
+ class PhiMLP(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.config = config
283
+ self.activation_fn = ACT2FN[config.hidden_act]
284
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
285
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
286
+
287
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
288
+ hidden_states = self.fc1(hidden_states)
289
+ hidden_states = self.activation_fn(hidden_states)
290
+ hidden_states = self.fc2(hidden_states)
291
+ return hidden_states
292
+
293
+
294
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
295
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
296
+ """
297
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
298
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
299
+ """
300
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
301
+ if n_rep == 1:
302
+ return hidden_states
303
+ hidden_states = hidden_states[:, :, None, :, :].expand(
304
+ batch, num_key_value_heads, n_rep, slen, head_dim
305
+ )
306
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
307
+
308
+
309
+ class PhiAttention(nn.Module):
310
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
311
+
312
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
313
+ super().__init__()
314
+ self.config = config
315
+ self.layer_idx = layer_idx
316
+ if layer_idx is None:
317
+ logger.warning_once(
318
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
319
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
320
+ "when creating this class."
321
+ )
322
+
323
+ self.attention_dropout = config.attention_dropout
324
+ self.hidden_size = config.hidden_size
325
+ self.num_heads = config.num_attention_heads
326
+ self.head_dim = self.hidden_size // self.num_heads
327
+ self.num_key_value_heads = config.num_key_value_heads
328
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
329
+ self.max_position_embeddings = config.max_position_embeddings
330
+ self.rope_theta = config.rope_theta
331
+ self.partial_rotary_factor = config.partial_rotary_factor
332
+ self.is_causal = True
333
+
334
+ if (self.head_dim * self.num_heads) != self.hidden_size:
335
+ raise ValueError(
336
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
337
+ f" and `num_heads`: {self.num_heads})."
338
+ )
339
+
340
+ self.Wqkv = nn.Linear(
341
+ self.hidden_size, 3 * self.num_heads * self.head_dim, bias=True
342
+ )
343
+ self.out_proj = nn.Linear(
344
+ self.num_heads * self.head_dim, self.hidden_size, bias=True
345
+ )
346
+
347
+ self._init_rope()
348
+
349
+ def _init_rope(self):
350
+ if self.config.rope_scaling is None:
351
+ self.rotary_emb = PhiRotaryEmbedding(
352
+ int(self.partial_rotary_factor * self.head_dim),
353
+ max_position_embeddings=self.max_position_embeddings,
354
+ base=self.rope_theta,
355
+ )
356
+ else:
357
+ scaling_type = self.config.rope_scaling["type"]
358
+ scaling_factor = self.config.rope_scaling["factor"]
359
+ if scaling_type == "linear":
360
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
361
+ int(self.partial_rotary_factor * self.head_dim),
362
+ max_position_embeddings=self.max_position_embeddings,
363
+ scaling_factor=scaling_factor,
364
+ base=self.rope_theta,
365
+ )
366
+ elif scaling_type == "dynamic":
367
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
368
+ int(self.partial_rotary_factor * self.head_dim),
369
+ max_position_embeddings=self.max_position_embeddings,
370
+ scaling_factor=scaling_factor,
371
+ base=self.rope_theta,
372
+ )
373
+ else:
374
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states: torch.Tensor,
379
+ attention_mask: Optional[torch.Tensor] = None,
380
+ position_ids: Optional[torch.LongTensor] = None,
381
+ past_key_value: Optional[Cache] = None,
382
+ output_attentions: bool = False,
383
+ use_cache: bool = False,
384
+ cache_position: Optional[torch.LongTensor] = None,
385
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
386
+ bsz, q_len, _ = hidden_states.size()
387
+
388
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
389
+ 3, dim=-1
390
+ )
391
+
392
+ query_states = query_states.view(
393
+ bsz, q_len, self.num_heads, self.head_dim
394
+ ).transpose(1, 2)
395
+ key_states = key_states.view(
396
+ bsz, q_len, self.num_key_value_heads, self.head_dim
397
+ ).transpose(1, 2)
398
+ value_states = value_states.view(
399
+ bsz, q_len, self.num_key_value_heads, self.head_dim
400
+ ).transpose(1, 2)
401
+
402
+ kv_seq_len = key_states.shape[-2]
403
+ if past_key_value is not None:
404
+ if self.layer_idx is None:
405
+ raise ValueError(
406
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
407
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
408
+ "with a layer index."
409
+ )
410
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
411
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
412
+
413
+ # Partial rotary embedding
414
+ query_rot, query_pass = (
415
+ query_states[..., : self.rotary_emb.dim],
416
+ query_states[..., self.rotary_emb.dim :],
417
+ )
418
+ key_rot, key_pass = (
419
+ key_states[..., : self.rotary_emb.dim],
420
+ key_states[..., self.rotary_emb.dim :],
421
+ )
422
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
423
+ query_rot, key_rot = apply_rotary_pos_emb(
424
+ query_rot, key_rot, cos, sin, position_ids
425
+ )
426
+
427
+ # [batch_size, seq_length, num_heads, head_dim]
428
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
429
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
430
+
431
+ if past_key_value is not None:
432
+ cache_kwargs = {
433
+ "sin": sin,
434
+ "cos": cos,
435
+ "partial_rotation_size": self.rotary_emb.dim,
436
+ "cache_position": cache_position,
437
+ }
438
+ key_states, value_states = past_key_value.update(
439
+ key_states, value_states, self.layer_idx, cache_kwargs
440
+ )
441
+
442
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
443
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
444
+
445
+ # Upcasting queries and keys to fp32 is required by Phi-2 to avoid overflow
446
+ attn_weights = torch.matmul(
447
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
448
+ ) / math.sqrt(self.head_dim)
449
+
450
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
451
+ raise ValueError(
452
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
453
+ f" {attn_weights.size()}"
454
+ )
455
+
456
+ if attention_mask is not None:
457
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
458
+ attn_weights += causal_mask
459
+
460
+ # upcast attention to fp32
461
+ attn_weights = nn.functional.softmax(
462
+ attn_weights, dim=-1, dtype=torch.float32
463
+ ).to(value_states.dtype)
464
+ attn_weights = nn.functional.dropout(
465
+ attn_weights, p=self.attention_dropout, training=self.training
466
+ )
467
+
468
+ attn_output = torch.matmul(attn_weights, value_states)
469
+
470
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
471
+ raise ValueError(
472
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
473
+ f" {attn_output.size()}"
474
+ )
475
+
476
+ attn_output = attn_output.transpose(1, 2).contiguous()
477
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
478
+
479
+ attn_output = self.out_proj(attn_output)
480
+
481
+ if not output_attentions:
482
+ attn_weights = None
483
+
484
+ return attn_output, attn_weights, past_key_value
485
+
486
+
487
+ class PhiFlashAttention2(PhiAttention):
488
+ """
489
+ Phi flash attention module. This module inherits from `PhiAttention`, as the weights of the module stay
490
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
491
+ flash attention and deal with padding tokens in case the input contains any of them.
492
+ """
493
+
494
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
495
+ def __init__(self, *args, **kwargs):
496
+ super().__init__(*args, **kwargs)
497
+
498
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
499
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
500
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
501
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
502
+
503
+ def forward(
504
+ self,
505
+ hidden_states: torch.Tensor,
506
+ attention_mask: Optional[torch.LongTensor] = None,
507
+ position_ids: Optional[torch.LongTensor] = None,
508
+ past_key_value: Optional[Cache] = None,
509
+ output_attentions: bool = False,
510
+ use_cache: bool = False,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **kwargs,
513
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
514
+ # PhiFlashAttention2 attention does not support output_attentions
515
+
516
+ output_attentions = False
517
+
518
+ bsz, q_len, _ = hidden_states.size()
519
+
520
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
521
+ 3, dim=-1
522
+ )
523
+
524
+ # Flash attention requires the input to have the shape
525
+ # batch_size x seq_length x num_heads x head_dim
526
+ # therefore we just need to keep the original shape
527
+ query_states = query_states.view(
528
+ bsz, q_len, self.num_heads, self.head_dim
529
+ ).transpose(1, 2)
530
+ key_states = key_states.view(
531
+ bsz, q_len, self.num_key_value_heads, self.head_dim
532
+ ).transpose(1, 2)
533
+ value_states = value_states.view(
534
+ bsz, q_len, self.num_key_value_heads, self.head_dim
535
+ ).transpose(1, 2)
536
+
537
+ kv_seq_len = key_states.shape[-2]
538
+ if past_key_value is not None:
539
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
540
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
541
+
542
+ # Partial rotary embedding
543
+ query_rot, query_pass = (
544
+ query_states[..., : self.rotary_emb.dim],
545
+ query_states[..., self.rotary_emb.dim :],
546
+ )
547
+ key_rot, key_pass = (
548
+ key_states[..., : self.rotary_emb.dim],
549
+ key_states[..., self.rotary_emb.dim :],
550
+ )
551
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
552
+ query_rot, key_rot = apply_rotary_pos_emb(
553
+ query_rot, key_rot, cos, sin, position_ids
554
+ )
555
+
556
+ # [batch_size, seq_length, num_heads, head_dim]
557
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
558
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
559
+
560
+ if past_key_value is not None:
561
+ cache_kwargs = {
562
+ "sin": sin,
563
+ "cos": cos,
564
+ "partial_rotation_size": self.rotary_emb.dim,
565
+ "cache_position": cache_position,
566
+ }
567
+ key_states, value_states = past_key_value.update(
568
+ key_states, value_states, self.layer_idx, cache_kwargs
569
+ )
570
+
571
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
572
+ # to be able to avoid many of these transpose/reshape/view.
573
+ query_states = query_states.transpose(1, 2)
574
+ key_states = key_states.transpose(1, 2)
575
+ value_states = value_states.transpose(1, 2)
576
+
577
+ attn_dropout = self.attention_dropout if self.training else 0.0
578
+
579
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
580
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
581
+ # cast them back to the correct dtype just to be sure everything works as expected.
582
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
583
+ # in fp32.
584
+
585
+ if query_states.dtype == torch.float32:
586
+ if torch.is_autocast_enabled():
587
+ target_dtype = torch.get_autocast_gpu_dtype()
588
+ # Handle the case where the model is quantized
589
+ elif hasattr(self.config, "_pre_quantization_dtype"):
590
+ target_dtype = self.config._pre_quantization_dtype
591
+ else:
592
+ target_dtype = self.Wqkv.weight.dtype  # this module uses a fused Wqkv projection rather than separate q/k/v projections
593
+
594
+ logger.warning_once(
595
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
596
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
597
+ f" {target_dtype}."
598
+ )
599
+
600
+ query_states = query_states.to(target_dtype)
601
+ key_states = key_states.to(target_dtype)
602
+ value_states = value_states.to(target_dtype)
603
+
604
+ attn_output = _flash_attention_forward(
605
+ query_states,
606
+ key_states,
607
+ value_states,
608
+ attention_mask,
609
+ q_len,
610
+ position_ids=position_ids,
611
+ dropout=attn_dropout,
612
+ softmax_scale=None,
613
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
614
+ is_causal=self.is_causal,
615
+ )
616
+
617
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
618
+ attn_output = self.out_proj(attn_output)
619
+
620
+ if not output_attentions:
621
+ attn_weights = None
622
+
623
+ return attn_output, attn_weights, past_key_value
624
+
625
+
626
+ class PhiSdpaAttention(PhiAttention):
627
+ def __init__(self, *args, **kwargs):
628
+ super().__init__(*args, **kwargs)
629
+ self.require_contiguous_qkv = version.parse(
630
+ get_torch_version()
631
+ ) < version.parse("2.2.0")
632
+
633
+ """
634
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
635
+ `PhiAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
636
+ SDPA API.
637
+ """
638
+
639
+ # Adapted from PhiAttention.forward
640
+ def forward(
641
+ self,
642
+ hidden_states: torch.Tensor,
643
+ attention_mask: Optional[torch.Tensor] = None,
644
+ position_ids: Optional[torch.LongTensor] = None,
645
+ past_key_value: Optional[Cache] = None,
646
+ output_attentions: bool = False,
647
+ use_cache: bool = False,
648
+ cache_position: Optional[torch.LongTensor] = None,
649
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
650
+ if output_attentions:
651
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
652
+ logger.warning_once(
653
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
654
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
655
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
656
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
657
+ )
658
+ return super().forward(
659
+ hidden_states=hidden_states,
660
+ attention_mask=attention_mask,
661
+ position_ids=position_ids,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
665
+ )
666
+
667
+ bsz, q_len, _ = hidden_states.size()
668
+
669
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
670
+ 3, dim=-1
671
+ )
672
+
673
+ query_states = query_states.view(
674
+ bsz, q_len, self.num_heads, self.head_dim
675
+ ).transpose(1, 2)
676
+ key_states = key_states.view(
677
+ bsz, q_len, self.num_key_value_heads, self.head_dim
678
+ ).transpose(1, 2)
679
+ value_states = value_states.view(
680
+ bsz, q_len, self.num_key_value_heads, self.head_dim
681
+ ).transpose(1, 2)
682
+
683
+ kv_seq_len = key_states.shape[-2]
684
+ if past_key_value is not None:
685
+ if self.layer_idx is None:
686
+ raise ValueError(
687
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
688
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
689
+ "with a layer index."
690
+ )
691
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
692
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
693
+
694
+ # Partial rotary embedding
695
+ query_rot, query_pass = (
696
+ query_states[..., : self.rotary_emb.dim],
697
+ query_states[..., self.rotary_emb.dim :],
698
+ )
699
+ key_rot, key_pass = (
700
+ key_states[..., : self.rotary_emb.dim],
701
+ key_states[..., self.rotary_emb.dim :],
702
+ )
703
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
704
+ query_rot, key_rot = apply_rotary_pos_emb(
705
+ query_rot, key_rot, cos, sin, position_ids
706
+ )
707
+
708
+ # [batch_size, seq_length, num_heads, head_dim]
709
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
710
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
711
+
712
+ if past_key_value is not None:
713
+ cache_kwargs = {
714
+ "sin": sin,
715
+ "cos": cos,
716
+ "partial_rotation_size": self.rotary_emb.dim,
717
+ "cache_position": cache_position,
718
+ }
719
+ key_states, value_states = past_key_value.update(
720
+ key_states, value_states, self.layer_idx, cache_kwargs
721
+ )
722
+
723
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
724
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
725
+
726
+ causal_mask = attention_mask
727
+ if attention_mask is not None:
728
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
729
+
730
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
731
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
732
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
733
+ if (
734
+ self.require_contiguous_qkv
735
+ and query_states.device.type == "cuda"
736
+ and attention_mask is not None
737
+ ):
738
+ query_states = query_states.contiguous()
739
+ key_states = key_states.contiguous()
740
+ value_states = value_states.contiguous()
741
+
742
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
743
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
744
+ is_causal = True if causal_mask is None and q_len > 1 else False
745
+
746
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
747
+ query_states,
748
+ key_states,
749
+ value_states,
750
+ attn_mask=causal_mask,
751
+ dropout_p=self.attention_dropout if self.training else 0.0,
752
+ is_causal=is_causal,
753
+ )
754
+
755
+ attn_output = attn_output.transpose(1, 2).contiguous()
756
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
757
+
758
+ attn_output = self.out_proj(attn_output)
759
+
760
+ return attn_output, None, past_key_value
761
+
762
+
763
+ PHI_ATTENTION_CLASSES = {
764
+ "eager": PhiAttention,
765
+ "flash_attention_2": PhiFlashAttention2,
766
+ "sdpa": PhiSdpaAttention,
767
+ }
768
+
769
+
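This mapping is the dispatch point used by `PhiDecoderLayer` below: `config._attn_implementation` selects one of the three classes, which share the same constructor and weights. A hedged sketch of the lookup (assuming `config` is an already-built `PhiConfig`):

    attn_cls = PHI_ATTENTION_CLASSES["sdpa"]     # or "eager" / "flash_attention_2"
    attention = attn_cls(config, layer_idx=0)    # `config` is assumed to exist; same signature for all three

"eager" keeps the explicit softmax path (and supports `output_attentions`), "sdpa" routes through `torch.nn.functional.scaled_dot_product_attention`, and "flash_attention_2" requires the flash-attn package.
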
770
+ class PhiDecoderLayer(nn.Module):
771
+ def __init__(self, config: PhiConfig, layer_idx: int):
772
+ super().__init__()
773
+ self.mixer = PHI_ATTENTION_CLASSES[config._attn_implementation](
774
+ config, layer_idx=layer_idx
775
+ )
776
+ self.mlp = PhiMLP(config)
777
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
778
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
779
+
780
+ def forward(
781
+ self,
782
+ hidden_states: torch.Tensor,
783
+ attention_mask: Optional[torch.Tensor] = None,
784
+ position_ids: Optional[torch.LongTensor] = None,
785
+ output_attentions: Optional[bool] = False,
786
+ use_cache: Optional[bool] = False,
787
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
788
+ cache_position: Optional[torch.LongTensor] = None,
789
+ **kwargs,
790
+ ) -> Tuple[
791
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792
+ ]:
793
+ """
794
+ Args:
795
+ hidden_states (`torch.FloatTensor`):
796
+ input to the layer of shape `(batch, seq_len, embed_dim)`
797
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
798
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
799
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
800
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
801
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
802
+ output_attentions (`bool`, *optional*):
803
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
804
+ returned tensors for more detail.
805
+ use_cache (`bool`, *optional*):
806
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
807
+ (see `past_key_values`).
808
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
809
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
810
+ Indices depicting the position of the input sequence tokens in the sequence
811
+ kwargs (`dict`, *optional*):
812
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
813
+ into the model
814
+ """
815
+
816
+ residual = hidden_states
817
+
818
+ hidden_states = self.ln(hidden_states)
819
+
820
+ # Self Attention
821
+ attn_outputs, self_attn_weights, present_key_value = self.mixer(
822
+ hidden_states=hidden_states,
823
+ attention_mask=attention_mask,
824
+ position_ids=position_ids,
825
+ past_key_value=past_key_value,
826
+ output_attentions=output_attentions,
827
+ use_cache=use_cache,
828
+ cache_position=cache_position,
829
+ )
830
+ attn_outputs = self.resid_dropout(attn_outputs)
831
+
832
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
833
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
834
+ outputs = (hidden_states,)
835
+
836
+ if output_attentions:
837
+ outputs += (self_attn_weights,)
838
+
839
+ if use_cache:
840
+ outputs += (present_key_value,)
841
+
842
+ return outputs
843
+
844
+
845
+ PHI_START_DOCSTRING = r"""
846
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
847
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
848
+ etc.)
849
+
850
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
851
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
852
+ and behavior.
853
+
854
+ Parameters:
855
+ config ([`PhiConfig`]):
856
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
857
+ load the weights associated with the model, only the configuration. Check out the
858
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
859
+ """
860
+
861
+
862
+ @add_start_docstrings(
863
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
864
+ PHI_START_DOCSTRING,
865
+ )
866
+ class PhiPreTrainedModel(PreTrainedModel):
867
+ config_class = PhiConfig
868
+ base_model_prefix = "model"
869
+ supports_gradient_checkpointing = True
870
+ _no_split_modules = ["PhiDecoderLayer"]
871
+ _skip_keys_device_placement = "past_key_values"
872
+ _supports_flash_attn_2 = True
873
+ _supports_sdpa = True
874
+ _supports_cache_class = True
875
+
876
+ def _init_weights(self, module):
877
+ std = self.config.initializer_range
878
+ if isinstance(module, nn.Linear):
879
+ module.weight.data.normal_(mean=0.0, std=std)
880
+ if module.bias is not None:
881
+ module.bias.data.zero_()
882
+ elif isinstance(module, nn.Embedding):
883
+ module.weight.data.normal_(mean=0.0, std=std)
884
+ if module.padding_idx is not None:
885
+ module.weight.data[module.padding_idx].zero_()
886
+
887
+
888
+ class Embedding(nn.Module):
889
+ def __init__(self, config: PhiConfig):
890
+ super().__init__()
891
+ self.wte = nn.Embedding(
892
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
893
+ )
894
+
895
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
896
+ return self.wte(input_ids)
897
+
898
+ PHI_INPUTS_DOCSTRING = r"""
899
+ Args:
900
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
+ it.
903
+
904
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
+ [`PreTrainedTokenizer.__call__`] for details.
906
+
907
+ [What are input IDs?](../glossary#input-ids)
908
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
+
911
+ - 1 for tokens that are **not masked**,
912
+ - 0 for tokens that are **masked**.
913
+
914
+ [What are attention masks?](../glossary#attention-mask)
915
+
916
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
+ `past_key_values`).
921
+
922
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
+ information on the default strategy.
925
+
926
+ - 1 indicates the head is **not masked**,
927
+ - 0 indicates the head is **masked**.
928
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
+ config.n_positions - 1]`.
931
+
932
+ [What are position IDs?](../glossary#position-ids)
933
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
936
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
+
938
+ Two formats are allowed:
939
+ - a [`~cache_utils.Cache`] instance;
940
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
+ cache format.
943
+
944
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
+ legacy cache format will be returned.
946
+
947
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
+ of shape `(batch_size, sequence_length)`.
950
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
+ model's internal embedding lookup matrix.
954
+ use_cache (`bool`, *optional*):
955
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
+ `past_key_values`).
957
+ output_attentions (`bool`, *optional*):
958
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
+ tensors for more detail.
960
+ output_hidden_states (`bool`, *optional*):
961
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
+ more detail.
963
+ return_dict (`bool`, *optional*):
964
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
+ the complete sequence length.
969
+ """
970
+
971
+
972
+ @add_start_docstrings(
973
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
+ PHI_START_DOCSTRING,
975
+ )
976
+ class PhiModel(PhiPreTrainedModel):
977
+ """
978
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
979
+
980
+ Args:
981
+ config: PhiConfig
982
+ """
983
+
984
+ def __init__(self, config: PhiConfig):
985
+ super().__init__(config)
986
+ self.padding_idx = config.pad_token_id
987
+ self.vocab_size = config.vocab_size
988
+
989
+ self.embd = Embedding(config)
990
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
991
+ self.h = nn.ModuleList(
992
+ [
993
+ PhiDecoderLayer(config, layer_idx)
994
+ for layer_idx in range(config.num_hidden_layers)
995
+ ]
996
+ )
997
+
998
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
+ self._use_sdpa = config._attn_implementation == "sdpa"
1000
+
1001
+ self.gradient_checkpointing = False
1002
+ # Initialize weights and apply final processing
1003
+ self.post_init()
1004
+
1005
+ def get_input_embeddings(self):
1006
+ return self.embd.wte
1007
+
1008
+ def set_input_embeddings(self, value):
1009
+ self.embd.wte = value
1010
+
1011
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
+ def forward(
1013
+ self,
1014
+ input_ids: torch.LongTensor = None,
1015
+ attention_mask: Optional[torch.Tensor] = None,
1016
+ position_ids: Optional[torch.LongTensor] = None,
1017
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1019
+ use_cache: Optional[bool] = None,
1020
+ output_attentions: Optional[bool] = None,
1021
+ output_hidden_states: Optional[bool] = None,
1022
+ return_dict: Optional[bool] = None,
1023
+ cache_position: Optional[torch.LongTensor] = None,
1024
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
+ output_attentions = (
1026
+ output_attentions
1027
+ if output_attentions is not None
1028
+ else self.config.output_attentions
1029
+ )
1030
+ output_hidden_states = (
1031
+ output_hidden_states
1032
+ if output_hidden_states is not None
1033
+ else self.config.output_hidden_states
1034
+ )
1035
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
+
1037
+ return_dict = (
1038
+ return_dict if return_dict is not None else self.config.use_return_dict
1039
+ )
1040
+
1041
+ if (input_ids is None) ^ (inputs_embeds is not None):
1042
+ raise ValueError(
1043
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
+ )
1045
+
1046
+ if self.gradient_checkpointing and self.training:
1047
+ if use_cache:
1048
+ logger.warning_once(
1049
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1050
+ )
1051
+ use_cache = False
1052
+
1053
+ use_legacy_cache = False
1054
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
+ use_legacy_cache = True
1056
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
+ logger.warning_once(
1058
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
1060
+ )
1061
+
1062
+ if inputs_embeds is None:
1063
+ inputs_embeds = self.embd(input_ids)
1064
+
1065
+ if cache_position is None:
1066
+ past_seen_tokens = (
1067
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1068
+ )
1069
+ cache_position = torch.arange(
1070
+ past_seen_tokens,
1071
+ past_seen_tokens + inputs_embeds.shape[1],
1072
+ device=inputs_embeds.device,
1073
+ )
1074
+ if position_ids is None:
1075
+ position_ids = cache_position.unsqueeze(0)
1076
+
1077
+ causal_mask = self._update_causal_mask(
1078
+ attention_mask,
1079
+ inputs_embeds,
1080
+ cache_position,
1081
+ past_key_values,
1082
+ output_attentions,
1083
+ )
1084
+
1085
+ hidden_states = inputs_embeds
1086
+
1087
+ # decoder layers
1088
+ all_hidden_states = () if output_hidden_states else None
1089
+ all_self_attns = () if output_attentions else None
1090
+ next_decoder_cache = None
1091
+
1092
+ for decoder_layer in self.h:
1093
+ if output_hidden_states:
1094
+ all_hidden_states += (hidden_states,)
1095
+
1096
+ if self.gradient_checkpointing and self.training:
1097
+ layer_outputs = self._gradient_checkpointing_func(
1098
+ decoder_layer.__call__,
1099
+ hidden_states,
1100
+ causal_mask,
1101
+ position_ids,
1102
+ output_attentions,
1103
+ use_cache,
1104
+ past_key_values,
1105
+ cache_position,
1106
+ )
1107
+ else:
1108
+ layer_outputs = decoder_layer(
1109
+ hidden_states,
1110
+ attention_mask=causal_mask,
1111
+ position_ids=position_ids,
1112
+ past_key_value=past_key_values,
1113
+ output_attentions=output_attentions,
1114
+ use_cache=use_cache,
1115
+ cache_position=cache_position,
1116
+ )
1117
+
1118
+ hidden_states = layer_outputs[0]
1119
+
1120
+ if use_cache:
1121
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1122
+
1123
+ if output_attentions:
1124
+ all_self_attns += (layer_outputs[1],)
1125
+
1126
+ # add hidden states from the last decoder layer
1127
+ if output_hidden_states:
1128
+ all_hidden_states += (hidden_states,)
1129
+
1130
+ next_cache = None
1131
+ if use_cache:
1132
+ next_cache = (
1133
+ next_decoder_cache.to_legacy_cache()
1134
+ if use_legacy_cache
1135
+ else next_decoder_cache
1136
+ )
1137
+ if not return_dict:
1138
+ return tuple(
1139
+ v
1140
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1141
+ if v is not None
1142
+ )
1143
+ return BaseModelOutputWithPast(
1144
+ last_hidden_state=hidden_states,
1145
+ past_key_values=next_cache,
1146
+ hidden_states=all_hidden_states,
1147
+ attentions=all_self_attns,
1148
+ )
1149
+
1150
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
+ def _update_causal_mask(
1152
+ self,
1153
+ attention_mask: torch.Tensor,
1154
+ input_tensor: torch.Tensor,
1155
+ cache_position: torch.Tensor,
1156
+ past_key_values: Cache,
1157
+ output_attentions: bool,
1158
+ ):
1159
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
+
1164
+ if self.config._attn_implementation == "flash_attention_2":
1165
+ if attention_mask is not None and 0.0 in attention_mask:
1166
+ return attention_mask
1167
+ return None
1168
+
1169
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
+ # to infer the attention mask.
1172
+ past_seen_tokens = (
1173
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1174
+ )
1175
+ using_static_cache = isinstance(past_key_values, StaticCache)
1176
+
1177
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
+ if (
1179
+ self.config._attn_implementation == "sdpa"
1180
+ and not using_static_cache
1181
+ and not output_attentions
1182
+ ):
1183
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
+ attention_mask,
1185
+ inputs_embeds=input_tensor,
1186
+ past_key_values_length=past_seen_tokens,
1187
+ is_training=self.training,
1188
+ ):
1189
+ return None
1190
+
1191
+ dtype, device = input_tensor.dtype, input_tensor.device
1192
+ min_dtype = torch.finfo(dtype).min
1193
+ sequence_length = input_tensor.shape[1]
1194
+ if using_static_cache:
1195
+ target_length = past_key_values.get_max_length()
1196
+ else:
1197
+ target_length = (
1198
+ attention_mask.shape[-1]
1199
+ if isinstance(attention_mask, torch.Tensor)
1200
+ else past_seen_tokens + sequence_length + 1
1201
+ )
1202
+
1203
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1204
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
+ attention_mask,
1206
+ sequence_length=sequence_length,
1207
+ target_length=target_length,
1208
+ dtype=dtype,
1209
+ device=device,
1210
+ min_dtype=min_dtype,
1211
+ cache_position=cache_position,
1212
+ batch_size=input_tensor.shape[0],
1213
+ )
1214
+
1215
+ if (
1216
+ self.config._attn_implementation == "sdpa"
1217
+ and attention_mask is not None
1218
+ and attention_mask.device.type == "cuda"
1219
+ and not output_attentions
1220
+ ):
1221
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1224
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1225
+ causal_mask, min_dtype
1226
+ )
1227
+
1228
+ return causal_mask
1229
+
1230
+
1231
+ class CausalLMHead(nn.Module):
1232
+ """Causal Language Modeling head. Simplified version."""
1233
+
1234
+ def __init__(self, config):
1235
+ super().__init__()
1236
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
+ self.linear = nn.Linear(config.hidden_size, config.vocab_size)
1238
+
1239
+ def forward(self, hidden_states):
1240
+ return self.linear(self.ln(hidden_states))
1241
+
1242
+
1243
+ class PhiForCausalLM(PhiPreTrainedModel):
1244
+
1245
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
+ def __init__(self, config):
1247
+ super().__init__(config)
1248
+ self.transformer = PhiModel(config)
1249
+ self.vocab_size = config.vocab_size
1250
+ self.lm_head = CausalLMHead(config)
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1256
+ def get_input_embeddings(self):
1257
+ return self.transformer.embd.wte
1258
+
1259
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
+ def set_input_embeddings(self, value):
1261
+ self.transformer.embd.wte = value
1262
+
1263
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
+ def get_output_embeddings(self):
1265
+ return self.lm_head.linear
1266
+
1267
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1268
+ def set_output_embeddings(self, new_embeddings):
1269
+ self.lm_head.linear = new_embeddings
1270
+
1271
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1272
+ def set_decoder(self, decoder):
1273
+ self.model = decoder
1274
+
1275
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1276
+ def get_decoder(self):
1277
+ return self.model
1278
+
1279
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
+ @replace_return_docstrings(
1281
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
+ )
1283
+ def forward(
1284
+ self,
1285
+ input_ids: torch.LongTensor = None,
1286
+ attention_mask: Optional[torch.Tensor] = None,
1287
+ position_ids: Optional[torch.LongTensor] = None,
1288
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1290
+ labels: Optional[torch.LongTensor] = None,
1291
+ use_cache: Optional[bool] = None,
1292
+ output_attentions: Optional[bool] = None,
1293
+ output_hidden_states: Optional[bool] = None,
1294
+ return_dict: Optional[bool] = None,
1295
+ cache_position: Optional[torch.LongTensor] = None,
1296
+ num_logits_to_keep: int = 0,
1297
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
+ r"""
1299
+ Args:
1300
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
+
1305
+ num_logits_to_keep (`int`, *optional*):
1306
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
+
1310
+ Returns:
1311
+
1312
+ Example:
1313
+
1314
+ ```python
1315
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1316
+
1317
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1318
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1319
+
1320
+ >>> prompt = "This is an example script ."
1321
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
+
1323
+ >>> # Generate
1324
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1327
+ ```"""
1328
+
1329
+ output_attentions = (
1330
+ output_attentions
1331
+ if output_attentions is not None
1332
+ else self.config.output_attentions
1333
+ )
1334
+ output_hidden_states = (
1335
+ output_hidden_states
1336
+ if output_hidden_states is not None
1337
+ else self.config.output_hidden_states
1338
+ )
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1344
+ outputs = self.transformer(
1345
+ input_ids=input_ids,
1346
+ attention_mask=attention_mask,
1347
+ position_ids=position_ids,
1348
+ past_key_values=past_key_values,
1349
+ inputs_embeds=inputs_embeds,
1350
+ use_cache=use_cache,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ cache_position=cache_position,
1355
+ )
1356
+
1357
+ hidden_states = outputs[0]
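Note on the slice below: with the default `num_logits_to_keep=0`, `hidden_states[:, -0:, :]` is equivalent to `hidden_states[:, 0:, :]` and keeps every position, while any positive value keeps only the trailing positions (typically 1 during generation). A tiny illustration:

    import torch

    x = torch.randn(1, 5, 8)
    assert torch.equal(x[:, -0:, :], x)          # -0 == 0, so the full sequence is kept
    assert x[:, -1:, :].shape == (1, 1, 8)       # keep only the last position
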
1358
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
+
1360
+ loss = None
1361
+ if labels is not None:
1362
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
+ logits = logits.float()
1364
+ # Shift so that tokens < n predict n
1365
+ shift_logits = logits[..., :-1, :].contiguous()
1366
+ shift_labels = labels[..., 1:].contiguous()
1367
+ # Flatten the tokens
1368
+ loss_fct = CrossEntropyLoss()
1369
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
+ shift_labels = shift_labels.view(-1)
1371
+ # Enable model parallelism
1372
+ shift_labels = shift_labels.to(shift_logits.device)
1373
+ loss = loss_fct(shift_logits, shift_labels)
1374
+
1375
+ if not return_dict:
1376
+ output = (logits,) + outputs[1:]
1377
+ return (loss,) + output if loss is not None else output
1378
+
1379
+ return CausalLMOutputWithPast(
1380
+ loss=loss,
1381
+ logits=logits,
1382
+ past_key_values=outputs.past_key_values,
1383
+ hidden_states=outputs.hidden_states,
1384
+ attentions=outputs.attentions,
1385
+ )
1386
+
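The shift in the loss above is standard next-token prediction: logits at position t are scored against the token at position t + 1. A compact sketch of the same computation on toy values (illustrative shapes only):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 10
    logits = torch.randn(1, 6, vocab_size)            # (batch, seq_len, vocab)
    labels = torch.randint(0, vocab_size, (1, 6))     # the input tokens themselves

    shift_logits = logits[..., :-1, :].contiguous()   # positions 0..4 ...
    shift_labels = labels[..., 1:].contiguous()       # ... predict tokens 1..5
    loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
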
1387
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1388
+ def prepare_inputs_for_generation(
1389
+ self,
1390
+ input_ids,
1391
+ past_key_values=None,
1392
+ attention_mask=None,
1393
+ inputs_embeds=None,
1394
+ cache_position=None,
1395
+ position_ids=None,
1396
+ use_cache=True,
1397
+ num_logits_to_keep=0,
1398
+ **kwargs,
1399
+ ):
1400
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1401
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1402
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1403
+ if past_key_values is not None:
1404
+ if inputs_embeds is not None: # Exception 1
1405
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1406
+ elif (
1407
+ input_ids.shape[1] != cache_position.shape[0]
1408
+ ): # Default case (the "else", a no op, is Exception 2)
1409
+ input_ids = input_ids[:, cache_position]
1410
+
1411
+ if attention_mask is not None and position_ids is None:
1412
+ # create position_ids on the fly for batch generation
1413
+ position_ids = attention_mask.long().cumsum(-1) - 1
1414
+ position_ids.masked_fill_(attention_mask == 0, 1)
1415
+ if past_key_values:
1416
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1417
+
1418
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1419
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1420
+
1421
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
+ if inputs_embeds is not None and cache_position[0] == 0:
1423
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1424
+ else:
1425
+ # The clone here is for the same reason as for `position_ids`.
1426
+ model_inputs = {
1427
+ "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1428
+ "inputs_embeds": None,
1429
+ }
1430
+
1431
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1432
+ if model_inputs["inputs_embeds"] is not None:
1433
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1434
+ device = model_inputs["inputs_embeds"].device
1435
+ else:
1436
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1437
+ device = model_inputs["input_ids"].device
1438
+
1439
+ dtype = self.lm_head.linear.weight.dtype  # lm_head is a CausalLMHead module; its projection weight lives on `.linear`
1440
+ min_dtype = torch.finfo(dtype).min
1441
+
1442
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1443
+ attention_mask,
1444
+ sequence_length=sequence_length,
1445
+ target_length=past_key_values.get_max_length(),
1446
+ dtype=dtype,
1447
+ device=device,
1448
+ min_dtype=min_dtype,
1449
+ cache_position=cache_position,
1450
+ batch_size=batch_size,
1451
+ )
1452
+
1453
+ model_inputs.update(
1454
+ {
1455
+ "position_ids": position_ids,
1456
+ "cache_position": cache_position,
1457
+ "past_key_values": past_key_values,
1458
+ "use_cache": use_cache,
1459
+ "attention_mask": attention_mask,
1460
+ "num_logits_to_keep": num_logits_to_keep,
1461
+ }
1462
+ )
1463
+ return model_inputs
moondream.py ADDED
@@ -0,0 +1,986 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+
5
+ from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
6
+ from PIL import Image
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
+
10
+ from .config import MoondreamConfig
11
+ from .image_crops import reconstruct_from_crops
12
+ from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
13
+ from .text import build_text_model, text_encoder, lm_head, text_decoder
14
+ from .region import (
15
+ decode_coordinate,
16
+ encode_coordinate,
17
+ decode_size,
18
+ encode_size,
19
+ encode_spatial_refs,
20
+ SpatialRefs,
21
+ )
22
+ from .layers import QuantizedLinear
23
+ from .lora import variant_state_dict
24
+ from .utils import remove_outlier_points
25
+
26
+ ImageEncodingSettings = TypedDict(
27
+ "ImageEncodingSettings",
28
+ {"variant": str},
29
+ total=False,
30
+ )
31
+
32
+ TextSamplingSettings = TypedDict(
33
+ "TextSamplingSettings",
34
+ {
35
+ "max_tokens": int,
36
+ "temperature": float,
37
+ "top_p": float,
38
+ "variant": str,
39
+ },
40
+ total=False,
41
+ )
42
+
43
+ ObjectSamplingSettings = TypedDict(
44
+ "ObjectSamplingSettings",
45
+ {"max_objects": int, "variant": str},
46
+ total=False,
47
+ )
48
+
49
+
50
+ DEFAULT_MAX_TOKENS = 768
51
+ DEFAULT_TEMPERATURE = 0.5
52
+ DEFAULT_TOP_P = 0.3
53
+ DEFAULT_MAX_OBJECTS = 50
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class EncodedImage:
58
+ pos: int
59
+ caches: List[Tuple[torch.Tensor, torch.Tensor]]
60
+
61
+
62
+ class KVCache(nn.Module):
63
+
64
+ def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
65
+ super().__init__()
66
+ cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
67
+ self.register_buffer(
68
+ "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
69
+ )
70
+ self.register_buffer(
71
+ "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
72
+ )
73
+
74
+ def update(self, pos_ids, k, v):
75
+ kout, vout = self.k_cache, self.v_cache
76
+ kout[:, :, pos_ids, :] = k
77
+ vout[:, :, pos_ids, :] = v
78
+ return kout, vout
79
+
80
+
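The cache above pre-allocates fixed-size key/value buffers and writes new entries in place at the given positions, which keeps tensor shapes static for `torch.compile`. A small usage sketch with assumed dimensions (not taken from the actual config):

    import torch

    cache = KVCache(n_heads=8, n_kv_heads=8, max_context=32, dim=512,
                    device="cpu", dtype=torch.float32)
    k_new = torch.randn(1, 8, 4, 64)                  # head_dim = dim // n_heads = 64
    v_new = torch.randn(1, 8, 4, 64)
    pos_ids = torch.arange(4)                         # write positions 0..3
    k_all, v_all = cache.update(pos_ids, k_new, v_new)
    assert k_all.shape == (1, 8, 32, 64)              # the full buffer is returned every time
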
81
+ class MoondreamModel(nn.Module):
82
+
83
+ def __init__(
84
+ self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
85
+ ):
86
+ super().__init__()
87
+ self.config = config
88
+
89
+ self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
90
+ self.vision = build_vision_model(config.vision, dtype)
91
+ self.text = build_text_model(config.text, dtype)
92
+
93
+ # Region Model
94
+ linear_cls = (
95
+ QuantizedLinear if config.region.group_size is not None else nn.Linear
96
+ )
97
+ self.region = nn.ModuleDict(
98
+ {
99
+ "coord_encoder": linear_cls(
100
+ config.region.coord_feat_dim, config.region.dim, dtype=dtype
101
+ ),
102
+ "coord_decoder": nn.ModuleDict(
103
+ {
104
+ "fc1": linear_cls(
105
+ config.region.dim, config.region.inner_dim, dtype=dtype
106
+ ),
107
+ "fc2": linear_cls(
108
+ config.region.inner_dim,
109
+ config.region.coord_out_dim,
110
+ dtype=dtype,
111
+ ),
112
+ }
113
+ ),
114
+ "size_encoder": linear_cls(
115
+ config.region.size_feat_dim, config.region.dim, dtype=dtype
116
+ ),
117
+ "size_decoder": nn.ModuleDict(
118
+ {
119
+ "fc1": linear_cls(
120
+ config.region.dim, config.region.inner_dim, dtype=dtype
121
+ ),
122
+ "fc2": linear_cls(
123
+ config.region.inner_dim,
124
+ config.region.size_out_dim,
125
+ dtype=dtype,
126
+ ),
127
+ }
128
+ ),
129
+ }
130
+ )
131
+ self.region.coord_features = nn.Parameter(
132
+ torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
133
+ )
134
+ self.region.size_features = nn.Parameter(
135
+ torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
136
+ )
137
+
138
+ attn_mask = torch.tril(
139
+ torch.ones(
140
+ 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
141
+ )
142
+ )
143
+ patch_w = config.vision.crop_size // config.vision.enc_patch_size
144
+ prefix_attn_len = 1 + patch_w**2
145
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
146
+ self.register_buffer("attn_mask", attn_mask, persistent=False)
147
+
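The mask built above is causal (lower-triangular) except over the first `1 + patch_w**2` positions, i.e. the BOS token plus the projected image tokens, which may attend to each other bidirectionally; tokens generated afterwards remain strictly causal. A toy version with a 2x2 patch grid (illustrative sizes only):

    import torch

    max_context, prefix_attn_len = 8, 1 + 2 * 2       # BOS + four image tokens
    mask = torch.tril(torch.ones(1, 1, max_context, max_context, dtype=torch.bool))
    mask[..., :prefix_attn_len, :prefix_attn_len] = 1
    assert mask[0, 0, 0, 4]          # the BOS position sees the last image token
    assert not mask[0, 0, 5, 6]      # later text positions stay causal
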
148
+ # Initialize KV caches.
149
+ if setup_caches:
150
+ self._setup_caches()
151
+
152
+ def _setup_caches(self):
153
+ c = self.config.text
154
+ for b in self.text.blocks:
155
+ b.kv_cache = KVCache(
156
+ c.n_heads,
157
+ c.n_kv_heads,
158
+ c.max_context,
159
+ c.dim,
160
+ device=self.device,
161
+ dtype=self.vision.pos_emb.dtype,
162
+ )
163
+
164
+ @property
165
+ def device(self):
166
+ return self.vision.pos_emb.device
167
+
168
+ def _vis_enc(self, x: torch.Tensor):
169
+ return vision_encoder(x, self.vision, self.config.vision)
170
+
171
+ def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
172
+ return vision_projection(g, r, self.vision, self.config.vision)
173
+
174
+ def _prefill(
175
+ self,
176
+ x: torch.Tensor,
177
+ attn_mask: torch.Tensor,
178
+ pos_ids: torch.Tensor,
179
+ lora: Optional[torch.Tensor],
180
+ ):
181
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
182
+
183
+ def _decode_one_tok(
184
+ self,
185
+ x: torch.Tensor,
186
+ attn_mask: torch.Tensor,
187
+ pos_ids: torch.Tensor,
188
+ lora: Optional[torch.Tensor],
189
+ ):
190
+ hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
191
+ logits = lm_head(hidden, self.text)
192
+ return logits, hidden
193
+
194
+ def compile(self):
195
+ for module in self.modules():
196
+ if isinstance(module, QuantizedLinear):
197
+ module.unpack()
198
+
199
+ # TODO: vision_projection is not being compiled
200
+ self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
201
+ self._prefill = torch.compile(self._prefill, fullgraph=True)
202
+ self._decode_one_tok = torch.compile(
203
+ self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
204
+ )
205
+
206
+ def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
207
+ all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
208
+
209
+ torch._dynamo.mark_dynamic(all_crops, 0)
210
+
211
+ outputs = self._vis_enc(all_crops)
212
+
213
+ global_features = outputs[0]
214
+ local_features = outputs[1:].view(
215
+ -1,
216
+ self.config.vision.enc_n_layers,
217
+ self.config.vision.enc_n_layers,
218
+ self.config.vision.enc_dim,
219
+ )
220
+
221
+ reconstructed = reconstruct_from_crops(
222
+ local_features,
223
+ tiling,
224
+ patch_size=1,
225
+ overlap_margin=self.config.vision.overlap_margin,
226
+ )
227
+
228
+ return self._vis_proj(global_features, reconstructed)
229
+
230
+ def encode_image(
231
+ self,
232
+ image: Union[Image.Image, EncodedImage],
233
+ settings: Optional[ImageEncodingSettings] = None,
234
+ ) -> EncodedImage:
235
+ if isinstance(image, EncodedImage):
236
+ return image
237
+ elif not isinstance(image, Image.Image):
238
+ raise ValueError("image must be a PIL Image or EncodedImage")
239
+
240
+ lora = (
241
+ variant_state_dict(settings["variant"], device=self.device)
242
+ if settings is not None and settings.get("variant") is not None
243
+ else None
244
+ )
245
+
246
+ # Run through text model in addition to the vision encoder, to minimize
247
+ # re-computation if multiple queries are performed on this image.
248
+ with torch.inference_mode():
249
+ img_emb = self._run_vision_encoder(image)
250
+ bos_emb = text_encoder(
251
+ torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
252
+ self.text,
253
+ )
254
+ inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
255
+ mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
256
+ pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
257
+ self._prefill(inputs_embeds, mask, pos_ids, lora)
258
+
259
+ return EncodedImage(
260
+ pos=inputs_embeds.size(1),
261
+ caches=[
262
+ (
263
+ b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
264
+ b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
265
+ )
266
+ for b in self.text.blocks
267
+ ],
268
+ )
269
+
270
+ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
271
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
272
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
273
+ mask = probs_sum - probs_sort > top_p
274
+ probs_sort[mask] = 0.0
275
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
276
+ next_probs = torch.zeros_like(probs)
277
+ next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
278
+ return next_probs
279
+
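`_apply_top_p` above implements nucleus (top-p) sampling: probabilities are sorted, everything outside the smallest set whose cumulative mass exceeds `top_p` is zeroed, and the remainder is renormalised. A standalone sketch of the same filtering on a toy distribution:

    import torch

    probs = torch.tensor([[0.05, 0.50, 0.30, 0.15]])
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[cum - probs_sort > 0.3] = 0.0          # drop tokens whose preceding cumulative mass > top_p
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    filtered = torch.zeros_like(probs).scatter_(dim=-1, index=probs_idx, src=probs_sort)
    # filtered == [[0.0, 1.0, 0.0, 0.0]] for top_p = 0.3: only the most likely token survives
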
280
+ def _prefill_prompt(
281
+ self,
282
+ prompt_tokens: torch.Tensor,
283
+ pos: int,
284
+ temperature: float,
285
+ top_p: float,
286
+ spatial_refs: Optional[SpatialRefs] = None,
287
+ attn_mask: Optional[torch.Tensor] = None,
288
+ lora: Optional[dict] = None,
289
+ ):
290
+ with torch.inference_mode():
291
+ prompt_emb = text_encoder(prompt_tokens, self.text)
292
+
293
+ if spatial_refs:
294
+ encoded_refs = encode_spatial_refs(spatial_refs, self.region)
295
+ prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
296
+ encoded_refs["coords"]
297
+ )
298
+ if encoded_refs["sizes"] is not None:
299
+ prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
300
+ encoded_refs["sizes"]
301
+ )
302
+
303
+ torch._dynamo.mark_dynamic(prompt_emb, 1)
304
+
305
+ if attn_mask is None:
306
+ attn_mask = self.attn_mask
307
+
308
+ mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
309
+ pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
310
+ hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
311
+ logits_BV = lm_head(hidden_BC, self.text)
312
+
313
+ if temperature == 0:
314
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
315
+ else:
316
+ probs = torch.softmax(logits_BV / temperature, dim=-1)
317
+ probs = self._apply_top_p(probs, top_p)
318
+ next_token = torch.multinomial(probs, num_samples=1)
319
+
320
+ pos = pos + prompt_emb.size(1)
321
+ return logits_BV, hidden_BC, next_token, pos
322
+
323
+ def _generate_reasoning(
324
+ self,
325
+ prompt_tokens,
326
+ pos,
327
+ settings: Optional[TextSamplingSettings] = None,
328
+ spatial_refs: Optional[SpatialRefs] = None,
329
+ attn_mask: Optional[torch.Tensor] = None,
330
+ ) -> Tuple[int, str, List[dict]]:
331
+ max_tokens = (
332
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
333
+ if settings
334
+ else DEFAULT_MAX_TOKENS
335
+ )
336
+ temperature = (
337
+ settings.get("temperature", DEFAULT_TEMPERATURE)
338
+ if settings
339
+ else DEFAULT_TEMPERATURE
340
+ )
341
+ lora = (
342
+ variant_state_dict(settings["variant"], device=self.device)
343
+ if settings is not None and "variant" in settings
344
+ else None
345
+ )
346
+
347
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
348
+ eos_id = self.config.tokenizer.answer_id
349
+
350
+ _, last_hidden_BC, next_token, pos = self._prefill_prompt(
351
+ prompt_tokens,
352
+ pos,
353
+ temperature,
354
+ top_p,
355
+ spatial_refs,
356
+ attn_mask=attn_mask,
357
+ lora=lora,
358
+ )
359
+
360
+ text_token_chunks = [[]]
361
+ grounding_chunks = [[]]
362
+
363
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
364
+ mask[:, :, :pos] = 1
365
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
366
+ generated_tokens = 0
367
+
368
+ while (
369
+ next_token_id := next_token.item()
370
+ ) != eos_id and generated_tokens < max_tokens:
371
+ if (
372
+ next_token_id == self.config.tokenizer.start_ground_points_id
373
+ or next_token_id == self.config.tokenizer.end_ground_id
374
+ ):
375
+ text_token_chunks.append([])
376
+ grounding_chunks.append([])
377
+
378
+ text_token_chunks[-1].append(next_token_id)
379
+
380
+ with torch.inference_mode():
381
+ if next_token_id == self.config.tokenizer.coord_id:
382
+ coord_logits = decode_coordinate(last_hidden_BC, self.region)
383
+ coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
384
+ grounding_chunks[-1].append(coord.item())
385
+
386
+ next_emb = encode_coordinate(
387
+ coord.to(dtype=coord_logits.dtype), self.region
388
+ ).unsqueeze(0)
389
+ else:
390
+ next_emb = text_encoder(next_token, self.text)
391
+
392
+ mask[:, :, pos], pos_ids[0] = 1, pos
393
+
394
+ logits_BV, last_hidden_BC = self._decode_one_tok(
395
+ next_emb, mask, pos_ids, lora
396
+ )
397
+ logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
398
+ logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
399
+
400
+ pos += 1
401
+
402
+ if temperature == 0:
403
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
404
+ else:
405
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
406
+ probs = self._apply_top_p(probs, top_p)
407
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
408
+
409
+ generated_tokens += 1
410
+
411
+ text_chunks = [
412
+ self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
413
+ ]
414
+ text = "".join(text_chunks)
415
+
416
+ start_idx = 0
417
+ grounding = []
418
+ for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
419
+ if len(grounding_chunk) > 1:
420
+ points = []
421
+ for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
422
+ points.append((grounding_chunk[i], grounding_chunk[i + 1]))
423
+ grounding.append(
424
+ {
425
+ "start_idx": start_idx,
426
+ "end_idx": start_idx + len(text_chunk),
427
+ "points": points,
428
+ }
429
+ )
430
+ start_idx += len(text_chunk)
431
+
432
+ return pos, text, grounding
433
+
434
+ def _generate_answer(
435
+ self,
436
+ prompt_tokens: torch.Tensor,
437
+ pos: int,
438
+ settings: Optional[TextSamplingSettings] = None,
439
+ spatial_refs: Optional[SpatialRefs] = None,
440
+ eos_id: Optional[int] = None,
441
+ attn_mask: Optional[torch.Tensor] = None,
442
+ ):
443
+ max_tokens = (
444
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
445
+ if settings
446
+ else DEFAULT_MAX_TOKENS
447
+ )
448
+ temperature = (
449
+ settings.get("temperature", DEFAULT_TEMPERATURE)
450
+ if settings
451
+ else DEFAULT_TEMPERATURE
452
+ )
453
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
454
+ eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
455
+ lora = (
456
+ variant_state_dict(settings["variant"], device=self.device)
457
+ if settings is not None and "variant" in settings
458
+ else None
459
+ )
460
+
461
+ _, _, next_token, pos = self._prefill_prompt(
462
+ prompt_tokens,
463
+ pos,
464
+ temperature,
465
+ top_p,
466
+ spatial_refs,
467
+ attn_mask=attn_mask,
468
+ lora=lora,
469
+ )
470
+
471
+ def generator(next_token, pos):
472
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
473
+ mask[:, :, :pos] = 1
474
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
475
+ generated_tokens = 0
476
+
477
+ # For properly handling token streaming with Unicode
478
+ token_cache = []
479
+ print_len = 0
480
+
481
+ while (
482
+ next_token_id := next_token.item()
483
+ ) != eos_id and generated_tokens < max_tokens:
484
+ # Add token to our cache
485
+ token_cache.append(next_token_id)
486
+
487
+ # Decode all tokens collected so far
488
+ text = self.tokenizer.decode(token_cache)
489
+
490
+ # After a newline, we flush the cache completely
491
+ if text.endswith("\n"):
492
+ printable_text = text[print_len:]
493
+ token_cache = []
494
+ print_len = 0
495
+ if printable_text:
496
+ yield printable_text
497
+ # If the last token is a CJK character, we can safely print it
498
+ elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
499
+ printable_text = text[print_len:]
500
+ print_len += len(printable_text)
501
+ if printable_text:
502
+ yield printable_text
503
+ # Otherwise, only yield up to the last space to avoid cutting words
504
+ else:
505
+ last_space_idx = text.rfind(" ", print_len)
506
+ if last_space_idx >= print_len:
507
+ printable_text = text[print_len : last_space_idx + 1]
508
+ print_len += len(printable_text)
509
+ if printable_text:
510
+ yield printable_text
511
+
512
+ with torch.inference_mode():
513
+ next_emb = text_encoder(next_token, self.text)
514
+ mask[:, :, pos], pos_ids[0] = 1, pos
515
+
516
+ logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
517
+ logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
518
+
519
+ pos += 1
520
+
521
+ if temperature == 0:
522
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
523
+ 1
524
+ ) # (1, 1)
525
+ else:
526
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
527
+ probs = self._apply_top_p(probs, top_p)
528
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
529
+
530
+ generated_tokens += 1
531
+
532
+ # Flush any remaining text in the cache
533
+ if token_cache:
534
+ text = self.tokenizer.decode(token_cache)
535
+ printable_text = text[print_len:]
536
+ if printable_text:
537
+ yield printable_text
538
+
539
+ return generator(next_token, pos)
540
+
541
+ def query(
542
+ self,
543
+ image: Optional[Union[Image.Image, EncodedImage]] = None,
544
+ question: str = None,
545
+ reasoning: bool = False,
546
+ spatial_refs: Optional[SpatialRefs] = None,
547
+ stream: bool = False,
548
+ settings: Optional[TextSamplingSettings] = None,
549
+ ):
550
+ if self.config.tokenizer.templates["query"] is None:
551
+ raise NotImplementedError("Model does not support querying.")
552
+
553
+ if question is None:
554
+ raise ValueError("question must be provided.")
555
+
556
+ if spatial_refs and image is None:
557
+ raise ValueError("spatial_refs can only be used with an image.")
558
+
559
+ attn_mask = self.attn_mask
560
+ if image is not None:
561
+ image = self.encode_image(image, settings)
562
+ self.load_encoded_image(image)
563
+ pos = image.pos
564
+ prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
565
+ else:
566
+ self._setup_caches()
567
+ pos = 0
568
+ prompt_toks = [
569
+ self.config.tokenizer.bos_id
570
+ ] + self.config.tokenizer.templates["query"]["prefix"]
571
+ max_context = self.config.text.max_context
572
+ attn_mask = torch.tril(
573
+ torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
574
+ ).to(self.device)
575
+
576
+ spatial_toks = []
577
+ if spatial_refs:
578
+ for ref in spatial_refs:
579
+ coord_id = self.config.tokenizer.coord_id
580
+ size_id = self.config.tokenizer.size_id
581
+ if len(ref) == 2:
582
+ spatial_toks.extend([coord_id, coord_id])
583
+ else:
584
+ spatial_toks.extend([coord_id, coord_id, size_id])
585
+
586
+ prompt_tokens = [
587
+ prompt_toks
588
+ + spatial_toks
589
+ + self.tokenizer.encode(question).ids
590
+ + self.config.tokenizer.templates["query"]["suffix"]
591
+ ]
592
+
593
+ if reasoning:
594
+ prompt_tokens[0] += [self.config.tokenizer.thinking_id]
595
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
596
+ pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
597
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
598
+ )
599
+ prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
600
+ reasoning_dict = {
601
+ "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
602
+ }
603
+ else:
604
+ prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
605
+ reasoning_dict = {}
606
+
607
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
608
+
609
+ def generator():
610
+ for token in self._generate_answer(
611
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
612
+ ):
613
+ yield token
614
+
615
+ if stream:
616
+ return {**reasoning_dict, "answer": generator()}
617
+ else:
618
+ return {**reasoning_dict, "answer": "".join(list(generator()))}
619
+
620
+ def load_encoded_image(self, encoded_image: EncodedImage):
621
+ for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
622
+ b.kv_cache.k_cache[:, :, : k.size(2), :] = k
623
+ b.kv_cache.v_cache[:, :, : v.size(2), :] = v
624
+
625
+ def caption(
626
+ self,
627
+ image: Union[Image.Image, EncodedImage],
628
+ length: Literal["normal", "short", "long"] = "normal",
629
+ stream: bool = False,
630
+ settings: Optional[TextSamplingSettings] = None,
631
+ ):
632
+ if self.config.tokenizer.templates["caption"] is None:
633
+ raise NotImplementedError("Model does not support captioning.")
634
+ if length not in self.config.tokenizer.templates["caption"]:
635
+ raise ValueError(f"Model does not support caption length '{length}'.")
636
+
637
+ image = self.encode_image(image, settings)
638
+ self.load_encoded_image(image)
639
+
640
+ prompt_tokens = torch.tensor(
641
+ [self.config.tokenizer.templates["caption"][length]], device=self.device
642
+ )
643
+
644
+ def generator():
645
+ for token in self._generate_answer(prompt_tokens, image.pos, settings):
646
+ yield token
647
+
648
+ if stream:
649
+ return {"caption": generator()}
650
+ else:
651
+ return {"caption": "".join(list(generator()))}
652
+
653
+ def _generate_points(
654
+ self,
655
+ hidden: torch.Tensor,
656
+ next_token: torch.Tensor,
657
+ pos: int,
658
+ include_size: bool = True,
659
+ max_objects: int = DEFAULT_MAX_OBJECTS,
660
+ lora: Optional[dict] = None,
661
+ ):
662
+ out = []
663
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
664
+ mask[:, :, :pos] = 1
665
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
666
+
667
+ with torch.inference_mode():
668
+ while (
669
+ next_token.item() != self.config.tokenizer.eos_id
670
+ and len(out) < max_objects
671
+ ):
672
+ x_logits = decode_coordinate(hidden, self.region)
673
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
674
+ next_emb = encode_coordinate(
675
+ x_center.to(dtype=x_logits.dtype), self.region
676
+ ).unsqueeze(0)
677
+
678
+ # Decode y-coordinate
679
+ mask[:, :, pos], pos_ids[0] = 1, pos
680
+ _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
681
+ pos += 1
682
+ y_logits = decode_coordinate(hidden, self.region)
683
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
684
+ next_emb = encode_coordinate(
685
+ y_center.to(dtype=y_logits.dtype), self.region
686
+ ).unsqueeze(0)
687
+
688
+ # Decode size
689
+ if include_size:
690
+ mask[:, :, pos], pos_ids[0] = 1, pos
691
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
692
+ pos += 1
693
+ size_logits = decode_size(hidden, self.region)
694
+
695
+ # Get bin indices from the logits
696
+ w_bin = torch.argmax(size_logits[0], dim=-1)
697
+ h_bin = torch.argmax(size_logits[1], dim=-1)
698
+
699
+ # Convert from bin indices to actual size values using the inverse of the log-scale mapping
700
+ # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
701
+ w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
702
+ h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
703
+
704
+ next_emb = (
705
+ encode_size(
706
+ torch.tensor(
707
+ [w, h], device=self.device, dtype=size_logits.dtype
708
+ ),
709
+ self.region,
710
+ )
711
+ .unsqueeze(0)
712
+ .unsqueeze(0)
713
+ )
714
+
715
+ # Add object
716
+ out.append(
717
+ {
718
+ "x_min": x_center.item() - w.item() / 2,
719
+ "y_min": y_center.item() - h.item() / 2,
720
+ "x_max": x_center.item() + w.item() / 2,
721
+ "y_max": y_center.item() + h.item() / 2,
722
+ }
723
+ )
724
+ else:
725
+ out.append({"x": x_center.item(), "y": y_center.item()})
726
+
727
+ # Decode next token (x-coordinate, or eos)
728
+ mask[:, :, pos], pos_ids[0] = 1, pos
729
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
730
+ pos += 1
731
+ next_token = torch.argmax(logits, dim=-1)
732
+
733
+ return out
734
+
735
+ def detect(
736
+ self,
737
+ image: Union[Image.Image, EncodedImage],
738
+ object: str,
739
+ settings: Optional[ObjectSamplingSettings] = None,
740
+ ):
741
+ if self.config.tokenizer.templates["detect"] is None:
742
+ raise NotImplementedError("Model does not support object detection.")
743
+
744
+ image = self.encode_image(image, settings)
745
+ self.load_encoded_image(image)
746
+
747
+ prompt_tokens = torch.tensor(
748
+ [
749
+ self.config.tokenizer.templates["detect"]["prefix"]
750
+ + self.tokenizer.encode(" " + object).ids
751
+ + self.config.tokenizer.templates["detect"]["suffix"]
752
+ ],
753
+ device=self.device,
754
+ )
755
+
756
+ lora = (
757
+ variant_state_dict(settings["variant"], device=self.device)
758
+ if settings is not None and "variant" in settings
759
+ else None
760
+ )
761
+
762
+ _, hidden, next_token, pos = self._prefill_prompt(
763
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
764
+ )
765
+ hidden = hidden[:, -1:, :]
766
+
767
+ max_objects = (
768
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
769
+ if settings
770
+ else DEFAULT_MAX_OBJECTS
771
+ )
772
+ objects = self._generate_points(
773
+ hidden,
774
+ next_token,
775
+ pos,
776
+ include_size=True,
777
+ max_objects=max_objects,
778
+ lora=lora,
779
+ )
780
+
781
+ return {"objects": objects}
782
+
783
+ def point(
784
+ self,
785
+ image: Union[Image.Image, EncodedImage],
786
+ object: str,
787
+ settings: Optional[ObjectSamplingSettings] = None,
788
+ ):
789
+ if self.config.tokenizer.templates["point"] is None:
790
+ raise NotImplementedError("Model does not support pointing.")
791
+
792
+ image = self.encode_image(image, settings)
793
+ self.load_encoded_image(image)
794
+
795
+ prompt_tokens = torch.tensor(
796
+ [
797
+ self.config.tokenizer.templates["point"]["prefix"]
798
+ + self.tokenizer.encode(" " + object).ids
799
+ + self.config.tokenizer.templates["point"]["suffix"]
800
+ ],
801
+ device=self.device,
802
+ )
803
+
804
+ lora = (
805
+ variant_state_dict(settings["variant"], device=self.device)
806
+ if settings is not None and "variant" in settings
807
+ else None
808
+ )
809
+
810
+ _, hidden, next_token, pos = self._prefill_prompt(
811
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
812
+ )
813
+ hidden = hidden[:, -1:, :]
814
+
815
+ max_objects = (
816
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
817
+ if settings
818
+ else DEFAULT_MAX_OBJECTS
819
+ )
820
+ objects = self._generate_points(
821
+ hidden,
822
+ next_token,
823
+ pos,
824
+ include_size=False,
825
+ max_objects=max_objects,
826
+ lora=lora,
827
+ )
828
+
829
+ return {"points": objects}
830
+
831
+ def _detect_gaze(
832
+ self,
833
+ image: EncodedImage,
834
+ source: Tuple[float, float],
835
+ force_detect: bool = False,
836
+ ):
837
+ with torch.inference_mode():
838
+ before_emb = text_encoder(
839
+ torch.tensor(
840
+ [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
841
+ ),
842
+ self.text,
843
+ )
844
+ after_emb = text_encoder(
845
+ torch.tensor(
846
+ [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
847
+ ),
848
+ self.text,
849
+ )
850
+ x_emb = encode_coordinate(
851
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
852
+ self.region,
853
+ )
854
+ y_emb = encode_coordinate(
855
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
856
+ self.region,
857
+ )
858
+
859
+ prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
860
+
861
+ self.load_encoded_image(image)
862
+
863
+ mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
864
+ pos_ids = torch.arange(
865
+ image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
866
+ )
867
+ hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
868
+ logits = lm_head(hidden, self.text)
869
+ next_token = torch.argmax(logits, dim=-1)
870
+ pos = image.pos + prompt_emb.size(1)
871
+ hidden = hidden[:, -1:, :]
872
+
873
+ if force_detect:
874
+ next_token = torch.tensor([[0]], device=self.device)
875
+
876
+ if next_token.item() == self.config.tokenizer.eos_id:
877
+ return None
878
+
879
+ gaze = self._generate_points(
880
+ hidden, next_token, pos, include_size=False, max_objects=1
881
+ )
882
+ return gaze[0]
883
+
884
+ def detect_gaze(
885
+ self,
886
+ image: Union[Image.Image, EncodedImage],
887
+ eye: Optional[Tuple[float, float]] = None,
888
+ face: Optional[Dict[str, float]] = None,
889
+ unstable_settings: Dict[str, Any] = {},
890
+ ):
891
+ if "force_detect" in unstable_settings:
892
+ force_detect = unstable_settings["force_detect"]
893
+ else:
894
+ force_detect = False
895
+
896
+ if "prioritize_accuracy" in unstable_settings:
897
+ prioritize_accuracy = unstable_settings["prioritize_accuracy"]
898
+ else:
899
+ prioritize_accuracy = False
900
+
901
+ if not prioritize_accuracy:
902
+ if eye is None:
903
+ raise ValueError("eye must be provided when prioritize_accuracy=False")
904
+ image = self.encode_image(image)
905
+ return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
906
+ else:
907
+ if (
908
+ not isinstance(image, Image.Image)
909
+ and "flip_enc_img" not in unstable_settings
910
+ ):
911
+ raise ValueError(
912
+ "image must be a PIL Image when prioritize_accuracy=True, "
913
+ "or flip_enc_img must be provided"
914
+ )
915
+ if face is None:
916
+ raise ValueError("face must be provided when prioritize_accuracy=True")
917
+
918
+ encoded_image = self.encode_image(image)
919
+ if (
920
+ isinstance(image, Image.Image)
921
+ and "flip_enc_img" not in unstable_settings
922
+ ):
923
+ flipped_pil = image.copy()
924
+ flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
925
+ encoded_flipped_image = self.encode_image(flipped_pil)
926
+ else:
927
+ encoded_flipped_image = unstable_settings["flip_enc_img"]
928
+
929
+ N = 10
930
+
931
+ detections = [
932
+ self._detect_gaze(
933
+ encoded_image,
934
+ (
935
+ random.uniform(face["x_min"], face["x_max"]),
936
+ random.uniform(face["y_min"], face["y_max"]),
937
+ ),
938
+ force_detect=force_detect,
939
+ )
940
+ for _ in range(N)
941
+ ]
942
+ detections = [
943
+ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
944
+ ]
945
+ flipped_detections = [
946
+ self._detect_gaze(
947
+ encoded_flipped_image,
948
+ (
949
+ 1 - random.uniform(face["x_min"], face["x_max"]),
950
+ random.uniform(face["y_min"], face["y_max"]),
951
+ ),
952
+ force_detect=force_detect,
953
+ )
954
+ for _ in range(N)
955
+ ]
956
+ detections.extend(
957
+ [
958
+ (1 - gaze["x"], gaze["y"])
959
+ for gaze in flipped_detections
960
+ if gaze is not None
961
+ ]
962
+ )
963
+
964
+ if len(detections) < N:
965
+ return {"gaze": None}
966
+
967
+ detections = remove_outlier_points(detections)
968
+ mean_gaze = (
969
+ sum(gaze[0] for gaze in detections) / len(detections),
970
+ sum(gaze[1] for gaze in detections) / len(detections),
971
+ )
972
+
973
+ return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
974
+
975
+
976
+ def _is_cjk_char(cp):
977
+ """Checks whether CP is the codepoint of a CJK character."""
978
+ # This defines a "chinese character" as anything in the CJK Unicode block:
979
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
980
+ if (
981
+ (cp >= 0x4E00 and cp <= 0x9FFF)
982
+ or (cp >= 0x3400 and cp <= 0x4DBF)
983
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
984
+ ):
985
+ return True
986
+ return False
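A minimal standalone sketch of the nucleus (top-p) filtering used by `_apply_top_p` above: probabilities are sorted, the cumulative tail beyond `top_p` is zeroed out, and the survivors are renormalized before sampling. The toy distribution below is illustrative only.

```python
import torch

def apply_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Same filtering logic as _apply_top_p above, reproduced standalone.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[probs_sum - probs_sort > top_p] = 0.0       # drop tokens outside the nucleus
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize the survivors
    out = torch.zeros_like(probs)
    out.scatter_(dim=-1, index=probs_idx, src=probs_sort)  # restore original token order
    return out

probs = torch.tensor([[0.05, 0.50, 0.10, 0.30, 0.05]])
print(apply_top_p(probs, top_p=0.8))
# tokens with p = 0.50, 0.30 and 0.10 survive, renormalized to roughly 0.56, 0.33 and 0.11
```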
region.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from typing import List, Tuple, Union
6
+
7
+ from .layers import mlp
8
+
9
+ SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
10
+
11
+
12
+ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Applies Fourier feature mapping to input tensor x using frequency matrix w. This
15
+ projects inputs through sinusoidal functions to create higher dimensional features
16
+ that help mitigate spectral bias - the tendency of neural networks to learn
17
+ low-frequency functions more easily than high-frequency ones. By explicitly
18
+ mapping inputs to higher frequencies through sin/cos transformations, we enable
19
+ better learning of fine details and higher frequency patterns.
20
+
21
+ Args:
22
+ x: Input tensor to transform
23
+ w: Matrix of frequencies for the Fourier features transformation
24
+
25
+ Returns:
26
+ Concatenated cosine and sine transformed features as a tensor
27
+ """
28
+ f = 2 * math.pi * x @ w
29
+ return torch.cat([f.cos(), f.sin()], dim=-1)
30
+
31
+
32
+ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
33
+ """
34
+ Takes as input a tensor containing a single float coordinate value (x or y)
35
+ and encodes it into hidden states for input to the text model.
36
+
37
+ Args:
38
+ coord: Tensor with single float coordinate value
39
+
40
+ Returns:
41
+ Encoded hidden states tensor for input to text model
42
+ """
43
+ return w.coord_encoder(fourier_features(coord, w.coord_features))
44
+
45
+
46
+ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
47
+ """
48
+ Takes as input the last hidden state from the text model and outputs logits over
49
+ coordinate bins representing either an x or y coordinate prediction.
50
+
51
+ Args:
52
+ hidden_state: The final hidden state tensor from the text model.
53
+
54
+ Returns:
55
+ Logits over coordinate bins for the predicted coordinate value (x or y)
56
+ """
57
+ return mlp(hidden_state, w.coord_decoder)
58
+
59
+
60
+ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
61
+ """
62
+ Takes a tensor containing width and height values and encodes them into
63
+ hidden states for input to the text model.
64
+
65
+ Args:
66
+ size: Tensor with two floats for width and height
67
+
68
+ Returns:
69
+ Encoded hidden states tensor for input to text model
70
+ """
71
+ return w.size_encoder(fourier_features(size, w.size_features))
72
+
73
+
74
+ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
75
+ """
76
+ Takes as input the last hidden state from the text model and outputs logits
77
+ for 1024 bins representing width and height in log-scale.
78
+
79
+ The bins are distributed according to the formula:
80
+ bin = (log2(size) + 10.0) / 10.0 * 1023.0
81
+ where size values are clamped to be at least 1/1024.
82
+
83
+ To convert from bin back to size:
84
+ size = 2^((bin / 1023.0) * 10.0 - 10.0)
85
+
86
+ Args:
87
+ hidden_state: The final hidden state tensor from the text model.
88
+
89
+ Returns:
90
+ A tensor containing logits for 1024 bins for width and height.
91
+ Shape is (2, 1024) where the first dimension corresponds to width and height.
92
+ """
93
+ return mlp(hidden_state, w.size_decoder).view(2, -1)
94
+
95
+
96
+ def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> dict:
97
+ """
98
+ Takes a list of spatial references (points or regions) and encodes them into
99
+ hidden states for input to the text model.
100
+
101
+ Args:
102
+ spatial_refs: List of spatial references (points or boxes)
103
+ - Points are represented as normalized (x, y) tuples
104
+ - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
105
+
106
+ Returns:
107
+ {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
108
+ """
109
+ coords, sizes = [], []
110
+ for ref in spatial_refs:
111
+ if len(ref) == 2:
112
+ coords.append(ref[0])
113
+ coords.append(ref[1])
114
+ else:
115
+ x_c = (ref[0] + ref[2]) / 2
116
+ y_c = (ref[1] + ref[3]) / 2
117
+ width = ref[2] - ref[0]
118
+ height = ref[3] - ref[1]
119
+ coords.append(x_c)
120
+ coords.append(y_c)
121
+ sizes.append([width, height])
122
+
123
+ coords = torch.tensor(
124
+ coords, device=w.coord_features.device, dtype=w.coord_features.dtype
125
+ ).view(-1, 1)
126
+ coords = encode_coordinate(coords, w)
127
+
128
+ if sizes:
129
+ sizes = torch.tensor(
130
+ sizes, device=w.size_features.device, dtype=w.size_features.dtype
131
+ )
132
+ sizes = encode_size(sizes, w)
133
+ else:
134
+ sizes = None
135
+
136
+ return {"coords": coords, "sizes": sizes}
region_model.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .fourier_features import FourierFeatures
4
+
5
+ class RegionModel(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ self.position_features = FourierFeatures(2, 256)
10
+ self.position_encoder = nn.Linear(256, 2048)
11
+ self.size_features = FourierFeatures(2, 256)
12
+ self.size_encoder = nn.Linear(256, 2048)
13
+
14
+ self.position_decoder = nn.Linear(2048, 2)
15
+ self.size_decoder = nn.Linear(2048, 2)
16
+ self.confidence_decoder = nn.Linear(2048, 1)
17
+
18
+ def encode_position(self, position):
19
+ return self.position_encoder(self.position_features(position))
20
+
21
+ def encode_size(self, size):
22
+ return self.size_encoder(self.size_features(size))
23
+
24
+ def decode_position(self, x):
25
+ return self.position_decoder(x)
26
+
27
+ def decode_size(self, x):
28
+ return self.size_decoder(x)
29
+
30
+ def decode_confidence(self, x):
31
+ return self.confidence_decoder(x)
32
+
33
+ def encode(self, position, size):
34
+ return torch.stack(
35
+ [self.encode_position(position), self.encode_size(size)], dim=0
36
+ )
37
+
38
+ def decode(self, position_logits, size_logits):
39
+ return (
40
+ self.decode_position(position_logits),
41
+ self.decode_size(size_logits),
42
+ self.decode_confidence(size_logits),
43
+ )
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ einops
2
+ pyvips-binary==8.16.0
3
+ pyvips==2.2.3
rope.py ADDED
@@ -0,0 +1,48 @@
1
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
2
+
3
+ import torch
4
+
5
+
6
+ def precompute_freqs_cis(
7
+ dim: int,
8
+ end: int,
9
+ theta: float = 10000.0,
10
+ use_scaled: bool = False,
11
+ dtype: torch.dtype = torch.float32,
12
+ ) -> torch.Tensor:
13
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
14
+ t = torch.arange(end, dtype=dtype).unsqueeze(1)
15
+ freqs = t * freqs.unsqueeze(0)
16
+ freqs = torch.exp(1j * freqs)
17
+ return torch.stack([freqs.real, freqs.imag], dim=-1)
18
+
19
+
20
+ def apply_rotary_emb(
21
+ x: torch.Tensor,
22
+ freqs_cis: torch.Tensor,
23
+ position_ids: torch.Tensor,
24
+ num_heads: int,
25
+ rot_dim: int = 32,
26
+ interleave: bool = False,
27
+ ) -> torch.Tensor:
28
+ assert rot_dim == freqs_cis.shape[-2] * 2
29
+ assert num_heads == x.shape[1]
30
+
31
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
32
+
33
+ if interleave:
34
+ xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
35
+ xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
36
+ else:
37
+ d_q = x_rot.shape[-1] // 2
38
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
39
+
40
+ freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
41
+ freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
42
+
43
+ # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
44
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
45
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
46
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
47
+
48
+ return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
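A quick shape check for the two RoPE helpers above, assuming `rope.py` is on the import path. Only the first `rot_dim` (default 32) channels of each head are rotated, so `precompute_freqs_cis` is called with `dim=32` and returns an `(end, 16, 2)` table of cos/sin values; the sizes below are toy values.

```python
import torch
from rope import precompute_freqs_cis, apply_rotary_emb  # assumes rope.py above is importable

n_heads, seq_len, head_dim = 8, 5, 64
freqs_cis = precompute_freqs_cis(dim=32, end=2048)   # (2048, 16, 2)
q = torch.randn(1, n_heads, seq_len, head_dim)       # (B, n_heads, seq, head_dim)
pos_ids = torch.arange(seq_len, dtype=torch.long)
q_rot = apply_rotary_emb(q, freqs_cis, pos_ids, num_heads=n_heads)
print(q_rot.shape)                                   # torch.Size([1, 8, 5, 64])
```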
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
text.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.nn import functional as F
5
+ from typing import Optional
6
+
7
+ from .layers import layer_norm, mlp, QuantizedLinear
8
+ from .rope import apply_rotary_emb, precompute_freqs_cis
9
+ from .config import TextConfig
10
+
11
+
12
+ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
13
+ return F.embedding(input_ids, w.wte)
14
+
15
+
16
+ def attn(
17
+ x: torch.Tensor,
18
+ w: nn.Module,
19
+ freqs_cis: torch.Tensor,
20
+ kv_cache: nn.Module,
21
+ attn_mask: torch.Tensor,
22
+ n_heads: int,
23
+ n_kv_heads: int,
24
+ position_ids: torch.Tensor,
25
+ lora: Optional[dict],
26
+ ):
27
+ bsz, q_len, d_model = x.shape
28
+ head_dim = d_model // n_heads
29
+
30
+ qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
31
+ if lora is not None:
32
+ qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
33
+ q_dim = n_heads * head_dim
34
+ kv_dim = n_kv_heads * head_dim
35
+ q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
36
+ del qkv_out
37
+
38
+ q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
39
+ k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
40
+ v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
41
+
42
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
43
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
44
+
45
+ if kv_cache is not None:
46
+ k, v = kv_cache.update(position_ids, k, v)
47
+
48
+ out = F.scaled_dot_product_attention(
49
+ q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
50
+ )
51
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
52
+
53
+ out0 = w.proj(out)
54
+ if lora is not None:
55
+ out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
56
+ out = out0 + out1
57
+ else:
58
+ out = out0
59
+
60
+ return out
61
+
62
+
63
+ def _attn(
64
+ x: torch.Tensor,
65
+ w: torch.Tensor,
66
+ freqs_cis: torch.Tensor,
67
+ attn_mask: torch.Tensor,
68
+ n_heads: int,
69
+ n_kv_heads: int,
70
+ ):
71
+ bsz, q_len, d_model = x.shape
72
+ head_dim = d_model // n_heads
73
+ pos = 0
74
+
75
+ qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
76
+ q_dim = n_heads * head_dim
77
+ kv_dim = n_kv_heads * head_dim
78
+
79
+ q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
80
+ k = (
81
+ qkv_out[..., q_dim : q_dim + kv_dim]
82
+ .view(bsz, q_len, n_kv_heads, head_dim)
83
+ .transpose(1, 2)
84
+ )
85
+ v = (
86
+ qkv_out[..., q_dim + kv_dim :]
87
+ .view(bsz, q_len, n_kv_heads, head_dim)
88
+ .transpose(1, 2)
89
+ )
90
+
91
+ position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
92
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
93
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
94
+ out = F.scaled_dot_product_attention(
95
+ q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
96
+ )
97
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
98
+ out = w.proj(out)
99
+ return out
100
+
101
+
102
+ def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
103
+ hidden_BTC = inputs_embeds
104
+
105
+ bsz, q_len, d_model = inputs_embeds.shape
106
+ attn_mask = torch.zeros(q_len, q_len)
107
+ attn_mask[:730, :730] = 1
108
+ for i in range(730, q_len):
109
+ attn_mask[i, : i + 1] = 1
110
+ attn_mask = attn_mask.to(dtype=torch.bool)
111
+
112
+ for i, block in enumerate(w.blocks):
113
+ l_in = layer_norm(hidden_BTC, block.ln)
114
+ l_attn = _attn(
115
+ x=l_in,
116
+ w=block.attn,
117
+ freqs_cis=w.freqs_cis,
118
+ attn_mask=attn_mask,
119
+ n_heads=config.n_heads,
120
+ n_kv_heads=config.n_kv_heads,
121
+ )
122
+ l_mlp = mlp(l_in, block.mlp)
123
+ hidden_BTC = hidden_BTC + l_attn + l_mlp
124
+
125
+ return hidden_BTC
126
+
127
+
128
+ def text_decoder(
129
+ x: torch.Tensor,
130
+ w: nn.Module,
131
+ attn_mask: torch.Tensor,
132
+ position_ids: torch.Tensor,
133
+ config: TextConfig,
134
+ lora: Optional[dict],
135
+ ):
136
+ for i, block in enumerate(w.blocks):
137
+ if lora is not None:
138
+ layer_lora = lora["text"]["blocks"][str(i)]
139
+ mlp_lora = layer_lora["mlp"]
140
+ attn_lora = layer_lora["attn"]
141
+ else:
142
+ mlp_lora = None
143
+ attn_lora = None
144
+
145
+ l_in = layer_norm(x, block.ln)
146
+ l_attn = attn(
147
+ l_in,
148
+ block.attn,
149
+ freqs_cis=w.freqs_cis,
150
+ kv_cache=block.kv_cache,
151
+ attn_mask=attn_mask,
152
+ n_heads=config.n_heads,
153
+ n_kv_heads=config.n_kv_heads,
154
+ position_ids=position_ids,
155
+ lora=attn_lora,
156
+ )
157
+ l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
158
+ x = x + l_attn + l_mlp
159
+
160
+ return x
161
+
162
+
163
+ def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
164
+ hidden_BC = hidden_BTC[:, -1, :]
165
+ hidden_BC = layer_norm(hidden_BC, w.post_ln)
166
+ logits = w.lm_head(hidden_BC)
167
+ return logits
168
+
169
+
170
+ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
171
+ hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
172
+ logits = w.lm_head(hidden_BTC)
173
+ return logits
174
+
175
+
176
+ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
177
+ qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
178
+ linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
179
+
180
+ text = nn.ModuleDict(
181
+ {
182
+ "blocks": nn.ModuleList(
183
+ [
184
+ nn.ModuleDict(
185
+ {
186
+ "ln": nn.LayerNorm(config.dim, dtype=dtype),
187
+ "attn": nn.ModuleDict(
188
+ {
189
+ "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
190
+ "proj": linear_cls(
191
+ config.dim, config.dim, dtype=dtype
192
+ ),
193
+ }
194
+ ),
195
+ "mlp": nn.ModuleDict(
196
+ {
197
+ "fc1": linear_cls(
198
+ config.dim, config.ff_dim, dtype=dtype
199
+ ),
200
+ "fc2": linear_cls(
201
+ config.ff_dim, config.dim, dtype=dtype
202
+ ),
203
+ }
204
+ ),
205
+ }
206
+ )
207
+ for _ in range(config.n_layers)
208
+ ]
209
+ ),
210
+ "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
211
+ "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
212
+ }
213
+ )
214
+ text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
215
+ text.register_buffer(
216
+ "freqs_cis",
217
+ precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
218
+ persistent=False,
219
+ )
220
+
221
+ return text
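`_produce_hidden` above builds a prefix attention mask: the first 730 positions (matching the BOS token plus the 729 image tokens produced by the vision projection) attend to each other bidirectionally, while every later position is causal. A small illustration of the same construction with toy sizes; the model itself uses `prefix = 730`.

```python
import torch

prefix, q_len = 4, 7                   # toy sizes; the model uses a 730-token image prefix
attn_mask = torch.zeros(q_len, q_len)
attn_mask[:prefix, :prefix] = 1        # prefix positions see each other fully
for i in range(prefix, q_len):
    attn_mask[i, : i + 1] = 1          # remaining positions are causal
print(attn_mask.to(dtype=torch.bool).int())
```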
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,323 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "50257": {
13
+ "content": " ",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": false
19
+ },
20
+ "50258": {
21
+ "content": " ",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": false
27
+ },
28
+ "50259": {
29
+ "content": " ",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": false
35
+ },
36
+ "50260": {
37
+ "content": " ",
38
+ "lstrip": false,
39
+ "normalized": true,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": false
43
+ },
44
+ "50261": {
45
+ "content": " ",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": false
51
+ },
52
+ "50262": {
53
+ "content": " ",
54
+ "lstrip": false,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": false
59
+ },
60
+ "50263": {
61
+ "content": " ",
62
+ "lstrip": false,
63
+ "normalized": true,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": false
67
+ },
68
+ "50264": {
69
+ "content": " ",
70
+ "lstrip": false,
71
+ "normalized": true,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": false
75
+ },
76
+ "50265": {
77
+ "content": " ",
78
+ "lstrip": false,
79
+ "normalized": true,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": false
83
+ },
84
+ "50266": {
85
+ "content": " ",
86
+ "lstrip": false,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": false
91
+ },
92
+ "50267": {
93
+ "content": " ",
94
+ "lstrip": false,
95
+ "normalized": true,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": false
99
+ },
100
+ "50268": {
101
+ "content": " ",
102
+ "lstrip": false,
103
+ "normalized": true,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": false
107
+ },
108
+ "50269": {
109
+ "content": " ",
110
+ "lstrip": false,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": false
115
+ },
116
+ "50270": {
117
+ "content": " ",
118
+ "lstrip": false,
119
+ "normalized": true,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": false
123
+ },
124
+ "50271": {
125
+ "content": " ",
126
+ "lstrip": false,
127
+ "normalized": true,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": false
131
+ },
132
+ "50272": {
133
+ "content": " ",
134
+ "lstrip": false,
135
+ "normalized": true,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": false
139
+ },
140
+ "50273": {
141
+ "content": " ",
142
+ "lstrip": false,
143
+ "normalized": true,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": false
147
+ },
148
+ "50274": {
149
+ "content": " ",
150
+ "lstrip": false,
151
+ "normalized": true,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": false
155
+ },
156
+ "50275": {
157
+ "content": " ",
158
+ "lstrip": false,
159
+ "normalized": true,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": false
163
+ },
164
+ "50276": {
165
+ "content": " ",
166
+ "lstrip": false,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": false
171
+ },
172
+ "50277": {
173
+ "content": " ",
174
+ "lstrip": false,
175
+ "normalized": true,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": false
179
+ },
180
+ "50278": {
181
+ "content": " ",
182
+ "lstrip": false,
183
+ "normalized": true,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": false
187
+ },
188
+ "50279": {
189
+ "content": " ",
190
+ "lstrip": false,
191
+ "normalized": true,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": false
195
+ },
196
+ "50280": {
197
+ "content": " ",
198
+ "lstrip": false,
199
+ "normalized": true,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": false
203
+ },
204
+ "50281": {
205
+ "content": " ",
206
+ "lstrip": false,
207
+ "normalized": true,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": false
211
+ },
212
+ "50282": {
213
+ "content": " ",
214
+ "lstrip": false,
215
+ "normalized": true,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": false
219
+ },
220
+ "50283": {
221
+ "content": " ",
222
+ "lstrip": false,
223
+ "normalized": true,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": false
227
+ },
228
+ "50284": {
229
+ "content": " ",
230
+ "lstrip": false,
231
+ "normalized": true,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": false
235
+ },
236
+ "50285": {
237
+ "content": " ",
238
+ "lstrip": false,
239
+ "normalized": true,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": false
243
+ },
244
+ "50286": {
245
+ "content": " ",
246
+ "lstrip": false,
247
+ "normalized": true,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": false
251
+ },
252
+ "50287": {
253
+ "content": "\t\t\t\t\t\t\t\t\t",
254
+ "lstrip": false,
255
+ "normalized": true,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": false
259
+ },
260
+ "50288": {
261
+ "content": "\t\t\t\t\t\t\t\t",
262
+ "lstrip": false,
263
+ "normalized": true,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": false
267
+ },
268
+ "50289": {
269
+ "content": "\t\t\t\t\t\t\t",
270
+ "lstrip": false,
271
+ "normalized": true,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": false
275
+ },
276
+ "50290": {
277
+ "content": "\t\t\t\t\t\t",
278
+ "lstrip": false,
279
+ "normalized": true,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": false
283
+ },
284
+ "50291": {
285
+ "content": "\t\t\t\t\t",
286
+ "lstrip": false,
287
+ "normalized": true,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": false
291
+ },
292
+ "50292": {
293
+ "content": "\t\t\t\t",
294
+ "lstrip": false,
295
+ "normalized": true,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": false
299
+ },
300
+ "50293": {
301
+ "content": "\t\t\t",
302
+ "lstrip": false,
303
+ "normalized": true,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": false
307
+ },
308
+ "50294": {
309
+ "content": "\t\t",
310
+ "lstrip": false,
311
+ "normalized": true,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": false
315
+ }
316
+ },
317
+ "bos_token": "<|endoftext|>",
318
+ "clean_up_tokenization_spaces": true,
319
+ "eos_token": "<|endoftext|>",
320
+ "model_max_length": 2048,
321
+ "tokenizer_class": "CodeGenTokenizer",
322
+ "unk_token": "<|endoftext|>"
323
+ }
utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+
4
+ def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
5
+ """
6
+ Robust outlier detection for list of (x,y) tuples.
7
+ Only requires numpy.
8
+
9
+ Args:
10
+ points_tuples: list of (x,y) tuples
11
+ k_nearest: number of neighbors to consider
12
+ threshold: multiplier for median distance
13
+
14
+ Returns:
15
+ list: filtered list of (x,y) tuples with outliers removed
17
+ """
18
+ points = np.array(points_tuples)
19
+ n_points = len(points)
20
+
21
+ # Calculate pairwise distances manually
22
+ dist_matrix = np.zeros((n_points, n_points))
23
+ for i in range(n_points):
24
+ for j in range(i + 1, n_points):
25
+ # Euclidean distance between points i and j
26
+ dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
27
+ dist_matrix[i, j] = dist
28
+ dist_matrix[j, i] = dist
29
+
30
+ # Get k nearest neighbors' distances
31
+ k = min(k_nearest, n_points - 1)
32
+ neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
33
+ avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
34
+
35
+ # Calculate mask using median distance
36
+ median_dist = np.median(avg_neighbor_dist)
37
+ mask = avg_neighbor_dist <= threshold * median_dist
38
+
39
+ # Return only the filtered tuples (outliers removed)
40
+ filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
41
+ return filtered_tuples
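A quick usage sketch for `remove_outlier_points`, assuming `utils.py` above is on the import path; the points are made-up gaze detections where one sample lands far from the cluster.

```python
from utils import remove_outlier_points  # assumes utils.py above is importable

points = [(0.40, 0.50), (0.41, 0.50), (0.40, 0.51), (0.41, 0.51), (0.95, 0.05)]
print(remove_outlier_points(points, k_nearest=2, threshold=2.0))
# the (0.95, 0.05) outlier is dropped; the four clustered points are kept
```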
versions.txt ADDED
@@ -0,0 +1,12 @@
1
+ 2024-03-04
2
+ 2024-03-06
3
+ 2024-03-13
4
+ 2024-04-02
5
+ 2024-05-08
6
+ 2024-05-20
7
+ 2024-07-23
8
+ 2024-08-26
9
+ 2025-01-09
10
+ 2025-03-27
11
+ 2025-04-14
12
+ 2025-06-21
vision.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from typing import Union, Tuple
7
+ from PIL import Image
8
+
9
+ from .layers import attn, layer_norm, mlp
10
+ from .image_crops import overlap_crop_image
11
+ from .config import VisionConfig
12
+
13
+ if torch.backends.mps.is_available():
14
+ # Non-divisible input sizes are not implemented on MPS device yet.
15
+ # https://github.com/pytorch/pytorch/issues/96056
16
+ def adaptive_avg_pool2d(input, output_size):
17
+ return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
18
+
19
+ else:
20
+ adaptive_avg_pool2d = F.adaptive_avg_pool2d
21
+
22
+ DeviceLike = Union[str, torch.device, int]
23
+
24
+
25
+ def prepare_crops(
26
+ image: Image.Image, config: VisionConfig, device: DeviceLike
27
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
28
+ np_image = np.array(image.convert("RGB"))
29
+ overlap_crops = overlap_crop_image(
30
+ np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
31
+ )
32
+ all_crops = overlap_crops["crops"]
33
+ all_crops = np.transpose(all_crops, (0, 3, 1, 2))
34
+ all_crops = (
35
+ torch.from_numpy(all_crops)
36
+ .to(device=device, dtype=torch.bfloat16)
37
+ .div_(255.0)
38
+ .sub_(0.5)
39
+ .div_(0.5)
40
+ )
41
+ return all_crops, overlap_crops["tiling"]
42
+
43
+
44
+ def create_patches(x, patch_size):
45
+ # Original shape: [B, C, H, W]
46
+ B, C, H, W = x.shape
47
+ P1 = P2 = patch_size
48
+
49
+ # Step 1: Split H and W dimensions into patches
50
+ # [B, C, H/P1, P1, W/P2, P2]
51
+ x = x.reshape(B, C, H // P1, P1, W // P2, P2)
52
+
53
+ # Step 2: Rearrange dimensions to match target shape
54
+ # [B, H/P1, W/P2, C, P1, P2]
55
+ x = x.permute(0, 2, 4, 1, 3, 5)
56
+
57
+ # Step 3: Combine dimensions to get final shape
58
+ # [B, (H/P1)*(W/P2), C*P1*P2]
59
+ x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
60
+
61
+ return x
62
+
63
+
64
+ def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
65
+ x = create_patches(input_BCHW, config.enc_patch_size)
66
+
67
+ x = w.patch_emb(x)
68
+ x = x + w.pos_emb
69
+ for block in w.blocks:
70
+ x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
71
+ x = x + mlp(layer_norm(x, block.ln2), block.mlp)
72
+ x = layer_norm(x, w.post_ln)
73
+
74
+ return x
75
+
76
+
77
+ def vision_projection(
78
+ global_features: torch.Tensor,
79
+ reconstructed: torch.Tensor,
80
+ w: nn.Module,
81
+ config: VisionConfig,
82
+ ):
83
+ reconstructed = reconstructed.permute(2, 0, 1)
84
+ reconstructed = adaptive_avg_pool2d(
85
+ reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
86
+ )
87
+ reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
88
+ final_features = torch.cat([global_features, reconstructed], dim=-1)
89
+ return mlp(final_features, w.proj_mlp)
90
+
91
+
92
+ def build_vision_model(config: VisionConfig, dtype: torch.dtype):
93
+ patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
94
+ grid_size = config.crop_size // config.enc_patch_size
95
+ num_patches = grid_size * grid_size
96
+
97
+ vision = nn.ModuleDict(
98
+ {
99
+ "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
100
+ "blocks": nn.ModuleList(
101
+ [
102
+ nn.ModuleDict(
103
+ {
104
+ "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
105
+ "attn": nn.ModuleDict(
106
+ {
107
+ "qkv": nn.Linear(
108
+ config.enc_dim, 3 * config.enc_dim, dtype=dtype
109
+ ),
110
+ "proj": nn.Linear(
111
+ config.enc_dim, config.enc_dim, dtype=dtype
112
+ ),
113
+ }
114
+ ),
115
+ "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
116
+ "mlp": nn.ModuleDict(
117
+ {
118
+ "fc1": nn.Linear(
119
+ config.enc_dim, config.enc_ff_dim, dtype=dtype
120
+ ),
121
+ "fc2": nn.Linear(
122
+ config.enc_ff_dim, config.enc_dim, dtype=dtype
123
+ ),
124
+ }
125
+ ),
126
+ }
127
+ )
128
+ for _ in range(config.enc_n_layers)
129
+ ]
130
+ ),
131
+ "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
132
+ "proj_mlp": nn.ModuleDict(
133
+ {
134
+ "fc1": nn.Linear(
135
+ config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
136
+ ),
137
+ "fc2": nn.Linear(
138
+ config.proj_inner_dim, config.proj_out_dim, dtype=dtype
139
+ ),
140
+ }
141
+ ),
142
+ }
143
+ )
144
+ vision.pos_emb = nn.Parameter(
145
+ torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
146
+ )
147
+ return vision
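A shape check for `create_patches` above, assuming the function is in scope (vision.py uses relative imports, so it is normally loaded as part of the package): with the repository's 378×378 crops and a patch size of 14, each crop splits into a 27×27 grid, i.e. 729 patches of 14·14·3 = 588 values, which is exactly what `patch_emb` consumes.

```python
import torch
# create_patches is defined in vision.py above; assumed to be in scope here.

x = torch.randn(2, 3, 378, 378)             # (B, C, H, W): two RGB crops
patches = create_patches(x, patch_size=14)
print(patches.shape)                        # torch.Size([2, 729, 588])
```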
vision_encoder.py ADDED
@@ -0,0 +1,325 @@
1
+ from typing import Union
2
+
3
+ import PIL.Image
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from einops import rearrange
8
+ import PIL
9
+ from torchvision.transforms.v2 import (
10
+ Compose,
11
+ Resize,
12
+ InterpolationMode,
13
+ ToImage,
14
+ ToDtype,
15
+ Normalize,
16
+ )
17
+ from transformers.utils import is_flash_attn_2_available
18
+
19
+ try:
20
+ if is_flash_attn_2_available():
21
+ from flash_attn.modules.mha import FlashSelfAttention
22
+ else:
23
+ FlashSelfAttention = None
24
+ except ImportError:
25
+ FlashSelfAttention = None
26
+
27
+
28
+ class Attention(nn.Module):
29
+
30
+ def __init__(self, dim, num_heads=16, use_flash_attn=False):
31
+ super().__init__()
32
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
33
+
34
+ self.num_heads = num_heads
35
+ self.head_dim = dim // num_heads
36
+
37
+ self.qkv = nn.Linear(dim, dim * 3)
38
+ self.proj = nn.Linear(dim, dim)
39
+
40
+ if use_flash_attn and FlashSelfAttention is not None:
41
+ self.flash_attn = FlashSelfAttention()
42
+ else:
43
+ self.flash_attn = None
44
+
45
+ torch.nn.init.kaiming_normal_(
46
+             self.qkv.weight, mode="fan_in", nonlinearity="relu"
+         )
+         torch.nn.init.kaiming_normal_(
+             self.proj.weight, mode="fan_in", nonlinearity="relu"
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.flash_attn is not None:
+             qkv = self.qkv(x)
+             qkv = rearrange(
+                 qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
+             )
+             attn_output = self.flash_attn(qkv)
+             output = rearrange(attn_output, "... h d -> ... (h d)")
+             output = self.proj(output)
+             return output
+         else:
+             B, N, C = x.shape
+             qkv = (
+                 self.qkv(x)
+                 .reshape(B, N, 3, self.num_heads, self.head_dim)
+                 .permute(2, 0, 3, 1, 4)
+             )
+             q, k, v = qkv.unbind(0)
+
+             x = F.scaled_dot_product_attention(q, k, v)
+
+             x = x.transpose(1, 2).reshape(B, N, C)
+             x = self.proj(x)
+             return x
+
+
+ class VitBlock(nn.Module):
+
+     def __init__(self, embed_dim, use_flash_attn=False):
+         super().__init__()
+         self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
+         self.mlp = MLP(embed_dim, 4304)
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.norm2 = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class VisionTransformer(nn.Module):
+
+     def __init__(self, use_flash_attn=False):
+         super().__init__()
+
+         embed_len = 729
+         embed_dim = 1152
+
+         self.patch_embed = LinearPatchEmbedding()
+         self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
+         self.blocks = nn.Sequential(
+             *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
+         )
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         x = self.patch_embed(x)
+         x = x + self.pos_embed
+         for block in self.blocks:
+             x = block(x)
+         return self.norm(x)
+
+
+ class EncoderWrapper(nn.Module):
+
+     def __init__(self, use_flash_attn=False):
+         super().__init__()
+         self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})
+
+     def forward(self, x):
+         return self.model["visual"](x)
+
+
+ class LinearPatchEmbedding(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.linear = nn.Linear(588, 1152)
+
+     def forward(self, x):
+         b, c, hp1, wp2 = x.shape
+         p1, p2 = 14, 14
+         h, w = hp1 // p1, wp2 // p2
+         x = x.reshape(b, c, h, p1, w, p2)
+         x = x.permute(0, 2, 4, 1, 3, 5)
+         x = x.reshape(b, h * w, c * p1 * p2)
+
+         return self.linear(x)
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: int = None,
+         out_features: int = None,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = nn.GELU(approximate="tanh")
+         self.fc2 = nn.Linear(hidden_features, out_features)
+
+         torch.nn.init.kaiming_normal_(
+             self.fc1.weight, mode="fan_in", nonlinearity="relu"
+         )
+         torch.nn.init.kaiming_normal_(
+             self.fc2.weight, mode="fan_in", nonlinearity="relu"
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+         return x
+
+
+ class VisionProjection(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         image_embedding_dim = 1152
+         model_dim = 2048
+         hidden_dim = model_dim * 4
+
+         self.mlp = MLP(image_embedding_dim * 2, hidden_dim, model_dim)
+
+     @property
+     def device(self):
+         return self.mlp.fc1.weight.device
+
+     def forward(self, x):
+         return self.mlp(x)
+
+
+ def create_patches(image, patch_size=(378, 378)):
+     assert image.dim() == 3, "Image must be in CHW format"
+
+     _, height, width = image.shape  # Channels, Height, Width
+     patch_height, patch_width = patch_size
+
+     if height == patch_height and width == patch_width:
+         return []
+
+     # Iterate over the image and create patches
+     patches = []
+     for i in range(0, height, patch_height):
+         row_patches = []
+         for j in range(0, width, patch_width):
+             patch = image[:, i : i + patch_height, j : j + patch_width]
+             row_patches.append(patch)
+         patches.append(torch.stack(row_patches))
+     return patches
+
+
+ class VisionEncoder(nn.Module):
+
+     def __init__(self, use_flash_attn=False):
+         super().__init__()
+
+         self.encoder = EncoderWrapper(use_flash_attn)
+         self.projection = VisionProjection()
+         self.supported_sizes = [(378, 378), (378, 756), (756, 378), (756, 756)]
+
+     @property
+     def device(self):
+         return self.projection.mlp.fc1.weight.device
+
+     @property
+     def dtype(self):
+         return self.projection.mlp.fc1.weight.dtype
+
+     def preprocess(self, image: PIL.Image.Image):
+         width, height = image.size
+         max_dim = max(width, height)
+         if max_dim < 512:
+             im_size = (378, 378)
+         else:
+             aspect_ratio = width / height
+             im_size = min(
+                 self.supported_sizes,
+                 key=lambda size: (
+                     abs((size[1] / size[0]) - aspect_ratio),
+                     abs(size[0] - width) + abs(size[1] - height),
+                 ),
+             )
+
+         return Compose(
+             [
+                 Resize(size=im_size, interpolation=InterpolationMode.BICUBIC),
+                 ToImage(),
+                 ToDtype(torch.float32, scale=True),
+                 Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+             ]
+         )(image)
+
+     def forward(
+         self, images: Union[PIL.Image.Image, list[PIL.Image.Image], torch.Tensor]
+     ) -> torch.Tensor:
+         im_list = None
+         if isinstance(images, torch.Tensor):
+             # Input must have dimensions (B, C, H, W)
+             assert (
+                 len(images.shape) == 4
+             ), "Tensor input must have dimensions (B, C, H, W)"
+             im_list = list(images)
+         elif isinstance(images, PIL.Image.Image):
+             im_list = [images]
+         elif isinstance(images, list):
+             im_list = images
+         else:
+             raise ValueError(
+                 "Input must be a PIL image, list of PIL images, or a tensor"
+             )
+
+         # Preprocess unless the images are already tensors (indicating that
+         # they have already been preprocessed)
+         if not isinstance(im_list[0], torch.Tensor):
+             im_list = [self.preprocess(im.convert("RGB")) for im in im_list]
+
+         patches = [create_patches(im) for im in im_list]
+         flat_patches = [patch for image_patches in patches for patch in image_patches]
+
+         # Images may be variable size, and need to be resized to a common size after
+         # creating patches.
+         resized_images = [
+             F.interpolate(im.unsqueeze(0), size=(378, 378), mode="bilinear")
+             for im in im_list
+         ]
+
+         combined_images = torch.cat([*resized_images, *flat_patches], dim=0)
+         combined_images = combined_images.to(self.device, dtype=self.dtype)
+
+         combined_features = self.encoder(combined_images)
+
+         full_img_features = combined_features[: len(im_list)]
+         patch_features = (
+             combined_features[len(im_list) :].transpose(1, 2).view(-1, 1152, 27, 27)
+         )
+
+         # Reshape patch features back to their original structure
+         reshaped_patch_features = []
+         patch_idx = 0
+         for i, patch_set in enumerate(patches):
+             if len(patch_set) == 0:
+                 reshaped_patch_features.append(
+                     full_img_features[i].transpose(0, 1).view(1152, 27, 27)
+                 )
+             else:
+                 sample_features = []
+                 for row_patches in patch_set:
+                     row_len = len(row_patches)
+                     row_features = patch_features[
+                         patch_idx : patch_idx + row_len
+                     ]  # row_len, T, C
+                     row_features = torch.cat(
+                         list(row_features), dim=2
+                     )  # T, C * row_len
+                     patch_idx += row_len
+                     sample_features.append(row_features)
+                 sample_features = torch.cat(sample_features, dim=1)
+                 sample_features = F.interpolate(
+                     sample_features.unsqueeze(0), size=(27, 27), mode="bilinear"
+                 ).squeeze(0)
+                 reshaped_patch_features.append(sample_features)
+         reshaped_patch_features = (
+             torch.stack(reshaped_patch_features).view(-1, 1152, 729).transpose(1, 2)
+         )
+
+         final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
+
+         return self.projection(final_features)
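Usage sketch (illustrative, not part of the commit above): driving this encoder end to end on one image, assuming the file's own imports (torch, PIL, torchvision v2 transforms); the import path and image filename are hypothetical, and the output is only meaningful once pretrained weights have been loaded into the module.

    import PIL.Image
    import torch

    from vision_encoder import VisionEncoder  # hypothetical import path for the file above

    encoder = VisionEncoder(use_flash_attn=False).eval()

    image = PIL.Image.open("photo.jpg")  # hypothetical local image
    with torch.no_grad():
        features = encoder(image)  # preprocess, run the ViT, project to the text-model width

    # With the dimensions hard-coded above this yields (batch, 729, 2048):
    # 729 image tokens per input, projected into the 2048-dim text-model space.
    print(features.shape)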
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
weights.py ADDED
@@ -0,0 +1,292 @@
+ import safetensors
+ import torch
+ import torch.nn as nn
+
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Callable, List
+
+ from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
+
+
+ @dataclass
+ class VisionBlock:
+     ln1: LayerNormWeights
+     attn: AttentionWeights
+     ln2: LayerNormWeights
+     mlp: MLPWeights
+
+
+ @dataclass
+ class VisionModel:
+     patch_emb: LinearWeights
+     pos_emb: torch.Tensor
+     blocks: List[VisionBlock]
+     post_ln: LayerNormWeights
+     proj_mlp: MLPWeights
+
+
+ @dataclass
+ class TextBlock:
+     ln: LayerNormWeights
+     attn: AttentionWeights
+     mlp: MLPWeights
+
+
+ @dataclass
+ class TextModel:
+     wte: torch.Tensor
+     blocks: List[TextBlock]
+     post_ln: LayerNormWeights
+     lm_head: LinearWeights
+
+
+ @dataclass
+ class RegionModel:
+     coord_features: torch.Tensor
+     coord_encoder: LinearWeights
+     coord_decoder: MLPWeights
+     size_features: torch.Tensor
+     size_encoder: LinearWeights
+     size_decoder: MLPWeights
+
+
+ @dataclass
+ class MoondreamModel:
+     vision: VisionModel
+     text: TextModel
+     region: RegionModel
+
+
+ @contextmanager
+ def safetensors_open(safetensors_file: str):
+     """
+     Simplify interfacing with safetensors files. Eliminates the need to ignore
+     type errors when using the `safe_open` function.
+     """
+     with safetensors.safe_open(
+         safetensors_file, framework="pt"
+     ) as st:  # pyright: ignore
+
+         def get_tensor(name: str) -> torch.Tensor:
+             return st.get_tensor(name)
+
+         def get_keys() -> List[str]:
+             return st.keys()
+
+         get_tensor.keys = get_keys
+
+         yield get_tensor
+
+
+ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
+     """Internal function to load weights using a tensor getter function."""
+     model = model.to(dtype=torch.float16)
+
+     # Vision Model
+     model.vision["patch_emb"].weight.data.copy_(
+         get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight")
+     )
+     model.vision["patch_emb"].bias.data.copy_(
+         get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias")
+     )
+     model.vision.pos_emb.data.copy_(
+         get_tensor("vision_encoder.encoder.model.visual.pos_embed")
+     )
+
+     for i in range(len(model.vision["blocks"])):
+         prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
+
+         # Layer norms
+         model.vision["blocks"][i]["ln1"].weight.data.copy_(
+             get_tensor(f"{prefix}.norm1.weight")
+         )
+         model.vision["blocks"][i]["ln1"].bias.data.copy_(
+             get_tensor(f"{prefix}.norm1.bias")
+         )
+         model.vision["blocks"][i]["ln2"].weight.data.copy_(
+             get_tensor(f"{prefix}.norm2.weight")
+         )
+         model.vision["blocks"][i]["ln2"].bias.data.copy_(
+             get_tensor(f"{prefix}.norm2.bias")
+         )
+
+         # Attention
+         model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_(
+             get_tensor(f"{prefix}.attn.qkv.weight")
+         )
+         model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_(
+             get_tensor(f"{prefix}.attn.qkv.bias")
+         )
+         model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_(
+             get_tensor(f"{prefix}.attn.proj.weight")
+         )
+         model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_(
+             get_tensor(f"{prefix}.attn.proj.bias")
+         )
+
+         # MLP
+         model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc1.weight")
+         )
+         model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc1.bias")
+         )
+         model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc2.weight")
+         )
+         model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc2.bias")
+         )
+
+     model.vision["post_ln"].weight.data.copy_(
+         get_tensor("vision_encoder.encoder.model.visual.norm.weight")
+     )
+     model.vision["post_ln"].bias.data.copy_(
+         get_tensor("vision_encoder.encoder.model.visual.norm.bias")
+     )
+
+     model.vision["proj_mlp"]["fc1"].weight.data.copy_(
+         get_tensor("vision_encoder.projection.mlp.fc1.weight")
+     )
+     model.vision["proj_mlp"]["fc1"].bias.data.copy_(
+         get_tensor("vision_encoder.projection.mlp.fc1.bias")
+     )
+     model.vision["proj_mlp"]["fc2"].weight.data.copy_(
+         get_tensor("vision_encoder.projection.mlp.fc2.weight")
+     )
+     model.vision["proj_mlp"]["fc2"].bias.data.copy_(
+         get_tensor("vision_encoder.projection.mlp.fc2.bias")
+     )
+
+     # Text Model
+     model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight"))
+
+     for i in range(len(model.text["blocks"])):
+         prefix = f"text_model.transformer.h.{i}"
+
+         # Layer norm
+         model.text["blocks"][i]["ln"].weight.data.copy_(
+             get_tensor(f"{prefix}.ln.weight")
+         )
+         model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias"))
+
+         # Attention
+         model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_(
+             get_tensor(f"{prefix}.mixer.Wqkv.weight")
+         )
+         model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_(
+             get_tensor(f"{prefix}.mixer.Wqkv.bias")
+         )
+         model.text["blocks"][i]["attn"]["proj"].weight.data.copy_(
+             get_tensor(f"{prefix}.mixer.out_proj.weight")
+         )
+         model.text["blocks"][i]["attn"]["proj"].bias.data.copy_(
+             get_tensor(f"{prefix}.mixer.out_proj.bias")
+         )
+
+         # MLP
+         model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc1.weight")
+         )
+         model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc1.bias")
+         )
+         model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc2.weight")
+         )
+         model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
+             get_tensor(f"{prefix}.mlp.fc2.bias")
+         )
+
+     model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight"))
+     model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias"))
+
+     model.text["lm_head"].weight.data.copy_(
+         get_tensor("text_model.lm_head.linear.weight")
+     )
+     model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias"))
+
+     # Region Model
+     model.region.coord_features.data.copy_(
+         get_tensor("region_model.coordinate_features.weight").T
+     )
+     model.region["coord_encoder"].weight.data.copy_(
+         get_tensor("region_model.coordinate_encoder.weight")
+     )
+     model.region["coord_encoder"].bias.data.copy_(
+         get_tensor("region_model.coordinate_encoder.bias")
+     )
+
+     model.region["coord_decoder"]["fc1"].weight.data.copy_(
+         get_tensor("region_model.coordinate_decoder.fc1.weight")
+     )
+     model.region["coord_decoder"]["fc1"].bias.data.copy_(
+         get_tensor("region_model.coordinate_decoder.fc1.bias")
+     )
+     model.region["coord_decoder"]["fc2"].weight.data.copy_(
+         get_tensor("region_model.coordinate_decoder.fc2.weight")
+     )
+     model.region["coord_decoder"]["fc2"].bias.data.copy_(
+         get_tensor("region_model.coordinate_decoder.fc2.bias")
+     )
+
+     model.region.size_features.data.copy_(
+         get_tensor("region_model.size_features.weight").T
+     )
+     model.region["size_encoder"].weight.data.copy_(
+         get_tensor("region_model.size_encoder.weight")
+     )
+     model.region["size_encoder"].bias.data.copy_(
+         get_tensor("region_model.size_encoder.bias")
+     )
+
+     model.region["size_decoder"]["fc1"].weight.data.copy_(
+         get_tensor("region_model.size_decoder.fc1.weight")
+     )
+     model.region["size_decoder"]["fc1"].bias.data.copy_(
+         get_tensor("region_model.size_decoder.fc1.bias")
+     )
+     model.region["size_decoder"]["fc2"].weight.data.copy_(
+         get_tensor("region_model.size_decoder.fc2.weight")
+     )
+     model.region["size_decoder"]["fc2"].bias.data.copy_(
+         get_tensor("region_model.size_decoder.fc2.bias")
+     )
+
+
+ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
+     """Load weights from a safetensors file into a MoondreamModel instance."""
+     with safetensors_open(weights_file) as get_tensor:
+         # Wrap the get_tensor function to handle key normalization
+         name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
+         _load_weights(lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model)
+
+
+ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
+     """Load weights from a PyTorch file into a MoondreamModel instance."""
+     device = str(torch.empty(0).device)
+     tensors = torch.load(weights_file, map_location=device, weights_only=True)
+     tensors = {
+         k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
+         for k, v in tensors.items()
+     }
+     _load_weights(lambda x: tensors[x], model)
+
+
+ def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
+     """
+     Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.
+
+     Args:
+         weights_file: Path to weights file (either .safetensors or .pt)
+         model: MoondreamModel instance to load weights into
+     """
+     if weights_file.endswith(".safetensors"):
+         load_weights_from_safetensors(weights_file, model)
+     else:
+         load_weights_from_pt(weights_file, model)
+
+     # Make all parameters contiguous
+     for param in model.parameters():
+         param.data = param.data.contiguous()
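Usage sketch (illustrative, not part of the commit above): invoking the loader in weights.py. The checkpoint filename and build_model() constructor are hypothetical stand-ins; the only requirement _load_weights imposes is that the module's submodule layout matches the key paths it reads (model.vision["blocks"][i]["attn"]["qkv"], model.text["lm_head"], and so on).

    import torch.nn as nn

    from weights import load_weights_into_model
    from model import build_model  # hypothetical constructor returning a module with the expected layout

    model: nn.Module = build_model()
    load_weights_into_model("model.safetensors", model)  # a .pt path would fall back to torch.load
    model.eval()

Note that _load_weights casts the module to float16 before copying and load_weights_into_model makes every parameter contiguous at the end, so no separate dtype conversion is needed afterwards.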