Error while running `run_inference_with_slide_encoder`
I tried running the `prov-gigapath/demo/run_gigapath.ipynb` notebook, but it throws the following error:
AssertionError Traceback (most recent call last)
Cell In[8], line 3
1 from gigapath.pipeline import run_inference_with_slide_encoder
2 # run inference with the slide encoder
----> 3 slide_embeds = run_inference_with_slide_encoder(slide_encoder_model=slide_encoder_model, **tile_encoder_outputs)
4 print(slide_embeds.keys())
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator..decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ../prov-gigapath/gigapath/pipeline.py:128, in run_inference_with_slide_encoder(tile_embeds, coords, slide_encoder_model)
126 # run inference
127 with torch.cuda.amp.autocast(dtype=torch.float16):
--> 128 slide_embeds = slide_encoder_model(tile_embeds.cuda(), coords.cuda(), all_layer_embed=True)
129 outputs = {"layer_{}_embed".format(i): slide_embeds[i].cpu() for i in range(len(slide_embeds))}
130 outputs["last_layer_embed"] = slide_embeds[-1].cpu()
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ../prov-gigapath/gigapath/slide_encoder.py:208, in LongNetViT.forward(self, x, coords, all_layer_embed)
206 # apply Transformer blocks
207 if all_layer_embed:
--> 208 x_list = self.encoder(src_tokens=None, token_embeddings=x, return_all_hiddens=all_layer_embed)["encoder_states"]
209 else:
210 x_list = [self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"]]
File /blue/pinaki.sarder/harishwarreddy.k/conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ../prov-gigapath/gigapath/torchscale/model/../../torchscale/architecture/encoder.py:374, in Encoder.forward(self, src_tokens, encoder_padding_mask, attn_mask, return_all_hiddens, token_embeddings, multiway_split_position, features_only, incremental_state, positions, **kwargs)
372 l_aux = []
373 for idx, layer in enumerate(self.layers):
--> 374 x, l_aux_i = layer(
375 x,
376 encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
377 attn_mask=attn_mask,
378 rel_pos=rel_pos_bias,
379 multiway_split_position=multiway_split_position,
380 incremental_state=incremental_state[idx] if incremental_state is not None else None,
381 )
382 if return_all_hiddens:
383 assert encoder_states is not None
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ../prov-gigapath/gigapath/torchscale/model/../../torchscale/architecture/encoder.py:127, in EncoderLayer.forward(self, x, encoder_padding_mask, attn_mask, rel_pos, multiway_split_position, incremental_state)
125 if self.normalize_before:
126 x = self.self_attn_layer_norm(x)
--> 127 x, _ = self.self_attn(
128 query=x,
129 key=x,
130 value=x,
131 key_padding_mask=encoder_padding_mask,
132 attn_mask=attn_mask,
133 rel_pos=rel_pos,
134 incremental_state=incremental_state,
135 )
136 x = self.dropout_module(x)
138 if self.drop_path is not None:
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ../conda/envs/Foundation/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ../prov-gigapath/gigapath/torchscale/model/../../torchscale/component/dilated_attention.py:205, in DilatedAttention.forward(self, query, key, value, incremental_state, key_padding_mask, attn_mask, rel_pos, is_first_step, is_causal)
202 vi = self.gathering(v, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
203 qi = self.gathering(q, dr, sl, is_causal=is_causal, offset=offset, is_kv=False, seq_parall=self.args.seq_parallel)
--> 205 out, lse = self.attention_ops(qi, ki, vi, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
207 outs.append(out)
208 lses.append(lse)
File ../prov-gigapath/gigapath/torchscale/model/../../torchscale/component/multihead_attention.py:98, in MultiheadAttention.attention_ops(self, q, k, v, key_padding_mask, attn_mask, rel_pos, is_causal)
96 attn = rearrange(attn, '(b h) l d -> b l (h d)', h=self.num_heads)
97 else:
---> 98 assert flash_attn_func is not None
99 assert rel_pos is None
100 q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
AssertionError:
Can someone help me out?
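Hi, it looks like flash-attn was not installed successfully: the failing check is `assert flash_attn_func is not None` in `multihead_attention.py`, which means the flash-attn attention function is unavailable in your environment. Please run the following command and try again: `pip install flash-attn --no-build-isolation`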