seems borked; ValueError: Shapes (1,10,1,256) and (1,10,1,128) cannot be broadcast.

#2
by bmorphism - opened
MLX Community org
% ipython
Python 3.11.6 (main, Oct  2 2023, 13:45:54) [Clang 16.0.6 ]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.24.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import mlx_lm

In [2]: mlx_lm.__version__
Out[2]: '0.13.1'

In [3]: from mlx_lm import load, generate
   ...:
   ...: model, tokenizer = load("mlx-community/Phi-3-medium-128k-instruct-8bit")
   ...: response = generate(model, tokenizer, prompt="hello", verbose=True)
   ...:
Fetching 11 files: 100%|████████████████████████████████| 11/11 [00:00<00:00, 30155.13it/s]
[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
==========
Prompt: hello
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 4
      1 from mlx_lm import load, generate
      3 model, tokenizer = load("mlx-community/Phi-3-medium-128k-instruct-8bit")
----> 4 response = generate(model, tokenizer, prompt="hello", verbose=True)

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/utils.py:247, in generate(model, tokenizer, prompt, temp, max_tokens, verbose, formatter, repetition_penalty, repetition_context_size, top_p, logit_bias)
    244 tic = time.perf_counter()
    245 detokenizer.reset()
--> 247 for (token, prob), n in zip(
    248     generate_step(
    249         prompt_tokens,
    250         model,
    251         temp,
    252         repetition_penalty,
    253         repetition_context_size,
    254         top_p,
    255         logit_bias,
    256     ),
    257     range(max_tokens),
    258 ):
    259     if n == 0:
    260         prompt_time = time.perf_counter() - tic

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/utils.py:195, in generate_step(prompt, model, temp, repetition_penalty, repetition_context_size, top_p, logit_bias)
    192             repetition_context = repetition_context[-repetition_context_size:]
    193     return y, prob
--> 195 y, p = _step(y)
    197 mx.async_eval(y)
    198 while True:

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/utils.py:178, in generate_step.<locals>._step(y)
    176 def _step(y):
    177     nonlocal repetition_context
--> 178     logits = model(y[None], cache=cache)
    179     logits = logits[:, -1, :]
    181     if repetition_penalty:

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/models/phi3.py:183, in Model.__call__(self, inputs, cache)
    178 def __call__(
    179     self,
    180     inputs: mx.array,
    181     cache=None,
    182 ):
--> 183     out = self.model(inputs, cache)
    184     return self.lm_head(out)

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/models/phi3.py:165, in Phi3Model.__call__(self, inputs, cache)
    162     cache = [None] * len(self.layers)
    164 for layer, c in zip(self.layers, cache):
--> 165     h = layer(h, mask, c)
    167 return self.norm(h)

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/models/phi3.py:129, in TransformerBlock.__call__(self, x, mask, cache)
    123 def __call__(
    124     self,
    125     x: mx.array,
    126     mask: Optional[mx.array] = None,
    127     cache: Optional[Tuple[mx.array, mx.array]] = None,
    128 ) -> mx.array:
--> 129     r = self.self_attn(self.input_layernorm(x), mask, cache)
    130     h = x + r
    131     r = self.mlp(self.post_attention_layernorm(h))

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/models/phi3.py:86, in Attention.__call__(self, x, mask, cache)
     84     queries = self.rope(queries, offset=cache.offset)
     85     keys = self.rope(keys, offset=cache.offset)
---> 86     keys, values = cache.update_and_fetch(keys, values)
     87 else:
     88     queries = self.rope(queries)

File ~/Library/Caches/pypoetry/virtualenvs/cyberneticus-rqHG5l2E-py3.11/lib/python3.11/site-packages/mlx_lm/models/base.py:34, in KVCache.update_and_fetch(self, keys, values)
     31         self.keys, self.values = new_k, new_v
     33 self.offset += keys.shape[2]
---> 34 self.keys[..., prev : self.offset, :] = keys
     35 self.values[..., prev : self.offset, :] = values
     36 return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

ValueError: Shapes (1,10,1,256) and (1,10,1,128) cannot be broadcast.
bmorphism changed discussion title from seems borked to seems borked ValueError: Shapes (1,10,1,256) and (1,10,1,128) cannot be broadcast.
bmorphism changed discussion title from seems borked ValueError: Shapes (1,10,1,256) and (1,10,1,128) cannot be broadcast. to seems borked; ValueError: Shapes (1,10,1,256) and (1,10,1,128) cannot be broadcast.
MLX Community org

I get the same error with both the original Microsoft version and this one, on mlx 0.13.1.

MLX Community org

Sign up or log in to comment