Benefit from loading fp32 weights instead?

#12
by ryanramos - opened

I noticed that the Flan-T5 XXL model used in this Space (philschmid/flan-t5-xxl-sharded-fp16) is supposedly stored in fp16, but IIRC there's been a lot of discussion about how 8-bit quantizing Flan-T5's XXL variant still requires some layers to be kept in fp32 for performance closer to the full-precision original (see here). Just wondering whether performance here could be improved if an fp32 version were loaded and then 8-bit quantized instead (IIRC the transformers library should already make sure that T5ForConditionalGeneration models keep the wo layers in fp32 by default)?
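Something like this is what I have in mind, as a rough, untested sketch (assuming recent transformers plus accelerate and bitsandbytes are installed; the checkpoint name is just the official fp32 repo):

```python
# Rough sketch: load the fp32 checkpoint and let bitsandbytes quantize most
# linear layers to int8. The point is that the wo layers can then stay in real
# fp32 instead of being upcast from an fp16 copy of the weights.
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xxl",   # checkpoint stored in fp32
    device_map="auto",      # required by the bitsandbytes integration
    load_in_8bit=True,      # int8 for most weights
)
```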

Yup, and I'm thinking it's because of the XXL variant's need for fp32 weights in some layers.

Sorry to ping you @osanseviero since you don't really owe anyone anything, but have you considered swapping the model out for tomrb/flan-t5-xxl-sharded (not that I can verify it's the correct model; I just found it a while back)? Unless of course it won't fit on an A10G.

Edit: I originally suggested tomrb/flan-t5-xxl-sharded because I wasn't sure whether an A10G could handle the original's shard sizes, but if it can, then of course google/flan-t5-xxl is better.

@ryanramos , I think the reason the 8-bit model was used is latency? An A10G has 24GB, whereas fp32 Flan-T5 XXL is around 46GB in weights alone. It could be very challenging to serve the full flan-t5-xxl due to cost and latency constraints; @osanseviero can confirm whether my logic makes sense. Perhaps we could branch out and have a second demo with FLAN-T5-XL in bf16 so that the community can compare the performance more exhaustively? That would be a great benefit to many research and practitioner communities!

@deathcrush Yeah, 8-bit quantization is definitely useful for those reasons. The thing is, though: while Flan-T5 XXL can (and on an A10G, should) be 8-bit quantized for the most part, its wo layers (and lm_head) need to be in 32-bit. The rest of the model can be in 8-bit, no problem; it's just that some layers need to be kept at full precision for stability (discussed here and here). This is why I was suggesting that a checkpoint stored (not loaded) in fp32 be used, so that the weights that need it can actually be kept at full precision.
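For what it's worth, here's a quick (sketched) way to check which weights actually stay in higher precision after 8-bit loading; I'm assuming the usual T5 parameter naming here (e.g. ...DenseReluDense.wo.weight):

```python
# Sketch: count parameter dtypes after loading with load_in_8bit=True.
# Most linear weights should report torch.int8; the wo projections and lm_head
# should show a higher-precision dtype if they were kept out of the conversion.
from collections import Counter
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xxl", device_map="auto", load_in_8bit=True
)

print(Counter(p.dtype for p in model.parameters()))
for name, param in model.named_parameters():
    if name.endswith(".wo.weight") or name.startswith("lm_head"):
        print(name, param.dtype)
```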

@ryanramos Ah, good point, now I see what you mean. So you think the strategy should be: store the checkpoint in fp32, quantize the right parts, and keep the other layers you mentioned in fp32. I'm happy to test whether it still fits on a 24GB GPU (I have an RTX 3090). I also have an A100-80GB for testing purposes. Should I just test the checkpoint you mention? I suppose quantisation simply happens by passing load_in_8bit=True to from_pretrained, right?

@deathcrush I should apologize, I made the title of this thread misleading. But yes: download the fp32 weights, load most of the model in 8-bit, and keep certain parts in 32-bit. As you mentioned, this should happen automatically after passing load_in_8bit=True together with device_map="auto" to from_pretrained. However, last time I checked, infer_auto_device_map (which is called internally when using device_map="auto") had no way of knowing about the 32-bit layers, i.e. it prepares a device map assuming everything is 8-bit and underestimates the true memory load because some of it is actually in 32-bit. If the whole thing doesn't fit onto the GPU, this can make things annoying to deal with, and you might have to make a custom device map. On a Tesla T4 I use

device map ``` {'encoder.block.0.layer.1.DenseReluDense.wo': 0, 'encoder.block.1.layer.1.DenseReluDense.wo': 0, 'encoder.block.2.layer.1.DenseReluDense.wo': 0, 'encoder.block.3.layer.1.DenseReluDense.wo': 0, 'encoder.block.4.layer.1.DenseReluDense.wo': 0, 'encoder.block.5.layer.1.DenseReluDense.wo': 0, 'encoder.block.6.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.7.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.8.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.9.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.10.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.11.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.12.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.13.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.14.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.15.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.16.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.17.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.18.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.19.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.20.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.21.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.22.layer.1.DenseReluDense.wo': 'cpu', 'encoder.block.23.layer.1.DenseReluDense.wo': 'cpu', 'decoder.block.0.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.1.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.2.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.3.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.4.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.5.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.6.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.7.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.8.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.9.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.10.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.11.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.12.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.13.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.14.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.15.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.16.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.17.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.18.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.19.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.20.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.21.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.22.layer.2.DenseReluDense.wo': 'cpu', 'decoder.block.23.layer.2.DenseReluDense.wo': 'cpu', 'shared': 0, 'encoder.embed_tokens': 0, 'encoder.block.0.layer.0': 0, 'encoder.block.0.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.0.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.0.layer.1.DenseReluDense.dropout': 0, 'encoder.block.0.layer.1.DenseReluDense.act': 0, 'encoder.block.0.layer.1.layer_norm': 0, 'encoder.block.0.layer.1.dropout': 0, 'encoder.block.1.layer.0': 0, 'encoder.block.1.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.1.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.1.layer.1.DenseReluDense.dropout': 0, 'encoder.block.1.layer.1.DenseReluDense.act': 0, 'encoder.block.1.layer.1.layer_norm': 0, 'encoder.block.1.layer.1.dropout': 0, 'encoder.block.2.layer.0': 0, 'encoder.block.2.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.2.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.2.layer.1.DenseReluDense.dropout': 0, 'encoder.block.2.layer.1.DenseReluDense.act': 0, 'encoder.block.2.layer.1.layer_norm': 0, 'encoder.block.2.layer.1.dropout': 0, 'encoder.block.3.layer.0': 0, 
'encoder.block.3.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.3.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.3.layer.1.DenseReluDense.dropout': 0, 'encoder.block.3.layer.1.DenseReluDense.act': 0, 'encoder.block.3.layer.1.layer_norm': 0, 'encoder.block.3.layer.1.dropout': 0, 'encoder.block.4.layer.0': 0, 'encoder.block.4.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.4.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.4.layer.1.DenseReluDense.dropout': 0, 'encoder.block.4.layer.1.DenseReluDense.act': 0, 'encoder.block.4.layer.1.layer_norm': 0, 'encoder.block.4.layer.1.dropout': 0, 'encoder.block.5.layer.0': 0, 'encoder.block.5.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.5.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.5.layer.1.DenseReluDense.dropout': 0, 'encoder.block.5.layer.1.DenseReluDense.act': 0, 'encoder.block.5.layer.1.layer_norm': 0, 'encoder.block.5.layer.1.dropout': 0, 'encoder.block.6.layer.0': 0, 'encoder.block.6.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.6.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.6.layer.1.DenseReluDense.dropout': 0, 'encoder.block.6.layer.1.DenseReluDense.act': 0, 'encoder.block.6.layer.1.layer_norm': 0, 'encoder.block.6.layer.1.dropout': 0, 'encoder.block.7.layer.0': 0, 'encoder.block.7.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.7.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.7.layer.1.DenseReluDense.dropout': 0, 'encoder.block.7.layer.1.DenseReluDense.act': 0, 'encoder.block.7.layer.1.layer_norm': 0, 'encoder.block.7.layer.1.dropout': 0, 'encoder.block.8.layer.0': 0, 'encoder.block.8.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.8.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.8.layer.1.DenseReluDense.dropout': 0, 'encoder.block.8.layer.1.DenseReluDense.act': 0, 'encoder.block.8.layer.1.layer_norm': 0, 'encoder.block.8.layer.1.dropout': 0, 'encoder.block.9.layer.0': 0, 'encoder.block.9.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.9.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.9.layer.1.DenseReluDense.dropout': 0, 'encoder.block.9.layer.1.DenseReluDense.act': 0, 'encoder.block.9.layer.1.layer_norm': 0, 'encoder.block.9.layer.1.dropout': 0, 'encoder.block.10.layer.0': 0, 'encoder.block.10.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.10.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.10.layer.1.DenseReluDense.dropout': 0, 'encoder.block.10.layer.1.DenseReluDense.act': 0, 'encoder.block.10.layer.1.layer_norm': 0, 'encoder.block.10.layer.1.dropout': 0, 'encoder.block.11.layer.0': 0, 'encoder.block.11.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.11.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.11.layer.1.DenseReluDense.dropout': 0, 'encoder.block.11.layer.1.DenseReluDense.act': 0, 'encoder.block.11.layer.1.layer_norm': 0, 'encoder.block.11.layer.1.dropout': 0, 'encoder.block.12.layer.0': 0, 'encoder.block.12.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.12.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.12.layer.1.DenseReluDense.dropout': 0, 'encoder.block.12.layer.1.DenseReluDense.act': 0, 'encoder.block.12.layer.1.layer_norm': 0, 'encoder.block.12.layer.1.dropout': 0, 'encoder.block.13.layer.0': 0, 'encoder.block.13.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.13.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.13.layer.1.DenseReluDense.dropout': 0, 'encoder.block.13.layer.1.DenseReluDense.act': 0, 'encoder.block.13.layer.1.layer_norm': 0, 'encoder.block.13.layer.1.dropout': 0, 'encoder.block.14.layer.0': 0, 'encoder.block.14.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.14.layer.1.DenseReluDense.wi_1': 0, 
'encoder.block.14.layer.1.DenseReluDense.dropout': 0, 'encoder.block.14.layer.1.DenseReluDense.act': 0, 'encoder.block.14.layer.1.layer_norm': 0, 'encoder.block.14.layer.1.dropout': 0, 'encoder.block.15.layer.0': 0, 'encoder.block.15.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.15.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.15.layer.1.DenseReluDense.dropout': 0, 'encoder.block.15.layer.1.DenseReluDense.act': 0, 'encoder.block.15.layer.1.layer_norm': 0, 'encoder.block.15.layer.1.dropout': 0, 'encoder.block.16.layer.0': 0, 'encoder.block.16.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.16.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.16.layer.1.DenseReluDense.dropout': 0, 'encoder.block.16.layer.1.DenseReluDense.act': 0, 'encoder.block.16.layer.1.layer_norm': 0, 'encoder.block.16.layer.1.dropout': 0, 'encoder.block.17.layer.0': 0, 'encoder.block.17.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.17.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.17.layer.1.DenseReluDense.dropout': 0, 'encoder.block.17.layer.1.DenseReluDense.act': 0, 'encoder.block.17.layer.1.layer_norm': 0, 'encoder.block.17.layer.1.dropout': 0, 'encoder.block.18.layer.0': 0, 'encoder.block.18.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.18.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.18.layer.1.DenseReluDense.dropout': 0, 'encoder.block.18.layer.1.DenseReluDense.act': 0, 'encoder.block.18.layer.1.layer_norm': 0, 'encoder.block.18.layer.1.dropout': 0, 'encoder.block.19.layer.0': 0, 'encoder.block.19.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.19.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.19.layer.1.DenseReluDense.dropout': 0, 'encoder.block.19.layer.1.DenseReluDense.act': 0, 'encoder.block.19.layer.1.layer_norm': 0, 'encoder.block.19.layer.1.dropout': 0, 'encoder.block.20.layer.0': 0, 'encoder.block.20.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.20.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.20.layer.1.DenseReluDense.dropout': 0, 'encoder.block.20.layer.1.DenseReluDense.act': 0, 'encoder.block.20.layer.1.layer_norm': 0, 'encoder.block.20.layer.1.dropout': 0, 'encoder.block.21.layer.0': 0, 'encoder.block.21.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.21.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.21.layer.1.DenseReluDense.dropout': 0, 'encoder.block.21.layer.1.DenseReluDense.act': 0, 'encoder.block.21.layer.1.layer_norm': 0, 'encoder.block.21.layer.1.dropout': 0, 'encoder.block.22.layer.0': 0, 'encoder.block.22.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.22.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.22.layer.1.DenseReluDense.dropout': 0, 'encoder.block.22.layer.1.DenseReluDense.act': 0, 'encoder.block.22.layer.1.layer_norm': 0, 'encoder.block.22.layer.1.dropout': 0, 'encoder.block.23.layer.0': 0, 'encoder.block.23.layer.1.DenseReluDense.wi_0': 0, 'encoder.block.23.layer.1.DenseReluDense.wi_1': 0, 'encoder.block.23.layer.1.DenseReluDense.dropout': 0, 'encoder.block.23.layer.1.DenseReluDense.act': 0, 'encoder.block.23.layer.1.layer_norm': 0, 'encoder.block.23.layer.1.dropout': 0, 'encoder.final_layer_norm': 0, 'encoder.dropout': 0, 'decoder.embed_tokens': 0, 'decoder.block.0.layer.0': 0, 'decoder.block.0.layer.1': 0, 'decoder.block.0.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.0.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.0.layer.2.DenseReluDense.dropout': 0, 'decoder.block.0.layer.2.DenseReluDense.act': 0, 'decoder.block.0.layer.2.layer_norm': 0, 'decoder.block.0.layer.2.dropout': 0, 'decoder.block.1.layer.0': 0, 'decoder.block.1.layer.1': 0, 
'decoder.block.1.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.1.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.1.layer.2.DenseReluDense.dropout': 0, 'decoder.block.1.layer.2.DenseReluDense.act': 0, 'decoder.block.1.layer.2.layer_norm': 0, 'decoder.block.1.layer.2.dropout': 0, 'decoder.block.2.layer.0': 0, 'decoder.block.2.layer.1': 0, 'decoder.block.2.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.2.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.2.layer.2.DenseReluDense.dropout': 0, 'decoder.block.2.layer.2.DenseReluDense.act': 0, 'decoder.block.2.layer.2.layer_norm': 0, 'decoder.block.2.layer.2.dropout': 0, 'decoder.block.3.layer.0': 0, 'decoder.block.3.layer.1': 0, 'decoder.block.3.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.3.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.3.layer.2.DenseReluDense.dropout': 0, 'decoder.block.3.layer.2.DenseReluDense.act': 0, 'decoder.block.3.layer.2.layer_norm': 0, 'decoder.block.3.layer.2.dropout': 0, 'decoder.block.4.layer.0': 0, 'decoder.block.4.layer.1': 0, 'decoder.block.4.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.4.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.4.layer.2.DenseReluDense.dropout': 0, 'decoder.block.4.layer.2.DenseReluDense.act': 0, 'decoder.block.4.layer.2.layer_norm': 0, 'decoder.block.4.layer.2.dropout': 0, 'decoder.block.5.layer.0': 0, 'decoder.block.5.layer.1': 0, 'decoder.block.5.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.5.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.5.layer.2.DenseReluDense.dropout': 0, 'decoder.block.5.layer.2.DenseReluDense.act': 0, 'decoder.block.5.layer.2.layer_norm': 0, 'decoder.block.5.layer.2.dropout': 0, 'decoder.block.6.layer.0': 0, 'decoder.block.6.layer.1': 0, 'decoder.block.6.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.6.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.6.layer.2.DenseReluDense.dropout': 0, 'decoder.block.6.layer.2.DenseReluDense.act': 0, 'decoder.block.6.layer.2.layer_norm': 0, 'decoder.block.6.layer.2.dropout': 0, 'decoder.block.7.layer.0': 0, 'decoder.block.7.layer.1': 0, 'decoder.block.7.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.7.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.7.layer.2.DenseReluDense.dropout': 0, 'decoder.block.7.layer.2.DenseReluDense.act': 0, 'decoder.block.7.layer.2.layer_norm': 0, 'decoder.block.7.layer.2.dropout': 0, 'decoder.block.8.layer.0': 0, 'decoder.block.8.layer.1': 0, 'decoder.block.8.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.8.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.8.layer.2.DenseReluDense.dropout': 0, 'decoder.block.8.layer.2.DenseReluDense.act': 0, 'decoder.block.8.layer.2.layer_norm': 0, 'decoder.block.8.layer.2.dropout': 0, 'decoder.block.9.layer.0': 0, 'decoder.block.9.layer.1': 0, 'decoder.block.9.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.9.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.9.layer.2.DenseReluDense.dropout': 0, 'decoder.block.9.layer.2.DenseReluDense.act': 0, 'decoder.block.9.layer.2.layer_norm': 0, 'decoder.block.9.layer.2.dropout': 0, 'decoder.block.10.layer.0': 0, 'decoder.block.10.layer.1': 0, 'decoder.block.10.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.10.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.10.layer.2.DenseReluDense.dropout': 0, 'decoder.block.10.layer.2.DenseReluDense.act': 0, 'decoder.block.10.layer.2.layer_norm': 0, 'decoder.block.10.layer.2.dropout': 0, 'decoder.block.11.layer.0': 0, 'decoder.block.11.layer.1': 0, 'decoder.block.11.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.11.layer.2.DenseReluDense.wi_1': 0, 
'decoder.block.11.layer.2.DenseReluDense.dropout': 0, 'decoder.block.11.layer.2.DenseReluDense.act': 0, 'decoder.block.11.layer.2.layer_norm': 0, 'decoder.block.11.layer.2.dropout': 0, 'decoder.block.12.layer.0': 0, 'decoder.block.12.layer.1': 0, 'decoder.block.12.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.12.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.12.layer.2.DenseReluDense.dropout': 0, 'decoder.block.12.layer.2.DenseReluDense.act': 0, 'decoder.block.12.layer.2.layer_norm': 0, 'decoder.block.12.layer.2.dropout': 0, 'decoder.block.13.layer.0': 0, 'decoder.block.13.layer.1': 0, 'decoder.block.13.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.13.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.13.layer.2.DenseReluDense.dropout': 0, 'decoder.block.13.layer.2.DenseReluDense.act': 0, 'decoder.block.13.layer.2.layer_norm': 0, 'decoder.block.13.layer.2.dropout': 0, 'decoder.block.14.layer.0': 0, 'decoder.block.14.layer.1': 0, 'decoder.block.14.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.14.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.14.layer.2.DenseReluDense.dropout': 0, 'decoder.block.14.layer.2.DenseReluDense.act': 0, 'decoder.block.14.layer.2.layer_norm': 0, 'decoder.block.14.layer.2.dropout': 0, 'decoder.block.15.layer.0': 0, 'decoder.block.15.layer.1': 0, 'decoder.block.15.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.15.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.15.layer.2.DenseReluDense.dropout': 0, 'decoder.block.15.layer.2.DenseReluDense.act': 0, 'decoder.block.15.layer.2.layer_norm': 0, 'decoder.block.15.layer.2.dropout': 0, 'decoder.block.16.layer.0': 0, 'decoder.block.16.layer.1': 0, 'decoder.block.16.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.16.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.16.layer.2.DenseReluDense.dropout': 0, 'decoder.block.16.layer.2.DenseReluDense.act': 0, 'decoder.block.16.layer.2.layer_norm': 0, 'decoder.block.16.layer.2.dropout': 0, 'decoder.block.17.layer.0': 0, 'decoder.block.17.layer.1': 0, 'decoder.block.17.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.17.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.17.layer.2.DenseReluDense.dropout': 0, 'decoder.block.17.layer.2.DenseReluDense.act': 0, 'decoder.block.17.layer.2.layer_norm': 0, 'decoder.block.17.layer.2.dropout': 0, 'decoder.block.18.layer.0': 0, 'decoder.block.18.layer.1': 0, 'decoder.block.18.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.18.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.18.layer.2.DenseReluDense.dropout': 0, 'decoder.block.18.layer.2.DenseReluDense.act': 0, 'decoder.block.18.layer.2.layer_norm': 0, 'decoder.block.18.layer.2.dropout': 0, 'decoder.block.19.layer.0': 0, 'decoder.block.19.layer.1': 0, 'decoder.block.19.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.19.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.19.layer.2.DenseReluDense.dropout': 0, 'decoder.block.19.layer.2.DenseReluDense.act': 0, 'decoder.block.19.layer.2.layer_norm': 0, 'decoder.block.19.layer.2.dropout': 0, 'decoder.block.20.layer.0': 0, 'decoder.block.20.layer.1': 0, 'decoder.block.20.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.20.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.20.layer.2.DenseReluDense.dropout': 0, 'decoder.block.20.layer.2.DenseReluDense.act': 0, 'decoder.block.20.layer.2.layer_norm': 0, 'decoder.block.20.layer.2.dropout': 0, 'decoder.block.21.layer.0': 0, 'decoder.block.21.layer.1': 0, 'decoder.block.21.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.21.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.21.layer.2.DenseReluDense.dropout': 0, 
'decoder.block.21.layer.2.DenseReluDense.act': 0, 'decoder.block.21.layer.2.layer_norm': 0, 'decoder.block.21.layer.2.dropout': 0, 'decoder.block.22.layer.0': 0, 'decoder.block.22.layer.1': 0, 'decoder.block.22.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.22.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.22.layer.2.DenseReluDense.dropout': 0, 'decoder.block.22.layer.2.DenseReluDense.act': 0, 'decoder.block.22.layer.2.layer_norm': 0, 'decoder.block.22.layer.2.dropout': 0, 'decoder.block.23.layer.0': 0, 'decoder.block.23.layer.1': 0, 'decoder.block.23.layer.2.DenseReluDense.wi_0': 0, 'decoder.block.23.layer.2.DenseReluDense.wi_1': 0, 'decoder.block.23.layer.2.DenseReluDense.dropout': 0, 'decoder.block.23.layer.2.DenseReluDense.act': 0, 'decoder.block.23.layer.2.layer_norm': 0, 'decoder.block.23.layer.2.dropout': 0, 'decoder.final_layer_norm': 0, 'decoder.dropout': 0, 'lm_head': 0} ```
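For anyone who ends up needing a custom map, this is roughly how a starting device map like the one above could be built and then hand-edited before loading; the max_memory numbers are placeholders, and I'm deliberately leaving headroom because infer_auto_device_map sizes everything as int8 and so underestimates the fp32 wo layers:

```python
# Sketch: build an initial device map on an empty (meta) model, tweak it,
# then pass the dict to from_pretrained alongside load_in_8bit=True.
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, T5ForConditionalGeneration

model_id = "google/flan-t5-xxl"
config = AutoConfig.from_pretrained(model_id)
with init_empty_weights():
    empty_model = T5ForConditionalGeneration(config)

device_map = infer_auto_device_map(
    empty_model,
    max_memory={0: "12GiB", "cpu": "30GiB"},  # placeholders; leave fp32 headroom
    no_split_module_classes=["T5Block"],
    dtype=torch.int8,  # assumes int8 everywhere, hence the underestimate
)
# ...hand-edit device_map here (e.g. push some wo layers to "cpu")...

model = T5ForConditionalGeneration.from_pretrained(
    model_id, device_map=device_map, load_in_8bit=True
)
```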

Regarding the checkpoint, I originally suggested tomrb/flan-t5-xxl-sharded because I wasn't sure if the Space could handle the original's shard sizes. But if your device can, the original google/flan-t5-xxl is probably best.

@ryanramos I can certainly try over the next couple of days! Just for my understanding: do we really need to pass device_map if we are running on a single device? I see that you map everything to device 0; will it error out if you don't do so?

@deathcrush Sorry, my bad, I was overthinking that; you can probably skip the whole device_map thing, assuming everything fits on the GPU.

@ryanramos , I can confirm that I was able to load google/flan-t5-xxl on a 24GB NVIDIA RTX 3090 and run inference on a trivial input, getting the same results as the demo! device_map='auto' is required even in the single-GPU case because it is needed for quantisation.
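For reference, a minimal sketch of that kind of load-and-generate sanity check (the prompt and generation settings below are arbitrary placeholders, not necessarily what was actually run):

```python
# Sketch: 8-bit load with an automatic device map, then a trivial generation.
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_id = "google/flan-t5-xxl"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", load_in_8bit=True
)

inputs = tokenizer("Translate English to German: How old are you?", return_tensors="pt").to(0)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```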

In that case, perhaps keeping the wo layers and lm_head in 16-bit is good enough? If so, that's good to know.

Well, I did not test extensively, I only checked a single input!
