Training in progress, step 500
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full listing.
- .gitattributes +19 -0
- added_tokens.json +28 -0
- config.json +32 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- special_tokens_map.json +31 -0
- tmpcnl2zetw/__pycache__/_remote_module_non_scriptable.cpython-312.pyc +0 -0
- tmpcnl2zetw/_remote_module_non_scriptable.py +81 -0
- tmpfehyc297/_remote_module_non_scriptable.py +81 -0
- tokenizer.json +3 -0
- tokenizer_config.json +240 -0
- torchinductor_ch-epfl-345354-j/2f/8f4778b9c2bdc504a3c5d1b5bc09dac279f18294d78f93a4c581178508bcf83b.best_config +1 -0
- torchinductor_ch-epfl-345354-j/2f/c2fzymhcr3rme5dtns3jvvyp6x3osfrlp7cc7zg7igwzigqmmg65.py +46 -0
- torchinductor_ch-epfl-345354-j/2r/717267c6902a1a61a8cc50b68a007cec3f90a0241185b112347df8b34fa8c605.best_config +1 -0
- torchinductor_ch-epfl-345354-j/2r/c2rhxwmxh62lojowjb65g6mbzowlwbjcacwjmn3vu63z4qatxuo3.py +26 -0
- torchinductor_ch-epfl-345354-j/2y/c2ykxnj2iqrpp4u3ihziotcanxl3tc27h7ajzahx5wypy4anuhuj.py +88 -0
- torchinductor_ch-epfl-345354-j/3m/c3mt4utggpr6zcsqyeele6646fofhvyk4xxtwll4gqqa5w6nrbct.py +55 -0
- torchinductor_ch-epfl-345354-j/43/c43m5ctxi7dcy4hjgz5jijzo4xp7fp3bmvzcjp3ygmirxptgoerd.py +53 -0
- torchinductor_ch-epfl-345354-j/4i/c4iarmybewwgyq7pa6izmajgs66hg4cgb6yhmezt4tg6j77oklfi.py +50 -0
- torchinductor_ch-epfl-345354-j/53/c53mrwlx5sxivgg5x5z6kkaldo2q5yn2pjsymcv27tpzj2cdoeww.py +66 -0
- torchinductor_ch-epfl-345354-j/56/c56q66j66nfzeu5puvuhal4wt2foih6rnb5nwmqolafn3iq33kjp.py +66 -0
- torchinductor_ch-epfl-345354-j/57/c573irrqes6p6it4yfyvqw2efgfbbgp7yjzxjjxq5jpeesj3bi77.py +353 -0
- torchinductor_ch-epfl-345354-j/57/c574kngiopy3pgespyoupnzlae4d5tokyeui7uglwglnym2qijvn.py +30 -0
- torchinductor_ch-epfl-345354-j/57/dad7be19dc394c1e08368515640dff88b78797aaabae100a15a1f195476a9a87.best_config +1 -0
- torchinductor_ch-epfl-345354-j/5e/c5enonf6qztlsw7dozsqkejk4exzt4n56gbz6fiey2gnus5vdf76.py +66 -0
- torchinductor_ch-epfl-345354-j/5x/c5xsvywggx5vrzm2l5uaktu7pipclhdn5h6263yru2ugvuhe2nak.py +57 -0
- torchinductor_ch-epfl-345354-j/6d/c6dsbxlebwjqawzeprkq3lkldtxoiept4c6bpgtva5r4mjlrnwlr.py +229 -0
- torchinductor_ch-epfl-345354-j/6j/7c215475e7b40a21cf286026270965eb7f07e7c3af1c4052d331de3f74c6449e.best_config +1 -0
- torchinductor_ch-epfl-345354-j/6j/c6j5lx5qgycfvyi3dm5f4mo3ssluzzsrmdq32pka7e6pyhg42zvd.py +499 -0
- torchinductor_ch-epfl-345354-j/6j/c6jqjdux4scc3alxlsrcpnhemegj7ym5pw3twg6xb2eyx4codkvz.py +40 -0
- torchinductor_ch-epfl-345354-j/7a/c7a4b5izank2343xz4473c4igojrrhlfxb5ulctqd32qrtkreq3m.py +42 -0
- torchinductor_ch-epfl-345354-j/7a/ddaeab32a9175f6d14ae7329f3defe09537605c41fa4d35da5bf9cbac1616b91.best_config +1 -0
- torchinductor_ch-epfl-345354-j/7q/c7qudnwq7tyfwnepjsm2ilmratxdwkx4euvow7brbvrfif7hgnwh.py +324 -0
- torchinductor_ch-epfl-345354-j/a6/ca64rxymdowafnowfq53ckfynl3yei5mmfkeefu6f6xndlg3ukok.py +200 -0
- torchinductor_ch-epfl-345354-j/aotautograd/acxk7xhb35e5myvrfk4m2smos5f3rwybegalnbqbgtl3ghlaw2vw/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/adw7o5w6jucvlwdu4mn3nk52nno5z3lt73pmvaksrn3cahxlwc5t/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/agm67xcx3b2ejeqf3t422b43zsalmtzgitagqmb4kcd76dzg2sr6/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/ahinqqlnserz457jqclv2vjeogmqix7jcrylpuhbc64kw4k3apfy/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/ahji7b2arusm47q6ox5itjvurtws6r6kls2kskgxfnc2rqm4ojdg/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/aibbpzcrlnv7lrbehiaaab4olrvijekv6m46vdzzqh3tbnvnl67m/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/aig3hpjgzj7f27hhdphh7ozndqiwpruhugzjsiwyog75fn4y3rbj/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/algm7vsngjdke6rmqon76peuppnhsp625k5d4zxnwgwdbdueo4ay/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/amtpnp6cq6z6ddoun3fwe4zemhgpsp5jicklj6cf3qzsd3xbdeps/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/aqskia64x2j4xks7dhp5cpq52le5j6js6ghxfhlvw7gfa6qr6stx/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/aub73aicaqeihl4qdqbrvljzl2qxdzyu52zezeket676qt3pkgwk/entry +0 -0
- torchinductor_ch-epfl-345354-j/aotautograd/aypob3g4nwzt66m7ur252rhjjobqgnn4hvdhagr4474twkamikxg/entry +0 -0
- torchinductor_ch-epfl-345354-j/bm/cbmn253c3hy77ciw3f6meqi4bsbiio5zhw7hra5np6k5jyjqetnp.py +29 -0
- torchinductor_ch-epfl-345354-j/ce/cceyvpztlniy45jdq6sxx7o44obzjinfuxgsvnlhcr3hjdvmek73.py +38 -0
- torchinductor_ch-epfl-345354-j/cf/ae1632ffa009afdc4d40d5477a8e2ffd544972ad9ddf0c636c451826b3219579.best_config +1 -0
- torchinductor_ch-epfl-345354-j/cf/ccfnt2f53rlwauznvnabnitvjchbzg7at22w4x4fskqzmyirxuxq.py +50 -0
.gitattributes
CHANGED
@@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/5s/f5sdzjvbwcmgigljts5qiy6lpxvzqdph6wzu6phn3y3ibrcaorli/khjgqt4e4qpxv6eeg3vxxgpwqsv3grytkxtje4324jtolexrezo filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/6x/f6x5yse7q3r4kgrx3mzcmr7b2m72jtxrtziweqwk5lposwvr3y52/wqdd5re3wcyjqm2yolzwsoaqo7xdfcnam6vnnh7oy4krfpy35mu filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/ap/fap4jdkhbhos2zlgzy4vqldjzu7uaf3wuhatrkr2kcwc42gvg2yz/2zrvkpixflxwoduo4mf45p672nlwxrnyfi7jiaahhn2lw6eafh7 filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/ay/fay24mgdmhudth5v4jelopeu2revquda5zewn436r6biwrnpgabo/rua4hs554jun3xlx2vr5pka3oqdzq77wfoxhvymrffnwxy27nbt filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/be/fbei2uvrpqs44s3zmuwyjn6byfolnsqoa7juh23nj5xwvypzh4qm/isndyt5bxypcz3icmz7sxxnxnv5k5flzys3ydjieaujxkpv2by2 filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/dz/fdz6ntzqbiwadat6ybb6wx5r72slohcdcufolrfguxpyy72ha4k3/6wwa6mdlxwrojkbhhttg3r434evsqlaub3wfx47anl25o536b6x filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/el/felw7pslqa4fo6ex4wmphdc2ybuhij3y7nr6ek7weyw6exi5n6un/vtvnvlc4u5nq7hdpjcieofwiwev7pvczdillvjg3yxva6aojzbt filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/io/fiop3vsjr3dy3bv5akbf4igspko5bzgikfs2c5r65fwccfuvu7ux/qzhalrsuhi7mryjabm3iq6rhxlb5ozesgq6j7k5e7iyja2k7vsq filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/kl/fklqttybh3jmeqtx6bvvj4haqgxnjqufqyoz2nks5uigm4r36cx4/hxqz22nepqoek7pc76k5r3icp4zjsh4kaue55ohl6xt6tbvp4oi filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/lg/flg4hgizerwmptq2h73kr43ahriht7psoidazshg6qzgrnt647t2/pijpcpasxrep34djgpqcasekfyyagfa4g3bjwvevt3aj3ofuu6q filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/ng/fngku4s6qemdyw2pe5ve3jmj4j3kwxfgrfoqpednglkm5rrggltm/gpu6deqxpscuw7qomk5b3bn5nhb7huhwwde46fignb66f72pvmw filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/ro/froa33wtzy5mq3utlkybs5daxqnv2apit6smxhgb22cf4rr7fysv/rvynlwjqk2pli2qmjywzpal7ly5ufkhjhxzh4v7y5dgjunk5bmh filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/tm/ftmxrirms7cnlnwoxoz62nacrceoyhhzhbnuwmnbxdx426kuxbca/nmw3t6ssgbvfrpuhau55inf3woabflb6qc3ynybm53y42slvkos filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/tu/ftuiasjkco5eqbhoc3ebj57kfzkbkbzglhq2wrkg4hjpkhixsfbl/qpwpe3w2vqzxprimhpukbj46zopeocm5rat5f5apx7ghtoyyi67 filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/fxgraph/xb/fxbygtwj5f3jfikmhfoz7vd3mtceabjyw5ggscr4lin73qu5vibc/typ73hul4yzbqjyctyc3spof5hr4okcdyp27oj4p5lpfovus76n filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/triton/0/35ZDQUBPR56EFY64BDG6UB7OKJODSPJPYCTXXIKSOWZ2CA3EPWGA/triton_poi_fused__to_copy_cos_mul_sin_1.cubin filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/triton/0/SS6ERMBZYPQUFPYGWYCICDOLY35WYNXDIHS3C45P5EZJOFLNB5SQ/triton_poi_fused__to_copy_cos_mul_sin_1.cubin filter=lfs diff=lfs merge=lfs -text
+torchinductor_ch-epfl-345354-j/triton/0/YGRXAKJ5T6UCGDY6RY3GMSSD4JVJFIRM6WYZVP7IT63SMM7NWHSA/triton_poi_fused__to_copy_cos_mul_sin_1.cubin filter=lfs diff=lfs merge=lfs -text
added_tokens.json
ADDED
@@ -0,0 +1,28 @@
+{
+  "</think>": 151668,
+  "</tool_call>": 151658,
+  "</tool_response>": 151666,
+  "<think>": 151667,
+  "<tool_call>": 151657,
+  "<tool_response>": 151665,
+  "<|box_end|>": 151649,
+  "<|box_start|>": 151648,
+  "<|endoftext|>": 151643,
+  "<|file_sep|>": 151664,
+  "<|fim_middle|>": 151660,
+  "<|fim_pad|>": 151662,
+  "<|fim_prefix|>": 151659,
+  "<|fim_suffix|>": 151661,
+  "<|im_end|>": 151645,
+  "<|im_start|>": 151644,
+  "<|image_pad|>": 151655,
+  "<|object_ref_end|>": 151647,
+  "<|object_ref_start|>": 151646,
+  "<|quad_end|>": 151651,
+  "<|quad_start|>": 151650,
+  "<|repo_name|>": 151663,
+  "<|video_pad|>": 151656,
+  "<|vision_end|>": 151653,
+  "<|vision_pad|>": 151654,
+  "<|vision_start|>": 151652
+}
config.json
ADDED
@@ -0,0 +1,32 @@
+{
+  "architectures": [
+    "Qwen3ForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "eos_token_id": 151643,
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 1024,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "max_position_embeddings": 32768,
+  "max_window_layers": 28,
+  "model_type": "qwen3",
+  "num_attention_heads": 16,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pad_token_id": 151654,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": null,
+  "rope_theta": 1000000,
+  "sliding_window": null,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.51.3",
+  "unsloth_fixed": true,
+  "unsloth_version": "2025.5.7",
+  "use_cache": true,
+  "use_sliding_window": false,
+  "vocab_size": 151936
+}
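config.json describes a small Qwen3 causal LM: 28 hidden layers, hidden size 1024, 16 attention heads with 8 key/value heads (GQA), tied embeddings, bf16 weights. A minimal loading sketch follows; the local path "./checkpoint-500" is an assumption for illustration, not a path stated in this commit.

# Sketch only: load the checkpoint described by config.json / model.safetensors.
# "./checkpoint-500" is a hypothetical directory holding the files from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./checkpoint-500", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("./checkpoint-500")
print(model.config.model_type, model.config.hidden_size)  # expected: qwen3 1024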
merges.txt
ADDED
The diff for this file is too large to render. See the raw diff.
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26045c44c8f7b94239544f69255de30ea5e05c7f8e23f0c01a67e755eaa0beba
+size 1192135096
special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
+{
+  "additional_special_tokens": [
+    "<|im_start|>",
+    "<|im_end|>",
+    "<|object_ref_start|>",
+    "<|object_ref_end|>",
+    "<|box_start|>",
+    "<|box_end|>",
+    "<|quad_start|>",
+    "<|quad_end|>",
+    "<|vision_start|>",
+    "<|vision_end|>",
+    "<|vision_pad|>",
+    "<|image_pad|>",
+    "<|video_pad|>"
+  ],
+  "eos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<|vision_pad|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tmpcnl2zetw/__pycache__/_remote_module_non_scriptable.cpython-312.pyc
ADDED
Binary file (2.6 kB).
tmpcnl2zetw/_remote_module_non_scriptable.py
ADDED
@@ -0,0 +1,81 @@
+from typing import *
+
+import torch
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch._jit_internal import Future
+from torch.distributed.rpc import RRef
+from typing import Tuple  # pyre-ignore: unused import
+
+
+module_interface_cls = None
+
+
+def forward_async(self, *args, **kwargs):
+    args = (self.module_rref, self.device, self.is_device_map_set, *args)
+    kwargs = {**kwargs}
+    return rpc.rpc_async(
+        self.module_rref.owner(),
+        _remote_forward,
+        args,
+        kwargs,
+    )
+
+
+def forward(self, *args, **kwargs):
+    args = (self.module_rref, self.device, self.is_device_map_set, *args)
+    kwargs = {**kwargs}
+    ret_fut = rpc.rpc_async(
+        self.module_rref.owner(),
+        _remote_forward,
+        args,
+        kwargs,
+    )
+    return ret_fut.wait()
+
+
+_generated_methods = [
+    forward_async,
+    forward,
+]
+
+
+
+
+def _remote_forward(
+    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs):
+    module = module_rref.local_value()
+    device = torch.device(device)
+
+    if device.type != "cuda":
+        return module.forward(*args, **kwargs)
+
+    # If the module is on a cuda device,
+    # move any CPU tensor in args or kwargs to the same cuda device.
+    # Since torch script does not support generator expression,
+    # have to use concatenation instead of
+    # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``.
+    args = (*args,)
+    out_args: Tuple[()] = ()
+    for arg in args:
+        arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
+        out_args = out_args + arg
+
+    kwargs = {**kwargs}
+    for k, v in kwargs.items():
+        if isinstance(v, Tensor):
+            kwargs[k] = kwargs[k].to(device)
+
+    if is_device_map_set:
+        return module.forward(*out_args, **kwargs)
+
+    # If the device map is empty, then only CPU tensors are allowed to send over wire,
+    # so have to move any GPU tensor to CPU in the output.
+    # Since torch script does not support generator expression,
+    # have to use concatenation instead of
+    # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, **kwargs))``.
+    ret: Tuple[()] = ()
+    for i in module.forward(*out_args, **kwargs):
+        i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
+        ret = ret + i
+    return ret
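The tmpcnl2zetw/ and tmpfehyc297/ files appear to be the per-process wrapper modules that torch.distributed.nn.RemoteModule instantiates from its non-scriptable template; they were most likely swept into the commit from the working directory rather than added deliberately. A minimal single-process sketch of the API those generated forward/forward_async wrappers serve (the worker name, port, and module are illustrative assumptions):

# Sketch, not part of this commit: RemoteModule generates wrapper files like the one above.
import os
import torch
import torch.distributed.rpc as rpc
from torch.distributed.nn import RemoteModule

os.environ.setdefault("MASTER_ADDR", "localhost")   # assumed local rendezvous
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)        # single process talking to itself

remote_linear = RemoteModule("worker0/cpu", torch.nn.Linear, args=(8, 4))
out = remote_linear.forward(torch.randn(2, 8))       # dispatched through the generated forward()
print(out.shape)                                     # torch.Size([2, 4])
rpc.shutdown()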
tmpfehyc297/_remote_module_non_scriptable.py
ADDED
@@ -0,0 +1,81 @@
Content is identical to tmpcnl2zetw/_remote_module_non_scriptable.py above (81 lines).
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+size 11422654
tokenizer_config.json
ADDED
@@ -0,0 +1,240 @@
+{
+  "add_bos_token": false,
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "151643": {"content": "<|endoftext|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151644": {"content": "<|im_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151645": {"content": "<|im_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151646": {"content": "<|object_ref_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151647": {"content": "<|object_ref_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151648": {"content": "<|box_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151649": {"content": "<|box_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151650": {"content": "<|quad_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151651": {"content": "<|quad_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151652": {"content": "<|vision_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151653": {"content": "<|vision_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151654": {"content": "<|vision_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151655": {"content": "<|image_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151656": {"content": "<|video_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "151657": {"content": "<tool_call>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151658": {"content": "</tool_call>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151659": {"content": "<|fim_prefix|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151660": {"content": "<|fim_middle|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151661": {"content": "<|fim_suffix|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151662": {"content": "<|fim_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151663": {"content": "<|repo_name|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151664": {"content": "<|file_sep|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151665": {"content": "<tool_response>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151666": {"content": "</tool_response>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151667": {"content": "<think>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
+    "151668": {"content": "</think>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false}
+  },
+  "additional_special_tokens": [
+    "<|im_start|>",
+    "<|im_end|>",
+    "<|object_ref_start|>",
+    "<|object_ref_end|>",
+    "<|box_start|>",
+    "<|box_end|>",
+    "<|quad_start|>",
+    "<|quad_end|>",
+    "<|vision_start|>",
+    "<|vision_end|>",
+    "<|vision_pad|>",
+    "<|image_pad|>",
+    "<|video_pad|>"
+  ],
+  "bos_token": null,
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "<|endoftext|>",
+  "errors": "replace",
+  "extra_special_tokens": {},
+  "model_max_length": 32768,
+  "pad_token": "<|vision_pad|>",
+  "padding_side": "left",
+  "split_special_tokens": false,
+  "tokenizer_class": "Qwen2Tokenizer",
+  "unk_token": null
+}
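As a quick sanity check (a sketch assuming the tokenizer files above live in a hypothetical local directory "./checkpoint-500"), the added tokens resolve to the ids from added_tokens.json and padding follows the settings above:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./checkpoint-500")   # hypothetical path
print(tok.convert_tokens_to_ids("<|im_start|>"))          # 151644, matching added_tokens.json
print(tok.pad_token, tok.padding_side)                    # <|vision_pad|> left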
torchinductor_ch-epfl-345354-j/2f/8f4778b9c2bdc504a3c5d1b5bc09dac279f18294d78f93a4c581178508bcf83b.best_config
ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 50}
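The torchinductor_ch-epfl-345354-j/ entries in this commit are TorchInductor compile-cache artifacts: generated Triton kernel sources, .best_config autotuning records, AOTAutograd and FX-graph cache entries, and compiled .cubin binaries. The directory name matches Inductor's default cache location, /tmp/torchinductor_<username>, so these were presumably committed by accident alongside the checkpoint. A hedged sketch of how such artifacts get produced (the cache path and the compiled function are assumptions, not taken from this repo):

import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "./torchinductor_cache"  # hypothetical location inside the work tree

import torch

def fused_step(x, w):
    # any small function run under torch.compile will populate the cache
    return torch.nn.functional.silu(x @ w)

compiled = torch.compile(fused_step)
if torch.cuda.is_available():
    x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    compiled(x, w)  # first call compiles and writes kernels plus .best_config files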
torchinductor_ch-epfl-345354-j/2f/c2fzymhcr3rme5dtns3jvvyp6x3osfrlp7cc7zg7igwzigqmmg65.py
ADDED
@@ -0,0 +1,46 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 8388608},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_add_cat_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x4 = xindex
+    x0 = (xindex % ks0)
+    x2 = ((xindex // ks1) % ks2)
+    x5 = xindex // ks0
+    tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+    tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+    tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+    tmp2 = tmp0 * tmp1
+    tmp3 = x0
+    tmp4 = tl.full([1], 0, tl.int64)
+    tmp5 = tmp3 >= tmp4
+    tmp6 = ks0 + (-1)*(ks0 // 2)
+    tmp7 = tmp3 < tmp6
+    tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
+    tmp9 = -tmp8
+    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
+    tmp11 = tl.where(tmp7, tmp9, tmp10)
+    tmp12 = tmp3 >= tmp6
+    tmp13 = ks0
+    tmp14 = tmp3 < tmp13
+    tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
+    tmp16 = tl.where(tmp7, tmp11, tmp15)
+    tmp18 = tmp16 * tmp17
+    tmp19 = tmp2 + tmp18
+    tl.store(out_ptr0 + (x4), tmp19, xmask)
torchinductor_ch-epfl-345354-j/2r/717267c6902a1a61a8cc50b68a007cec3f90a0241185b112347df8b34fa8c605.best_config
ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 81}
torchinductor_ch-epfl-345354-j/2r/c2rhxwmxh62lojowjb65g6mbzowlwbjcacwjmn3vu63z4qatxuo3.py
ADDED
@@ -0,0 +1,26 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 2048},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'ks0': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_3', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_nll_loss_backward_3(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = xindex
+    tmp0 = tl.load(in_ptr0 + (x0), xmask)
+    tl.device_assert(((0 <= tmp0) & (tmp0 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp0 < ks0")
+    tmp2 = -1.0
+    tl.store(out_ptr0 + (tmp0 + ks0*x0), tmp2, xmask)
torchinductor_ch-epfl-345354-j/2y/c2ykxnj2iqrpp4u3ihziotcanxl3tc27h7ajzahx5wypy4anuhuj.py
ADDED
@@ -0,0 +1,88 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 1, 'r0_': 2048},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'in_ptr7': '*fp32', 'in_ptr8': '*fp32', 'in_ptr9': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_nll_loss_backward_nll_loss_forward_15', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 10, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused_add_nll_loss_backward_nll_loss_forward_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    xnumel = 1
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    _tmp28 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_0 = r0_index
+        tmp5 = tl.load(in_ptr1 + (r0_0 + 6*((6 + ks0*ks1) // 7)), r0_mask, eviction_policy='evict_first', other=0.0)
+        tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+        tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+        tmp0 = ((r0_0 + 6*((6 + ks0*ks1) // 7)) % ks1)
+        tmp1 = (-1) + ks1
+        tmp2 = tmp0 == tmp1
+        tmp3 = tmp0 < tmp1
+        tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + r0_0 + 6*((6 + ks0*ks1) // 7), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_first', other=0.0)
+        tmp6 = tl.where(tmp3, tmp4, tmp5)
+        tmp7 = tl.full([1, 1], -100, tl.int64)
+        tmp8 = tl.where(tmp2, tmp7, tmp6)
+        tmp9 = tmp8 != tmp7
+        tmp10 = tl.full([1, 1], 0, tl.int64)
+        tmp11 = tl.where(tmp9, tmp8, tmp10)
+        tmp12 = ks2
+        tmp13 = tmp11 + tmp12
+        tmp14 = tmp11 < 0
+        tmp15 = tl.where(tmp14, tmp13, tmp11)
+        tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
+        tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + 6*ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+        tmp18 = tmp17.to(tl.float32)
+        tmp20 = tmp18 - tmp19
+        tmp22 = tl_math.log(tmp21)
+        tmp23 = tmp20 - tmp22
+        tmp24 = -tmp23
+        tmp25 = 0.0
+        tmp26 = tl.where(tmp9, tmp24, tmp25)
+        tmp27 = tl.broadcast_to(tmp26, [XBLOCK, R0_BLOCK])
+        tmp29 = _tmp28 + tmp27
+        _tmp28 = tl.where(r0_mask, tmp29, _tmp28)
+        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
+        tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
+    tmp28 = tl.sum(_tmp28, 1)[:, None]
+    tmp30 = tl.load(in_out_ptr0 + (0))
+    tmp31 = tl.broadcast_to(tmp30, [XBLOCK, 1])
+    tmp34 = tl.load(in_ptr5 + (0))
+    tmp35 = tl.broadcast_to(tmp34, [XBLOCK, 1])
+    tmp37 = tl.load(in_ptr6 + (0))
+    tmp38 = tl.broadcast_to(tmp37, [XBLOCK, 1])
+    tmp40 = tl.load(in_ptr7 + (0))
+    tmp41 = tl.broadcast_to(tmp40, [XBLOCK, 1])
+    tmp43 = tl.load(in_ptr8 + (0))
+    tmp44 = tl.broadcast_to(tmp43, [XBLOCK, 1])
+    tmp46 = tl.load(in_ptr9 + (0))
+    tmp47 = tl.broadcast_to(tmp46, [XBLOCK, 1])
+    tmp32 = 0.0
+    tmp33 = tmp31 + tmp32
+    tmp36 = tmp33 + tmp35
+    tmp39 = tmp36 + tmp38
+    tmp42 = tmp39 + tmp41
+    tmp45 = tmp42 + tmp44
+    tmp48 = tmp45 + tmp47
+    tmp49 = tmp48 + tmp28
+    tl.debug_barrier()
+    tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp49, None)
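Judging by the op names (nll_loss forward/backward), the -100 sentinel, and the one-position shift in the index arithmetic, this reduction computes the usual next-token cross-entropy over one shard of the flattened sequence. A plain PyTorch reference for the same quantity, offered as a sketch (tensor names are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

def shifted_ce_sum(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (B, T, V); labels: (B, T) with -100 marking ignored positions
    shift_logits = logits[:, :-1, :].float()
    shift_labels = labels[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction="sum",  # the kernel accumulates a sum over valid tokens; averaging happens elsewhere
    )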
torchinductor_ch-epfl-345354-j/3m/c3mt4utggpr6zcsqyeele6646fofhvyk4xxtwll4gqqa5w6nrbct.py
ADDED
@@ -0,0 +1,55 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 2048, 'r0_': 262144},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__to_copy_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 2, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__log_softmax__to_copy_2(in_out_ptr0, in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = xindex
+    _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
+    _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_1 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_1 + ks2*x0 + 2*ks2*((6 + ks0*ks1) // 7)), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+        tmp1 = tmp0.to(tl.float32)
+        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+
+        _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
+            _tmp3_max, _tmp3_sum, tmp2, False
+        )
+
+        _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
+        _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
+
+    tmp5, tmp6 = triton_helpers.online_softmax_reduce(
+        _tmp3_max, _tmp3_sum, 1, False)
+    tmp5 = tmp5[:, None]
+    tmp6 = tmp6[:, None]
+    tmp3 = tmp5
+    tmp4 = tmp6
+    tl.store(out_ptr0 + (x0), tmp3, xmask)
+    tmp7 = tl_math.log(tmp4)
+    tl.debug_barrier()
+    tl.store(in_out_ptr0 + (x0), tmp7, xmask)
torchinductor_ch-epfl-345354-j/43/c43m5ctxi7dcy4hjgz5jijzo4xp7fp3bmvzcjp3ygmirxptgoerd.py
ADDED
@@ -0,0 +1,53 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 2048, 'r0_': 262144},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_5', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 2, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__to_copy_5(in_ptr0, out_ptr0, out_ptr1, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = xindex
+    _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
+    _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_1 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_1 + ks2*x0 + 5*ks2*((6 + ks0*ks1) // 7)), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+        tmp1 = tmp0.to(tl.float32)
+        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+
+        _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
+            _tmp3_max, _tmp3_sum, tmp2, False
+        )
+
+        _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
+        _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
+
+    tmp5, tmp6 = triton_helpers.online_softmax_reduce(
+        _tmp3_max, _tmp3_sum, 1, False)
+    tmp5 = tmp5[:, None]
+    tmp6 = tmp6[:, None]
+    tmp3 = tmp5
+    tmp4 = tmp6
+    tl.store(out_ptr0 + (x0), tmp3, xmask)
+    tl.store(out_ptr1 + (x0), tmp4, xmask)
torchinductor_ch-epfl-345354-j/4i/c4iarmybewwgyq7pa6izmajgs66hg4cgb6yhmezt4tg6j77oklfi.py
ADDED
@@ -0,0 +1,50 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.persistent_reduction(
+    size_hints={'x': 8192, 'r0_': 1024},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': True, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel):
+    XBLOCK: tl.constexpr = 1
+    r0_numel = 1024
+    R0_BLOCK: tl.constexpr = 1024
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = tl.full([1], xoffset, tl.int32)
+    xmask = tl.full([R0_BLOCK], True, tl.int1)
+    r0_index = tl.arange(0, R0_BLOCK)[:]
+    r0_offset = 0
+    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
+    roffset = r0_offset
+    rindex = r0_index
+    r0_1 = r0_index
+    x0 = xindex
+    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
+    tmp11 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+    tmp1 = tmp0.to(tl.float32)
+    tmp2 = tmp1 * tmp1
+    tmp3 = tl.broadcast_to(tmp2, [R0_BLOCK])
+    tmp5 = triton_helpers.promote_to_tensor(tl.sum(tmp3, 0))
+    tmp6 = 1024.0
+    tmp7 = (tmp5 / tmp6)
+    tmp8 = 1e-06
+    tmp9 = tmp7 + tmp8
+    tmp10 = libdevice.rsqrt(tmp9)
+    tmp12 = tmp1 * tmp10
+    tmp13 = tmp12.to(tl.float32)
+    tmp14 = tmp11 * tmp13
+    tl.debug_barrier()
+    tl.store(in_out_ptr0 + (x0), tmp10, None)
+    tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp14, None)
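The fused op names (to_copy/add/mean/mul/pow/rsqrt), the fixed 1024-wide reduction, and the 1e-06 epsilon indicate this kernel covers the model's RMSNorm (rms_norm_eps is 1e-06 and hidden_size is 1024 in config.json). A plain PyTorch reference for the same computation, as a sketch:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # upcast to fp32, normalize by rsqrt(mean(x^2) + eps), cast back and scale by weight
    h = x.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return weight * h.to(x.dtype)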
torchinductor_ch-epfl-345354-j/53/c53mrwlx5sxivgg5x5z6kkaldo2q5yn2pjsymcv27tpzj2cdoeww.py
ADDED
@@ -0,0 +1,66 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 1, 'r0_': 2048},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_backward_nll_loss_forward_11', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused_nll_loss_backward_nll_loss_forward_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    xnumel = 1
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    _tmp27 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_0 = r0_index
+        tmp5 = tl.load(in_ptr1 + (((r0_0 + 2*((6 + ks0*ks1) // 7)) % (ks0*ks1))), r0_mask, eviction_policy='evict_last', other=0.0)
+        tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+        tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+        tmp0 = ((r0_0 + 2*((6 + ks0*ks1) // 7)) % ks1)
+        tmp1 = (-1) + ks1
+        tmp2 = tmp0 == tmp1
+        tmp3 = tmp0 < tmp1
+        tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + (((r0_0 + 2*((6 + ks0*ks1) // 7)) % (ks0*ks1))), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_last', other=0.0)
+        tmp6 = tl.where(tmp3, tmp4, tmp5)
+        tmp7 = tl.full([1, 1], -100, tl.int64)
+        tmp8 = tl.where(tmp2, tmp7, tmp6)
+        tmp9 = tmp8 != tmp7
+        tmp10 = tl.full([1, 1], 0, tl.int64)
+        tmp11 = tl.where(tmp9, tmp8, tmp10)
+        tmp12 = ks2
+        tmp13 = tmp11 + tmp12
+        tmp14 = tmp11 < 0
+        tmp15 = tl.where(tmp14, tmp13, tmp11)
+        tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
+        tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + 2*ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+        tmp18 = tmp17.to(tl.float32)
+        tmp20 = tmp18 - tmp19
+        tmp22 = tmp20 - tmp21
+        tmp23 = -tmp22
+        tmp24 = 0.0
+        tmp25 = tl.where(tmp9, tmp23, tmp24)
+        tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
+        tmp28 = _tmp27 + tmp26
+        _tmp27 = tl.where(r0_mask, tmp28, _tmp27)
+        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
+        tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
+    tmp27 = tl.sum(_tmp27, 1)[:, None]
+    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
torchinductor_ch-epfl-345354-j/56/c56q66j66nfzeu5puvuhal4wt2foih6rnb5nwmqolafn3iq33kjp.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.reduction(
|
| 11 |
+
size_hints={'x': 1, 'r0_': 2048},
|
| 12 |
+
reduction_hint=ReductionHint.INNER,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_backward_nll_loss_forward_10', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_red_fused_nll_loss_backward_nll_loss_forward_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 1
|
| 20 |
+
rnumel = r0_numel
|
| 21 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 22 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 23 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 24 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 25 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 26 |
+
rbase = r0_base
|
| 27 |
+
_tmp27 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 28 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 29 |
+
r0_index = r0_offset + r0_base
|
| 30 |
+
r0_mask = r0_index < r0_numel
|
| 31 |
+
roffset = r0_offset
|
| 32 |
+
rindex = r0_index
|
| 33 |
+
r0_0 = r0_index
|
| 34 |
+
tmp5 = tl.load(in_ptr1 + (((r0_0 + ((6 + ks0*ks1) // 7)) % (ks0*ks1))), r0_mask, eviction_policy='evict_last', other=0.0)
|
| 35 |
+
tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
|
| 36 |
+
tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
|
| 37 |
+
tmp0 = ((r0_0 + ((6 + ks0*ks1) // 7)) % ks1)
|
| 38 |
+
tmp1 = (-1) + ks1
|
| 39 |
+
tmp2 = tmp0 == tmp1
|
| 40 |
+
tmp3 = tmp0 < tmp1
|
| 41 |
+
tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + (((r0_0 + ((6 + ks0*ks1) // 7)) % (ks0*ks1))), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_last', other=0.0)
|
| 42 |
+
tmp6 = tl.where(tmp3, tmp4, tmp5)
|
| 43 |
+
tmp7 = tl.full([1, 1], -100, tl.int64)
|
| 44 |
+
tmp8 = tl.where(tmp2, tmp7, tmp6)
|
| 45 |
+
tmp9 = tmp8 != tmp7
|
| 46 |
+
tmp10 = tl.full([1, 1], 0, tl.int64)
|
| 47 |
+
tmp11 = tl.where(tmp9, tmp8, tmp10)
|
| 48 |
+
tmp12 = ks2
|
| 49 |
+
tmp13 = tmp11 + tmp12
|
| 50 |
+
tmp14 = tmp11 < 0
|
| 51 |
+
tmp15 = tl.where(tmp14, tmp13, tmp11)
|
| 52 |
+
tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
|
| 53 |
+
tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 54 |
+
tmp18 = tmp17.to(tl.float32)
|
| 55 |
+
tmp20 = tmp18 - tmp19
|
| 56 |
+
tmp22 = tmp20 - tmp21
|
| 57 |
+
tmp23 = -tmp22
|
| 58 |
+
tmp24 = 0.0
|
| 59 |
+
tmp25 = tl.where(tmp9, tmp23, tmp24)
|
| 60 |
+
tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
|
| 61 |
+
tmp28 = _tmp27 + tmp26
|
| 62 |
+
_tmp27 = tl.where(r0_mask, tmp28, _tmp27)
|
| 63 |
+
tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
|
| 64 |
+
tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
|
| 65 |
+
tmp27 = tl.sum(_tmp27, 1)[:, None]
|
| 66 |
+
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
|
torchinductor_ch-epfl-345354-j/57/c573irrqes6p6it4yfyvqw2efgfbbgp7yjzxjjxq5jpeesj3bi77.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compile-time auto-tuning block:
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._dynamo.testing import rand_strided
|
| 6 |
+
from torch._dynamo.utils import preserve_rng_state
|
| 7 |
+
from torch._inductor.select_algorithm import AlgorithmSelectorCache
|
| 8 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 9 |
+
|
| 10 |
+
async_compile = AsyncCompile()
|
| 11 |
+
generate_example_value = AlgorithmSelectorCache.generate_example_value
|
| 12 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 13 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
triton_poi_fused_add_cat_mul_0 = async_compile.triton('triton_poi_fused_add_cat_mul_0', '''
|
| 17 |
+
import triton
|
| 18 |
+
import triton.language as tl
|
| 19 |
+
|
| 20 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 21 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 22 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 23 |
+
triton_helpers.set_driver_to_gpu()
|
| 24 |
+
|
| 25 |
+
@triton_heuristics.pointwise(
|
| 26 |
+
size_hints={'x': 16777216},
|
| 27 |
+
filename=__file__,
|
| 28 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
|
| 29 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
|
| 30 |
+
min_elem_per_thread=0
|
| 31 |
+
)
|
| 32 |
+
@triton.jit
|
| 33 |
+
def triton_poi_fused_add_cat_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
|
| 34 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 35 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 36 |
+
xmask = xindex < xnumel
|
| 37 |
+
x4 = xindex
|
| 38 |
+
x0 = (xindex % ks0)
|
| 39 |
+
x2 = ((xindex // ks1) % ks2)
|
| 40 |
+
x5 = xindex // ks0
|
| 41 |
+
tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 42 |
+
tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 43 |
+
tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 44 |
+
tmp2 = tmp0 * tmp1
|
| 45 |
+
tmp3 = x0
|
| 46 |
+
tmp4 = tl.full([1], 0, tl.int64)
|
| 47 |
+
tmp5 = tmp3 >= tmp4
|
| 48 |
+
tmp6 = ks0 + (-1)*(ks0 // 2)
|
| 49 |
+
tmp7 = tmp3 < tmp6
|
| 50 |
+
tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 51 |
+
tmp9 = -tmp8
|
| 52 |
+
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
|
| 53 |
+
tmp11 = tl.where(tmp7, tmp9, tmp10)
|
| 54 |
+
tmp12 = tmp3 >= tmp6
|
| 55 |
+
tmp13 = ks0
|
| 56 |
+
tmp14 = tmp3 < tmp13
|
| 57 |
+
tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 58 |
+
tmp16 = tl.where(tmp7, tmp11, tmp15)
|
| 59 |
+
tmp18 = tmp16 * tmp17
|
| 60 |
+
tmp19 = tmp2 + tmp18
|
| 61 |
+
tl.store(out_ptr0 + (x4), tmp19, xmask)
|
| 62 |
+
''', device_str='cuda')
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
triton_poi_fused_add_cat_mul_1 = async_compile.triton('triton_poi_fused_add_cat_mul_1', '''
|
| 66 |
+
import triton
|
| 67 |
+
import triton.language as tl
|
| 68 |
+
|
| 69 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 70 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 71 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 72 |
+
triton_helpers.set_driver_to_gpu()
|
| 73 |
+
|
| 74 |
+
@triton_heuristics.pointwise(
|
| 75 |
+
size_hints={'x': 8388608},
|
| 76 |
+
filename=__file__,
|
| 77 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
|
| 78 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
|
| 79 |
+
min_elem_per_thread=0
|
| 80 |
+
)
|
| 81 |
+
@triton.jit
|
| 82 |
+
def triton_poi_fused_add_cat_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
|
| 83 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 84 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 85 |
+
xmask = xindex < xnumel
|
| 86 |
+
x4 = xindex
|
| 87 |
+
x0 = (xindex % ks0)
|
| 88 |
+
x2 = ((xindex // ks1) % ks2)
|
| 89 |
+
x5 = xindex // ks0
|
| 90 |
+
tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 91 |
+
tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 92 |
+
tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 93 |
+
tmp2 = tmp0 * tmp1
|
| 94 |
+
tmp3 = x0
|
| 95 |
+
tmp4 = tl.full([1], 0, tl.int64)
|
| 96 |
+
tmp5 = tmp3 >= tmp4
|
| 97 |
+
tmp6 = ks0 + (-1)*(ks0 // 2)
|
| 98 |
+
tmp7 = tmp3 < tmp6
|
| 99 |
+
tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 100 |
+
tmp9 = -tmp8
|
| 101 |
+
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
|
| 102 |
+
tmp11 = tl.where(tmp7, tmp9, tmp10)
|
| 103 |
+
tmp12 = tmp3 >= tmp6
|
| 104 |
+
tmp13 = ks0
|
| 105 |
+
tmp14 = tmp3 < tmp13
|
| 106 |
+
tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 107 |
+
tmp16 = tl.where(tmp7, tmp11, tmp15)
|
| 108 |
+
tmp18 = tmp16 * tmp17
|
| 109 |
+
tmp19 = tmp2 + tmp18
|
| 110 |
+
tl.store(out_ptr0 + (x4), tmp19, xmask)
|
| 111 |
+
''', device_str='cuda')
|
| 112 |
+
|
| 113 |
+
async_compile.wait(globals())
|
| 114 |
+
del async_compile
|
| 115 |
+
|
| 116 |
+
import triton
|
| 117 |
+
import triton.language as tl
|
| 118 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 119 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 120 |
+
with torch.cuda._DeviceGuard(0):
|
| 121 |
+
torch.cuda.set_device(0)
|
| 122 |
+
stream0 = get_raw_stream(0)
|
| 123 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 124 |
+
stream0 = get_raw_stream(0)
|
| 125 |
+
arg7_1 = generate_example_value((8, 16, 1000, 128), (2048000, 128, 2048, 1), 'cuda:0', torch.bfloat16, 0, (8, 16, 1000, 128))
|
| 126 |
+
arg2_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
|
| 127 |
+
arg4_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
|
| 128 |
+
buf0 = generate_example_value((8, 16, 1000, 128), (2048000, 128, 2048, 1), 'cuda:0', torch.bfloat16, 0, (8, 16, 1000, 128))
|
| 129 |
+
triton_poi_fused_add_cat_mul_0.run(arg7_1, arg2_1, arg4_1, buf0, 128, 2048, 1000, 16384000, stream=stream0)
|
| 130 |
+
del arg7_1, arg2_1, arg4_1, buf0
|
| 131 |
+
|
| 132 |
+
stream0 = get_raw_stream(0)
|
| 133 |
+
arg8_1 = generate_example_value((8, 8, 1000, 128), (1024000, 128, 1024, 1), 'cuda:0', torch.bfloat16, 0, (8, 8, 1000, 128))
|
| 134 |
+
arg2_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
|
| 135 |
+
arg4_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
|
| 136 |
+
buf1 = generate_example_value((8, 8, 1000, 128), (1024000, 128, 1024, 1), 'cuda:0', torch.bfloat16, 0, (8, 8, 1000, 128))
|
| 137 |
+
triton_poi_fused_add_cat_mul_1.run(arg8_1, arg2_1, arg4_1, buf1, 128, 1024, 1000, 8192000, stream=stream0)
|
| 138 |
+
del arg8_1, arg2_1, arg4_1, buf1
|
| 139 |
+
|
| 140 |
+
"""
|
| 141 |
+
# AOT ID: ['3_inference']
|
| 142 |
+
from ctypes import c_void_p, c_long, c_int
|
| 143 |
+
import torch
|
| 144 |
+
import math
|
| 145 |
+
import random
|
| 146 |
+
import os
|
| 147 |
+
import tempfile
|
| 148 |
+
from math import inf, nan
|
| 149 |
+
from cmath import nanj
|
| 150 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 151 |
+
from torch._inductor.utils import maybe_profile
|
| 152 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 153 |
+
from torch import device, empty_strided
|
| 154 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 155 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 156 |
+
from torch._inductor.codegen.multi_kernel import MultiKernelCall
|
| 157 |
+
import triton
|
| 158 |
+
import triton.language as tl
|
| 159 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 160 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 161 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 162 |
+
|
| 163 |
+
aten = torch.ops.aten
|
| 164 |
+
inductor_ops = torch.ops.inductor
|
| 165 |
+
_quantized = torch.ops._quantized
|
| 166 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 167 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 168 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 169 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 170 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 171 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 172 |
+
async_compile = AsyncCompile()
|
| 173 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# kernel path: /tmp/torchinductor_ch-epfl-345354-j/nf/cnf7ddto2mtv7utbz2ev3zn2hkgmh5nivfwjn3kwh7vb2j3fnuyw.py
|
| 177 |
+
# Topologically Sorted Source Nodes: [mul, cat, mul_1, q_embed], Original ATen: [aten.mul, aten.cat, aten.add]
|
| 178 |
+
# Source node to ATen node mapping:
|
| 179 |
+
# cat => cat
|
| 180 |
+
# mul => mul_8
|
| 181 |
+
# mul_1 => mul_29
|
| 182 |
+
# q_embed => add_36
|
| 183 |
+
# Graph fragment:
|
| 184 |
+
# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg7_1, %unsqueeze), kwargs = {})
|
| 185 |
+
# %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
|
| 186 |
+
# %mul_29 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {})
|
| 187 |
+
# %add_36 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_8, %mul_29), kwargs = {})
|
| 188 |
+
triton_poi_fused_add_cat_mul_0 = async_compile.triton('triton_poi_fused_add_cat_mul_0', '''
|
| 189 |
+
import triton
|
| 190 |
+
import triton.language as tl
|
| 191 |
+
|
| 192 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 193 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 194 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 195 |
+
triton_helpers.set_driver_to_gpu()
|
| 196 |
+
|
| 197 |
+
@triton_heuristics.pointwise(
|
| 198 |
+
size_hints={'x': 16777216},
|
| 199 |
+
filename=__file__,
|
| 200 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
|
| 201 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
|
| 202 |
+
min_elem_per_thread=0
|
| 203 |
+
)
|
| 204 |
+
@triton.jit
|
| 205 |
+
def triton_poi_fused_add_cat_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
|
| 206 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 207 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 208 |
+
xmask = xindex < xnumel
|
| 209 |
+
x4 = xindex
|
| 210 |
+
x0 = (xindex % ks0)
|
| 211 |
+
x2 = ((xindex // ks1) % ks2)
|
| 212 |
+
x5 = xindex // ks0
|
| 213 |
+
tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 214 |
+
tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 215 |
+
tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 216 |
+
tmp2 = tmp0 * tmp1
|
| 217 |
+
tmp3 = x0
|
| 218 |
+
tmp4 = tl.full([1], 0, tl.int64)
|
| 219 |
+
tmp5 = tmp3 >= tmp4
|
| 220 |
+
tmp6 = ks0 + (-1)*(ks0 // 2)
|
| 221 |
+
tmp7 = tmp3 < tmp6
|
| 222 |
+
tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 223 |
+
tmp9 = -tmp8
|
| 224 |
+
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
|
| 225 |
+
tmp11 = tl.where(tmp7, tmp9, tmp10)
|
| 226 |
+
tmp12 = tmp3 >= tmp6
|
| 227 |
+
tmp13 = ks0
|
| 228 |
+
tmp14 = tmp3 < tmp13
|
| 229 |
+
tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 230 |
+
tmp16 = tl.where(tmp7, tmp11, tmp15)
|
| 231 |
+
tmp18 = tmp16 * tmp17
|
| 232 |
+
tmp19 = tmp2 + tmp18
|
| 233 |
+
tl.store(out_ptr0 + (x4), tmp19, xmask)
|
| 234 |
+
''', device_str='cuda')
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# kernel path: /tmp/torchinductor_ch-epfl-345354-j/2f/c2fzymhcr3rme5dtns3jvvyp6x3osfrlp7cc7zg7igwzigqmmg65.py
|
| 238 |
+
# Topologically Sorted Source Nodes: [mul_2, cat_1, mul_3, k_embed], Original ATen: [aten.mul, aten.cat, aten.add]
|
| 239 |
+
# Source node to ATen node mapping:
|
| 240 |
+
# cat_1 => cat_1
|
| 241 |
+
# k_embed => add_72
|
| 242 |
+
# mul_2 => mul_38
|
| 243 |
+
# mul_3 => mul_59
|
| 244 |
+
# Graph fragment:
|
| 245 |
+
# %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg8_1, %unsqueeze), kwargs = {})
|
| 246 |
+
# %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
|
| 247 |
+
# %mul_59 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {})
|
| 248 |
+
# %add_72 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_38, %mul_59), kwargs = {})
|
| 249 |
+
triton_poi_fused_add_cat_mul_1 = async_compile.triton('triton_poi_fused_add_cat_mul_1', '''
|
| 250 |
+
import triton
|
| 251 |
+
import triton.language as tl
|
| 252 |
+
|
| 253 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 254 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 255 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 256 |
+
triton_helpers.set_driver_to_gpu()
|
| 257 |
+
|
| 258 |
+
@triton_heuristics.pointwise(
|
| 259 |
+
size_hints={'x': 8388608},
|
| 260 |
+
filename=__file__,
|
| 261 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
|
| 262 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
|
| 263 |
+
min_elem_per_thread=0
|
| 264 |
+
)
|
| 265 |
+
@triton.jit
|
| 266 |
+
def triton_poi_fused_add_cat_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
|
| 267 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 268 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 269 |
+
xmask = xindex < xnumel
|
| 270 |
+
x4 = xindex
|
| 271 |
+
x0 = (xindex % ks0)
|
| 272 |
+
x2 = ((xindex // ks1) % ks2)
|
| 273 |
+
x5 = xindex // ks0
|
| 274 |
+
tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 275 |
+
tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 276 |
+
tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
|
| 277 |
+
tmp2 = tmp0 * tmp1
|
| 278 |
+
tmp3 = x0
|
| 279 |
+
tmp4 = tl.full([1], 0, tl.int64)
|
| 280 |
+
tmp5 = tmp3 >= tmp4
|
| 281 |
+
tmp6 = ks0 + (-1)*(ks0 // 2)
|
| 282 |
+
tmp7 = tmp3 < tmp6
|
| 283 |
+
tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 284 |
+
tmp9 = -tmp8
|
| 285 |
+
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
|
| 286 |
+
tmp11 = tl.where(tmp7, tmp9, tmp10)
|
| 287 |
+
tmp12 = tmp3 >= tmp6
|
| 288 |
+
tmp13 = ks0
|
| 289 |
+
tmp14 = tmp3 < tmp13
|
| 290 |
+
tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 291 |
+
tmp16 = tl.where(tmp7, tmp11, tmp15)
|
| 292 |
+
tmp18 = tmp16 * tmp17
|
| 293 |
+
tmp19 = tmp2 + tmp18
|
| 294 |
+
tl.store(out_ptr0 + (x4), tmp19, xmask)
|
| 295 |
+
''', device_str='cuda')
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
async_compile.wait(globals())
|
| 299 |
+
del async_compile
|
| 300 |
+
|
| 301 |
+
def call(args):
|
| 302 |
+
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1 = args
|
| 303 |
+
args.clear()
|
| 304 |
+
s0 = arg0_1
|
| 305 |
+
s1 = arg1_1
|
| 306 |
+
s7 = arg5_1
|
| 307 |
+
s8 = arg6_1
|
| 308 |
+
assert_size_stride(arg2_1, (1, s0, s1), (s0*s1, s1, 1))
|
| 309 |
+
assert_size_stride(arg4_1, (1, s0, s1), (s0*s1, s1, 1))
|
| 310 |
+
assert_size_stride(arg7_1, (s7, s8, s0, s1), (s0*s1*s8, s1, s1*s8, 1))
|
| 311 |
+
assert_size_stride(arg8_1, (s7, s7, s0, s1), (s0*s1*s7, s1, s1*s7, 1))
|
| 312 |
+
with torch.cuda._DeviceGuard(0):
|
| 313 |
+
torch.cuda.set_device(0)
|
| 314 |
+
ps0 = s1*s8
|
| 315 |
+
pool1 = empty_strided_cuda((s7, s8, s0, s1), (s0*s1*s8, s1, s1*s8, 1), torch.bfloat16)
|
| 316 |
+
buf0 = pool1 # alloc
|
| 317 |
+
# Topologically Sorted Source Nodes: [mul, cat, mul_1, q_embed], Original ATen: [aten.mul, aten.cat, aten.add]
|
| 318 |
+
triton_poi_fused_add_cat_mul_0_xnumel = s0*s1*s7*s8
|
| 319 |
+
stream0 = get_raw_stream(0)
|
| 320 |
+
triton_poi_fused_add_cat_mul_0.run(arg7_1, arg2_1, arg4_1, buf0, s1, ps0, s0, triton_poi_fused_add_cat_mul_0_xnumel, stream=stream0)
|
| 321 |
+
del arg7_1
|
| 322 |
+
ps1 = s1*s7
|
| 323 |
+
pool0 = empty_strided_cuda((s7, s7, s0, s1), (s0*s1*s7, s1, s1*s7, 1), torch.bfloat16)
|
| 324 |
+
buf1 = pool0 # alloc
|
| 325 |
+
# Topologically Sorted Source Nodes: [mul_2, cat_1, mul_3, k_embed], Original ATen: [aten.mul, aten.cat, aten.add]
|
| 326 |
+
triton_poi_fused_add_cat_mul_1_xnumel = s0*s1*s7*s7
|
| 327 |
+
stream0 = get_raw_stream(0)
|
| 328 |
+
triton_poi_fused_add_cat_mul_1.run(arg8_1, arg2_1, arg4_1, buf1, s1, ps1, s0, triton_poi_fused_add_cat_mul_1_xnumel, stream=stream0)
|
| 329 |
+
del arg2_1
|
| 330 |
+
del arg4_1
|
| 331 |
+
del arg8_1
|
| 332 |
+
return (buf0, buf1, )
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 336 |
+
from torch._dynamo.testing import rand_strided
|
| 337 |
+
from torch._inductor.utils import print_performance
|
| 338 |
+
arg0_1 = 1000
|
| 339 |
+
arg1_1 = 128
|
| 340 |
+
arg2_1 = rand_strided((1, 1000, 128), (128000, 128, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 341 |
+
arg3_1 = 1
|
| 342 |
+
arg4_1 = rand_strided((1, 1000, 128), (128000, 128, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 343 |
+
arg5_1 = 8
|
| 344 |
+
arg6_1 = 16
|
| 345 |
+
arg7_1 = rand_strided((8, 16, 1000, 128), (2048000, 128, 2048, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 346 |
+
arg8_1 = rand_strided((8, 8, 1000, 128), (1024000, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 347 |
+
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1])
|
| 348 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
if __name__ == "__main__":
|
| 352 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 353 |
+
compiled_module_main('None', benchmark_compiled_module)
|
torchinductor_ch-epfl-345354-j/57/c574kngiopy3pgespyoupnzlae4d5tokyeui7uglwglnym2qijvn.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 33554432},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 20 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 21 |
+
xmask = xindex < xnumel
|
| 22 |
+
x0 = xindex
|
| 23 |
+
tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
|
| 24 |
+
tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
|
| 25 |
+
tmp1 = tmp0.to(tl.float32)
|
| 26 |
+
tmp2 = tl.sigmoid(tmp1)
|
| 27 |
+
tmp3 = tmp1 * tmp2
|
| 28 |
+
tmp4 = tmp3.to(tl.float32)
|
| 29 |
+
tmp6 = tmp4 * tmp5
|
| 30 |
+
tl.store(in_out_ptr0 + (x0), tmp6, xmask)
|
torchinductor_ch-epfl-345354-j/57/dad7be19dc394c1e08368515640dff88b78797aaabae100a15a1f195476a9a87.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 67}
|
torchinductor_ch-epfl-345354-j/5e/c5enonf6qztlsw7dozsqkejk4exzt4n56gbz6fiey2gnus5vdf76.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.reduction(
|
| 11 |
+
size_hints={'x': 1, 'r0_': 2048},
|
| 12 |
+
reduction_hint=ReductionHint.INNER,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_backward_nll_loss_forward_12', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_red_fused_nll_loss_backward_nll_loss_forward_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 1
|
| 20 |
+
rnumel = r0_numel
|
| 21 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 22 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 23 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 24 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 25 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 26 |
+
rbase = r0_base
|
| 27 |
+
_tmp27 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 28 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 29 |
+
r0_index = r0_offset + r0_base
|
| 30 |
+
r0_mask = r0_index < r0_numel
|
| 31 |
+
roffset = r0_offset
|
| 32 |
+
rindex = r0_index
|
| 33 |
+
r0_0 = r0_index
|
| 34 |
+
tmp5 = tl.load(in_ptr1 + (((r0_0 + 3*((6 + ks0*ks1) // 7)) % (ks0*ks1))), r0_mask, eviction_policy='evict_last', other=0.0)
|
| 35 |
+
tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
|
| 36 |
+
tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
|
| 37 |
+
tmp0 = ((r0_0 + 3*((6 + ks0*ks1) // 7)) % ks1)
|
| 38 |
+
tmp1 = (-1) + ks1
|
| 39 |
+
tmp2 = tmp0 == tmp1
|
| 40 |
+
tmp3 = tmp0 < tmp1
|
| 41 |
+
tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + (((r0_0 + 3*((6 + ks0*ks1) // 7)) % (ks0*ks1))), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_last', other=0.0)
|
| 42 |
+
tmp6 = tl.where(tmp3, tmp4, tmp5)
|
| 43 |
+
tmp7 = tl.full([1, 1], -100, tl.int64)
|
| 44 |
+
tmp8 = tl.where(tmp2, tmp7, tmp6)
|
| 45 |
+
tmp9 = tmp8 != tmp7
|
| 46 |
+
tmp10 = tl.full([1, 1], 0, tl.int64)
|
| 47 |
+
tmp11 = tl.where(tmp9, tmp8, tmp10)
|
| 48 |
+
tmp12 = ks2
|
| 49 |
+
tmp13 = tmp11 + tmp12
|
| 50 |
+
tmp14 = tmp11 < 0
|
| 51 |
+
tmp15 = tl.where(tmp14, tmp13, tmp11)
|
| 52 |
+
tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
|
| 53 |
+
tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + 3*ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 54 |
+
tmp18 = tmp17.to(tl.float32)
|
| 55 |
+
tmp20 = tmp18 - tmp19
|
| 56 |
+
tmp22 = tmp20 - tmp21
|
| 57 |
+
tmp23 = -tmp22
|
| 58 |
+
tmp24 = 0.0
|
| 59 |
+
tmp25 = tl.where(tmp9, tmp23, tmp24)
|
| 60 |
+
tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
|
| 61 |
+
tmp28 = _tmp27 + tmp26
|
| 62 |
+
_tmp27 = tl.where(r0_mask, tmp28, _tmp27)
|
| 63 |
+
tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
|
| 64 |
+
tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
|
| 65 |
+
tmp27 = tl.sum(_tmp27, 1)[:, None]
|
| 66 |
+
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
|
torchinductor_ch-epfl-345354-j/5x/c5xsvywggx5vrzm2l5uaktu7pipclhdn5h6263yru2ugvuhe2nak.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.persistent_reduction(
|
| 11 |
+
size_hints={'x': 8192, 'r0_': 1024},
|
| 12 |
+
reduction_hint=ReductionHint.INNER,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel):
|
| 19 |
+
XBLOCK: tl.constexpr = 1
|
| 20 |
+
r0_numel = 1024
|
| 21 |
+
R0_BLOCK: tl.constexpr = 1024
|
| 22 |
+
rnumel = r0_numel
|
| 23 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 24 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 25 |
+
xindex = tl.full([1], xoffset, tl.int32)
|
| 26 |
+
xmask = tl.full([R0_BLOCK], True, tl.int1)
|
| 27 |
+
r0_index = tl.arange(0, R0_BLOCK)[:]
|
| 28 |
+
r0_offset = 0
|
| 29 |
+
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
|
| 30 |
+
roffset = r0_offset
|
| 31 |
+
rindex = r0_index
|
| 32 |
+
r0_1 = r0_index
|
| 33 |
+
x0 = xindex
|
| 34 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
|
| 35 |
+
tmp1 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
|
| 36 |
+
tmp4 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
|
| 37 |
+
tmp10 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
|
| 38 |
+
tmp2 = tmp0 * tmp1
|
| 39 |
+
tmp3 = tmp2.to(tl.float32)
|
| 40 |
+
tmp5 = tmp4.to(tl.float32)
|
| 41 |
+
tmp6 = tmp3 * tmp5
|
| 42 |
+
tmp7 = tl.broadcast_to(tmp6, [R0_BLOCK])
|
| 43 |
+
tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
|
| 44 |
+
tmp11 = tmp3 * tmp10
|
| 45 |
+
tmp12 = -0.5
|
| 46 |
+
tmp13 = tmp9 * tmp12
|
| 47 |
+
tmp14 = tmp10 * tmp10
|
| 48 |
+
tmp15 = tmp14 * tmp10
|
| 49 |
+
tmp16 = tmp13 * tmp15
|
| 50 |
+
tmp17 = 0.0009765625
|
| 51 |
+
tmp18 = tmp16 * tmp17
|
| 52 |
+
tmp19 = 2.0
|
| 53 |
+
tmp20 = tmp5 * tmp19
|
| 54 |
+
tmp21 = tmp18 * tmp20
|
| 55 |
+
tmp22 = tmp11 + tmp21
|
| 56 |
+
tmp23 = tmp22.to(tl.float32)
|
| 57 |
+
tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp23, None)
|
torchinductor_ch-epfl-345354-j/6d/c6dsbxlebwjqawzeprkq3lkldtxoiept4c6bpgtva5r4mjlrnwlr.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compile-time auto-tuning block:
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._dynamo.testing import rand_strided
|
| 6 |
+
from torch._dynamo.utils import preserve_rng_state
|
| 7 |
+
from torch._inductor.select_algorithm import AlgorithmSelectorCache
|
| 8 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 9 |
+
|
| 10 |
+
async_compile = AsyncCompile()
|
| 11 |
+
generate_example_value = AlgorithmSelectorCache.generate_example_value
|
| 12 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 13 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
|
| 17 |
+
import triton
|
| 18 |
+
import triton.language as tl
|
| 19 |
+
|
| 20 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 21 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 22 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 23 |
+
triton_helpers.set_driver_to_gpu()
|
| 24 |
+
|
| 25 |
+
@triton_heuristics.persistent_reduction(
|
| 26 |
+
size_hints={'x': 131072, 'r0_': 128},
|
| 27 |
+
reduction_hint=ReductionHint.INNER,
|
| 28 |
+
filename=__file__,
|
| 29 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
|
| 30 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 31 |
+
)
|
| 32 |
+
@triton.jit
|
| 33 |
+
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 34 |
+
r0_numel = 128
|
| 35 |
+
R0_BLOCK: tl.constexpr = 128
|
| 36 |
+
rnumel = r0_numel
|
| 37 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 38 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 39 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 40 |
+
xmask = xindex < xnumel
|
| 41 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 42 |
+
r0_offset = 0
|
| 43 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 44 |
+
roffset = r0_offset
|
| 45 |
+
rindex = r0_index
|
| 46 |
+
r0_1 = r0_index
|
| 47 |
+
x0 = xindex
|
| 48 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
|
| 49 |
+
tmp7 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
|
| 50 |
+
tmp1 = tmp0.to(tl.float32)
|
| 51 |
+
tmp2 = tmp1 * tmp1
|
| 52 |
+
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
|
| 53 |
+
tmp5 = tl.where(xmask, tmp3, 0)
|
| 54 |
+
tmp6 = tl.sum(tmp5, 1)[:, None]
|
| 55 |
+
tmp8 = 128.0
|
| 56 |
+
tmp9 = (tmp6 / tmp8)
|
| 57 |
+
tmp10 = 1e-06
|
| 58 |
+
tmp11 = tmp9 + tmp10
|
| 59 |
+
tmp12 = libdevice.rsqrt(tmp11)
|
| 60 |
+
tmp13 = tmp1 * tmp12
|
| 61 |
+
tmp14 = tmp13.to(tl.float32)
|
| 62 |
+
tmp15 = tmp7 * tmp14
|
| 63 |
+
tl.store(out_ptr1 + (r0_1 + 128*x0), tmp15, xmask)
|
| 64 |
+
''', device_str='cuda')
|
| 65 |
+
|
| 66 |
+
async_compile.wait(globals())
|
| 67 |
+
del async_compile
|
| 68 |
+
|
| 69 |
+
import triton
|
| 70 |
+
import triton.language as tl
|
| 71 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 72 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 73 |
+
with torch.cuda._DeviceGuard(0):
|
| 74 |
+
torch.cuda.set_device(0)
|
| 75 |
+
stream0 = get_raw_stream(0)
|
| 76 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 77 |
+
stream0 = get_raw_stream(0)
|
| 78 |
+
arg3_1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 79 |
+
arg4_1 = generate_example_value((128,), (1,), 'cuda:0', torch.bfloat16, 0, (128,))
|
| 80 |
+
buf1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 81 |
+
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(arg3_1, arg4_1, buf1, 128000, 128, stream=stream0)
|
| 82 |
+
del arg3_1, arg4_1, buf1
|
| 83 |
+
|
| 84 |
+
"""
|
| 85 |
+
# AOT ID: ['2_inference']
|
| 86 |
+
from ctypes import c_void_p, c_long, c_int
|
| 87 |
+
import torch
|
| 88 |
+
import math
|
| 89 |
+
import random
|
| 90 |
+
import os
|
| 91 |
+
import tempfile
|
| 92 |
+
+from math import inf, nan
+from cmath import nanj
+from torch._inductor.hooks import run_intermediate_hooks
+from torch._inductor.utils import maybe_profile
+from torch._inductor.codegen.memory_planning import _align as align
+from torch import device, empty_strided
+from torch._inductor.async_compile import AsyncCompile
+from torch._inductor.select_algorithm import extern_kernels
+from torch._inductor.codegen.multi_kernel import MultiKernelCall
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+aten = torch.ops.aten
+inductor_ops = torch.ops.inductor
+_quantized = torch.ops._quantized
+assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+alloc_from_pool = torch.ops.inductor._alloc_from_pool
+async_compile = AsyncCompile()
+empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+# kernel path: /tmp/torchinductor_ch-epfl-345354-j/l7/cl7z3tsg3vjxfjb3vqjym4iefsq6h5o7fsfsmbi65q55e56x3lm7.py
+# Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, add, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
+# Source node to ATen node mapping:
+# add => add_15
+# hidden_states => convert_element_type
+# hidden_states_1 => mul_17
+# mul_1 => mul_26
+# pow_1 => pow_1
+# rsqrt => rsqrt
+# to_1 => convert_element_type_1
+# variance => mean
+# Graph fragment:
+# %convert_element_type : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg3_1, torch.float32), kwargs = {})
+# %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {})
+# %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
+# %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
+# %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_15,), kwargs = {})
+# %mul_17 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
+# %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_17, torch.bfloat16), kwargs = {})
+# %mul_26 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %convert_element_type_1), kwargs = {})
+triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.persistent_reduction(
+    size_hints={'x': 131072, 'r0_': 128},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
+    r0_numel = 128
+    R0_BLOCK: tl.constexpr = 128
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_index = tl.arange(0, R0_BLOCK)[None, :]
+    r0_offset = 0
+    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+    roffset = r0_offset
+    rindex = r0_index
+    r0_1 = r0_index
+    x0 = xindex
+    tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
+    tmp7 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+    tmp1 = tmp0.to(tl.float32)
+    tmp2 = tmp1 * tmp1
+    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
+    tmp5 = tl.where(xmask, tmp3, 0)
+    tmp6 = tl.sum(tmp5, 1)[:, None]
+    tmp8 = 128.0
+    tmp9 = (tmp6 / tmp8)
+    tmp10 = 1e-06
+    tmp11 = tmp9 + tmp10
+    tmp12 = libdevice.rsqrt(tmp11)
+    tmp13 = tmp1 * tmp12
+    tmp14 = tmp13.to(tl.float32)
+    tmp15 = tmp7 * tmp14
+    tl.store(out_ptr1 + (r0_1 + 128*x0), tmp15, xmask)
+''', device_str='cuda')
+
+
+async_compile.wait(globals())
+del async_compile
+
+def call(args):
+    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
+    args.clear()
+    s3 = arg0_1
+    s4 = arg1_1
+    s5 = arg2_1
+    assert_size_stride(arg3_1, (s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1))
+    assert_size_stride(arg4_1, (128, ), (1, ))
+    with torch.cuda._DeviceGuard(0):
+        torch.cuda.set_device(0)
+        pool0 = empty_strided_cuda((s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1), torch.bfloat16)
+        buf1 = pool0  # alloc
+        # Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, add, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
+        triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0_xnumel = s3*s4*s5
+        stream0 = get_raw_stream(0)
+        triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(arg3_1, arg4_1, buf1, triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0_xnumel, 128, stream=stream0)
+        del arg3_1
+        del arg4_1
+    return (buf1, )
+
+
+def benchmark_compiled_module(times=10, repeat=10):
+    from torch._dynamo.testing import rand_strided
+    from torch._inductor.utils import print_performance
+    arg0_1 = 8
+    arg1_1 = 1000
+    arg2_1 = 16
+    arg3_1 = rand_strided((8, 1000, 16, 128), (2048000, 2048, 128, 1), device='cuda:0', dtype=torch.bfloat16)
+    arg4_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
+    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1])
+    return print_performance(fn, times=times, repeat=repeat)
+
+
+if __name__ == "__main__":
+    from torch._inductor.wrapper_benchmark import compiled_module_main
+    compiled_module_main('None', benchmark_compiled_module)
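For orientation, the fused persistent reduction above is an RMSNorm. A minimal eager-mode sketch of the same math, using the shapes from benchmark_compiled_module (the name rms_norm_reference is illustrative and not part of the generated module), would be:

import torch

def rms_norm_reference(hidden_states, weight, eps=1e-6):
    # upcast to fp32, mean of squares over the last dim, rsqrt(variance + eps),
    # rescale, cast back to bf16, then multiply by the learned weight
    hidden_fp32 = hidden_states.to(torch.float32)
    variance = hidden_fp32.pow(2).mean(-1, keepdim=True)
    normed = hidden_fp32 * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)

x = torch.randn(8, 1000, 16, 128, device='cuda', dtype=torch.bfloat16)
w = torch.randn(128, device='cuda', dtype=torch.bfloat16)
out = rms_norm_reference(x, w)  # plays the role of buf1 returned by call(...)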
torchinductor_ch-epfl-345354-j/6j/7c215475e7b40a21cf286026270965eb7f07e7c3af1c4052d331de3f74c6449e.best_config
ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 68}
torchinductor_ch-epfl-345354-j/6j/c6j5lx5qgycfvyi3dm5f4mo3ssluzzsrmdq32pka7e6pyhg42zvd.py
ADDED
|
@@ -0,0 +1,499 @@
| 1 |
+
"""
|
| 2 |
+
Compile-time auto-tuning block:
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._dynamo.testing import rand_strided
|
| 6 |
+
from torch._dynamo.utils import preserve_rng_state
|
| 7 |
+
from torch._inductor.select_algorithm import AlgorithmSelectorCache
|
| 8 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 9 |
+
|
| 10 |
+
async_compile = AsyncCompile()
|
| 11 |
+
generate_example_value = AlgorithmSelectorCache.generate_example_value
|
| 12 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 13 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
|
| 17 |
+
import triton
|
| 18 |
+
import triton.language as tl
|
| 19 |
+
|
| 20 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 21 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 22 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 23 |
+
triton_helpers.set_driver_to_gpu()
|
| 24 |
+
|
| 25 |
+
@triton_heuristics.reduction(
|
| 26 |
+
size_hints={'x': 65536, 'r0_': 512},
|
| 27 |
+
reduction_hint=ReductionHint.OUTER,
|
| 28 |
+
filename=__file__,
|
| 29 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 30 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 31 |
+
)
|
| 32 |
+
@triton.jit
|
| 33 |
+
def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 34 |
+
xnumel = 40960
|
| 35 |
+
rnumel = r0_numel
|
| 36 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 37 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 38 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 39 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 40 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 41 |
+
rbase = r0_base
|
| 42 |
+
x1 = xindex // 128
|
| 43 |
+
x0 = (xindex % 128)
|
| 44 |
+
_tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 45 |
+
x3 = xindex
|
| 46 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 47 |
+
r0_index = r0_offset + r0_base
|
| 48 |
+
r0_mask = r0_index < r0_numel
|
| 49 |
+
roffset = r0_offset
|
| 50 |
+
rindex = r0_index
|
| 51 |
+
r0_2 = r0_index
|
| 52 |
+
tmp0 = r0_2 + x1*((319 + ks0*ks1*ks2) // 320)
|
| 53 |
+
tmp1 = ks0*ks1*ks2
|
| 54 |
+
tmp2 = tmp0 < tmp1
|
| 55 |
+
tmp3 = tl.load(in_ptr0 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 56 |
+
tmp4 = tl.load(in_ptr1 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 57 |
+
tmp5 = tmp4.to(tl.float32)
|
| 58 |
+
tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2))), r0_mask & tmp2, eviction_policy='evict_last', other=0.0)
|
| 59 |
+
tmp7 = tmp5 * tmp6
|
| 60 |
+
tmp8 = tmp7.to(tl.float32)
|
| 61 |
+
tmp9 = tmp3 * tmp8
|
| 62 |
+
tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
|
| 63 |
+
tmp11 = tl.where(tmp2, tmp9, tmp10)
|
| 64 |
+
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
|
| 65 |
+
tmp14 = _tmp13 + tmp12
|
| 66 |
+
_tmp13 = tl.where(r0_mask, tmp14, _tmp13)
|
| 67 |
+
tmp13 = tl.sum(_tmp13, 1)[:, None]
|
| 68 |
+
tl.store(out_ptr0 + (x3), tmp13, None)
|
| 69 |
+
''', device_str='cuda')
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
triton_red_fused__to_copy_mul_sum_1 = async_compile.triton('triton_red_fused__to_copy_mul_sum_1', '''
|
| 73 |
+
import triton
|
| 74 |
+
import triton.language as tl
|
| 75 |
+
|
| 76 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 77 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 78 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 79 |
+
triton_helpers.set_driver_to_gpu()
|
| 80 |
+
|
| 81 |
+
@triton_heuristics.reduction(
|
| 82 |
+
size_hints={'x': 128, 'r0_': 512},
|
| 83 |
+
reduction_hint=ReductionHint.OUTER_TINY,
|
| 84 |
+
filename=__file__,
|
| 85 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
|
| 86 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 87 |
+
)
|
| 88 |
+
@triton.jit
|
| 89 |
+
def triton_red_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 90 |
+
xnumel = 128
|
| 91 |
+
r0_numel = 320
|
| 92 |
+
rnumel = r0_numel
|
| 93 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 94 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 95 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 96 |
+
xmask = xindex < xnumel
|
| 97 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 98 |
+
rbase = r0_base
|
| 99 |
+
x0 = xindex
|
| 100 |
+
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 101 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 102 |
+
r0_index = r0_offset + r0_base
|
| 103 |
+
r0_mask = r0_index < r0_numel
|
| 104 |
+
roffset = r0_offset
|
| 105 |
+
rindex = r0_index
|
| 106 |
+
r0_1 = r0_index
|
| 107 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 128*r0_1), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
|
| 108 |
+
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
|
| 109 |
+
tmp3 = _tmp2 + tmp1
|
| 110 |
+
_tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
|
| 111 |
+
tmp2 = tl.sum(_tmp2, 1)[:, None]
|
| 112 |
+
tl.store(out_ptr0 + (x0), tmp2, xmask)
|
| 113 |
+
''', device_str='cuda')
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
|
| 117 |
+
import triton
|
| 118 |
+
import triton.language as tl
|
| 119 |
+
|
| 120 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 121 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 122 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 123 |
+
triton_helpers.set_driver_to_gpu()
|
| 124 |
+
|
| 125 |
+
@triton_heuristics.persistent_reduction(
|
| 126 |
+
size_hints={'x': 131072, 'r0_': 128},
|
| 127 |
+
reduction_hint=ReductionHint.INNER,
|
| 128 |
+
filename=__file__,
|
| 129 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
|
| 130 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 131 |
+
)
|
| 132 |
+
@triton.jit
|
| 133 |
+
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 134 |
+
r0_numel = 128
|
| 135 |
+
R0_BLOCK: tl.constexpr = 128
|
| 136 |
+
rnumel = r0_numel
|
| 137 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 138 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 139 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 140 |
+
xmask = xindex < xnumel
|
| 141 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 142 |
+
r0_offset = 0
|
| 143 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 144 |
+
roffset = r0_offset
|
| 145 |
+
rindex = r0_index
|
| 146 |
+
r0_1 = r0_index
|
| 147 |
+
x0 = xindex
|
| 148 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
|
| 149 |
+
tmp1 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
|
| 150 |
+
tmp4 = tl.load(in_ptr2 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
|
| 151 |
+
tmp11 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
|
| 152 |
+
tmp2 = tmp0 * tmp1
|
| 153 |
+
tmp3 = tmp2.to(tl.float32)
|
| 154 |
+
tmp5 = tmp4.to(tl.float32)
|
| 155 |
+
tmp6 = tmp3 * tmp5
|
| 156 |
+
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
|
| 157 |
+
tmp9 = tl.where(xmask, tmp7, 0)
|
| 158 |
+
tmp10 = tl.sum(tmp9, 1)[:, None]
|
| 159 |
+
tmp12 = tmp3 * tmp11
|
| 160 |
+
tmp13 = -0.5
|
| 161 |
+
tmp14 = tmp10 * tmp13
|
| 162 |
+
tmp15 = tmp11 * tmp11
|
| 163 |
+
tmp16 = tmp15 * tmp11
|
| 164 |
+
tmp17 = tmp14 * tmp16
|
| 165 |
+
tmp18 = 0.0078125
|
| 166 |
+
tmp19 = tmp17 * tmp18
|
| 167 |
+
tmp20 = 2.0
|
| 168 |
+
tmp21 = tmp5 * tmp20
|
| 169 |
+
tmp22 = tmp19 * tmp21
|
| 170 |
+
tmp23 = tmp12 + tmp22
|
| 171 |
+
tmp24 = tmp23.to(tl.float32)
|
| 172 |
+
tl.store(out_ptr1 + (r0_1 + 128*x0), tmp24, xmask)
|
| 173 |
+
''', device_str='cuda')
|
| 174 |
+
|
| 175 |
+
async_compile.wait(globals())
|
| 176 |
+
del async_compile
|
| 177 |
+
|
| 178 |
+
import triton
|
| 179 |
+
import triton.language as tl
|
| 180 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 181 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 182 |
+
with torch.cuda._DeviceGuard(0):
|
| 183 |
+
torch.cuda.set_device(0)
|
| 184 |
+
stream0 = get_raw_stream(0)
|
| 185 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 186 |
+
stream0 = get_raw_stream(0)
|
| 187 |
+
tangents_1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 188 |
+
primals_4 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 189 |
+
rsqrt = generate_example_value((8, 1000, 16, 1), (16000, 16, 1, 1), 'cuda:0', torch.float32, 0, (8, 1000, 16, 1))
|
| 190 |
+
buf0 = generate_example_value((1, 1, 1, 128, 320), (40960, 40960, 40960, 1, 128), 'cuda:0', torch.float32, 0, (1, 1, 1, 128, 320))
|
| 191 |
+
triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, 8, 1000, 16, 40960, 400, stream=stream0)
|
| 192 |
+
del tangents_1, primals_4, rsqrt, buf0
|
| 193 |
+
|
| 194 |
+
stream0 = get_raw_stream(0)
|
| 195 |
+
buf0 = generate_example_value((1, 1, 1, 128, 320), (40960, 40960, 40960, 1, 128), 'cuda:0', torch.float32, 0, (1, 1, 1, 128, 320))
|
| 196 |
+
buf1 = generate_example_value((1, 1, 1, 128), (128, 128, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1, 1, 128))
|
| 197 |
+
triton_red_fused__to_copy_mul_sum_1.run(buf0, buf1, 128, 320, stream=stream0)
|
| 198 |
+
del buf0, buf1
|
| 199 |
+
|
| 200 |
+
stream0 = get_raw_stream(0)
|
| 201 |
+
tangents_1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 202 |
+
primals_5 = generate_example_value((128,), (1,), 'cuda:0', torch.bfloat16, 0, (128,))
|
| 203 |
+
primals_4 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 204 |
+
rsqrt = generate_example_value((8, 1000, 16, 1), (16000, 16, 1, 1), 'cuda:0', torch.float32, 0, (8, 1000, 16, 1))
|
| 205 |
+
buf3 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
|
| 206 |
+
triton_per_fused__to_copy_add_div_mul_pow_sum_2.run(tangents_1, primals_5, primals_4, rsqrt, buf3, 128000, 128, stream=stream0)
|
| 207 |
+
del tangents_1, primals_5, primals_4, rsqrt, buf3
|
| 208 |
+
|
| 209 |
+
"""
|
| 210 |
+
# AOT ID: ['9_backward']
|
| 211 |
+
from ctypes import c_void_p, c_long, c_int
|
| 212 |
+
import torch
|
| 213 |
+
import math
|
| 214 |
+
import random
|
| 215 |
+
import os
|
| 216 |
+
import tempfile
|
| 217 |
+
from math import inf, nan
|
| 218 |
+
from cmath import nanj
|
| 219 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 220 |
+
from torch._inductor.utils import maybe_profile
|
| 221 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 222 |
+
from torch import device, empty_strided
|
| 223 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 224 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 225 |
+
from torch._inductor.codegen.multi_kernel import MultiKernelCall
|
| 226 |
+
import triton
|
| 227 |
+
import triton.language as tl
|
| 228 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 229 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 230 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 231 |
+
|
| 232 |
+
aten = torch.ops.aten
|
| 233 |
+
inductor_ops = torch.ops.inductor
|
| 234 |
+
_quantized = torch.ops._quantized
|
| 235 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 236 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 237 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 238 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 239 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 240 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 241 |
+
async_compile = AsyncCompile()
|
| 242 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# kernel path: /tmp/torchinductor_ch-epfl-345354-j/mm/cmmurv7ol6o3kll2wm4b6wgtdca4tsysrq7yrhhvfkf7ikm72y24.py
|
| 246 |
+
# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
|
| 247 |
+
# Source node to ATen node mapping:
|
| 248 |
+
# hidden_states => convert_element_type
|
| 249 |
+
# hidden_states_1 => mul_23
|
| 250 |
+
# to_1 => convert_element_type_1
|
| 251 |
+
# Graph fragment:
|
| 252 |
+
# %convert_element_type : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
|
| 253 |
+
# %mul_23 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
|
| 254 |
+
# %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_23, torch.bfloat16), kwargs = {})
|
| 255 |
+
# %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
|
| 256 |
+
# %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_38, [0, 1, 2], True), kwargs = {})
|
| 257 |
+
triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
|
| 258 |
+
import triton
|
| 259 |
+
import triton.language as tl
|
| 260 |
+
|
| 261 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 262 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 263 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 264 |
+
triton_helpers.set_driver_to_gpu()
|
| 265 |
+
|
| 266 |
+
@triton_heuristics.reduction(
|
| 267 |
+
size_hints={'x': 65536, 'r0_': 512},
|
| 268 |
+
reduction_hint=ReductionHint.OUTER,
|
| 269 |
+
filename=__file__,
|
| 270 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 271 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 272 |
+
)
|
| 273 |
+
@triton.jit
|
| 274 |
+
def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 275 |
+
xnumel = 40960
|
| 276 |
+
rnumel = r0_numel
|
| 277 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 278 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 279 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 280 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 281 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 282 |
+
rbase = r0_base
|
| 283 |
+
x1 = xindex // 128
|
| 284 |
+
x0 = (xindex % 128)
|
| 285 |
+
_tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 286 |
+
x3 = xindex
|
| 287 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 288 |
+
r0_index = r0_offset + r0_base
|
| 289 |
+
r0_mask = r0_index < r0_numel
|
| 290 |
+
roffset = r0_offset
|
| 291 |
+
rindex = r0_index
|
| 292 |
+
r0_2 = r0_index
|
| 293 |
+
tmp0 = r0_2 + x1*((319 + ks0*ks1*ks2) // 320)
|
| 294 |
+
tmp1 = ks0*ks1*ks2
|
| 295 |
+
tmp2 = tmp0 < tmp1
|
| 296 |
+
tmp3 = tl.load(in_ptr0 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 297 |
+
tmp4 = tl.load(in_ptr1 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 298 |
+
tmp5 = tmp4.to(tl.float32)
|
| 299 |
+
tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2))), r0_mask & tmp2, eviction_policy='evict_last', other=0.0)
|
| 300 |
+
tmp7 = tmp5 * tmp6
|
| 301 |
+
tmp8 = tmp7.to(tl.float32)
|
| 302 |
+
tmp9 = tmp3 * tmp8
|
| 303 |
+
tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
|
| 304 |
+
tmp11 = tl.where(tmp2, tmp9, tmp10)
|
| 305 |
+
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
|
| 306 |
+
tmp14 = _tmp13 + tmp12
|
| 307 |
+
_tmp13 = tl.where(r0_mask, tmp14, _tmp13)
|
| 308 |
+
tmp13 = tl.sum(_tmp13, 1)[:, None]
|
| 309 |
+
tl.store(out_ptr0 + (x3), tmp13, None)
|
| 310 |
+
''', device_str='cuda')
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# kernel path: /tmp/torchinductor_ch-epfl-345354-j/om/com66ihngf42hseqjadt4jcosvwffgk5ynrurmthhraywffjqcop.py
|
| 314 |
+
# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
|
| 315 |
+
# Source node to ATen node mapping:
|
| 316 |
+
# hidden_states => convert_element_type
|
| 317 |
+
# hidden_states_1 => mul_23
|
| 318 |
+
# to_1 => convert_element_type_1
|
| 319 |
+
# Graph fragment:
|
| 320 |
+
# %convert_element_type : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
|
| 321 |
+
# %mul_23 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
|
| 322 |
+
# %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_23, torch.bfloat16), kwargs = {})
|
| 323 |
+
# %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
|
| 324 |
+
# %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_38, [0, 1, 2], True), kwargs = {})
|
| 325 |
+
triton_red_fused__to_copy_mul_sum_1 = async_compile.triton('triton_red_fused__to_copy_mul_sum_1', '''
|
| 326 |
+
import triton
|
| 327 |
+
import triton.language as tl
|
| 328 |
+
|
| 329 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 330 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 331 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 332 |
+
triton_helpers.set_driver_to_gpu()
|
| 333 |
+
|
| 334 |
+
@triton_heuristics.reduction(
|
| 335 |
+
size_hints={'x': 128, 'r0_': 512},
|
| 336 |
+
reduction_hint=ReductionHint.OUTER_TINY,
|
| 337 |
+
filename=__file__,
|
| 338 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
|
| 339 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 340 |
+
)
|
| 341 |
+
@triton.jit
|
| 342 |
+
def triton_red_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 343 |
+
xnumel = 128
|
| 344 |
+
r0_numel = 320
|
| 345 |
+
rnumel = r0_numel
|
| 346 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 347 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 348 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 349 |
+
xmask = xindex < xnumel
|
| 350 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 351 |
+
rbase = r0_base
|
| 352 |
+
x0 = xindex
|
| 353 |
+
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 354 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 355 |
+
r0_index = r0_offset + r0_base
|
| 356 |
+
r0_mask = r0_index < r0_numel
|
| 357 |
+
roffset = r0_offset
|
| 358 |
+
rindex = r0_index
|
| 359 |
+
r0_1 = r0_index
|
| 360 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 128*r0_1), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
|
| 361 |
+
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
|
| 362 |
+
tmp3 = _tmp2 + tmp1
|
| 363 |
+
_tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
|
| 364 |
+
tmp2 = tl.sum(_tmp2, 1)[:, None]
|
| 365 |
+
tl.store(out_ptr0 + (x0), tmp2, xmask)
|
| 366 |
+
''', device_str='cuda')
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# kernel path: /tmp/torchinductor_ch-epfl-345354-j/p3/cp3uyorjp57oo5jc6wsaegyfcdcbyu6gppqliksrkascc6kis3o2.py
|
| 370 |
+
# Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
|
| 371 |
+
# Source node to ATen node mapping:
|
| 372 |
+
# hidden_states => convert_element_type
|
| 373 |
+
# Graph fragment:
|
| 374 |
+
# %mul_37 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_5), kwargs = {})
|
| 375 |
+
# %convert_element_type : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
|
| 376 |
+
# %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_37, torch.float32), kwargs = {})
|
| 377 |
+
# %mul_39 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {})
|
| 378 |
+
# %mul_40 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {})
|
| 379 |
+
# %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_39, [3], True), kwargs = {})
|
| 380 |
+
# %div : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, 128), kwargs = {})
|
| 381 |
+
# %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {})
|
| 382 |
+
# %mul_43 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {})
|
| 383 |
+
# %mul_44 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_43), kwargs = {})
|
| 384 |
+
# %add_46 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_40, %mul_44), kwargs = {})
|
| 385 |
+
# %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_46, torch.bfloat16), kwargs = {})
|
| 386 |
+
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
|
| 387 |
+
import triton
|
| 388 |
+
import triton.language as tl
|
| 389 |
+
|
| 390 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 391 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 392 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 393 |
+
triton_helpers.set_driver_to_gpu()
|
| 394 |
+
|
| 395 |
+
@triton_heuristics.persistent_reduction(
|
| 396 |
+
size_hints={'x': 131072, 'r0_': 128},
|
| 397 |
+
reduction_hint=ReductionHint.INNER,
|
| 398 |
+
filename=__file__,
|
| 399 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
|
| 400 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 401 |
+
)
|
| 402 |
+
@triton.jit
|
| 403 |
+
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 404 |
+
r0_numel = 128
|
| 405 |
+
R0_BLOCK: tl.constexpr = 128
|
| 406 |
+
rnumel = r0_numel
|
| 407 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 408 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 409 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 410 |
+
xmask = xindex < xnumel
|
| 411 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 412 |
+
r0_offset = 0
|
| 413 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 414 |
+
roffset = r0_offset
|
| 415 |
+
rindex = r0_index
|
| 416 |
+
r0_1 = r0_index
|
| 417 |
+
x0 = xindex
|
| 418 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
|
| 419 |
+
tmp1 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
|
| 420 |
+
tmp4 = tl.load(in_ptr2 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
|
| 421 |
+
tmp11 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
|
| 422 |
+
tmp2 = tmp0 * tmp1
|
| 423 |
+
tmp3 = tmp2.to(tl.float32)
|
| 424 |
+
tmp5 = tmp4.to(tl.float32)
|
| 425 |
+
tmp6 = tmp3 * tmp5
|
| 426 |
+
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
|
| 427 |
+
tmp9 = tl.where(xmask, tmp7, 0)
|
| 428 |
+
tmp10 = tl.sum(tmp9, 1)[:, None]
|
| 429 |
+
tmp12 = tmp3 * tmp11
|
| 430 |
+
tmp13 = -0.5
|
| 431 |
+
tmp14 = tmp10 * tmp13
|
| 432 |
+
tmp15 = tmp11 * tmp11
|
| 433 |
+
tmp16 = tmp15 * tmp11
|
| 434 |
+
tmp17 = tmp14 * tmp16
|
| 435 |
+
tmp18 = 0.0078125
|
| 436 |
+
tmp19 = tmp17 * tmp18
|
| 437 |
+
tmp20 = 2.0
|
| 438 |
+
tmp21 = tmp5 * tmp20
|
| 439 |
+
tmp22 = tmp19 * tmp21
|
| 440 |
+
tmp23 = tmp12 + tmp22
|
| 441 |
+
tmp24 = tmp23.to(tl.float32)
|
| 442 |
+
tl.store(out_ptr1 + (r0_1 + 128*x0), tmp24, xmask)
|
| 443 |
+
''', device_str='cuda')
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
async_compile.wait(globals())
|
| 447 |
+
del async_compile
|
| 448 |
+
|
| 449 |
+
def call(args):
|
| 450 |
+
primals_1, primals_2, primals_3, primals_4, primals_5, rsqrt, tangents_1 = args
|
| 451 |
+
args.clear()
|
| 452 |
+
s3 = primals_1
|
| 453 |
+
s4 = primals_2
|
| 454 |
+
s5 = primals_3
|
| 455 |
+
assert_size_stride(primals_4, (s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1))
|
| 456 |
+
assert_size_stride(primals_5, (128, ), (1, ))
|
| 457 |
+
assert_size_stride(rsqrt, (s3, s4, s5, 1), (s4*s5, s5, 1, 1))
|
| 458 |
+
assert_size_stride(tangents_1, (s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1))
|
| 459 |
+
with torch.cuda._DeviceGuard(0):
|
| 460 |
+
torch.cuda.set_device(0)
|
| 461 |
+
buf0 = empty_strided_cuda((1, 1, 1, 128, 320), (40960, 40960, 40960, 1, 128), torch.float32)
|
| 462 |
+
# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
|
| 463 |
+
triton_red_fused__to_copy_mul_sum_0_r0_numel = (319 + s3*s4*s5) // 320
|
| 464 |
+
stream0 = get_raw_stream(0)
|
| 465 |
+
triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s3, s4, s5, 40960, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream0)
|
| 466 |
+
buf1 = empty_strided_cuda((1, 1, 1, 128), (128, 128, 128, 1), torch.bfloat16)
|
| 467 |
+
# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
|
| 468 |
+
stream0 = get_raw_stream(0)
|
| 469 |
+
triton_red_fused__to_copy_mul_sum_1.run(buf0, buf1, 128, 320, stream=stream0)
|
| 470 |
+
del buf0
|
| 471 |
+
buf3 = empty_strided_cuda((s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1), torch.bfloat16)
|
| 472 |
+
# Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
|
| 473 |
+
triton_per_fused__to_copy_add_div_mul_pow_sum_2_xnumel = s3*s4*s5
|
| 474 |
+
stream0 = get_raw_stream(0)
|
| 475 |
+
triton_per_fused__to_copy_add_div_mul_pow_sum_2.run(tangents_1, primals_5, primals_4, rsqrt, buf3, triton_per_fused__to_copy_add_div_mul_pow_sum_2_xnumel, 128, stream=stream0)
|
| 476 |
+
del primals_4
|
| 477 |
+
del primals_5
|
| 478 |
+
del rsqrt
|
| 479 |
+
del tangents_1
|
| 480 |
+
return (None, None, None, buf3, reinterpret_tensor(buf1, (128, ), (1, ), 0), )
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 484 |
+
from torch._dynamo.testing import rand_strided
|
| 485 |
+
from torch._inductor.utils import print_performance
|
| 486 |
+
primals_1 = 8
|
| 487 |
+
primals_2 = 1000
|
| 488 |
+
primals_3 = 16
|
| 489 |
+
primals_4 = rand_strided((8, 1000, 16, 128), (2048000, 2048, 128, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 490 |
+
primals_5 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
|
| 491 |
+
rsqrt = rand_strided((8, 1000, 16, 1), (16000, 16, 1, 1), device='cuda:0', dtype=torch.float32)
|
| 492 |
+
tangents_1 = rand_strided((8, 1000, 16, 128), (2048000, 2048, 128, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 493 |
+
fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, rsqrt, tangents_1])
|
| 494 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
if __name__ == "__main__":
|
| 498 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 499 |
+
compiled_module_main('None', benchmark_compiled_module)
|
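The module above (AOT ID '9_backward') is the matching generated backward for the RMSNorm: a two-stage reduction builds the weight gradient (buf0, then buf1) and a persistent reduction builds the input gradient (buf3). A hedged eager sketch of what those outputs correspond to, assuming a CUDA device and the benchmark shapes:

import torch

def rms_norm_reference(x, weight, eps=1e-6):
    x32 = x.to(torch.float32)
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * normed.to(x.dtype)

x = torch.randn(8, 1000, 16, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
rms_norm_reference(x, w).sum().backward()
# x.grad corresponds to buf3 (gradient w.r.t. primals_4) and w.grad to the
# reinterpreted buf1 (gradient w.r.t. primals_5) returned by call(...) above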
torchinductor_ch-epfl-345354-j/6j/c6jqjdux4scc3alxlsrcpnhemegj7ym5pw3twg6xb2eyx4codkvz.py
ADDED
@@ -0,0 +1,40 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 33554432},
+    filename=__file__,
+    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = xindex
+    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+    tmp1 = tl.load(in_ptr1 + (x0), xmask).to(tl.float32)
+    tmp7 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+    tmp2 = tmp1.to(tl.float32)
+    tmp3 = tl.sigmoid(tmp2)
+    tmp4 = tmp2 * tmp3
+    tmp5 = tmp4.to(tl.float32)
+    tmp6 = tmp0 * tmp5
+    tmp8 = tmp0 * tmp7
+    tmp9 = tl.sigmoid(tmp1)
+    tmp10 = 1.0
+    tmp11 = tmp10 - tmp9
+    tmp12 = tmp1 * tmp11
+    tmp13 = tmp12 + tmp10
+    tmp14 = tmp9 * tmp13
+    tmp15 = tmp8 * tmp14
+    tl.store(out_ptr0 + (x0), tmp6, xmask)
+    tl.store(in_out_ptr0 + (x0), tmp15, xmask)
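The pointwise kernel above combines a SiLU evaluation, silu(x) = x * sigmoid(x), with the SiLU derivative sigmoid(x) * (1 + x * (1 - sigmoid(x))), the pieces of a fused SwiGLU-style backward. A small self-contained check of that closed form (purely illustrative, not part of the generated file):

import torch
import torch.nn.functional as F

def silu_and_grad(x):
    # silu(x) = x * sigmoid(x); d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = torch.sigmoid(x)
    return x * s, s * (1 + x * (1 - s))

x = torch.randn(1024, requires_grad=True)
y, dy = silu_and_grad(x.detach())
F.silu(x).sum().backward()
print(torch.allclose(y, F.silu(x.detach())), torch.allclose(dy, x.grad))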
torchinductor_ch-epfl-345354-j/7a/c7a4b5izank2343xz4473c4igojrrhlfxb5ulctqd32qrtkreq3m.py
ADDED
@@ -0,0 +1,42 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 268435456},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*i1', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__log_softmax__log_softmax_backward_data__to_copy_nll_loss_backward_nll_loss_forward_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused__log_softmax__log_softmax_backward_data__to_copy_nll_loss_backward_nll_loss_forward_7(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x2 = xindex
+    x1 = xindex // ks0
+    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
+    tmp1 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last').to(tl.int1)
+    tmp2 = tl.load(in_ptr2 + (0))
+    tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
+    tmp7 = tl.load(in_ptr3 + (x2), xmask, eviction_policy='evict_last').to(tl.float32)
+    tmp9 = tl.load(in_ptr4 + (x1), xmask, eviction_policy='evict_last')
+    tmp11 = tl.load(in_ptr5 + (x1), xmask, eviction_policy='evict_last')
+    tmp14 = tl.load(in_ptr6 + (x1), xmask, eviction_policy='evict_last')
+    tmp4 = 0.0
+    tmp5 = tl.where(tmp1, tmp3, tmp4)
+    tmp6 = tmp0 * tmp5
+    tmp8 = tmp7.to(tl.float32)
+    tmp10 = tmp8 - tmp9
+    tmp12 = tmp10 - tmp11
+    tmp13 = tl_math.exp(tmp12)
+    tmp15 = tmp13 * tmp14
+    tmp16 = tmp6 - tmp15
+    tmp17 = tmp16.to(tl.float32)
+    tl.store(out_ptr0 + (x2), tmp17, xmask)
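The kernel above fuses the backward of log_softmax with the backward of nll_loss: roughly, it combines the gradient routed through the target class (with in_ptr1 carrying the ignore_index mask) and a softmax(logits) term scaled by a per-row summed gradient. Stripped of the masking and bf16 casts, the underlying math is the usual cross-entropy gradient; an illustrative eager check with made-up shapes (names and sizes are not from the generated code):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 1000, requires_grad=True)
targets = torch.randint(0, 1000, (32,))
F.cross_entropy(logits, targets).backward()

# closed form for mean reduction: (softmax(logits) - one_hot(targets)) / N
probs = torch.softmax(logits.detach(), dim=-1)
manual = (probs - F.one_hot(targets, num_classes=1000).float()) / logits.size(0)
print(torch.allclose(manual, logits.grad, atol=1e-6))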
torchinductor_ch-epfl-345354-j/7a/ddaeab32a9175f6d14ae7329f3defe09537605c41fa4d35da5bf9cbac1616b91.best_config
ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 85}
torchinductor_ch-epfl-345354-j/7q/c7qudnwq7tyfwnepjsm2ilmratxdwkx4euvow7brbvrfif7hgnwh.py
ADDED
|
@@ -0,0 +1,324 @@
"""
Compile-time auto-tuning block:

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.select_algorithm import AlgorithmSelectorCache
from torch._inductor.async_compile import AsyncCompile

async_compile = AsyncCompile()
generate_example_value = AlgorithmSelectorCache.generate_example_value
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu


triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0 = async_compile.triton('triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0), xmask).to(tl.float32)
    tmp7 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tl.sigmoid(tmp2)
    tmp4 = tmp2 * tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp0 * tmp5
    tmp8 = tmp0 * tmp7
    tmp9 = tl.sigmoid(tmp1)
    tmp10 = 1.0
    tmp11 = tmp10 - tmp9
    tmp12 = tmp1 * tmp11
    tmp13 = tmp12 + tmp10
    tmp14 = tmp9 * tmp13
    tmp15 = tmp8 * tmp14
    tl.store(out_ptr0 + (x0), tmp6, xmask)
    tl.store(in_out_ptr0 + (x0), tmp15, xmask)
''', device_str='cuda')


triton_poi_fused_add_1 = async_compile.triton('triton_poi_fused_add_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8388608},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_1(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(in_out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')

async_compile.wait(globals())
del async_compile

import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
with torch.cuda._DeviceGuard(0):
    torch.cuda.set_device(0)
    stream0 = get_raw_stream(0)
    from torch._C import _cuda_getCurrentRawStream as get_raw_stream
    stream0 = get_raw_stream(0)
    buf5 = generate_example_value((8, 1000, 3072), (3072000, 3072, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 3072))
    buf1 = generate_example_value((8000, 3072), (3072, 1), 'cuda:0', torch.bfloat16, 0, (8000, 3072))
    mm = generate_example_value((8000, 3072), (3072, 1), 'cuda:0', torch.bfloat16, 0, (8000, 3072))
    buf2 = generate_example_value((8, 1000, 3072), (3072000, 3072, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 3072))
    triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0.run(buf5, buf1, mm, buf2, 24576000, stream=stream0)
    del buf5, buf1, mm, buf2

    stream0 = get_raw_stream(0)
    buf8 = generate_example_value((8, 1000, 1024), (1024000, 1024, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 1024))
    buf7 = generate_example_value((8000, 1024), (1024, 1), 'cuda:0', torch.bfloat16, 0, (8000, 1024))
    triton_poi_fused_add_1.run(buf8, buf7, 8192000, stream=stream0)
    del buf8, buf7

"""
# AOT ID: ['11_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /tmp/torchinductor_ch-epfl-345354-j/6j/c6jqjdux4scc3alxlsrcpnhemegj7ym5pw3twg6xb2eyx4codkvz.py
# Topologically Sorted Source Nodes: [silu], Original ATen: [aten.silu, aten.mul, aten.sigmoid, aten.fill, aten.sub, aten.add]
# Source node to ATen node mapping:
# silu => convert_element_type_2, convert_element_type_3, mul_19, sigmoid
# Graph fragment:
# %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {})
# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_2,), kwargs = {})
# %mul_19 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %sigmoid), kwargs = {})
# %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_19, torch.bfloat16), kwargs = {})
# %mul_61 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_7, %convert_element_type_3), kwargs = {})
# %mul_62 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_7, %view_3), kwargs = {})
# %sigmoid_1 : [num_users=2] = call_function[target=torch.ops.aten.sigmoid.default](args = (%view_1,), kwargs = {})
# %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%primals_2, %primals_3, 3072], 1), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %sub_16 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%full_default, %sigmoid_1), kwargs = {})
# %mul_63 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_1, %sub_16), kwargs = {})
# %add_38 : [num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mul_63, 1), kwargs = {})
# %mul_64 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sigmoid_1, %add_38), kwargs = {})
# %mul_65 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_62, %mul_64), kwargs = {})
triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0 = async_compile.triton('triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0), xmask).to(tl.float32)
    tmp7 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tl.sigmoid(tmp2)
    tmp4 = tmp2 * tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp0 * tmp5
    tmp8 = tmp0 * tmp7
    tmp9 = tl.sigmoid(tmp1)
    tmp10 = 1.0
    tmp11 = tmp10 - tmp9
    tmp12 = tmp1 * tmp11
    tmp13 = tmp12 + tmp10
    tmp14 = tmp9 * tmp13
    tmp15 = tmp8 * tmp14
    tl.store(out_ptr0 + (x0), tmp6, xmask)
    tl.store(in_out_ptr0 + (x0), tmp15, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_ch-epfl-345354-j/cm/ccm7s7qxxaw3nofvzpftd6fqmd57jvnzfc74xrk54wxxhdnqddlo.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %add_39 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_9, %view_11), kwargs = {})
triton_poi_fused_add_1 = async_compile.triton('triton_poi_fused_add_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8388608},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_1(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(in_out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    primals_2, primals_3, mul, view, mm, mm_1, view_4, permute_5, permute_9, permute_14, tangents_1 = args
    args.clear()
    s0 = primals_2
    s1 = primals_3
    assert_size_stride(view, (s0*s1, 1024), (1024, 1))
    assert_size_stride(mm, (s0*s1, 3072), (3072, 1))
    assert_size_stride(mm_1, (s0*s1, 3072), (3072, 1))
    assert_size_stride(view_4, (s0*s1, 3072), (3072, 1))
    assert_size_stride(permute_5, (1024, 3072), (3072, 1))
    assert_size_stride(permute_9, (3072, 1024), (1024, 1))
    assert_size_stride(permute_14, (3072, 1024), (1024, 1))
    assert_size_stride(tangents_1, (s0, s1, 1024), (1024*s1, 1024, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1024, 3072), (3072, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(tangents_1, (1024, s0*s1), (1, 1024), 0), view_4, out=buf0)
        del view_4
        buf1 = empty_strided_cuda((s0*s1, 3072), (3072, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(tangents_1, (s0*s1, 1024), (1024, 1), 0), permute_5, out=buf1)
        del permute_5
        del tangents_1
        buf2 = empty_strided_cuda((s0, s1, 3072), (3072*s1, 3072, 1), torch.bfloat16)
        buf5 = reinterpret_tensor(mm_1, (s0, s1, 3072), (3072*s1, 3072, 1), 0); del mm_1  # reuse
        # Topologically Sorted Source Nodes: [silu], Original ATen: [aten.silu, aten.mul, aten.sigmoid, aten.fill, aten.sub, aten.add]
        triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0_xnumel = 3072*s0*s1
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0.run(buf5, buf1, mm, buf2, triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0_xnumel, stream=stream0)
        del buf1
        del mm
        buf3 = empty_strided_cuda((3072, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (3072, s0*s1), (1, 3072), 0), view, out=buf3)
        buf4 = empty_strided_cuda((s0*s1, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (s0*s1, 3072), (3072, 1), 0), permute_9, out=buf4)
        del buf2
        del permute_9
        buf6 = empty_strided_cuda((3072, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf5, (3072, s0*s1), (1, 3072), 0), view, out=buf6)
        del view
        buf7 = empty_strided_cuda((s0*s1, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf5, (s0*s1, 3072), (3072, 1), 0), permute_14, out=buf7)
        del buf5
        del permute_14
        buf8 = reinterpret_tensor(buf4, (s0, s1, 1024), (1024*s1, 1024, 1), 0); del buf4  # reuse
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.add]
        triton_poi_fused_add_1_xnumel = 1024*s0*s1
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_1.run(buf8, buf7, triton_poi_fused_add_1_xnumel, stream=stream0)
        del buf7
    return (buf6, None, None, buf8, buf3, buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_2 = 8
    primals_3 = 1000
    mul = 8000
    view = rand_strided((8000, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    mm = rand_strided((8000, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
    mm_1 = rand_strided((8000, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
    view_4 = rand_strided((8000, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
    permute_5 = rand_strided((1024, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
    permute_9 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    permute_14 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    tangents_1 = rand_strided((8, 1000, 1024), (1024000, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
    fn = lambda: call([primals_2, primals_3, mul, view, mm, mm_1, view_4, permute_5, permute_9, permute_14, tangents_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
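For reference, a minimal eager-mode sketch of what the fused elementwise backward kernel in this file computes for a SiLU-gated MLP; the names grad_out, gate, and up are illustrative assumptions, and the exact bf16/fp32 cast placement is simplified relative to the generated Triton code:

import torch

def swiglu_elementwise_backward(grad_out, gate, up):
    # grad_out: upstream gradient of silu(gate) * up; gate, up: the two 3072-wide projections
    sig = torch.sigmoid(gate.float())
    grad_up = grad_out * (gate.float() * sig).to(gate.dtype)        # the value the kernel stores to out_ptr0
    dsilu = sig * (1.0 + gate.float() * (1.0 - sig))                # derivative of silu evaluated at gate
    grad_gate = (grad_out * up) * dsilu.to(gate.dtype)              # the value the kernel writes back in place
    return grad_gate, grad_up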
torchinductor_ch-epfl-345354-j/a6/ca64rxymdowafnowfq53ckfynl3yei5mmfkeefu6f6xndlg3ukok.py
ADDED
@@ -0,0 +1,200 @@
"""
Compile-time auto-tuning block:

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.select_algorithm import AlgorithmSelectorCache
from torch._inductor.async_compile import AsyncCompile

async_compile = AsyncCompile()
generate_example_value = AlgorithmSelectorCache.generate_example_value
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu


triton_poi_fused_mul_silu_0 = async_compile.triton('triton_poi_fused_mul_silu_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')

async_compile.wait(globals())
del async_compile

import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
with torch.cuda._DeviceGuard(0):
    torch.cuda.set_device(0)
    stream0 = get_raw_stream(0)
    from torch._C import _cuda_getCurrentRawStream as get_raw_stream
    stream0 = get_raw_stream(0)
    buf2 = generate_example_value((8, 1000, 3072), (3072000, 3072, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 3072))
    buf1 = generate_example_value((8000, 3072), (3072, 1), 'cuda:0', torch.bfloat16, 0, (8000, 3072))
    triton_poi_fused_mul_silu_0.run(buf2, buf1, 24576000, stream=stream0)
    del buf2, buf1

"""
# AOT ID: ['5_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /tmp/torchinductor_ch-epfl-345354-j/57/c574kngiopy3pgespyoupnzlae4d5tokyeui7uglwglnym2qijvn.py
# Topologically Sorted Source Nodes: [silu, mul], Original ATen: [aten.silu, aten.mul]
# Source node to ATen node mapping:
# mul => mul_38
# silu => convert_element_type_2, convert_element_type_3, mul_19, sigmoid
# Graph fragment:
# %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {})
# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_2,), kwargs = {})
# %mul_19 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %sigmoid), kwargs = {})
# %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_19, torch.bfloat16), kwargs = {})
# %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_3, %view_3), kwargs = {})
triton_poi_fused_mul_silu_0 = async_compile.triton('triton_poi_fused_mul_silu_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    args.clear()
    s0 = arg1_1
    s1 = arg2_1
    assert_size_stride(arg0_1, (3072, 1024), (1024, 1))
    assert_size_stride(arg3_1, (s0, s1, 1024), (1024*s1, 1024, 1))
    assert_size_stride(arg4_1, (3072, 1024), (1024, 1))
    assert_size_stride(arg5_1, (1024, 3072), (3072, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        pool1 = empty_strided_cuda((s0*s1, 3072), (3072, 1), torch.bfloat16)
        buf0 = pool1  # alloc
        # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(arg3_1, (s0*s1, 1024), (1024, 1), 0), reinterpret_tensor(arg0_1, (1024, 3072), (1, 1024), 0), out=buf0)
        del arg0_1
        pool2 = empty_strided_cuda((s0*s1, 3072), (3072, 1), torch.bfloat16)
        buf1 = pool2  # alloc
        # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(arg3_1, (s0*s1, 1024), (1024, 1), 0), reinterpret_tensor(arg4_1, (1024, 3072), (1, 1024), 0), out=buf1)
        del arg3_1
        del arg4_1
        buf2 = reinterpret_tensor(buf0, (s0, s1, 3072), (3072*s1, 3072, 1), 0);  # reuse
        # Topologically Sorted Source Nodes: [silu, mul], Original ATen: [aten.silu, aten.mul]
        triton_poi_fused_mul_silu_0_xnumel = 3072*s0*s1
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_0.run(buf2, buf1, triton_poi_fused_mul_silu_0_xnumel, stream=stream0)
        del pool2, buf1
        pool0 = empty_strided_cuda((s0*s1, 1024), (1024, 1), torch.bfloat16)
        buf3 = pool0  # alloc
        # Topologically Sorted Source Nodes: [down_proj], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (s0*s1, 3072), (3072, 1), 0), reinterpret_tensor(arg5_1, (3072, 1024), (1, 3072), 0), out=buf3)
        del arg5_1
        del pool1, buf0, buf2
    return (reinterpret_tensor(buf3, (s0, s1, 1024), (1024*s1, 1024, 1), 0), )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg1_1 = 8
    arg2_1 = 1000
    arg3_1 = rand_strided((8, 1000, 1024), (1024000, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg4_1 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg5_1 = rand_strided((1024, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
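For reference, the '5_inference' graph above matches a SiLU-gated MLP (1024 -> 3072 -> 1024) in bfloat16: two projections, a fused silu-times-up elementwise step, then the down projection. A minimal eager-mode sketch under that reading, with illustrative names (x, w_gate, w_up, w_down are assumptions, not identifiers from the generated file):

import torch
import torch.nn.functional as F

def swiglu_mlp(x, w_gate, w_up, w_down):
    # x: (batch, seq, 1024); w_gate, w_up: (3072, 1024); w_down: (1024, 3072), all bfloat16
    gate = x @ w_gate.t()            # first extern_kernels.mm (buf0)
    up = x @ w_up.t()                # second extern_kernels.mm (buf1)
    hidden = F.silu(gate) * up       # triton_poi_fused_mul_silu_0, done in place on buf0
    return hidden @ w_down.t()       # final extern_kernels.mm (buf3)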
torchinductor_ch-epfl-345354-j/aotautograd/acxk7xhb35e5myvrfk4m2smos5f3rwybegalnbqbgtl3ghlaw2vw/entry
ADDED
Binary file (96.7 kB)

torchinductor_ch-epfl-345354-j/aotautograd/adw7o5w6jucvlwdu4mn3nk52nno5z3lt73pmvaksrn3cahxlwc5t/entry
ADDED
Binary file (5.55 kB)

torchinductor_ch-epfl-345354-j/aotautograd/agm67xcx3b2ejeqf3t422b43zsalmtzgitagqmb4kcd76dzg2sr6/entry
ADDED
Binary file (15.3 kB)

torchinductor_ch-epfl-345354-j/aotautograd/ahinqqlnserz457jqclv2vjeogmqix7jcrylpuhbc64kw4k3apfy/entry
ADDED
Binary file (7.48 kB)

torchinductor_ch-epfl-345354-j/aotautograd/ahji7b2arusm47q6ox5itjvurtws6r6kls2kskgxfnc2rqm4ojdg/entry
ADDED
Binary file (4.83 kB)

torchinductor_ch-epfl-345354-j/aotautograd/aibbpzcrlnv7lrbehiaaab4olrvijekv6m46vdzzqh3tbnvnl67m/entry
ADDED
Binary file (26.4 kB)

torchinductor_ch-epfl-345354-j/aotautograd/aig3hpjgzj7f27hhdphh7ozndqiwpruhugzjsiwyog75fn4y3rbj/entry
ADDED
Binary file (7.63 kB)

torchinductor_ch-epfl-345354-j/aotautograd/algm7vsngjdke6rmqon76peuppnhsp625k5d4zxnwgwdbdueo4ay/entry
ADDED
Binary file (4.37 kB)

torchinductor_ch-epfl-345354-j/aotautograd/amtpnp6cq6z6ddoun3fwe4zemhgpsp5jicklj6cf3qzsd3xbdeps/entry
ADDED
Binary file (4.65 kB)

torchinductor_ch-epfl-345354-j/aotautograd/aqskia64x2j4xks7dhp5cpq52le5j6js6ghxfhlvw7gfa6qr6stx/entry
ADDED
Binary file (16.2 kB)

torchinductor_ch-epfl-345354-j/aotautograd/aub73aicaqeihl4qdqbrvljzl2qxdzyu52zezeket676qt3pkgwk/entry
ADDED
Binary file (4.65 kB)

torchinductor_ch-epfl-345354-j/aotautograd/aypob3g4nwzt66m7ur252rhjjobqgnn4hvdhagr4474twkamikxg/entry
ADDED
Binary file (18.9 kB)
torchinductor_ch-epfl-345354-j/bm/cbmn253c3hy77ciw3f6meqi4bsbiio5zhw7hra5np6k5jyjqetnp.py
ADDED
@@ -0,0 +1,29 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*i64', 'out_ptr0': '*fp32', 'xnumel': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_div_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_div_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
    tmp4 = tmp3.to(tl.float32)
    tmp5 = (tmp1 / tmp4)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp5, None)
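A hedged reading of this single-element kernel, inferred only from its pointer types: it divides one fp32 value by an int64 count, as in normalizing a summed loss by the number of contributing tokens. An eager-mode equivalent under that assumption (loss_sum and num_tokens are illustrative names):

mean_loss = loss_sum / num_tokens.to(torch.float32)   # loss_sum: fp32 scalar tensor, num_tokens: int64 scalar tensor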
torchinductor_ch-epfl-345354-j/ce/cceyvpztlniy45jdq6sxx7o44obzjinfuxgsvnlhcr3hjdvmek73.py
ADDED
@@ -0,0 +1,38 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 1024, 'r0_': 64},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 1024
    r0_numel = 40
    R0_BLOCK: tl.constexpr = 64
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 1024*r0_1), xmask & r0_mask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.where(r0_mask & xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp4, xmask)
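A hedged reading of triton_per_fused__to_copy_mul_sum_1, inferred from its shapes rather than stated in the file: it collapses a (40, 1024) float32 buffer of per-block partial sums down to one 1024-wide vector and casts the result to bfloat16, as a two-stage gradient reduction would. An eager-mode equivalent under that assumption (partials is an illustrative name):

reduced = partials.sum(dim=0).to(torch.bfloat16)   # partials: (40, 1024) fp32 tensor of partial sums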
torchinductor_ch-epfl-345354-j/cf/ae1632ffa009afdc4d40d5477a8e2ffd544972ad9ddf0c636c451826b3219579.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 108}
torchinductor_ch-epfl-345354-j/cf/ccfnt2f53rlwauznvnabnitvjchbzg7at22w4x4fskqzmyirxuxq.py
ADDED
@@ -0,0 +1,50 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 131072, 'r0_': 128},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 128
    R0_BLOCK: tl.constexpr = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
    tmp12 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tmp1 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5 = tl.where(xmask, tmp3, 0)
    tmp6 = tl.sum(tmp5, 1)[:, None]
    tmp7 = 128.0
    tmp8 = (tmp6 / tmp7)
    tmp9 = 1e-06
    tmp10 = tmp8 + tmp9
    tmp11 = libdevice.rsqrt(tmp10)
    tmp13 = tmp1 * tmp11
    tmp14 = tmp13.to(tl.float32)
    tmp15 = tmp12 * tmp14
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp11, xmask)
    tl.store(out_ptr0 + (r0_1 + 128*x0), tmp15, xmask)
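The persistent-reduction kernel above matches an RMSNorm over a hidden size of 128 with eps=1e-6: it accumulates the mean of squares per row, writes the rsqrt factor back in place, and emits weight * normalized input in bfloat16. A minimal eager-mode sketch under that reading, with illustrative names (the exact cast placement in the generated kernel is simplified):

import torch

def rms_norm_128(x_bf16, weight_bf16, eps=1e-6):
    # x_bf16: (..., 128) activations, weight_bf16: (128,) scale, both bfloat16
    x = x_bf16.float()
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)   # tmp11 in the kernel
    y = (weight_bf16.float() * (x * inv_rms)).to(torch.bfloat16)       # tmp15 in the kernel
    return y, inv_rms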