MusYW committed on
Commit 7f0c6e6 · verified · 1 Parent(s): 4d5676b

Training in progress, step 500

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +19 -0
  2. added_tokens.json +28 -0
  3. config.json +32 -0
  4. merges.txt +0 -0
  5. model.safetensors +3 -0
  6. special_tokens_map.json +31 -0
  7. tmpcnl2zetw/__pycache__/_remote_module_non_scriptable.cpython-312.pyc +0 -0
  8. tmpcnl2zetw/_remote_module_non_scriptable.py +81 -0
  9. tmpfehyc297/_remote_module_non_scriptable.py +81 -0
  10. tokenizer.json +3 -0
  11. tokenizer_config.json +240 -0
  12. torchinductor_ch-epfl-345354-j/2f/8f4778b9c2bdc504a3c5d1b5bc09dac279f18294d78f93a4c581178508bcf83b.best_config +1 -0
  13. torchinductor_ch-epfl-345354-j/2f/c2fzymhcr3rme5dtns3jvvyp6x3osfrlp7cc7zg7igwzigqmmg65.py +46 -0
  14. torchinductor_ch-epfl-345354-j/2r/717267c6902a1a61a8cc50b68a007cec3f90a0241185b112347df8b34fa8c605.best_config +1 -0
  15. torchinductor_ch-epfl-345354-j/2r/c2rhxwmxh62lojowjb65g6mbzowlwbjcacwjmn3vu63z4qatxuo3.py +26 -0
  16. torchinductor_ch-epfl-345354-j/2y/c2ykxnj2iqrpp4u3ihziotcanxl3tc27h7ajzahx5wypy4anuhuj.py +88 -0
  17. torchinductor_ch-epfl-345354-j/3m/c3mt4utggpr6zcsqyeele6646fofhvyk4xxtwll4gqqa5w6nrbct.py +55 -0
  18. torchinductor_ch-epfl-345354-j/43/c43m5ctxi7dcy4hjgz5jijzo4xp7fp3bmvzcjp3ygmirxptgoerd.py +53 -0
  19. torchinductor_ch-epfl-345354-j/4i/c4iarmybewwgyq7pa6izmajgs66hg4cgb6yhmezt4tg6j77oklfi.py +50 -0
  20. torchinductor_ch-epfl-345354-j/53/c53mrwlx5sxivgg5x5z6kkaldo2q5yn2pjsymcv27tpzj2cdoeww.py +66 -0
  21. torchinductor_ch-epfl-345354-j/56/c56q66j66nfzeu5puvuhal4wt2foih6rnb5nwmqolafn3iq33kjp.py +66 -0
  22. torchinductor_ch-epfl-345354-j/57/c573irrqes6p6it4yfyvqw2efgfbbgp7yjzxjjxq5jpeesj3bi77.py +353 -0
  23. torchinductor_ch-epfl-345354-j/57/c574kngiopy3pgespyoupnzlae4d5tokyeui7uglwglnym2qijvn.py +30 -0
  24. torchinductor_ch-epfl-345354-j/57/dad7be19dc394c1e08368515640dff88b78797aaabae100a15a1f195476a9a87.best_config +1 -0
  25. torchinductor_ch-epfl-345354-j/5e/c5enonf6qztlsw7dozsqkejk4exzt4n56gbz6fiey2gnus5vdf76.py +66 -0
  26. torchinductor_ch-epfl-345354-j/5x/c5xsvywggx5vrzm2l5uaktu7pipclhdn5h6263yru2ugvuhe2nak.py +57 -0
  27. torchinductor_ch-epfl-345354-j/6d/c6dsbxlebwjqawzeprkq3lkldtxoiept4c6bpgtva5r4mjlrnwlr.py +229 -0
  28. torchinductor_ch-epfl-345354-j/6j/7c215475e7b40a21cf286026270965eb7f07e7c3af1c4052d331de3f74c6449e.best_config +1 -0
  29. torchinductor_ch-epfl-345354-j/6j/c6j5lx5qgycfvyi3dm5f4mo3ssluzzsrmdq32pka7e6pyhg42zvd.py +499 -0
  30. torchinductor_ch-epfl-345354-j/6j/c6jqjdux4scc3alxlsrcpnhemegj7ym5pw3twg6xb2eyx4codkvz.py +40 -0
  31. torchinductor_ch-epfl-345354-j/7a/c7a4b5izank2343xz4473c4igojrrhlfxb5ulctqd32qrtkreq3m.py +42 -0
  32. torchinductor_ch-epfl-345354-j/7a/ddaeab32a9175f6d14ae7329f3defe09537605c41fa4d35da5bf9cbac1616b91.best_config +1 -0
  33. torchinductor_ch-epfl-345354-j/7q/c7qudnwq7tyfwnepjsm2ilmratxdwkx4euvow7brbvrfif7hgnwh.py +324 -0
  34. torchinductor_ch-epfl-345354-j/a6/ca64rxymdowafnowfq53ckfynl3yei5mmfkeefu6f6xndlg3ukok.py +200 -0
  35. torchinductor_ch-epfl-345354-j/aotautograd/acxk7xhb35e5myvrfk4m2smos5f3rwybegalnbqbgtl3ghlaw2vw/entry +0 -0
  36. torchinductor_ch-epfl-345354-j/aotautograd/adw7o5w6jucvlwdu4mn3nk52nno5z3lt73pmvaksrn3cahxlwc5t/entry +0 -0
  37. torchinductor_ch-epfl-345354-j/aotautograd/agm67xcx3b2ejeqf3t422b43zsalmtzgitagqmb4kcd76dzg2sr6/entry +0 -0
  38. torchinductor_ch-epfl-345354-j/aotautograd/ahinqqlnserz457jqclv2vjeogmqix7jcrylpuhbc64kw4k3apfy/entry +0 -0
  39. torchinductor_ch-epfl-345354-j/aotautograd/ahji7b2arusm47q6ox5itjvurtws6r6kls2kskgxfnc2rqm4ojdg/entry +0 -0
  40. torchinductor_ch-epfl-345354-j/aotautograd/aibbpzcrlnv7lrbehiaaab4olrvijekv6m46vdzzqh3tbnvnl67m/entry +0 -0
  41. torchinductor_ch-epfl-345354-j/aotautograd/aig3hpjgzj7f27hhdphh7ozndqiwpruhugzjsiwyog75fn4y3rbj/entry +0 -0
  42. torchinductor_ch-epfl-345354-j/aotautograd/algm7vsngjdke6rmqon76peuppnhsp625k5d4zxnwgwdbdueo4ay/entry +0 -0
  43. torchinductor_ch-epfl-345354-j/aotautograd/amtpnp6cq6z6ddoun3fwe4zemhgpsp5jicklj6cf3qzsd3xbdeps/entry +0 -0
  44. torchinductor_ch-epfl-345354-j/aotautograd/aqskia64x2j4xks7dhp5cpq52le5j6js6ghxfhlvw7gfa6qr6stx/entry +0 -0
  45. torchinductor_ch-epfl-345354-j/aotautograd/aub73aicaqeihl4qdqbrvljzl2qxdzyu52zezeket676qt3pkgwk/entry +0 -0
  46. torchinductor_ch-epfl-345354-j/aotautograd/aypob3g4nwzt66m7ur252rhjjobqgnn4hvdhagr4474twkamikxg/entry +0 -0
  47. torchinductor_ch-epfl-345354-j/bm/cbmn253c3hy77ciw3f6meqi4bsbiio5zhw7hra5np6k5jyjqetnp.py +29 -0
  48. torchinductor_ch-epfl-345354-j/ce/cceyvpztlniy45jdq6sxx7o44obzjinfuxgsvnlhcr3hjdvmek73.py +38 -0
  49. torchinductor_ch-epfl-345354-j/cf/ae1632ffa009afdc4d40d5477a8e2ffd544972ad9ddf0c636c451826b3219579.best_config +1 -0
  50. torchinductor_ch-epfl-345354-j/cf/ccfnt2f53rlwauznvnabnitvjchbzg7at22w4x4fskqzmyirxuxq.py +50 -0
.gitattributes CHANGED
@@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/5s/f5sdzjvbwcmgigljts5qiy6lpxvzqdph6wzu6phn3y3ibrcaorli/khjgqt4e4qpxv6eeg3vxxgpwqsv3grytkxtje4324jtolexrezo filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/6x/f6x5yse7q3r4kgrx3mzcmr7b2m72jtxrtziweqwk5lposwvr3y52/wqdd5re3wcyjqm2yolzwsoaqo7xdfcnam6vnnh7oy4krfpy35mu filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/ap/fap4jdkhbhos2zlgzy4vqldjzu7uaf3wuhatrkr2kcwc42gvg2yz/2zrvkpixflxwoduo4mf45p672nlwxrnyfi7jiaahhn2lw6eafh7 filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/ay/fay24mgdmhudth5v4jelopeu2revquda5zewn436r6biwrnpgabo/rua4hs554jun3xlx2vr5pka3oqdzq77wfoxhvymrffnwxy27nbt filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/be/fbei2uvrpqs44s3zmuwyjn6byfolnsqoa7juh23nj5xwvypzh4qm/isndyt5bxypcz3icmz7sxxnxnv5k5flzys3ydjieaujxkpv2by2 filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/dz/fdz6ntzqbiwadat6ybb6wx5r72slohcdcufolrfguxpyy72ha4k3/6wwa6mdlxwrojkbhhttg3r434evsqlaub3wfx47anl25o536b6x filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/el/felw7pslqa4fo6ex4wmphdc2ybuhij3y7nr6ek7weyw6exi5n6un/vtvnvlc4u5nq7hdpjcieofwiwev7pvczdillvjg3yxva6aojzbt filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/io/fiop3vsjr3dy3bv5akbf4igspko5bzgikfs2c5r65fwccfuvu7ux/qzhalrsuhi7mryjabm3iq6rhxlb5ozesgq6j7k5e7iyja2k7vsq filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/kl/fklqttybh3jmeqtx6bvvj4haqgxnjqufqyoz2nks5uigm4r36cx4/hxqz22nepqoek7pc76k5r3icp4zjsh4kaue55ohl6xt6tbvp4oi filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/lg/flg4hgizerwmptq2h73kr43ahriht7psoidazshg6qzgrnt647t2/pijpcpasxrep34djgpqcasekfyyagfa4g3bjwvevt3aj3ofuu6q filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/ng/fngku4s6qemdyw2pe5ve3jmj4j3kwxfgrfoqpednglkm5rrggltm/gpu6deqxpscuw7qomk5b3bn5nhb7huhwwde46fignb66f72pvmw filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/ro/froa33wtzy5mq3utlkybs5daxqnv2apit6smxhgb22cf4rr7fysv/rvynlwjqk2pli2qmjywzpal7ly5ufkhjhxzh4v7y5dgjunk5bmh filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/tm/ftmxrirms7cnlnwoxoz62nacrceoyhhzhbnuwmnbxdx426kuxbca/nmw3t6ssgbvfrpuhau55inf3woabflb6qc3ynybm53y42slvkos filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/tu/ftuiasjkco5eqbhoc3ebj57kfzkbkbzglhq2wrkg4hjpkhixsfbl/qpwpe3w2vqzxprimhpukbj46zopeocm5rat5f5apx7ghtoyyi67 filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/fxgraph/xb/fxbygtwj5f3jfikmhfoz7vd3mtceabjyw5ggscr4lin73qu5vibc/typ73hul4yzbqjyctyc3spof5hr4okcdyp27oj4p5lpfovus76n filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/triton/0/35ZDQUBPR56EFY64BDG6UB7OKJODSPJPYCTXXIKSOWZ2CA3EPWGA/triton_poi_fused__to_copy_cos_mul_sin_1.cubin filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/triton/0/SS6ERMBZYPQUFPYGWYCICDOLY35WYNXDIHS3C45P5EZJOFLNB5SQ/triton_poi_fused__to_copy_cos_mul_sin_1.cubin filter=lfs diff=lfs merge=lfs -text
+ torchinductor_ch-epfl-345354-j/triton/0/YGRXAKJ5T6UCGDY6RY3GMSSD4JVJFIRM6WYZVP7IT63SMM7NWHSA/triton_poi_fused__to_copy_cos_mul_sin_1.cubin filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "architectures": [
+     "Qwen3ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "eos_token_id": 151643,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "qwen3",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 8,
+   "pad_token_id": 151654,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 1000000,
+   "sliding_window": null,
+   "tie_word_embeddings": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.51.3",
+   "unsloth_fixed": true,
+   "unsloth_version": "2025.5.7",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151936
+ }
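
The config describes a small Qwen3 causal LM: 28 layers, hidden size 1024, grouped-query attention with 16 query and 8 key/value heads, tied input/output embeddings, bf16 weights. A minimal loading sketch, assuming the files in this commit sit together in a local directory (the name "checkpoint-500" is illustrative, not part of the commit) and an installed transformers with Qwen3 support (the config pins 4.51.3):

# Sketch: load the checkpoint this commit uploads. "checkpoint-500" is a
# placeholder for wherever these files live locally.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("checkpoint-500")
assert config.model_type == "qwen3" and config.hidden_size == 1024

model = AutoModelForCausalLM.from_pretrained(
    "checkpoint-500",
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" above
)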
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26045c44c8f7b94239544f69255de30ea5e05c7f8e23f0c01a67e755eaa0beba
+ size 1192135096
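
The weights are stored through Git LFS, so the repository itself holds only this three-line pointer (spec version, SHA-256 of the blob, size in bytes); the 1,192,135,096-byte payload lives in LFS storage, which at 2 bytes per bf16 parameter is consistent with a ~0.6B-parameter model. A small illustrative parser for the pointer format:

# Illustrative only: a Git LFS pointer file is "key value" lines.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

ptr = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:26045c44c8f7b94239544f69255de30ea5e05c7f8e23f0c01a67e755eaa0beba\n"
    "size 1192135096\n"
)
print(int(ptr["size"]) / 2**30)  # ~1.11 GiB of bf16 weights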
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|vision_pad|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tmpcnl2zetw/__pycache__/_remote_module_non_scriptable.cpython-312.pyc ADDED
Binary file (2.6 kB).
 
tmpcnl2zetw/_remote_module_non_scriptable.py ADDED
@@ -0,0 +1,81 @@
+ from typing import *
+
+ import torch
+ import torch.distributed.rpc as rpc
+ from torch import Tensor
+ from torch._jit_internal import Future
+ from torch.distributed.rpc import RRef
+ from typing import Tuple  # pyre-ignore: unused import
+
+
+ module_interface_cls = None
+
+
+ def forward_async(self, *args, **kwargs):
+     args = (self.module_rref, self.device, self.is_device_map_set, *args)
+     kwargs = {**kwargs}
+     return rpc.rpc_async(
+         self.module_rref.owner(),
+         _remote_forward,
+         args,
+         kwargs,
+     )
+
+
+ def forward(self, *args, **kwargs):
+     args = (self.module_rref, self.device, self.is_device_map_set, *args)
+     kwargs = {**kwargs}
+     ret_fut = rpc.rpc_async(
+         self.module_rref.owner(),
+         _remote_forward,
+         args,
+         kwargs,
+     )
+     return ret_fut.wait()
+
+
+ _generated_methods = [
+     forward_async,
+     forward,
+ ]
+
+
+
+
+ def _remote_forward(
+         module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs):
+     module = module_rref.local_value()
+     device = torch.device(device)
+
+     if device.type != "cuda":
+         return module.forward(*args, **kwargs)
+
+     # If the module is on a cuda device,
+     # move any CPU tensor in args or kwargs to the same cuda device.
+     # Since torch script does not support generator expression,
+     # have to use concatenation instead of
+     # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``.
+     args = (*args,)
+     out_args: Tuple[()] = ()
+     for arg in args:
+         arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
+         out_args = out_args + arg
+
+     kwargs = {**kwargs}
+     for k, v in kwargs.items():
+         if isinstance(v, Tensor):
+             kwargs[k] = kwargs[k].to(device)
+
+     if is_device_map_set:
+         return module.forward(*out_args, **kwargs)
+
+     # If the device map is empty, then only CPU tensors are allowed to send over wire,
+     # so have to move any GPU tensor to CPU in the output.
+     # Since torch script does not support generator expression,
+     # have to use concatenation instead of
+     # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, **kwargs))``.
+     ret: Tuple[()] = ()
+     for i in module.forward(*out_args, **kwargs):
+         i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
+         ret = ret + i
+     return ret
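
This file is PyTorch's generated template for non-scriptable RemoteModule forwards; the tmp*/ directories look like instantiation scratch dirs that were committed along with the checkpoint. The tuple-concatenation loops above exist only because TorchScript cannot compile generator expressions; in eager Python the same tensor-moving step collapses to one line, e.g. (illustrative sketch, not part of the generated file):

# Eager-mode equivalent of the out_args accumulation loop above.
from typing import Any, Tuple
import torch
from torch import Tensor

def move_to(device: torch.device, *args: Any) -> Tuple[Any, ...]:
    return tuple(a.to(device) if isinstance(a, Tensor) else a for a in args)

print(move_to(torch.device("cpu"), torch.ones(2), 3, "x"))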
tmpfehyc297/_remote_module_non_scriptable.py ADDED
@@ -0,0 +1,81 @@
+ from typing import *
+
+ import torch
+ import torch.distributed.rpc as rpc
+ from torch import Tensor
+ from torch._jit_internal import Future
+ from torch.distributed.rpc import RRef
+ from typing import Tuple  # pyre-ignore: unused import
+
+
+ module_interface_cls = None
+
+
+ def forward_async(self, *args, **kwargs):
+     args = (self.module_rref, self.device, self.is_device_map_set, *args)
+     kwargs = {**kwargs}
+     return rpc.rpc_async(
+         self.module_rref.owner(),
+         _remote_forward,
+         args,
+         kwargs,
+     )
+
+
+ def forward(self, *args, **kwargs):
+     args = (self.module_rref, self.device, self.is_device_map_set, *args)
+     kwargs = {**kwargs}
+     ret_fut = rpc.rpc_async(
+         self.module_rref.owner(),
+         _remote_forward,
+         args,
+         kwargs,
+     )
+     return ret_fut.wait()
+
+
+ _generated_methods = [
+     forward_async,
+     forward,
+ ]
+
+
+
+
+ def _remote_forward(
+         module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs):
+     module = module_rref.local_value()
+     device = torch.device(device)
+
+     if device.type != "cuda":
+         return module.forward(*args, **kwargs)
+
+     # If the module is on a cuda device,
+     # move any CPU tensor in args or kwargs to the same cuda device.
+     # Since torch script does not support generator expression,
+     # have to use concatenation instead of
+     # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``.
+     args = (*args,)
+     out_args: Tuple[()] = ()
+     for arg in args:
+         arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
+         out_args = out_args + arg
+
+     kwargs = {**kwargs}
+     for k, v in kwargs.items():
+         if isinstance(v, Tensor):
+             kwargs[k] = kwargs[k].to(device)
+
+     if is_device_map_set:
+         return module.forward(*out_args, **kwargs)
+
+     # If the device map is empty, then only CPU tensors are allowed to send over wire,
+     # so have to move any GPU tensor to CPU in the output.
+     # Since torch script does not support generator expression,
+     # have to use concatenation instead of
+     # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, **kwargs))``.
+     ret: Tuple[()] = ()
+     for i in module.forward(*out_args, **kwargs):
+         i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
+         ret = ret + i
+     return ret
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 32768,
+   "pad_token": "<|vision_pad|>",
+   "padding_side": "left",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
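
Note the padding setup: the tokenizer reuses <|vision_pad|> (id 151654, matching pad_token_id in config.json) as the pad token and pads on the left, a common Unsloth-style choice since that token is otherwise unused by a text-only model. A quick check, using the same illustrative directory name as above:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("checkpoint-500")
print(tok.pad_token, tok.pad_token_id, tok.padding_side)  # <|vision_pad|> 151654 left

# Left padding keeps the last position of every row a real token, which is
# the position causal-LM generation reads from.
batch = tok(["hi", "a longer example"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)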
torchinductor_ch-epfl-345354-j/2f/8f4778b9c2bdc504a3c5d1b5bc09dac279f18294d78f93a4c581178508bcf83b.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 50}
torchinductor_ch-epfl-345354-j/2f/c2fzymhcr3rme5dtns3jvvyp6x3osfrlp7cc7zg7igwzigqmmg65.py ADDED
@@ -0,0 +1,46 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+     size_hints={'x': 8388608},
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_cat_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = xindex < xnumel
+     x4 = xindex
+     x0 = (xindex % ks0)
+     x2 = ((xindex // ks1) % ks2)
+     x5 = xindex // ks0
+     tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+     tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+     tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+     tmp2 = tmp0 * tmp1
+     tmp3 = x0
+     tmp4 = tl.full([1], 0, tl.int64)
+     tmp5 = tmp3 >= tmp4
+     tmp6 = ks0 + (-1)*(ks0 // 2)
+     tmp7 = tmp3 < tmp6
+     tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
+     tmp9 = -tmp8
+     tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
+     tmp11 = tl.where(tmp7, tmp9, tmp10)
+     tmp12 = tmp3 >= tmp6
+     tmp13 = ks0
+     tmp14 = tmp3 < tmp13
+     tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
+     tmp16 = tl.where(tmp7, tmp11, tmp15)
+     tmp18 = tmp16 * tmp17
+     tmp19 = tmp2 + tmp18
+     tl.store(out_ptr0 + (x4), tmp19, xmask)
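
The fused mul/cat/add pattern in this kernel is the rotary-position-embedding application used by Qwen-style attention: the "cat" of the negated second half with the first half is rotate_half, and the output is x*cos + rotate_half(x)*sin. A plain PyTorch reference of that math (a sketch of the computation, not the kernel's exact indexing):

# Reference sketch: rotary position embedding, Qwen/LLaMA style.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # The kernel's "cat" branch: negate the second half, swap the halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin

x = torch.randn(1, 16, 8, 128, dtype=torch.bfloat16)   # (batch, heads, seq, head_dim)
cos = torch.randn(8, 128, dtype=torch.bfloat16)        # broadcast over batch/heads
sin = torch.randn(8, 128, dtype=torch.bfloat16)
out = apply_rope(x, cos, sin)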
torchinductor_ch-epfl-345354-j/2r/717267c6902a1a61a8cc50b68a007cec3f90a0241185b112347df8b34fa8c605.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 81}
torchinductor_ch-epfl-345354-j/2r/c2rhxwmxh62lojowjb65g6mbzowlwbjcacwjmn3vu63z4qatxuo3.py ADDED
@@ -0,0 +1,26 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+     size_hints={'x': 2048},
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'ks0': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_3', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_nll_loss_backward_3(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = xindex < xnumel
+     x0 = xindex
+     tmp0 = tl.load(in_ptr0 + (x0), xmask)
+     tl.device_assert(((0 <= tmp0) & (tmp0 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp0 < ks0")
+     tmp2 = -1.0
+     tl.store(out_ptr0 + (tmp0 + ks0*x0), tmp2, xmask)
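
This kernel writes the sparse part of the NLL-loss backward: for each row it scatters -1.0 into the column of that row's target class (out_ptr0 has row stride ks0). The dense equivalent, sketched in PyTorch:

# Reference sketch: d(nll_loss)/d(log_probs), before reduction scaling and
# ignore_index handling, is -1 at each row's target column.
import torch

targets = torch.tensor([2, 5, 0])
num_classes = 8
grad = torch.zeros(targets.numel(), num_classes)
grad[torch.arange(targets.numel()), targets] = -1.0
print(grad)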
torchinductor_ch-epfl-345354-j/2y/c2ykxnj2iqrpp4u3ihziotcanxl3tc27h7ajzahx5wypy4anuhuj.py ADDED
@@ -0,0 +1,88 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 1, 'r0_': 2048},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+     triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'in_ptr7': '*fp32', 'in_ptr8': '*fp32', 'in_ptr9': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_nll_loss_backward_nll_loss_forward_15', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 10, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused_add_nll_loss_backward_nll_loss_forward_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 1
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     _tmp28 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_0 = r0_index
+         tmp5 = tl.load(in_ptr1 + (r0_0 + 6*((6 + ks0*ks1) // 7)), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp0 = ((r0_0 + 6*((6 + ks0*ks1) // 7)) % ks1)
+         tmp1 = (-1) + ks1
+         tmp2 = tmp0 == tmp1
+         tmp3 = tmp0 < tmp1
+         tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + r0_0 + 6*((6 + ks0*ks1) // 7), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_first', other=0.0)
+         tmp6 = tl.where(tmp3, tmp4, tmp5)
+         tmp7 = tl.full([1, 1], -100, tl.int64)
+         tmp8 = tl.where(tmp2, tmp7, tmp6)
+         tmp9 = tmp8 != tmp7
+         tmp10 = tl.full([1, 1], 0, tl.int64)
+         tmp11 = tl.where(tmp9, tmp8, tmp10)
+         tmp12 = ks2
+         tmp13 = tmp11 + tmp12
+         tmp14 = tmp11 < 0
+         tmp15 = tl.where(tmp14, tmp13, tmp11)
+         tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
+         tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + 6*ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+         tmp18 = tmp17.to(tl.float32)
+         tmp20 = tmp18 - tmp19
+         tmp22 = tl_math.log(tmp21)
+         tmp23 = tmp20 - tmp22
+         tmp24 = -tmp23
+         tmp25 = 0.0
+         tmp26 = tl.where(tmp9, tmp24, tmp25)
+         tmp27 = tl.broadcast_to(tmp26, [XBLOCK, R0_BLOCK])
+         tmp29 = _tmp28 + tmp27
+         _tmp28 = tl.where(r0_mask, tmp29, _tmp28)
+         tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
+         tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
+     tmp28 = tl.sum(_tmp28, 1)[:, None]
+     tmp30 = tl.load(in_out_ptr0 + (0))
+     tmp31 = tl.broadcast_to(tmp30, [XBLOCK, 1])
+     tmp34 = tl.load(in_ptr5 + (0))
+     tmp35 = tl.broadcast_to(tmp34, [XBLOCK, 1])
+     tmp37 = tl.load(in_ptr6 + (0))
+     tmp38 = tl.broadcast_to(tmp37, [XBLOCK, 1])
+     tmp40 = tl.load(in_ptr7 + (0))
+     tmp41 = tl.broadcast_to(tmp40, [XBLOCK, 1])
+     tmp43 = tl.load(in_ptr8 + (0))
+     tmp44 = tl.broadcast_to(tmp43, [XBLOCK, 1])
+     tmp46 = tl.load(in_ptr9 + (0))
+     tmp47 = tl.broadcast_to(tmp46, [XBLOCK, 1])
+     tmp32 = 0.0
+     tmp33 = tmp31 + tmp32
+     tmp36 = tmp33 + tmp35
+     tmp39 = tmp36 + tmp38
+     tmp42 = tmp39 + tmp41
+     tmp45 = tmp42 + tmp44
+     tmp48 = tmp45 + tmp47
+     tmp49 = tmp48 + tmp28
+     tl.debug_barrier()
+     tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp49, None)
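
The index arithmetic at the top of the loop (tmp0 through tmp8) reconstructs shifted labels on the fly: position i in each sequence is scored against token i+1, and the final position of each sequence (where (r0 % ks1) == ks1 - 1) is masked with -100, the usual causal-LM ignore index. The (6 + ks0*ks1) // 7 offsets suggest the loss is computed in seven chunks, with this kernel handling the last chunk and folding the six earlier partial sums (in_ptr5..in_ptr9 plus in_out_ptr0) into the final total. The label handling alone, in plain PyTorch:

# Reference sketch: standard causal-LM next-token loss, with labels shifted
# left by one and the final position masked via ignore_index=-100.
import torch
import torch.nn.functional as F

input_ids = torch.tensor([[5, 9, 2, 7]])
labels = torch.roll(input_ids, shifts=-1, dims=1)
labels[:, -1] = -100  # no next token to predict at the final position

logits = torch.randn(1, 4, 16)  # (batch, seq, vocab)
loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(), ignore_index=-100)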
torchinductor_ch-epfl-345354-j/3m/c3mt4utggpr6zcsqyeele6646fofhvyk4xxtwll4gqqa5w6nrbct.py ADDED
@@ -0,0 +1,55 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 2048, 'r0_': 262144},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+     triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__to_copy_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 2, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused__log_softmax__to_copy_2(in_out_ptr0, in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = xindex < xnumel
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     x0 = xindex
+     _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
+     _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_1 = r0_index
+         tmp0 = tl.load(in_ptr0 + (r0_1 + ks2*x0 + 2*ks2*((6 + ks0*ks1) // 7)), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+         tmp1 = tmp0.to(tl.float32)
+         tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+
+         _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
+             _tmp3_max, _tmp3_sum, tmp2, False
+         )
+
+         _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
+         _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
+
+     tmp5, tmp6 = triton_helpers.online_softmax_reduce(
+         _tmp3_max, _tmp3_sum, 1, False)
+     tmp5 = tmp5[:, None]
+     tmp6 = tmp6[:, None]
+     tmp3 = tmp5
+     tmp4 = tmp6
+     tl.store(out_ptr0 + (x0), tmp3, xmask)
+     tmp7 = tl_math.log(tmp4)
+     tl.debug_barrier()
+     tl.store(in_out_ptr0 + (x0), tmp7, xmask)
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 2048, 'r0_': 262144},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_5', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 2, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_5(in_ptr0, out_ptr0, out_ptr1, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x0 = xindex
27
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
28
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
29
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
30
+ r0_index = r0_offset + r0_base
31
+ r0_mask = r0_index < r0_numel
32
+ roffset = r0_offset
33
+ rindex = r0_index
34
+ r0_1 = r0_index
35
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks2*x0 + 5*ks2*((6 + ks0*ks1) // 7)), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
36
+ tmp1 = tmp0.to(tl.float32)
37
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
38
+
39
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
40
+ _tmp3_max, _tmp3_sum, tmp2, False
41
+ )
42
+
43
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
44
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
45
+
46
+ tmp5, tmp6 = triton_helpers.online_softmax_reduce(
47
+ _tmp3_max, _tmp3_sum, 1, False)
48
+ tmp5 = tmp5[:, None]
49
+ tmp6 = tmp6[:, None]
50
+ tmp3 = tmp5
51
+ tmp4 = tmp6
52
+ tl.store(out_ptr0 + (x0), tmp3, xmask)
53
+ tl.store(out_ptr1 + (x0), tmp4, xmask)
torchinductor_ch-epfl-345354-j/4i/c4iarmybewwgyq7pa6izmajgs66hg4cgb6yhmezt4tg6j77oklfi.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 8192, 'r0_': 1024},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': True, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel):
19
+ XBLOCK: tl.constexpr = 1
20
+ r0_numel = 1024
21
+ R0_BLOCK: tl.constexpr = 1024
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = tl.full([1], xoffset, tl.int32)
26
+ xmask = tl.full([R0_BLOCK], True, tl.int1)
27
+ r0_index = tl.arange(0, R0_BLOCK)[:]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_1 = r0_index
33
+ x0 = xindex
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
35
+ tmp11 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
36
+ tmp1 = tmp0.to(tl.float32)
37
+ tmp2 = tmp1 * tmp1
38
+ tmp3 = tl.broadcast_to(tmp2, [R0_BLOCK])
39
+ tmp5 = triton_helpers.promote_to_tensor(tl.sum(tmp3, 0))
40
+ tmp6 = 1024.0
41
+ tmp7 = (tmp5 / tmp6)
42
+ tmp8 = 1e-06
43
+ tmp9 = tmp7 + tmp8
44
+ tmp10 = libdevice.rsqrt(tmp9)
45
+ tmp12 = tmp1 * tmp10
46
+ tmp13 = tmp12.to(tl.float32)
47
+ tmp14 = tmp11 * tmp13
48
+ tl.debug_barrier()
49
+ tl.store(in_out_ptr0 + (x0), tmp10, None)
50
+ tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp14, None)
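
This per-row reduction is RMSNorm with the sizes baked in: hidden dimension 1024 and eps 1e-06, matching hidden_size and rms_norm_eps in config.json. The input is upcast to fp32 for the mean-of-squares, cast back for the weighted output, and the rsqrt factor is also stashed (in_out_ptr0) for reuse in the backward pass. Reference math:

# Reference sketch of the fused kernel's computation: RMSNorm over the last dim.
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    h = x.float()                                   # upcast, as the kernel does
    rrms = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight * (h * rrms).to(x.dtype)          # scale by the learned weight

x = torch.randn(2, 1024, dtype=torch.bfloat16)
w = torch.ones(1024, dtype=torch.bfloat16)
y = rms_norm(x, w)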
torchinductor_ch-epfl-345354-j/53/c53mrwlx5sxivgg5x5z6kkaldo2q5yn2pjsymcv27tpzj2cdoeww.py ADDED
@@ -0,0 +1,66 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 1, 'r0_': 2048},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_backward_nll_loss_forward_11', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused_nll_loss_backward_nll_loss_forward_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 1
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     _tmp27 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_0 = r0_index
+         tmp5 = tl.load(in_ptr1 + (((r0_0 + 2*((6 + ks0*ks1) // 7)) % (ks0*ks1))), r0_mask, eviction_policy='evict_last', other=0.0)
+         tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp0 = ((r0_0 + 2*((6 + ks0*ks1) // 7)) % ks1)
+         tmp1 = (-1) + ks1
+         tmp2 = tmp0 == tmp1
+         tmp3 = tmp0 < tmp1
+         tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + (((r0_0 + 2*((6 + ks0*ks1) // 7)) % (ks0*ks1))), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_last', other=0.0)
+         tmp6 = tl.where(tmp3, tmp4, tmp5)
+         tmp7 = tl.full([1, 1], -100, tl.int64)
+         tmp8 = tl.where(tmp2, tmp7, tmp6)
+         tmp9 = tmp8 != tmp7
+         tmp10 = tl.full([1, 1], 0, tl.int64)
+         tmp11 = tl.where(tmp9, tmp8, tmp10)
+         tmp12 = ks2
+         tmp13 = tmp11 + tmp12
+         tmp14 = tmp11 < 0
+         tmp15 = tl.where(tmp14, tmp13, tmp11)
+         tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
+         tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + 2*ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+         tmp18 = tmp17.to(tl.float32)
+         tmp20 = tmp18 - tmp19
+         tmp22 = tmp20 - tmp21
+         tmp23 = -tmp22
+         tmp24 = 0.0
+         tmp25 = tl.where(tmp9, tmp23, tmp24)
+         tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
+         tmp28 = _tmp27 + tmp26
+         _tmp27 = tl.where(r0_mask, tmp28, _tmp27)
+         tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
+         tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
+     tmp27 = tl.sum(_tmp27, 1)[:, None]
+     tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
torchinductor_ch-epfl-345354-j/56/c56q66j66nfzeu5puvuhal4wt2foih6rnb5nwmqolafn3iq33kjp.py ADDED
@@ -0,0 +1,66 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 1, 'r0_': 2048},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_backward_nll_loss_forward_10', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused_nll_loss_backward_nll_loss_forward_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 1
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     _tmp27 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_0 = r0_index
+         tmp5 = tl.load(in_ptr1 + (((r0_0 + ((6 + ks0*ks1) // 7)) % (ks0*ks1))), r0_mask, eviction_policy='evict_last', other=0.0)
+         tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp0 = ((r0_0 + ((6 + ks0*ks1) // 7)) % ks1)
+         tmp1 = (-1) + ks1
+         tmp2 = tmp0 == tmp1
+         tmp3 = tmp0 < tmp1
+         tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + (((r0_0 + ((6 + ks0*ks1) // 7)) % (ks0*ks1))), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_last', other=0.0)
+         tmp6 = tl.where(tmp3, tmp4, tmp5)
+         tmp7 = tl.full([1, 1], -100, tl.int64)
+         tmp8 = tl.where(tmp2, tmp7, tmp6)
+         tmp9 = tmp8 != tmp7
+         tmp10 = tl.full([1, 1], 0, tl.int64)
+         tmp11 = tl.where(tmp9, tmp8, tmp10)
+         tmp12 = ks2
+         tmp13 = tmp11 + tmp12
+         tmp14 = tmp11 < 0
+         tmp15 = tl.where(tmp14, tmp13, tmp11)
+         tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
+         tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+         tmp18 = tmp17.to(tl.float32)
+         tmp20 = tmp18 - tmp19
+         tmp22 = tmp20 - tmp21
+         tmp23 = -tmp22
+         tmp24 = 0.0
+         tmp25 = tl.where(tmp9, tmp23, tmp24)
+         tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
+         tmp28 = _tmp27 + tmp26
+         _tmp27 = tl.where(r0_mask, tmp28, _tmp27)
+         tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
+         tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
+     tmp27 = tl.sum(_tmp27, 1)[:, None]
+     tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
torchinductor_ch-epfl-345354-j/57/c573irrqes6p6it4yfyvqw2efgfbbgp7yjzxjjxq5jpeesj3bi77.py ADDED
@@ -0,0 +1,353 @@
+ """
+ Compile-time auto-tuning block:
+
+ import torch
+ from torch._dynamo.testing import rand_strided
+ from torch._dynamo.utils import preserve_rng_state
+ from torch._inductor.select_algorithm import AlgorithmSelectorCache
+ from torch._inductor.async_compile import AsyncCompile
+
+ async_compile = AsyncCompile()
+ generate_example_value = AlgorithmSelectorCache.generate_example_value
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+
+
+ triton_poi_fused_add_cat_mul_0 = async_compile.triton('triton_poi_fused_add_cat_mul_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 16777216},
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_cat_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x4 = xindex
+ x0 = (xindex % ks0)
+ x2 = ((xindex // ks1) % ks2)
+ x5 = xindex // ks0
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp2 = tmp0 * tmp1
+ tmp3 = x0
+ tmp4 = tl.full([1], 0, tl.int64)
+ tmp5 = tmp3 >= tmp4
+ tmp6 = ks0 + (-1)*(ks0 // 2)
+ tmp7 = tmp3 < tmp6
+ tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp9 = -tmp8
+ tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
+ tmp11 = tl.where(tmp7, tmp9, tmp10)
+ tmp12 = tmp3 >= tmp6
+ tmp13 = ks0
+ tmp14 = tmp3 < tmp13
+ tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp16 = tl.where(tmp7, tmp11, tmp15)
+ tmp18 = tmp16 * tmp17
+ tmp19 = tmp2 + tmp18
+ tl.store(out_ptr0 + (x4), tmp19, xmask)
+ ''', device_str='cuda')
+
+
+ triton_poi_fused_add_cat_mul_1 = async_compile.triton('triton_poi_fused_add_cat_mul_1', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 8388608},
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_cat_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x4 = xindex
+ x0 = (xindex % ks0)
+ x2 = ((xindex // ks1) % ks2)
+ x5 = xindex // ks0
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp2 = tmp0 * tmp1
+ tmp3 = x0
+ tmp4 = tl.full([1], 0, tl.int64)
+ tmp5 = tmp3 >= tmp4
+ tmp6 = ks0 + (-1)*(ks0 // 2)
+ tmp7 = tmp3 < tmp6
+ tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp9 = -tmp8
+ tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
+ tmp11 = tl.where(tmp7, tmp9, tmp10)
+ tmp12 = tmp3 >= tmp6
+ tmp13 = ks0
+ tmp14 = tmp3 < tmp13
+ tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp16 = tl.where(tmp7, tmp11, tmp15)
+ tmp18 = tmp16 * tmp17
+ tmp19 = tmp2 + tmp18
+ tl.store(out_ptr0 + (x4), tmp19, xmask)
+ ''', device_str='cuda')
+
+ async_compile.wait(globals())
+ del async_compile
+
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ stream0 = get_raw_stream(0)
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ stream0 = get_raw_stream(0)
+ arg7_1 = generate_example_value((8, 16, 1000, 128), (2048000, 128, 2048, 1), 'cuda:0', torch.bfloat16, 0, (8, 16, 1000, 128))
+ arg2_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
+ arg4_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
+ buf0 = generate_example_value((8, 16, 1000, 128), (2048000, 128, 2048, 1), 'cuda:0', torch.bfloat16, 0, (8, 16, 1000, 128))
+ triton_poi_fused_add_cat_mul_0.run(arg7_1, arg2_1, arg4_1, buf0, 128, 2048, 1000, 16384000, stream=stream0)
+ del arg7_1, arg2_1, arg4_1, buf0
+
+ stream0 = get_raw_stream(0)
+ arg8_1 = generate_example_value((8, 8, 1000, 128), (1024000, 128, 1024, 1), 'cuda:0', torch.bfloat16, 0, (8, 8, 1000, 128))
+ arg2_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
+ arg4_1 = generate_example_value((1, 1000, 128), (128000, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1000, 128))
+ buf1 = generate_example_value((8, 8, 1000, 128), (1024000, 128, 1024, 1), 'cuda:0', torch.bfloat16, 0, (8, 8, 1000, 128))
+ triton_poi_fused_add_cat_mul_1.run(arg8_1, arg2_1, arg4_1, buf1, 128, 1024, 1000, 8192000, stream=stream0)
+ del arg8_1, arg2_1, arg4_1, buf1
+
+ """
+ # AOT ID: ['3_inference']
+ from ctypes import c_void_p, c_long, c_int
+ import torch
+ import math
+ import random
+ import os
+ import tempfile
+ from math import inf, nan
+ from cmath import nanj
+ from torch._inductor.hooks import run_intermediate_hooks
+ from torch._inductor.utils import maybe_profile
+ from torch._inductor.codegen.memory_planning import _align as align
+ from torch import device, empty_strided
+ from torch._inductor.async_compile import AsyncCompile
+ from torch._inductor.select_algorithm import extern_kernels
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+ aten = torch.ops.aten
+ inductor_ops = torch.ops.inductor
+ _quantized = torch.ops._quantized
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
+ async_compile = AsyncCompile()
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/nf/cnf7ddto2mtv7utbz2ev3zn2hkgmh5nivfwjn3kwh7vb2j3fnuyw.py
+ # Topologically Sorted Source Nodes: [mul, cat, mul_1, q_embed], Original ATen: [aten.mul, aten.cat, aten.add]
+ # Source node to ATen node mapping:
+ # cat => cat
+ # mul => mul_8
+ # mul_1 => mul_29
+ # q_embed => add_36
+ # Graph fragment:
+ # %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg7_1, %unsqueeze), kwargs = {})
+ # %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
+ # %mul_29 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {})
+ # %add_36 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_8, %mul_29), kwargs = {})
+ triton_poi_fused_add_cat_mul_0 = async_compile.triton('triton_poi_fused_add_cat_mul_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 16777216},
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_cat_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x4 = xindex
+ x0 = (xindex % ks0)
+ x2 = ((xindex // ks1) % ks2)
+ x5 = xindex // ks0
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp2 = tmp0 * tmp1
+ tmp3 = x0
+ tmp4 = tl.full([1], 0, tl.int64)
+ tmp5 = tmp3 >= tmp4
+ tmp6 = ks0 + (-1)*(ks0 // 2)
+ tmp7 = tmp3 < tmp6
+ tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp9 = -tmp8
+ tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
+ tmp11 = tl.where(tmp7, tmp9, tmp10)
+ tmp12 = tmp3 >= tmp6
+ tmp13 = ks0
+ tmp14 = tmp3 < tmp13
+ tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp16 = tl.where(tmp7, tmp11, tmp15)
+ tmp18 = tmp16 * tmp17
+ tmp19 = tmp2 + tmp18
+ tl.store(out_ptr0 + (x4), tmp19, xmask)
+ ''', device_str='cuda')
+
+
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/2f/c2fzymhcr3rme5dtns3jvvyp6x3osfrlp7cc7zg7igwzigqmmg65.py
+ # Topologically Sorted Source Nodes: [mul_2, cat_1, mul_3, k_embed], Original ATen: [aten.mul, aten.cat, aten.add]
+ # Source node to ATen node mapping:
+ # cat_1 => cat_1
+ # k_embed => add_72
+ # mul_2 => mul_38
+ # mul_3 => mul_59
+ # Graph fragment:
+ # %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg8_1, %unsqueeze), kwargs = {})
+ # %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
+ # %mul_59 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {})
+ # %add_72 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_38, %mul_59), kwargs = {})
+ triton_poi_fused_add_cat_mul_1 = async_compile.triton('triton_poi_fused_add_cat_mul_1', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 8388608},
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_cat_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x4 = xindex
+ x0 = (xindex % ks0)
+ x2 = ((xindex // ks1) % ks2)
+ x5 = xindex // ks0
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp17 = tl.load(in_ptr2 + (x0 + ks0*x2), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp2 = tmp0 * tmp1
+ tmp3 = x0
+ tmp4 = tl.full([1], 0, tl.int64)
+ tmp5 = tmp3 >= tmp4
+ tmp6 = ks0 + (-1)*(ks0 // 2)
+ tmp7 = tmp3 < tmp6
+ tmp8 = tl.load(in_ptr0 + (ks0*x5 + (ks0 // 2) + (x0)), xmask & tmp7, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp9 = -tmp8
+ tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
+ tmp11 = tl.where(tmp7, tmp9, tmp10)
+ tmp12 = tmp3 >= tmp6
+ tmp13 = ks0
+ tmp14 = tmp3 < tmp13
+ tmp15 = tl.load(in_ptr0 + (ks0*x5 + (x0 + ((-1)*ks0) + (ks0 // 2))), xmask & tmp12, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp16 = tl.where(tmp7, tmp11, tmp15)
+ tmp18 = tmp16 * tmp17
+ tmp19 = tmp2 + tmp18
+ tl.store(out_ptr0 + (x4), tmp19, xmask)
+ ''', device_str='cuda')
+
+
+ async_compile.wait(globals())
+ del async_compile
+
+ def call(args):
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1 = args
+ args.clear()
+ s0 = arg0_1
+ s1 = arg1_1
+ s7 = arg5_1
+ s8 = arg6_1
+ assert_size_stride(arg2_1, (1, s0, s1), (s0*s1, s1, 1))
+ assert_size_stride(arg4_1, (1, s0, s1), (s0*s1, s1, 1))
+ assert_size_stride(arg7_1, (s7, s8, s0, s1), (s0*s1*s8, s1, s1*s8, 1))
+ assert_size_stride(arg8_1, (s7, s7, s0, s1), (s0*s1*s7, s1, s1*s7, 1))
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ ps0 = s1*s8
+ pool1 = empty_strided_cuda((s7, s8, s0, s1), (s0*s1*s8, s1, s1*s8, 1), torch.bfloat16)
+ buf0 = pool1 # alloc
+ # Topologically Sorted Source Nodes: [mul, cat, mul_1, q_embed], Original ATen: [aten.mul, aten.cat, aten.add]
+ triton_poi_fused_add_cat_mul_0_xnumel = s0*s1*s7*s8
+ stream0 = get_raw_stream(0)
+ triton_poi_fused_add_cat_mul_0.run(arg7_1, arg2_1, arg4_1, buf0, s1, ps0, s0, triton_poi_fused_add_cat_mul_0_xnumel, stream=stream0)
+ del arg7_1
+ ps1 = s1*s7
+ pool0 = empty_strided_cuda((s7, s7, s0, s1), (s0*s1*s7, s1, s1*s7, 1), torch.bfloat16)
+ buf1 = pool0 # alloc
+ # Topologically Sorted Source Nodes: [mul_2, cat_1, mul_3, k_embed], Original ATen: [aten.mul, aten.cat, aten.add]
+ triton_poi_fused_add_cat_mul_1_xnumel = s0*s1*s7*s7
+ stream0 = get_raw_stream(0)
+ triton_poi_fused_add_cat_mul_1.run(arg8_1, arg2_1, arg4_1, buf1, s1, ps1, s0, triton_poi_fused_add_cat_mul_1_xnumel, stream=stream0)
+ del arg2_1
+ del arg4_1
+ del arg8_1
+ return (buf0, buf1, )
+
+
+ def benchmark_compiled_module(times=10, repeat=10):
+ from torch._dynamo.testing import rand_strided
+ from torch._inductor.utils import print_performance
+ arg0_1 = 1000
+ arg1_1 = 128
+ arg2_1 = rand_strided((1, 1000, 128), (128000, 128, 1), device='cuda:0', dtype=torch.bfloat16)
+ arg3_1 = 1
+ arg4_1 = rand_strided((1, 1000, 128), (128000, 128, 1), device='cuda:0', dtype=torch.bfloat16)
+ arg5_1 = 8
+ arg6_1 = 16
+ arg7_1 = rand_strided((8, 16, 1000, 128), (2048000, 128, 2048, 1), device='cuda:0', dtype=torch.bfloat16)
+ arg8_1 = rand_strided((8, 8, 1000, 128), (1024000, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1])
+ return print_performance(fn, times=times, repeat=repeat)
+
+
+ if __name__ == "__main__":
+ from torch._inductor.wrapper_benchmark import compiled_module_main
+ compiled_module_main('None', benchmark_compiled_module)
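Note (editorial): the two fused kernels in this file apply rotary position embeddings (RoPE) to the query and key tensors; the graph fragments above spell out the pattern (mul, cat([neg, slice], -1), mul, add). A minimal eager-mode sketch, assuming the standard rotate_half formulation; the helper names are illustrative, not taken from the generated code:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Negate the second half of the last dim and swap it in front,
    # mirroring the cat([%neg, %slice_1], -1) node above.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q: (B, H, S, D); k: (B, H_kv, S, D); cos/sin: (1, S, D), matching the
    # (1, 1000, 128) example inputs in benchmark_compiled_module above.
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    q_embed = q * cos + rotate_half(q) * sin       # kernel 0 (q_embed)
    k_embed = k * cos + rotate_half(k) * sin       # kernel 1 (k_embed)
    return q_embed, k_embed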
torchinductor_ch-epfl-345354-j/57/c574kngiopy3pgespyoupnzlae4d5tokyeui7uglwglnym2qijvn.py ADDED
@@ -0,0 +1,30 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 33554432},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = tl.sigmoid(tmp1)
+ tmp3 = tmp1 * tmp2
+ tmp4 = tmp3.to(tl.float32)
+ tmp6 = tmp4 * tmp5
+ tl.store(in_out_ptr0 + (x0), tmp6, xmask)
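Note (editorial): triton_poi_fused_mul_silu_0 is the fused SwiGLU-style MLP gate, silu(gate) * up, written back in place over the gate buffer; the kernel computes the SiLU in float32 and casts back to bfloat16 before the multiply. A minimal eager-mode sketch; the argument names are illustrative:

import torch
import torch.nn.functional as F

def silu_mul_(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # gate corresponds to the mutated in_out_ptr0, up to in_ptr0.
    act = F.silu(gate.float()).to(gate.dtype)  # x * sigmoid(x) in fp32, then back
    return gate.copy_(act * up)                # in-place store, as in the kernel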
torchinductor_ch-epfl-345354-j/57/dad7be19dc394c1e08368515640dff88b78797aaabae100a15a1f195476a9a87.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 67}
torchinductor_ch-epfl-345354-j/5e/c5enonf6qztlsw7dozsqkejk4exzt4n56gbz6fiey2gnus5vdf76.py ADDED
@@ -0,0 +1,66 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+ size_hints={'x': 1, 'r0_': 2048},
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*i1', 'out_ptr2': '*i64', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_backward_nll_loss_forward_12', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused_nll_loss_backward_nll_loss_forward_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+ xnumel = 1
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
+ rbase = r0_base
+ _tmp27 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
+ r0_index = r0_offset + r0_base
+ r0_mask = r0_index < r0_numel
+ roffset = r0_offset
+ rindex = r0_index
+ r0_0 = r0_index
+ tmp5 = tl.load(in_ptr1 + (((r0_0 + 3*((6 + ks0*ks1) // 7)) % (ks0*ks1))), r0_mask, eviction_policy='evict_last', other=0.0)
+ tmp19 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+ tmp21 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+ tmp0 = ((r0_0 + 3*((6 + ks0*ks1) // 7)) % ks1)
+ tmp1 = (-1) + ks1
+ tmp2 = tmp0 == tmp1
+ tmp3 = tmp0 < tmp1
+ tmp4 = tl.load(in_ptr0 + (tl.broadcast_to(1 + (((r0_0 + 3*((6 + ks0*ks1) // 7)) % (ks0*ks1))), [XBLOCK, R0_BLOCK])), r0_mask & tmp3, eviction_policy='evict_last', other=0.0)
+ tmp6 = tl.where(tmp3, tmp4, tmp5)
+ tmp7 = tl.full([1, 1], -100, tl.int64)
+ tmp8 = tl.where(tmp2, tmp7, tmp6)
+ tmp9 = tmp8 != tmp7
+ tmp10 = tl.full([1, 1], 0, tl.int64)
+ tmp11 = tl.where(tmp9, tmp8, tmp10)
+ tmp12 = ks2
+ tmp13 = tmp11 + tmp12
+ tmp14 = tmp11 < 0
+ tmp15 = tl.where(tmp14, tmp13, tmp11)
+ tl.device_assert(((0 <= tmp15) & (tmp15 < ks2)) | ~(r0_mask), "index out of bounds: 0 <= tmp15 < ks2")
+ tmp17 = tl.load(in_ptr2 + (tmp15 + ks2*r0_0 + 3*ks2*((6 + ks0*ks3) // 7)), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp18 = tmp17.to(tl.float32)
+ tmp20 = tmp18 - tmp19
+ tmp22 = tmp20 - tmp21
+ tmp23 = -tmp22
+ tmp24 = 0.0
+ tmp25 = tl.where(tmp9, tmp23, tmp24)
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
+ tmp28 = _tmp27 + tmp26
+ _tmp27 = tl.where(r0_mask, tmp28, _tmp27)
+ tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp9, r0_mask)
+ tl.store(out_ptr2 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp11, r0_mask)
+ tmp27 = tl.sum(_tmp27, 1)[:, None]
+ tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp27, None)
torchinductor_ch-epfl-345354-j/5x/c5xsvywggx5vrzm2l5uaktu7pipclhdn5h6263yru2ugvuhe2nak.py ADDED
@@ -0,0 +1,57 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+ size_hints={'x': 8192, 'r0_': 1024},
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel):
+ XBLOCK: tl.constexpr = 1
+ r0_numel = 1024
+ R0_BLOCK: tl.constexpr = 1024
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = tl.full([1], xoffset, tl.int32)
+ xmask = tl.full([R0_BLOCK], True, tl.int1)
+ r0_index = tl.arange(0, R0_BLOCK)[:]
+ r0_offset = 0
+ r0_mask = tl.full([R0_BLOCK], True, tl.int1)
+ roffset = r0_offset
+ rindex = r0_index
+ r0_1 = r0_index
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+ tmp4 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
+ tmp10 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
+ tmp2 = tmp0 * tmp1
+ tmp3 = tmp2.to(tl.float32)
+ tmp5 = tmp4.to(tl.float32)
+ tmp6 = tmp3 * tmp5
+ tmp7 = tl.broadcast_to(tmp6, [R0_BLOCK])
+ tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
+ tmp11 = tmp3 * tmp10
+ tmp12 = -0.5
+ tmp13 = tmp9 * tmp12
+ tmp14 = tmp10 * tmp10
+ tmp15 = tmp14 * tmp10
+ tmp16 = tmp13 * tmp15
+ tmp17 = 0.0009765625
+ tmp18 = tmp16 * tmp17
+ tmp19 = 2.0
+ tmp20 = tmp5 * tmp19
+ tmp21 = tmp18 * tmp20
+ tmp22 = tmp11 + tmp21
+ tmp23 = tmp22.to(tl.float32)
+ tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp23, None)
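Note (editorial): this per-row persistent reduction appears to be the input-gradient of an RMSNorm over a hidden size of 1024 (note the 1/1024 = 0.0009765625 constant). Given g = grad_out * weight, the saved activations x, and the saved r = rsqrt(mean(x^2) + eps), it computes dx = g*r - sum(g*x) * r^3 * x / N. A minimal eager-mode sketch of that reconstruction; names are illustrative:

import torch

def rmsnorm_backward_input(grad_out, weight, x, r):
    # grad_out, x: (..., N) bf16; weight: (N,) bf16; r: (..., 1) fp32 saved rsqrt.
    g = (grad_out * weight).float()
    xf = x.float()
    s = (g * xf).sum(dim=-1, keepdim=True)   # the per-row reduction (tmp10 above)
    n = x.shape[-1]                          # 1024 here, so 1/n = 0.0009765625
    dx = g * r + (-0.5 * s * r**3 / n) * (2.0 * xf)
    return dx.to(x.dtype)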
torchinductor_ch-epfl-345354-j/6d/c6dsbxlebwjqawzeprkq3lkldtxoiept4c6bpgtva5r4mjlrnwlr.py ADDED
@@ -0,0 +1,229 @@
+ """
+ Compile-time auto-tuning block:
+
+ import torch
+ from torch._dynamo.testing import rand_strided
+ from torch._dynamo.utils import preserve_rng_state
+ from torch._inductor.select_algorithm import AlgorithmSelectorCache
+ from torch._inductor.async_compile import AsyncCompile
+
+ async_compile = AsyncCompile()
+ generate_example_value = AlgorithmSelectorCache.generate_example_value
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+
+
+ triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+ size_hints={'x': 131072, 'r0_': 128},
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
+ r0_numel = 128
+ R0_BLOCK: tl.constexpr = 128
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
+ r0_offset = 0
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+ roffset = r0_offset
+ rindex = r0_index
+ r0_1 = r0_index
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
+ tmp7 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = tmp1 * tmp1
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
+ tmp5 = tl.where(xmask, tmp3, 0)
+ tmp6 = tl.sum(tmp5, 1)[:, None]
+ tmp8 = 128.0
+ tmp9 = (tmp6 / tmp8)
+ tmp10 = 1e-06
+ tmp11 = tmp9 + tmp10
+ tmp12 = libdevice.rsqrt(tmp11)
+ tmp13 = tmp1 * tmp12
+ tmp14 = tmp13.to(tl.float32)
+ tmp15 = tmp7 * tmp14
+ tl.store(out_ptr1 + (r0_1 + 128*x0), tmp15, xmask)
+ ''', device_str='cuda')
+
+ async_compile.wait(globals())
+ del async_compile
+
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ stream0 = get_raw_stream(0)
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ stream0 = get_raw_stream(0)
+ arg3_1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ arg4_1 = generate_example_value((128,), (1,), 'cuda:0', torch.bfloat16, 0, (128,))
+ buf1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(arg3_1, arg4_1, buf1, 128000, 128, stream=stream0)
+ del arg3_1, arg4_1, buf1
+
+ """
+ # AOT ID: ['2_inference']
+ from ctypes import c_void_p, c_long, c_int
+ import torch
+ import math
+ import random
+ import os
+ import tempfile
+ from math import inf, nan
+ from cmath import nanj
+ from torch._inductor.hooks import run_intermediate_hooks
+ from torch._inductor.utils import maybe_profile
+ from torch._inductor.codegen.memory_planning import _align as align
+ from torch import device, empty_strided
+ from torch._inductor.async_compile import AsyncCompile
+ from torch._inductor.select_algorithm import extern_kernels
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+ aten = torch.ops.aten
+ inductor_ops = torch.ops.inductor
+ _quantized = torch.ops._quantized
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
+ async_compile = AsyncCompile()
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/l7/cl7z3tsg3vjxfjb3vqjym4iefsq6h5o7fsfsmbi65q55e56x3lm7.py
+ # Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, add, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
+ # Source node to ATen node mapping:
+ # add => add_15
+ # hidden_states => convert_element_type
+ # hidden_states_1 => mul_17
+ # mul_1 => mul_26
+ # pow_1 => pow_1
+ # rsqrt => rsqrt
+ # to_1 => convert_element_type_1
+ # variance => mean
+ # Graph fragment:
+ # %convert_element_type : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg3_1, torch.float32), kwargs = {})
+ # %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {})
+ # %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
+ # %add_15 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
+ # %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_15,), kwargs = {})
+ # %mul_17 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
+ # %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_17, torch.bfloat16), kwargs = {})
+ # %mul_26 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %convert_element_type_1), kwargs = {})
+ triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+ size_hints={'x': 131072, 'r0_': 128},
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
+ r0_numel = 128
+ R0_BLOCK: tl.constexpr = 128
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
+ r0_offset = 0
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+ roffset = r0_offset
+ rindex = r0_index
+ r0_1 = r0_index
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
+ tmp7 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = tmp1 * tmp1
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
+ tmp5 = tl.where(xmask, tmp3, 0)
+ tmp6 = tl.sum(tmp5, 1)[:, None]
+ tmp8 = 128.0
+ tmp9 = (tmp6 / tmp8)
+ tmp10 = 1e-06
+ tmp11 = tmp9 + tmp10
+ tmp12 = libdevice.rsqrt(tmp11)
+ tmp13 = tmp1 * tmp12
+ tmp14 = tmp13.to(tl.float32)
+ tmp15 = tmp7 * tmp14
+ tl.store(out_ptr1 + (r0_1 + 128*x0), tmp15, xmask)
+ ''', device_str='cuda')
+
+
+ async_compile.wait(globals())
+ del async_compile
+
+ def call(args):
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
+ args.clear()
+ s3 = arg0_1
+ s4 = arg1_1
+ s5 = arg2_1
+ assert_size_stride(arg3_1, (s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1))
+ assert_size_stride(arg4_1, (128, ), (1, ))
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ pool0 = empty_strided_cuda((s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1), torch.bfloat16)
+ buf1 = pool0 # alloc
+ # Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, add, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
+ triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0_xnumel = s3*s4*s5
+ stream0 = get_raw_stream(0)
+ triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(arg3_1, arg4_1, buf1, triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0_xnumel, 128, stream=stream0)
+ del arg3_1
+ del arg4_1
+ return (buf1, )
+
+
+ def benchmark_compiled_module(times=10, repeat=10):
+ from torch._dynamo.testing import rand_strided
+ from torch._inductor.utils import print_performance
+ arg0_1 = 8
+ arg1_1 = 1000
+ arg2_1 = 16
+ arg3_1 = rand_strided((8, 1000, 16, 128), (2048000, 2048, 128, 1), device='cuda:0', dtype=torch.bfloat16)
+ arg4_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1])
+ return print_performance(fn, times=times, repeat=repeat)
+
+
+ if __name__ == "__main__":
+ from torch._inductor.wrapper_benchmark import compiled_module_main
+ compiled_module_main('None', benchmark_compiled_module)
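Note (editorial): the fused kernel in this file is a per-head RMSNorm over a last dimension of 128 with eps = 1e-6, exactly as the graph fragment above spells out. A minimal eager-mode equivalent of that formulation:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # x: (..., 128) bf16; weight: (128,) bf16; variance computed in fp32.
    xf = x.float()
    r = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight * (xf * r).to(x.dtype)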
torchinductor_ch-epfl-345354-j/6j/7c215475e7b40a21cf286026270965eb7f07e7c3af1c4052d331de3f74c6449e.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 68}
torchinductor_ch-epfl-345354-j/6j/c6j5lx5qgycfvyi3dm5f4mo3ssluzzsrmdq32pka7e6pyhg42zvd.py ADDED
@@ -0,0 +1,499 @@
+ """
+ Compile-time auto-tuning block:
+
+ import torch
+ from torch._dynamo.testing import rand_strided
+ from torch._dynamo.utils import preserve_rng_state
+ from torch._inductor.select_algorithm import AlgorithmSelectorCache
+ from torch._inductor.async_compile import AsyncCompile
+
+ async_compile = AsyncCompile()
+ generate_example_value = AlgorithmSelectorCache.generate_example_value
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+
+
+ triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+ size_hints={'x': 65536, 'r0_': 512},
+ reduction_hint=ReductionHint.OUTER,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+ xnumel = 40960
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
+ rbase = r0_base
+ x1 = xindex // 128
+ x0 = (xindex % 128)
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+ x3 = xindex
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
+ r0_index = r0_offset + r0_base
+ r0_mask = r0_index < r0_numel
+ roffset = r0_offset
+ rindex = r0_index
+ r0_2 = r0_index
+ tmp0 = r0_2 + x1*((319 + ks0*ks1*ks2) // 320)
+ tmp1 = ks0*ks1*ks2
+ tmp2 = tmp0 < tmp1
+ tmp3 = tl.load(in_ptr0 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp4 = tl.load(in_ptr1 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
+ tmp5 = tmp4.to(tl.float32)
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2))), r0_mask & tmp2, eviction_policy='evict_last', other=0.0)
+ tmp7 = tmp5 * tmp6
+ tmp8 = tmp7.to(tl.float32)
+ tmp9 = tmp3 * tmp8
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
+ tmp14 = _tmp13 + tmp12
+ _tmp13 = tl.where(r0_mask, tmp14, _tmp13)
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
+ tl.store(out_ptr0 + (x3), tmp13, None)
+ ''', device_str='cuda')
+
+
+ triton_red_fused__to_copy_mul_sum_1 = async_compile.triton('triton_red_fused__to_copy_mul_sum_1', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+ size_hints={'x': 128, 'r0_': 512},
+ reduction_hint=ReductionHint.OUTER_TINY,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+ xnumel = 128
+ r0_numel = 320
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
+ rbase = r0_base
+ x0 = xindex
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
+ r0_index = r0_offset + r0_base
+ r0_mask = r0_index < r0_numel
+ roffset = r0_offset
+ rindex = r0_index
+ r0_1 = r0_index
+ tmp0 = tl.load(in_ptr0 + (x0 + 128*r0_1), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+ tmp3 = _tmp2 + tmp1
+ _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
+ ''', device_str='cuda')
+
+
+ triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+ size_hints={'x': 131072, 'r0_': 128},
+ reduction_hint=ReductionHint.INNER,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
+ r0_numel = 128
+ R0_BLOCK: tl.constexpr = 128
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
+ r0_offset = 0
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+ roffset = r0_offset
+ rindex = r0_index
+ r0_1 = r0_index
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+ tmp4 = tl.load(in_ptr2 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
+ tmp11 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
+ tmp2 = tmp0 * tmp1
+ tmp3 = tmp2.to(tl.float32)
+ tmp5 = tmp4.to(tl.float32)
+ tmp6 = tmp3 * tmp5
+ tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
+ tmp9 = tl.where(xmask, tmp7, 0)
+ tmp10 = tl.sum(tmp9, 1)[:, None]
+ tmp12 = tmp3 * tmp11
+ tmp13 = -0.5
+ tmp14 = tmp10 * tmp13
+ tmp15 = tmp11 * tmp11
+ tmp16 = tmp15 * tmp11
+ tmp17 = tmp14 * tmp16
+ tmp18 = 0.0078125
+ tmp19 = tmp17 * tmp18
+ tmp20 = 2.0
+ tmp21 = tmp5 * tmp20
+ tmp22 = tmp19 * tmp21
+ tmp23 = tmp12 + tmp22
+ tmp24 = tmp23.to(tl.float32)
+ tl.store(out_ptr1 + (r0_1 + 128*x0), tmp24, xmask)
+ ''', device_str='cuda')
+
+ async_compile.wait(globals())
+ del async_compile
+
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ stream0 = get_raw_stream(0)
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ stream0 = get_raw_stream(0)
+ tangents_1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ primals_4 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ rsqrt = generate_example_value((8, 1000, 16, 1), (16000, 16, 1, 1), 'cuda:0', torch.float32, 0, (8, 1000, 16, 1))
+ buf0 = generate_example_value((1, 1, 1, 128, 320), (40960, 40960, 40960, 1, 128), 'cuda:0', torch.float32, 0, (1, 1, 1, 128, 320))
+ triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, 8, 1000, 16, 40960, 400, stream=stream0)
+ del tangents_1, primals_4, rsqrt, buf0
+
+ stream0 = get_raw_stream(0)
+ buf0 = generate_example_value((1, 1, 1, 128, 320), (40960, 40960, 40960, 1, 128), 'cuda:0', torch.float32, 0, (1, 1, 1, 128, 320))
+ buf1 = generate_example_value((1, 1, 1, 128), (128, 128, 128, 1), 'cuda:0', torch.bfloat16, 0, (1, 1, 1, 128))
+ triton_red_fused__to_copy_mul_sum_1.run(buf0, buf1, 128, 320, stream=stream0)
+ del buf0, buf1
+
+ stream0 = get_raw_stream(0)
+ tangents_1 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ primals_5 = generate_example_value((128,), (1,), 'cuda:0', torch.bfloat16, 0, (128,))
+ primals_4 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ rsqrt = generate_example_value((8, 1000, 16, 1), (16000, 16, 1, 1), 'cuda:0', torch.float32, 0, (8, 1000, 16, 1))
+ buf3 = generate_example_value((8, 1000, 16, 128), (2048000, 2048, 128, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 16, 128))
+ triton_per_fused__to_copy_add_div_mul_pow_sum_2.run(tangents_1, primals_5, primals_4, rsqrt, buf3, 128000, 128, stream=stream0)
+ del tangents_1, primals_5, primals_4, rsqrt, buf3
+
+ """
+ # AOT ID: ['9_backward']
+ from ctypes import c_void_p, c_long, c_int
+ import torch
+ import math
+ import random
+ import os
+ import tempfile
+ from math import inf, nan
+ from cmath import nanj
+ from torch._inductor.hooks import run_intermediate_hooks
+ from torch._inductor.utils import maybe_profile
+ from torch._inductor.codegen.memory_planning import _align as align
+ from torch import device, empty_strided
+ from torch._inductor.async_compile import AsyncCompile
224
+ from torch._inductor.select_algorithm import extern_kernels
225
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
226
+ import triton
227
+ import triton.language as tl
228
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
229
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
230
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
231
+
232
+ aten = torch.ops.aten
233
+ inductor_ops = torch.ops.inductor
234
+ _quantized = torch.ops._quantized
235
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
236
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
237
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
238
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
239
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
240
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
241
+ async_compile = AsyncCompile()
242
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
243
+
244
+
245
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/mm/cmmurv7ol6o3kll2wm4b6wgtdca4tsysrq7yrhhvfkf7ikm72y24.py
246
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
247
+ # Source node to ATen node mapping:
248
+ # hidden_states => convert_element_type
249
+ # hidden_states_1 => mul_23
250
+ # to_1 => convert_element_type_1
251
+ # Graph fragment:
252
+ # %convert_element_type : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
253
+ # %mul_23 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
254
+ # %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_23, torch.bfloat16), kwargs = {})
255
+ # %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
256
+ # %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_38, [0, 1, 2], True), kwargs = {})
257
+ triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
258
+ import triton
259
+ import triton.language as tl
260
+
261
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
262
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
263
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
264
+ triton_helpers.set_driver_to_gpu()
265
+
266
+ @triton_heuristics.reduction(
267
+ size_hints={'x': 65536, 'r0_': 512},
268
+ reduction_hint=ReductionHint.OUTER,
269
+ filename=__file__,
270
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
271
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
272
+ )
273
+ @triton.jit
274
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
275
+ xnumel = 40960
276
+ rnumel = r0_numel
277
+ RBLOCK: tl.constexpr = R0_BLOCK
278
+ xoffset = tl.program_id(0) * XBLOCK
279
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
280
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
281
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
282
+ rbase = r0_base
283
+ x1 = xindex // 128
284
+ x0 = (xindex % 128)
285
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
286
+ x3 = xindex
287
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
288
+ r0_index = r0_offset + r0_base
289
+ r0_mask = r0_index < r0_numel
290
+ roffset = r0_offset
291
+ rindex = r0_index
292
+ r0_2 = r0_index
293
+ tmp0 = r0_2 + x1*((319 + ks0*ks1*ks2) // 320)
294
+ tmp1 = ks0*ks1*ks2
295
+ tmp2 = tmp0 < tmp1
296
+ tmp3 = tl.load(in_ptr0 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
297
+ tmp4 = tl.load(in_ptr1 + (x0 + 128*(((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2)))), r0_mask & tmp2, eviction_policy='evict_first', other=0.0).to(tl.float32)
298
+ tmp5 = tmp4.to(tl.float32)
299
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((319 + ks0*ks1*ks2) // 320)) % (ks0*ks1*ks2))), r0_mask & tmp2, eviction_policy='evict_last', other=0.0)
300
+ tmp7 = tmp5 * tmp6
301
+ tmp8 = tmp7.to(tl.float32)
302
+ tmp9 = tmp3 * tmp8
303
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
304
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
305
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
306
+ tmp14 = _tmp13 + tmp12
307
+ _tmp13 = tl.where(r0_mask, tmp14, _tmp13)
308
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
309
+ tl.store(out_ptr0 + (x3), tmp13, None)
310
+ ''', device_str='cuda')
311
+
312
+
313
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/om/com66ihngf42hseqjadt4jcosvwffgk5ynrurmthhraywffjqcop.py
314
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
315
+ # Source node to ATen node mapping:
316
+ # hidden_states => convert_element_type
317
+ # hidden_states_1 => mul_23
318
+ # to_1 => convert_element_type_1
319
+ # Graph fragment:
320
+ # %convert_element_type : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
321
+ # %mul_23 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
322
+ # %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_23, torch.bfloat16), kwargs = {})
323
+ # %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
324
+ # %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_38, [0, 1, 2], True), kwargs = {})
325
+ triton_red_fused__to_copy_mul_sum_1 = async_compile.triton('triton_red_fused__to_copy_mul_sum_1', '''
326
+ import triton
327
+ import triton.language as tl
328
+
329
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
330
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
331
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
332
+ triton_helpers.set_driver_to_gpu()
333
+
334
+ @triton_heuristics.reduction(
335
+ size_hints={'x': 128, 'r0_': 512},
336
+ reduction_hint=ReductionHint.OUTER_TINY,
337
+ filename=__file__,
338
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
339
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
340
+ )
341
+ @triton.jit
342
+ def triton_red_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
343
+ xnumel = 128
344
+ r0_numel = 320
345
+ rnumel = r0_numel
346
+ RBLOCK: tl.constexpr = R0_BLOCK
347
+ xoffset = tl.program_id(0) * XBLOCK
348
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
349
+ xmask = xindex < xnumel
350
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
351
+ rbase = r0_base
352
+ x0 = xindex
353
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
354
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
355
+ r0_index = r0_offset + r0_base
356
+ r0_mask = r0_index < r0_numel
357
+ roffset = r0_offset
358
+ rindex = r0_index
359
+ r0_1 = r0_index
360
+ tmp0 = tl.load(in_ptr0 + (x0 + 128*r0_1), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
361
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
362
+ tmp3 = _tmp2 + tmp1
363
+ _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
364
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
365
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
366
+ ''', device_str='cuda')
367
+
368
+
369
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/p3/cp3uyorjp57oo5jc6wsaegyfcdcbyu6gppqliksrkascc6kis3o2.py
370
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
371
+ # Source node to ATen node mapping:
372
+ # hidden_states => convert_element_type
373
+ # Graph fragment:
374
+ # %mul_37 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_5), kwargs = {})
375
+ # %convert_element_type : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
376
+ # %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_37, torch.float32), kwargs = {})
377
+ # %mul_39 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {})
378
+ # %mul_40 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {})
379
+ # %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_39, [3], True), kwargs = {})
380
+ # %div : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, 128), kwargs = {})
381
+ # %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {})
382
+ # %mul_43 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {})
383
+ # %mul_44 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_43), kwargs = {})
384
+ # %add_46 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_40, %mul_44), kwargs = {})
385
+ # %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_46, torch.bfloat16), kwargs = {})
386
+ triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
387
+ import triton
388
+ import triton.language as tl
389
+
390
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
391
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
392
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
393
+ triton_helpers.set_driver_to_gpu()
394
+
395
+ @triton_heuristics.persistent_reduction(
396
+ size_hints={'x': 131072, 'r0_': 128},
397
+ reduction_hint=ReductionHint.INNER,
398
+ filename=__file__,
399
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
400
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
401
+ )
402
+ @triton.jit
403
+ def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
404
+ r0_numel = 128
405
+ R0_BLOCK: tl.constexpr = 128
406
+ rnumel = r0_numel
407
+ RBLOCK: tl.constexpr = R0_BLOCK
408
+ xoffset = tl.program_id(0) * XBLOCK
409
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
410
+ xmask = xindex < xnumel
411
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
412
+ r0_offset = 0
413
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
414
+ roffset = r0_offset
415
+ rindex = r0_index
416
+ r0_1 = r0_index
417
+ x0 = xindex
418
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
419
+ tmp1 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
420
+ tmp4 = tl.load(in_ptr2 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
421
+ tmp11 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
422
+ tmp2 = tmp0 * tmp1
423
+ tmp3 = tmp2.to(tl.float32)
424
+ tmp5 = tmp4.to(tl.float32)
425
+ tmp6 = tmp3 * tmp5
426
+ tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
427
+ tmp9 = tl.where(xmask, tmp7, 0)
428
+ tmp10 = tl.sum(tmp9, 1)[:, None]
429
+ tmp12 = tmp3 * tmp11
430
+ tmp13 = -0.5
431
+ tmp14 = tmp10 * tmp13
432
+ tmp15 = tmp11 * tmp11
433
+ tmp16 = tmp15 * tmp11
434
+ tmp17 = tmp14 * tmp16
435
+ tmp18 = 0.0078125
436
+ tmp19 = tmp17 * tmp18
437
+ tmp20 = 2.0
438
+ tmp21 = tmp5 * tmp20
439
+ tmp22 = tmp19 * tmp21
440
+ tmp23 = tmp12 + tmp22
441
+ tmp24 = tmp23.to(tl.float32)
442
+ tl.store(out_ptr1 + (r0_1 + 128*x0), tmp24, xmask)
443
+ ''', device_str='cuda')
444
+
445
+
446
+ async_compile.wait(globals())
447
+ del async_compile
448
+
449
+ def call(args):
450
+ primals_1, primals_2, primals_3, primals_4, primals_5, rsqrt, tangents_1 = args
451
+ args.clear()
452
+ s3 = primals_1
453
+ s4 = primals_2
454
+ s5 = primals_3
455
+ assert_size_stride(primals_4, (s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1))
456
+ assert_size_stride(primals_5, (128, ), (1, ))
457
+ assert_size_stride(rsqrt, (s3, s4, s5, 1), (s4*s5, s5, 1, 1))
458
+ assert_size_stride(tangents_1, (s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1))
459
+ with torch.cuda._DeviceGuard(0):
460
+ torch.cuda.set_device(0)
461
+ buf0 = empty_strided_cuda((1, 1, 1, 128, 320), (40960, 40960, 40960, 1, 128), torch.float32)
462
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
463
+ triton_red_fused__to_copy_mul_sum_0_r0_numel = (319 + s3*s4*s5) // 320
464
+ stream0 = get_raw_stream(0)
465
+ triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s3, s4, s5, 40960, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream0)
466
+ buf1 = empty_strided_cuda((1, 1, 1, 128), (128, 128, 128, 1), torch.bfloat16)
467
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
468
+ stream0 = get_raw_stream(0)
469
+ triton_red_fused__to_copy_mul_sum_1.run(buf0, buf1, 128, 320, stream=stream0)
470
+ del buf0
471
+ buf3 = empty_strided_cuda((s3, s4, s5, 128), (128*s4*s5, 128*s5, 128, 1), torch.bfloat16)
472
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
473
+ triton_per_fused__to_copy_add_div_mul_pow_sum_2_xnumel = s3*s4*s5
474
+ stream0 = get_raw_stream(0)
475
+ triton_per_fused__to_copy_add_div_mul_pow_sum_2.run(tangents_1, primals_5, primals_4, rsqrt, buf3, triton_per_fused__to_copy_add_div_mul_pow_sum_2_xnumel, 128, stream=stream0)
476
+ del primals_4
477
+ del primals_5
478
+ del rsqrt
479
+ del tangents_1
480
+ return (None, None, None, buf3, reinterpret_tensor(buf1, (128, ), (1, ), 0), )
481
+
482
+
483
+ def benchmark_compiled_module(times=10, repeat=10):
484
+ from torch._dynamo.testing import rand_strided
485
+ from torch._inductor.utils import print_performance
486
+ primals_1 = 8
487
+ primals_2 = 1000
488
+ primals_3 = 16
489
+ primals_4 = rand_strided((8, 1000, 16, 128), (2048000, 2048, 128, 1), device='cuda:0', dtype=torch.bfloat16)
490
+ primals_5 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
491
+ rsqrt = rand_strided((8, 1000, 16, 1), (16000, 16, 1, 1), device='cuda:0', dtype=torch.float32)
492
+ tangents_1 = rand_strided((8, 1000, 16, 128), (2048000, 2048, 128, 1), device='cuda:0', dtype=torch.bfloat16)
493
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, rsqrt, tangents_1])
494
+ return print_performance(fn, times=times, repeat=repeat)
495
+
496
+
497
+ if __name__ == "__main__":
498
+ from torch._inductor.wrapper_benchmark import compiled_module_main
499
+ compiled_module_main('None', benchmark_compiled_module)
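
Reader's note on the module above: the fused kernel triton_per_fused__to_copy_add_div_mul_pow_sum_2 is the input-gradient half of an RMSNorm backward over a hidden size of 128 (hence the constants 0.0078125 = 1/128, the -0.5 from differentiating rsqrt, and the rsqrt**3 term), while kernels 0 and 1 accumulate the weight gradient as a two-stage split reduction over 320 chunks. A minimal eager-mode sketch of the same math follows; the names (grad_out, x, weight) are illustrative and do not appear in the generated code:

import torch

def rmsnorm_backward_sketch(grad_out, x, weight, rsqrt, hidden=128):
    # grad_out, x: (..., hidden) bf16; rsqrt: (..., 1) fp32 saved from the forward
    g = (grad_out * weight).float()              # %mul_37 / %convert_element_type_2
    x32 = x.float()                              # %convert_element_type
    sum_gx = (g * x32).sum(-1, keepdim=True)     # %sum_2 over the last dim
    # chain rule through rsqrt(mean(x^2)): d rsqrt = -0.5 * rsqrt^3 * d mean
    dx = g * rsqrt + (-0.5 * sum_gx * rsqrt**3) * (2.0 * x32 / hidden)
    return dx.to(x.dtype)                        # %convert_element_type_3

def rmsnorm_weight_grad_sketch(grad_out, x, rsqrt):
    # roughly what kernels 0 and 1 compute (%mul_38, %sum_1); the generated
    # code accumulates in fp32 across two reduction stages before the bf16 cast
    return (grad_out.float() * (x.float() * rsqrt)).sum(dim=(0, 1, 2)).to(grad_out.dtype)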
torchinductor_ch-epfl-345354-j/6j/c6jqjdux4scc3alxlsrcpnhemegj7ym5pw3twg6xb2eyx4codkvz.py ADDED
@@ -0,0 +1,40 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 33554432},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0), xmask).to(tl.float32)
+ tmp7 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp2 = tmp1.to(tl.float32)
+ tmp3 = tl.sigmoid(tmp2)
+ tmp4 = tmp2 * tmp3
+ tmp5 = tmp4.to(tl.float32)
+ tmp6 = tmp0 * tmp5
+ tmp8 = tmp0 * tmp7
+ tmp9 = tl.sigmoid(tmp1)
+ tmp10 = 1.0
+ tmp11 = tmp10 - tmp9
+ tmp12 = tmp1 * tmp11
+ tmp13 = tmp12 + tmp10
+ tmp14 = tmp9 * tmp13
+ tmp15 = tmp8 * tmp14
+ tl.store(out_ptr0 + (x0), tmp6, xmask)
+ tl.store(in_out_ptr0 + (x0), tmp15, xmask)
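
Reader's note: the pointwise kernel above is the fused backward of a SiLU-gated product (a SwiGLU gate): out_ptr0 receives grad * silu(gate), the gradient flowing to the up-projection branch, and in_out_ptr0 is overwritten with grad * up * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate))), i.e. grad times silu'(gate). A rough eager equivalent, assuming illustrative names (grad, gate, up) not taken from the kernel:

import torch

def swiglu_backward_sketch(grad, gate, up):
    # grad: upstream gradient; gate, up: the two intermediate projections (bf16)
    s = torch.sigmoid(gate.float())
    grad_up = grad * (gate.float() * s).to(grad.dtype)      # tmp6  -> out_ptr0
    dsilu = s * (1.0 + gate.float() * (1.0 - s))            # silu'(gate)
    grad_gate = (grad * up) * dsilu.to(grad.dtype)          # tmp15 -> in_out_ptr0
    return grad_gate, grad_up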
torchinductor_ch-epfl-345354-j/7a/c7a4b5izank2343xz4473c4igojrrhlfxb5ulctqd32qrtkreq3m.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 268435456},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*i1', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__log_softmax__log_softmax_backward_data__to_copy_nll_loss_backward_nll_loss_forward_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused__log_softmax__log_softmax_backward_data__to_copy_nll_loss_backward_nll_loss_forward_7(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x2 = xindex
23
+ x1 = xindex // ks0
24
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
25
+ tmp1 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last').to(tl.int1)
26
+ tmp2 = tl.load(in_ptr2 + (0))
27
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
28
+ tmp7 = tl.load(in_ptr3 + (x2), xmask, eviction_policy='evict_last').to(tl.float32)
29
+ tmp9 = tl.load(in_ptr4 + (x1), xmask, eviction_policy='evict_last')
30
+ tmp11 = tl.load(in_ptr5 + (x1), xmask, eviction_policy='evict_last')
31
+ tmp14 = tl.load(in_ptr6 + (x1), xmask, eviction_policy='evict_last')
32
+ tmp4 = 0.0
33
+ tmp5 = tl.where(tmp1, tmp3, tmp4)
34
+ tmp6 = tmp0 * tmp5
35
+ tmp8 = tmp7.to(tl.float32)
36
+ tmp10 = tmp8 - tmp9
37
+ tmp12 = tmp10 - tmp11
38
+ tmp13 = tl_math.exp(tmp12)
39
+ tmp15 = tmp13 * tmp14
40
+ tmp16 = tmp6 - tmp15
41
+ tmp17 = tmp16.to(tl.float32)
42
+ tl.store(out_ptr0 + (x2), tmp17, xmask)
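
Reader's note: this kernel is the fused log_softmax + nll_loss backward. Per row it reconstructs softmax = exp(logits - row_max - logsumexp) and emits (target-scatter gradient) - softmax * (row gradient sum), with the target term zeroed where the label mask (in_ptr1) is false. A hedged per-row sketch with illustrative names (the pointer-to-variable mapping is an assumption read off the loads above):

import torch

def ce_backward_row_sketch(logits, grad_target, valid_row, grad_loss,
                           row_max, logsumexp, row_grad_sum):
    # logits: (rows, classes) bf16; row_max/logsumexp/row_grad_sum: (rows, 1) fp32
    # grad_target: scatter of the loss gradient at the label positions (in_ptr0)
    softmax = torch.exp(logits.float() - row_max - logsumexp)
    scale = torch.where(valid_row, grad_loss, torch.zeros_like(grad_loss))
    return (grad_target * scale - softmax * row_grad_sum).to(logits.dtype)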
torchinductor_ch-epfl-345354-j/7a/ddaeab32a9175f6d14ae7329f3defe09537605c41fa4d35da5bf9cbac1616b91.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 85}
torchinductor_ch-epfl-345354-j/7q/c7qudnwq7tyfwnepjsm2ilmratxdwkx4euvow7brbvrfif7hgnwh.py ADDED
@@ -0,0 +1,324 @@
+ """
+ Compile-time auto-tuning block:
+
+ import torch
+ from torch._dynamo.testing import rand_strided
+ from torch._dynamo.utils import preserve_rng_state
+ from torch._inductor.select_algorithm import AlgorithmSelectorCache
+ from torch._inductor.async_compile import AsyncCompile
+
+ async_compile = AsyncCompile()
+ generate_example_value = AlgorithmSelectorCache.generate_example_value
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+
+
+ triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0 = async_compile.triton('triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 33554432},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0), xmask).to(tl.float32)
+ tmp7 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp2 = tmp1.to(tl.float32)
+ tmp3 = tl.sigmoid(tmp2)
+ tmp4 = tmp2 * tmp3
+ tmp5 = tmp4.to(tl.float32)
+ tmp6 = tmp0 * tmp5
+ tmp8 = tmp0 * tmp7
+ tmp9 = tl.sigmoid(tmp1)
+ tmp10 = 1.0
+ tmp11 = tmp10 - tmp9
+ tmp12 = tmp1 * tmp11
+ tmp13 = tmp12 + tmp10
+ tmp14 = tmp9 * tmp13
+ tmp15 = tmp8 * tmp14
+ tl.store(out_ptr0 + (x0), tmp6, xmask)
+ tl.store(in_out_ptr0 + (x0), tmp15, xmask)
+ ''', device_str='cuda')
+
+
+ triton_poi_fused_add_1 = async_compile.triton('triton_poi_fused_add_1', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 8388608},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_1(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp2 = tmp0 + tmp1
+ tl.store(in_out_ptr0 + (x0), tmp2, xmask)
+ ''', device_str='cuda')
+
+ async_compile.wait(globals())
+ del async_compile
+
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ stream0 = get_raw_stream(0)
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ stream0 = get_raw_stream(0)
+ buf5 = generate_example_value((8, 1000, 3072), (3072000, 3072, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 3072))
+ buf1 = generate_example_value((8000, 3072), (3072, 1), 'cuda:0', torch.bfloat16, 0, (8000, 3072))
+ mm = generate_example_value((8000, 3072), (3072, 1), 'cuda:0', torch.bfloat16, 0, (8000, 3072))
+ buf2 = generate_example_value((8, 1000, 3072), (3072000, 3072, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 3072))
+ triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0.run(buf5, buf1, mm, buf2, 24576000, stream=stream0)
+ del buf5, buf1, mm, buf2
+
+ stream0 = get_raw_stream(0)
+ buf8 = generate_example_value((8, 1000, 1024), (1024000, 1024, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 1024))
+ buf7 = generate_example_value((8000, 1024), (1024, 1), 'cuda:0', torch.bfloat16, 0, (8000, 1024))
+ triton_poi_fused_add_1.run(buf8, buf7, 8192000, stream=stream0)
+ del buf8, buf7
+
+ """
+ # AOT ID: ['11_backward']
+ from ctypes import c_void_p, c_long, c_int
+ import torch
+ import math
+ import random
+ import os
+ import tempfile
+ from math import inf, nan
+ from cmath import nanj
+ from torch._inductor.hooks import run_intermediate_hooks
+ from torch._inductor.utils import maybe_profile
+ from torch._inductor.codegen.memory_planning import _align as align
+ from torch import device, empty_strided
+ from torch._inductor.async_compile import AsyncCompile
+ from torch._inductor.select_algorithm import extern_kernels
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+ aten = torch.ops.aten
+ inductor_ops = torch.ops.inductor
+ _quantized = torch.ops._quantized
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
+ async_compile = AsyncCompile()
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/6j/c6jqjdux4scc3alxlsrcpnhemegj7ym5pw3twg6xb2eyx4codkvz.py
+ # Topologically Sorted Source Nodes: [silu], Original ATen: [aten.silu, aten.mul, aten.sigmoid, aten.fill, aten.sub, aten.add]
+ # Source node to ATen node mapping:
+ # silu => convert_element_type_2, convert_element_type_3, mul_19, sigmoid
+ # Graph fragment:
+ # %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {})
+ # %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_2,), kwargs = {})
+ # %mul_19 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %sigmoid), kwargs = {})
+ # %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_19, torch.bfloat16), kwargs = {})
+ # %mul_61 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_7, %convert_element_type_3), kwargs = {})
+ # %mul_62 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_7, %view_3), kwargs = {})
+ # %sigmoid_1 : [num_users=2] = call_function[target=torch.ops.aten.sigmoid.default](args = (%view_1,), kwargs = {})
+ # %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%primals_2, %primals_3, 3072], 1), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
+ # %sub_16 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%full_default, %sigmoid_1), kwargs = {})
+ # %mul_63 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_1, %sub_16), kwargs = {})
+ # %add_38 : [num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mul_63, 1), kwargs = {})
+ # %mul_64 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sigmoid_1, %add_38), kwargs = {})
+ # %mul_65 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_62, %mul_64), kwargs = {})
+ triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0 = async_compile.triton('triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 33554432},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x0), xmask).to(tl.float32)
+ tmp7 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp2 = tmp1.to(tl.float32)
+ tmp3 = tl.sigmoid(tmp2)
+ tmp4 = tmp2 * tmp3
+ tmp5 = tmp4.to(tl.float32)
+ tmp6 = tmp0 * tmp5
+ tmp8 = tmp0 * tmp7
+ tmp9 = tl.sigmoid(tmp1)
+ tmp10 = 1.0
+ tmp11 = tmp10 - tmp9
+ tmp12 = tmp1 * tmp11
+ tmp13 = tmp12 + tmp10
+ tmp14 = tmp9 * tmp13
+ tmp15 = tmp8 * tmp14
+ tl.store(out_ptr0 + (x0), tmp6, xmask)
+ tl.store(in_out_ptr0 + (x0), tmp15, xmask)
+ ''', device_str='cuda')
+
+
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/cm/ccm7s7qxxaw3nofvzpftd6fqmd57jvnzfc74xrk54wxxhdnqddlo.py
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.add]
+ # Source node to ATen node mapping:
+ # Graph fragment:
+ # %add_39 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_9, %view_11), kwargs = {})
+ triton_poi_fused_add_1 = async_compile.triton('triton_poi_fused_add_1', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 8388608},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_1(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp2 = tmp0 + tmp1
+ tl.store(in_out_ptr0 + (x0), tmp2, xmask)
+ ''', device_str='cuda')
+
+
+ async_compile.wait(globals())
+ del async_compile
+
+ def call(args):
+ primals_2, primals_3, mul, view, mm, mm_1, view_4, permute_5, permute_9, permute_14, tangents_1 = args
+ args.clear()
+ s0 = primals_2
+ s1 = primals_3
+ assert_size_stride(view, (s0*s1, 1024), (1024, 1))
+ assert_size_stride(mm, (s0*s1, 3072), (3072, 1))
+ assert_size_stride(mm_1, (s0*s1, 3072), (3072, 1))
+ assert_size_stride(view_4, (s0*s1, 3072), (3072, 1))
+ assert_size_stride(permute_5, (1024, 3072), (3072, 1))
+ assert_size_stride(permute_9, (3072, 1024), (1024, 1))
+ assert_size_stride(permute_14, (3072, 1024), (1024, 1))
+ assert_size_stride(tangents_1, (s0, s1, 1024), (1024*s1, 1024, 1))
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ buf0 = empty_strided_cuda((1024, 3072), (3072, 1), torch.bfloat16)
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
+ extern_kernels.mm(reinterpret_tensor(tangents_1, (1024, s0*s1), (1, 1024), 0), view_4, out=buf0)
+ del view_4
+ buf1 = empty_strided_cuda((s0*s1, 3072), (3072, 1), torch.bfloat16)
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
+ extern_kernels.mm(reinterpret_tensor(tangents_1, (s0*s1, 1024), (1024, 1), 0), permute_5, out=buf1)
+ del permute_5
+ del tangents_1
+ buf2 = empty_strided_cuda((s0, s1, 3072), (3072*s1, 3072, 1), torch.bfloat16)
+ buf5 = reinterpret_tensor(mm_1, (s0, s1, 3072), (3072*s1, 3072, 1), 0); del mm_1 # reuse
+ # Topologically Sorted Source Nodes: [silu], Original ATen: [aten.silu, aten.mul, aten.sigmoid, aten.fill, aten.sub, aten.add]
+ triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0_xnumel = 3072*s0*s1
+ stream0 = get_raw_stream(0)
+ triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0.run(buf5, buf1, mm, buf2, triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0_xnumel, stream=stream0)
+ del buf1
+ del mm
+ buf3 = empty_strided_cuda((3072, 1024), (1024, 1), torch.bfloat16)
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
+ extern_kernels.mm(reinterpret_tensor(buf2, (3072, s0*s1), (1, 3072), 0), view, out=buf3)
+ buf4 = empty_strided_cuda((s0*s1, 1024), (1024, 1), torch.bfloat16)
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
+ extern_kernels.mm(reinterpret_tensor(buf2, (s0*s1, 3072), (3072, 1), 0), permute_9, out=buf4)
+ del buf2
+ del permute_9
+ buf6 = empty_strided_cuda((3072, 1024), (1024, 1), torch.bfloat16)
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
+ extern_kernels.mm(reinterpret_tensor(buf5, (3072, s0*s1), (1, 3072), 0), view, out=buf6)
+ del view
+ buf7 = empty_strided_cuda((s0*s1, 1024), (1024, 1), torch.bfloat16)
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
+ extern_kernels.mm(reinterpret_tensor(buf5, (s0*s1, 3072), (3072, 1), 0), permute_14, out=buf7)
+ del buf5
+ del permute_14
+ buf8 = reinterpret_tensor(buf4, (s0, s1, 1024), (1024*s1, 1024, 1), 0); del buf4 # reuse
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.add]
+ triton_poi_fused_add_1_xnumel = 1024*s0*s1
+ stream0 = get_raw_stream(0)
+ triton_poi_fused_add_1.run(buf8, buf7, triton_poi_fused_add_1_xnumel, stream=stream0)
+ del buf7
+ return (buf6, None, None, buf8, buf3, buf0, )
+
+
+ def benchmark_compiled_module(times=10, repeat=10):
+ from torch._dynamo.testing import rand_strided
+ from torch._inductor.utils import print_performance
+ primals_2 = 8
+ primals_3 = 1000
+ mul = 8000
+ view = rand_strided((8000, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
+ mm = rand_strided((8000, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
+ mm_1 = rand_strided((8000, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
+ view_4 = rand_strided((8000, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
+ permute_5 = rand_strided((1024, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
+ permute_9 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
+ permute_14 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
+ tangents_1 = rand_strided((8, 1000, 1024), (1024000, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
+ fn = lambda: call([primals_2, primals_3, mul, view, mm, mm_1, view_4, permute_5, permute_9, permute_14, tangents_1])
+ return print_performance(fn, times=times, repeat=repeat)
+
+
+ if __name__ == "__main__":
+ from torch._inductor.wrapper_benchmark import compiled_module_main
+ compiled_module_main('None', benchmark_compiled_module)
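
Reader's note: the '11_backward' module above differentiates a SwiGLU MLP with hidden size 1024 and intermediate size 3072: buf0, buf3, and buf6 are the three weight gradients, the fused kernel splits the upstream gradient between the gate and up branches, and triton_poi_fused_add_1 sums the two matmul contributions to the input gradient. A minimal sketch of the forward pass it corresponds to, assuming illustrative weight names laid out (out_features, in_features) like nn.Linear:

import torch
import torch.nn.functional as F

def swiglu_mlp_forward_sketch(x, w_gate, w_up, w_down):
    # x: (batch, seq, 1024); w_gate, w_up: (3072, 1024); w_down: (1024, 3072)
    gate = x @ w_gate.t()                    # 'mm'   : (batch, seq, 3072)
    up = x @ w_up.t()                        # 'mm_1' : (batch, seq, 3072)
    return (F.silu(gate) * up) @ w_down.t()  # back to (batch, seq, 1024)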
torchinductor_ch-epfl-345354-j/a6/ca64rxymdowafnowfq53ckfynl3yei5mmfkeefu6f6xndlg3ukok.py ADDED
@@ -0,0 +1,200 @@
+ """
+ Compile-time auto-tuning block:
+
+ import torch
+ from torch._dynamo.testing import rand_strided
+ from torch._dynamo.utils import preserve_rng_state
+ from torch._inductor.select_algorithm import AlgorithmSelectorCache
+ from torch._inductor.async_compile import AsyncCompile
+
+ async_compile = AsyncCompile()
+ generate_example_value = AlgorithmSelectorCache.generate_example_value
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+
+
+ triton_poi_fused_mul_silu_0 = async_compile.triton('triton_poi_fused_mul_silu_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 33554432},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = tl.sigmoid(tmp1)
+ tmp3 = tmp1 * tmp2
+ tmp4 = tmp3.to(tl.float32)
+ tmp6 = tmp4 * tmp5
+ tl.store(in_out_ptr0 + (x0), tmp6, xmask)
+ ''', device_str='cuda')
+
+ async_compile.wait(globals())
+ del async_compile
+
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ stream0 = get_raw_stream(0)
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ stream0 = get_raw_stream(0)
+ buf2 = generate_example_value((8, 1000, 3072), (3072000, 3072, 1), 'cuda:0', torch.bfloat16, 0, (8, 1000, 3072))
+ buf1 = generate_example_value((8000, 3072), (3072, 1), 'cuda:0', torch.bfloat16, 0, (8000, 3072))
+ triton_poi_fused_mul_silu_0.run(buf2, buf1, 24576000, stream=stream0)
+ del buf2, buf1
+
+ """
+ # AOT ID: ['5_inference']
+ from ctypes import c_void_p, c_long, c_int
+ import torch
+ import math
+ import random
+ import os
+ import tempfile
+ from math import inf, nan
+ from cmath import nanj
+ from torch._inductor.hooks import run_intermediate_hooks
+ from torch._inductor.utils import maybe_profile
+ from torch._inductor.codegen.memory_planning import _align as align
+ from torch import device, empty_strided
+ from torch._inductor.async_compile import AsyncCompile
+ from torch._inductor.select_algorithm import extern_kernels
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+ aten = torch.ops.aten
+ inductor_ops = torch.ops.inductor
+ _quantized = torch.ops._quantized
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
+ async_compile = AsyncCompile()
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+ # kernel path: /tmp/torchinductor_ch-epfl-345354-j/57/c574kngiopy3pgespyoupnzlae4d5tokyeui7uglwglnym2qijvn.py
+ # Topologically Sorted Source Nodes: [silu, mul], Original ATen: [aten.silu, aten.mul]
+ # Source node to ATen node mapping:
+ # mul => mul_38
+ # silu => convert_element_type_2, convert_element_type_3, mul_19, sigmoid
+ # Graph fragment:
+ # %convert_element_type_2 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {})
+ # %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_2,), kwargs = {})
+ # %mul_19 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %sigmoid), kwargs = {})
+ # %convert_element_type_3 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_19, torch.bfloat16), kwargs = {})
+ # %mul_38 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_3, %view_3), kwargs = {})
+ triton_poi_fused_mul_silu_0 = async_compile.triton('triton_poi_fused_mul_silu_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 33554432},
+ filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_mul_silu_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x0 = xindex
+ tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
+ tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
+ tmp1 = tmp0.to(tl.float32)
+ tmp2 = tl.sigmoid(tmp1)
+ tmp3 = tmp1 * tmp2
+ tmp4 = tmp3.to(tl.float32)
+ tmp6 = tmp4 * tmp5
+ tl.store(in_out_ptr0 + (x0), tmp6, xmask)
+ ''', device_str='cuda')
+
+
+ async_compile.wait(globals())
+ del async_compile
+
+ def call(args):
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
+ args.clear()
+ s0 = arg1_1
+ s1 = arg2_1
+ assert_size_stride(arg0_1, (3072, 1024), (1024, 1))
+ assert_size_stride(arg3_1, (s0, s1, 1024), (1024*s1, 1024, 1))
+ assert_size_stride(arg4_1, (3072, 1024), (1024, 1))
+ assert_size_stride(arg5_1, (1024, 3072), (3072, 1))
+ with torch.cuda._DeviceGuard(0):
+ torch.cuda.set_device(0)
+ pool1 = empty_strided_cuda((s0*s1, 3072), (3072, 1), torch.bfloat16)
+ buf0 = pool1 # alloc
+ # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
162
+ extern_kernels.mm(reinterpret_tensor(arg3_1, (s0*s1, 1024), (1024, 1), 0), reinterpret_tensor(arg0_1, (1024, 3072), (1, 1024), 0), out=buf0)
163
+ del arg0_1
164
+ pool2 = empty_strided_cuda((s0*s1, 3072), (3072, 1), torch.bfloat16)
165
+ buf1 = pool2 # alloc
166
+ # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
167
+ extern_kernels.mm(reinterpret_tensor(arg3_1, (s0*s1, 1024), (1024, 1), 0), reinterpret_tensor(arg4_1, (1024, 3072), (1, 1024), 0), out=buf1)
168
+ del arg3_1
169
+ del arg4_1
170
+ buf2 = reinterpret_tensor(buf0, (s0, s1, 3072), (3072*s1, 3072, 1), 0); # reuse
171
+ # Topologically Sorted Source Nodes: [silu, mul], Original ATen: [aten.silu, aten.mul]
172
+ triton_poi_fused_mul_silu_0_xnumel = 3072*s0*s1
173
+ stream0 = get_raw_stream(0)
174
+ triton_poi_fused_mul_silu_0.run(buf2, buf1, triton_poi_fused_mul_silu_0_xnumel, stream=stream0)
175
+ del pool2, buf1
176
+ pool0 = empty_strided_cuda((s0*s1, 1024), (1024, 1), torch.bfloat16)
177
+ buf3 = pool0 # alloc
178
+ # Topologically Sorted Source Nodes: [down_proj], Original ATen: [aten.mm]
179
+ extern_kernels.mm(reinterpret_tensor(buf2, (s0*s1, 3072), (3072, 1), 0), reinterpret_tensor(arg5_1, (3072, 1024), (1, 3072), 0), out=buf3)
180
+ del arg5_1
181
+ del pool1, buf0, buf2
182
+ return (reinterpret_tensor(buf3, (s0, s1, 1024), (1024*s1, 1024, 1), 0), )
183
+
184
+
185
+ def benchmark_compiled_module(times=10, repeat=10):
186
+ from torch._dynamo.testing import rand_strided
187
+ from torch._inductor.utils import print_performance
188
+ arg0_1 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
189
+ arg1_1 = 8
190
+ arg2_1 = 1000
191
+ arg3_1 = rand_strided((8, 1000, 1024), (1024000, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
192
+ arg4_1 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
193
+ arg5_1 = rand_strided((1024, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
194
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
195
+ return print_performance(fn, times=times, repeat=repeat)
196
+
197
+
198
+ if __name__ == "__main__":
199
+ from torch._inductor.wrapper_benchmark import compiled_module_main
200
+ compiled_module_main('None', benchmark_compiled_module)
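
For orientation: the shapes asserted in call() (two 1024 -> 3072 projections feeding a fused SiLU-multiply, then a 3072 -> 1024 projection, all in bf16) match a SwiGLU-style feed-forward block. Below is a minimal eager-mode sketch under that assumption; the class and parameter names are illustrative, not taken from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    """Hypothetical eager-mode counterpart of the compiled wrapper above."""
    def __init__(self, hidden: int = 1024, intermediate: int = 3072):
        super().__init__()
        kw = dict(bias=False, dtype=torch.bfloat16)
        self.gate_proj = nn.Linear(hidden, intermediate, **kw)  # arg0_1: (3072, 1024)
        self.up_proj = nn.Linear(hidden, intermediate, **kw)    # arg4_1: (3072, 1024)
        self.down_proj = nn.Linear(intermediate, hidden, **kw)  # arg5_1: (1024, 3072)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inductor decomposes aten.silu on bf16 into the fp32 upcast, sigmoid,
        # multiply, and downcast seen in the graph fragment above, then fuses
        # the trailing multiply into triton_poi_fused_mul_silu_0.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

# Compiling such a module, e.g. torch.compile(SwiGLUMLP().cuda()), is the kind
# of call that produces a wrapper like the one in this file.
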
torchinductor_ch-epfl-345354-j/aotautograd/acxk7xhb35e5myvrfk4m2smos5f3rwybegalnbqbgtl3ghlaw2vw/entry ADDED
Binary file (96.7 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/adw7o5w6jucvlwdu4mn3nk52nno5z3lt73pmvaksrn3cahxlwc5t/entry ADDED
Binary file (5.55 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/agm67xcx3b2ejeqf3t422b43zsalmtzgitagqmb4kcd76dzg2sr6/entry ADDED
Binary file (15.3 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/ahinqqlnserz457jqclv2vjeogmqix7jcrylpuhbc64kw4k3apfy/entry ADDED
Binary file (7.48 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/ahji7b2arusm47q6ox5itjvurtws6r6kls2kskgxfnc2rqm4ojdg/entry ADDED
Binary file (4.83 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/aibbpzcrlnv7lrbehiaaab4olrvijekv6m46vdzzqh3tbnvnl67m/entry ADDED
Binary file (26.4 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/aig3hpjgzj7f27hhdphh7ozndqiwpruhugzjsiwyog75fn4y3rbj/entry ADDED
Binary file (7.63 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/algm7vsngjdke6rmqon76peuppnhsp625k5d4zxnwgwdbdueo4ay/entry ADDED
Binary file (4.37 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/amtpnp6cq6z6ddoun3fwe4zemhgpsp5jicklj6cf3qzsd3xbdeps/entry ADDED
Binary file (4.65 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/aqskia64x2j4xks7dhp5cpq52le5j6js6ghxfhlvw7gfa6qr6stx/entry ADDED
Binary file (16.2 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/aub73aicaqeihl4qdqbrvljzl2qxdzyu52zezeket676qt3pkgwk/entry ADDED
Binary file (4.65 kB).
 
torchinductor_ch-epfl-345354-j/aotautograd/aypob3g4nwzt66m7ur252rhjjobqgnn4hvdhagr4474twkamikxg/entry ADDED
Binary file (18.9 kB).
 
torchinductor_ch-epfl-345354-j/bm/cbmn253c3hy77ciw3f6meqi4bsbiio5zhw7hra5np6k5jyjqetnp.py ADDED
@@ -0,0 +1,29 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+     size_hints={'x': 1},
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*i64', 'out_ptr0': '*fp32', 'xnumel': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_div_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_div_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
+     xnumel = 1
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = tl.full([XBLOCK], True, tl.int1)
+     tmp0 = tl.load(in_ptr0 + (0))
+     tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
+     tmp2 = tl.load(in_ptr1 + (0))
+     tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
+     tmp4 = tmp3.to(tl.float32)
+     tmp5 = (tmp1 / tmp4)
+     tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp5, None)
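
A hedged reading of triton_poi_fused_div_0: it loads one fp32 value and one int64 value, casts the count to fp32, and stores the quotient. That is the shape of a sum-then-average step, for example a summed loss divided by a token count. A one-line eager equivalent (variable names illustrative):

import torch

loss_sum = torch.tensor([2.5], device="cuda")     # in_ptr0: *fp32, one element
num_items = torch.tensor([8000], device="cuda")   # in_ptr1: *i64, one element
mean = loss_sum / num_items.to(torch.float32)     # stored to out_ptr0 as fp32
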
torchinductor_ch-epfl-345354-j/ce/cceyvpztlniy45jdq6sxx7o44obzjinfuxgsvnlhcr3hjdvmek73.py ADDED
@@ -0,0 +1,38 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+     size_hints={'x': 1024, 'r0_': 64},
+     reduction_hint=ReductionHint.OUTER,
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
+     xnumel = 1024
+     r0_numel = 40
+     R0_BLOCK: tl.constexpr = 64
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = xindex < xnumel
+     r0_index = tl.arange(0, R0_BLOCK)[None, :]
+     r0_offset = 0
+     r0_mask = r0_index < r0_numel
+     roffset = r0_offset
+     rindex = r0_index
+     r0_1 = r0_index
+     x0 = xindex
+     tmp0 = tl.load(in_ptr0 + (x0 + 1024*r0_1), xmask & r0_mask, other=0.0)
+     tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+     tmp3 = tl.where(r0_mask & xmask, tmp1, 0)
+     tmp4 = tl.sum(tmp3, 1)[:, None]
+     tl.store(out_ptr0 + (x0), tmp4, xmask)
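
Reading the indexing (x0 + 1024*r0_1 with xnumel = 1024 and r0_numel = 40), triton_per_fused__to_copy_mul_sum_1 sums a row-major (40, 1024) fp32 buffer over its first dimension and writes the 1024 results through a bf16 pointer, a pattern typical of accumulating a per-column quantity such as a weight gradient. A minimal eager sketch under that assumption (tensor names illustrative):

import torch

partials = torch.randn(40, 1024, device="cuda")    # in_ptr0: fp32, row-major
col_sums = partials.sum(dim=0).to(torch.bfloat16)  # out_ptr0: one bf16 value per column
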
torchinductor_ch-epfl-345354-j/cf/ae1632ffa009afdc4d40d5477a8e2ffd544972ad9ddf0c636c451826b3219579.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 108}
torchinductor_ch-epfl-345354-j/cf/ccfnt2f53rlwauznvnabnitvjchbzg7at22w4x4fskqzmyirxuxq.py ADDED
@@ -0,0 +1,50 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+     size_hints={'x': 131072, 'r0_': 128},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+     triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=42, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '8D9A40F96256AE993B0CB3DAC1136935BA540F7848683690590C84AF795CC5ED', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
+     r0_numel = 128
+     R0_BLOCK: tl.constexpr = 128
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = xindex < xnumel
+     r0_index = tl.arange(0, R0_BLOCK)[None, :]
+     r0_offset = 0
+     r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+     roffset = r0_offset
+     rindex = r0_index
+     r0_1 = r0_index
+     x0 = xindex
+     tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), xmask, other=0.0).to(tl.float32)
+     tmp12 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
+     tmp1 = tmp0.to(tl.float32)
+     tmp2 = tmp1 * tmp1
+     tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
+     tmp5 = tl.where(xmask, tmp3, 0)
+     tmp6 = tl.sum(tmp5, 1)[:, None]
+     tmp7 = 128.0
+     tmp8 = (tmp6 / tmp7)
+     tmp9 = 1e-06
+     tmp10 = tmp8 + tmp9
+     tmp11 = libdevice.rsqrt(tmp10)
+     tmp13 = tmp1 * tmp11
+     tmp14 = tmp13.to(tl.float32)
+     tmp15 = tmp12 * tmp14
+     tl.debug_barrier()
+     tl.store(in_out_ptr0 + (x0), tmp11, xmask)
+     tl.store(out_ptr0 + (r0_1 + 128*x0), tmp15, xmask)
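
The arithmetic in this kernel (mean of squares over a trailing dimension of 128, add eps = 1e-06, rsqrt, scale by a per-channel weight, with the fp32 inverse RMS also saved through in_out_ptr0 for reuse) is a standard RMSNorm forward computed in fp32 on bf16 data. A minimal eager equivalent; shapes other than the size-128 axis are illustrative:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize in fp32, as the kernel does, then cast back and scale.
    x32 = x.to(torch.float32)
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)  # tmp11
    return weight * (x32 * inv_rms).to(torch.bfloat16)                  # tmp15

x = torch.randn(4096, 128, device="cuda", dtype=torch.bfloat16)
w = torch.ones(128, device="cuda", dtype=torch.bfloat16)
y = rms_norm(x, w)
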