diff --git "a/openai_whisper-tiny.en/AudioEncoder.mlmodelc/model.mil" "b/openai_whisper-tiny.en/AudioEncoder.mlmodelc/model.mil" --- "a/openai_whisper-tiny.en/AudioEncoder.mlmodelc/model.mil" +++ "b/openai_whisper-tiny.en/AudioEncoder.mlmodelc/model.mil" @@ -1,5 +1,5 @@ program(1.0) -[buildInfo = dict, tensor>({{"coremlc-component-MIL", "5.33.5"}, {"coremlc-version", "1877.40.3"}, {"coremltools-component-torch", "2.3.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "7.2"}})] +[buildInfo = dict, tensor>({{"coremlc-component-MIL", "5.33.5"}, {"coremlc-version", "1877.40.3"}, {"coremltools-component-torch", "2.3.1"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "7.2"}})] { func main(tensor melspectrogram_features) { tensor var_34 = const()[name = tensor("op_34"), val = tensor([1, 1])]; @@ -225,102 +225,102 @@ program(1.0) tensor var_421_end_0 = const()[name = tensor("op_421_end_0"), val = tensor([1, 384, 1, 1500])]; tensor var_421_end_mask_0 = const()[name = tensor("op_421_end_mask_0"), val = tensor([true, false, true, true])]; tensor var_421_cast_fp16 = slice_by_index(begin = var_421_begin_0, end = var_421_end_0, end_mask = var_421_end_mask_0, x = value_1_cast_fp16)[name = tensor("op_421_cast_fp16")]; - tensor var_425_equation_0 = const()[name = tensor("op_425_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_425_cast_fp16 = einsum(equation = var_425_equation_0, values = (var_379_cast_fp16, var_213_cast_fp16))[name = tensor("op_425_cast_fp16")]; - tensor var_426_to_fp16 = const()[name = tensor("op_426_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_1_cast_fp16 = mul(x = var_425_cast_fp16, y = var_426_to_fp16)[name = tensor("aw_chunk_1_cast_fp16")]; - tensor var_429_equation_0 = const()[name = tensor("op_429_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_429_cast_fp16 = einsum(equation = var_429_equation_0, values = (var_379_cast_fp16, var_220_cast_fp16))[name = tensor("op_429_cast_fp16")]; - tensor var_430_to_fp16 = const()[name = tensor("op_430_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_3_cast_fp16 = mul(x = var_429_cast_fp16, y = var_430_to_fp16)[name = tensor("aw_chunk_3_cast_fp16")]; - tensor var_433_equation_0 = const()[name = tensor("op_433_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_433_cast_fp16 = einsum(equation = var_433_equation_0, values = (var_379_cast_fp16, var_227_cast_fp16))[name = tensor("op_433_cast_fp16")]; - tensor var_434_to_fp16 = const()[name = tensor("op_434_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_5_cast_fp16 = mul(x = var_433_cast_fp16, y = var_434_to_fp16)[name = tensor("aw_chunk_5_cast_fp16")]; - tensor var_437_equation_0 = const()[name = tensor("op_437_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_437_cast_fp16 = einsum(equation = var_437_equation_0, values = (var_379_cast_fp16, var_234_cast_fp16))[name = tensor("op_437_cast_fp16")]; - tensor var_438_to_fp16 = const()[name = tensor("op_438_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_7_cast_fp16 = mul(x = var_437_cast_fp16, y = var_438_to_fp16)[name = tensor("aw_chunk_7_cast_fp16")]; - tensor var_441_equation_0 = const()[name = tensor("op_441_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_441_cast_fp16 = einsum(equation = var_441_equation_0, values = (var_383_cast_fp16, var_241_cast_fp16))[name = tensor("op_441_cast_fp16")]; - tensor var_442_to_fp16 = const()[name = tensor("op_442_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_9_cast_fp16 = mul(x = var_441_cast_fp16, y = var_442_to_fp16)[name = tensor("aw_chunk_9_cast_fp16")]; - tensor var_445_equation_0 = const()[name = tensor("op_445_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_445_cast_fp16 = einsum(equation = var_445_equation_0, values = (var_383_cast_fp16, var_248_cast_fp16))[name = tensor("op_445_cast_fp16")]; - tensor var_446_to_fp16 = const()[name = tensor("op_446_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_11_cast_fp16 = mul(x = var_445_cast_fp16, y = var_446_to_fp16)[name = tensor("aw_chunk_11_cast_fp16")]; - tensor var_449_equation_0 = const()[name = tensor("op_449_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_449_cast_fp16 = einsum(equation = var_449_equation_0, values = (var_383_cast_fp16, var_255_cast_fp16))[name = tensor("op_449_cast_fp16")]; - tensor var_450_to_fp16 = const()[name = tensor("op_450_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_13_cast_fp16 = mul(x = var_449_cast_fp16, y = var_450_to_fp16)[name = tensor("aw_chunk_13_cast_fp16")]; - tensor var_453_equation_0 = const()[name = tensor("op_453_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_453_cast_fp16 = einsum(equation = var_453_equation_0, values = (var_383_cast_fp16, var_262_cast_fp16))[name = tensor("op_453_cast_fp16")]; - tensor var_454_to_fp16 = const()[name = tensor("op_454_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_15_cast_fp16 = mul(x = var_453_cast_fp16, y = var_454_to_fp16)[name = tensor("aw_chunk_15_cast_fp16")]; - tensor var_457_equation_0 = const()[name = tensor("op_457_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_457_cast_fp16 = einsum(equation = var_457_equation_0, values = (var_387_cast_fp16, var_269_cast_fp16))[name = tensor("op_457_cast_fp16")]; - tensor var_458_to_fp16 = const()[name = tensor("op_458_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_17_cast_fp16 = mul(x = var_457_cast_fp16, y = var_458_to_fp16)[name = tensor("aw_chunk_17_cast_fp16")]; - tensor var_461_equation_0 = const()[name = tensor("op_461_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_461_cast_fp16 = einsum(equation = var_461_equation_0, values = (var_387_cast_fp16, var_276_cast_fp16))[name = tensor("op_461_cast_fp16")]; - tensor var_462_to_fp16 = const()[name = tensor("op_462_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_19_cast_fp16 = mul(x = var_461_cast_fp16, y = var_462_to_fp16)[name = tensor("aw_chunk_19_cast_fp16")]; - tensor var_465_equation_0 = const()[name = tensor("op_465_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_465_cast_fp16 = einsum(equation = var_465_equation_0, values = (var_387_cast_fp16, var_283_cast_fp16))[name = tensor("op_465_cast_fp16")]; - tensor var_466_to_fp16 = const()[name = tensor("op_466_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_21_cast_fp16 = mul(x = var_465_cast_fp16, y = var_466_to_fp16)[name = tensor("aw_chunk_21_cast_fp16")]; - tensor var_469_equation_0 = const()[name = tensor("op_469_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_469_cast_fp16 = einsum(equation = var_469_equation_0, values = (var_387_cast_fp16, var_290_cast_fp16))[name = tensor("op_469_cast_fp16")]; - tensor var_470_to_fp16 = const()[name = tensor("op_470_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_23_cast_fp16 = mul(x = var_469_cast_fp16, y = var_470_to_fp16)[name = tensor("aw_chunk_23_cast_fp16")]; - tensor var_473_equation_0 = const()[name = tensor("op_473_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_473_cast_fp16 = einsum(equation = var_473_equation_0, values = (var_391_cast_fp16, var_297_cast_fp16))[name = tensor("op_473_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_1_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_1_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_1_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_1_equation_0, values = (var_379_cast_fp16, var_213_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_1_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_3_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_3_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_3_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_3_equation_0, values = (var_379_cast_fp16, var_220_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_3_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_5_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_5_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_5_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_5_equation_0, values = (var_379_cast_fp16, var_227_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_5_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_7_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_7_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_7_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_7_equation_0, values = (var_379_cast_fp16, var_234_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_7_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_9_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_9_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_9_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_9_equation_0, values = (var_383_cast_fp16, var_241_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_9_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_11_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_11_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_11_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_11_equation_0, values = (var_383_cast_fp16, var_248_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_11_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_13_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_13_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_13_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_13_equation_0, values = (var_383_cast_fp16, var_255_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_13_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_15_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_15_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_15_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_15_equation_0, values = (var_383_cast_fp16, var_262_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_15_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_17_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_17_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_17_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_17_equation_0, values = (var_387_cast_fp16, var_269_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_17_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_19_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_19_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_19_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_19_equation_0, values = (var_387_cast_fp16, var_276_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_19_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_21_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_21_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_21_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_21_equation_0, values = (var_387_cast_fp16, var_283_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_21_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_23_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_23_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_23_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_23_equation_0, values = (var_387_cast_fp16, var_290_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_23_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_25_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_25_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_25_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_25_equation_0, values = (var_391_cast_fp16, var_297_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_25_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_27_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_27_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_27_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_27_equation_0, values = (var_391_cast_fp16, var_304_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_27_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_29_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_29_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_29_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_29_equation_0, values = (var_391_cast_fp16, var_311_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_29_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_31_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_31_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_31_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_31_equation_0, values = (var_391_cast_fp16, var_318_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_31_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_33_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_33_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_33_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_33_equation_0, values = (var_395_cast_fp16, var_325_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_33_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_35_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_35_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_35_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_35_equation_0, values = (var_395_cast_fp16, var_332_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_35_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_37_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_37_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_37_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_37_equation_0, values = (var_395_cast_fp16, var_339_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_37_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_39_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_39_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_39_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_39_equation_0, values = (var_395_cast_fp16, var_346_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_39_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_41_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_41_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_41_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_41_equation_0, values = (var_399_cast_fp16, var_353_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_41_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_43_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_43_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_43_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_43_equation_0, values = (var_399_cast_fp16, var_360_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_43_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_45_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_45_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_45_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_45_equation_0, values = (var_399_cast_fp16, var_367_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_45_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_47_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_47_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_47_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_47_equation_0, values = (var_399_cast_fp16, var_374_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_47_cast_fp16")]; + tensor var_472_to_fp16 = const()[name = tensor("op_472_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_1_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_1_cast_fp16, y = var_472_to_fp16)[name = tensor("aw_chunk_1_cast_fp16")]; tensor var_474_to_fp16 = const()[name = tensor("op_474_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_25_cast_fp16 = mul(x = var_473_cast_fp16, y = var_474_to_fp16)[name = tensor("aw_chunk_25_cast_fp16")]; - tensor var_477_equation_0 = const()[name = tensor("op_477_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_477_cast_fp16 = einsum(equation = var_477_equation_0, values = (var_391_cast_fp16, var_304_cast_fp16))[name = tensor("op_477_cast_fp16")]; + tensor aw_chunk_3_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_3_cast_fp16, y = var_474_to_fp16)[name = tensor("aw_chunk_3_cast_fp16")]; + tensor var_476_to_fp16 = const()[name = tensor("op_476_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_5_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_5_cast_fp16, y = var_476_to_fp16)[name = tensor("aw_chunk_5_cast_fp16")]; tensor var_478_to_fp16 = const()[name = tensor("op_478_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_27_cast_fp16 = mul(x = var_477_cast_fp16, y = var_478_to_fp16)[name = tensor("aw_chunk_27_cast_fp16")]; - tensor var_481_equation_0 = const()[name = tensor("op_481_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_481_cast_fp16 = einsum(equation = var_481_equation_0, values = (var_391_cast_fp16, var_311_cast_fp16))[name = tensor("op_481_cast_fp16")]; + tensor aw_chunk_7_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_7_cast_fp16, y = var_478_to_fp16)[name = tensor("aw_chunk_7_cast_fp16")]; + tensor var_480_to_fp16 = const()[name = tensor("op_480_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_9_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_9_cast_fp16, y = var_480_to_fp16)[name = tensor("aw_chunk_9_cast_fp16")]; tensor var_482_to_fp16 = const()[name = tensor("op_482_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_29_cast_fp16 = mul(x = var_481_cast_fp16, y = var_482_to_fp16)[name = tensor("aw_chunk_29_cast_fp16")]; - tensor var_485_equation_0 = const()[name = tensor("op_485_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_485_cast_fp16 = einsum(equation = var_485_equation_0, values = (var_391_cast_fp16, var_318_cast_fp16))[name = tensor("op_485_cast_fp16")]; + tensor aw_chunk_11_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_11_cast_fp16, y = var_482_to_fp16)[name = tensor("aw_chunk_11_cast_fp16")]; + tensor var_484_to_fp16 = const()[name = tensor("op_484_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_13_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_13_cast_fp16, y = var_484_to_fp16)[name = tensor("aw_chunk_13_cast_fp16")]; tensor var_486_to_fp16 = const()[name = tensor("op_486_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_31_cast_fp16 = mul(x = var_485_cast_fp16, y = var_486_to_fp16)[name = tensor("aw_chunk_31_cast_fp16")]; - tensor var_489_equation_0 = const()[name = tensor("op_489_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_489_cast_fp16 = einsum(equation = var_489_equation_0, values = (var_395_cast_fp16, var_325_cast_fp16))[name = tensor("op_489_cast_fp16")]; + tensor aw_chunk_15_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_15_cast_fp16, y = var_486_to_fp16)[name = tensor("aw_chunk_15_cast_fp16")]; + tensor var_488_to_fp16 = const()[name = tensor("op_488_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_17_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_17_cast_fp16, y = var_488_to_fp16)[name = tensor("aw_chunk_17_cast_fp16")]; tensor var_490_to_fp16 = const()[name = tensor("op_490_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_33_cast_fp16 = mul(x = var_489_cast_fp16, y = var_490_to_fp16)[name = tensor("aw_chunk_33_cast_fp16")]; - tensor var_493_equation_0 = const()[name = tensor("op_493_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_493_cast_fp16 = einsum(equation = var_493_equation_0, values = (var_395_cast_fp16, var_332_cast_fp16))[name = tensor("op_493_cast_fp16")]; + tensor aw_chunk_19_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_19_cast_fp16, y = var_490_to_fp16)[name = tensor("aw_chunk_19_cast_fp16")]; + tensor var_492_to_fp16 = const()[name = tensor("op_492_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_21_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_21_cast_fp16, y = var_492_to_fp16)[name = tensor("aw_chunk_21_cast_fp16")]; tensor var_494_to_fp16 = const()[name = tensor("op_494_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_35_cast_fp16 = mul(x = var_493_cast_fp16, y = var_494_to_fp16)[name = tensor("aw_chunk_35_cast_fp16")]; - tensor var_497_equation_0 = const()[name = tensor("op_497_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_497_cast_fp16 = einsum(equation = var_497_equation_0, values = (var_395_cast_fp16, var_339_cast_fp16))[name = tensor("op_497_cast_fp16")]; + tensor aw_chunk_23_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_23_cast_fp16, y = var_494_to_fp16)[name = tensor("aw_chunk_23_cast_fp16")]; + tensor var_496_to_fp16 = const()[name = tensor("op_496_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_25_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_25_cast_fp16, y = var_496_to_fp16)[name = tensor("aw_chunk_25_cast_fp16")]; tensor var_498_to_fp16 = const()[name = tensor("op_498_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_37_cast_fp16 = mul(x = var_497_cast_fp16, y = var_498_to_fp16)[name = tensor("aw_chunk_37_cast_fp16")]; - tensor var_501_equation_0 = const()[name = tensor("op_501_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_501_cast_fp16 = einsum(equation = var_501_equation_0, values = (var_395_cast_fp16, var_346_cast_fp16))[name = tensor("op_501_cast_fp16")]; + tensor aw_chunk_27_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_27_cast_fp16, y = var_498_to_fp16)[name = tensor("aw_chunk_27_cast_fp16")]; + tensor var_500_to_fp16 = const()[name = tensor("op_500_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_29_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_29_cast_fp16, y = var_500_to_fp16)[name = tensor("aw_chunk_29_cast_fp16")]; tensor var_502_to_fp16 = const()[name = tensor("op_502_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_39_cast_fp16 = mul(x = var_501_cast_fp16, y = var_502_to_fp16)[name = tensor("aw_chunk_39_cast_fp16")]; - tensor var_505_equation_0 = const()[name = tensor("op_505_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_505_cast_fp16 = einsum(equation = var_505_equation_0, values = (var_399_cast_fp16, var_353_cast_fp16))[name = tensor("op_505_cast_fp16")]; + tensor aw_chunk_31_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_31_cast_fp16, y = var_502_to_fp16)[name = tensor("aw_chunk_31_cast_fp16")]; + tensor var_504_to_fp16 = const()[name = tensor("op_504_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_33_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_33_cast_fp16, y = var_504_to_fp16)[name = tensor("aw_chunk_33_cast_fp16")]; tensor var_506_to_fp16 = const()[name = tensor("op_506_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_41_cast_fp16 = mul(x = var_505_cast_fp16, y = var_506_to_fp16)[name = tensor("aw_chunk_41_cast_fp16")]; - tensor var_509_equation_0 = const()[name = tensor("op_509_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_509_cast_fp16 = einsum(equation = var_509_equation_0, values = (var_399_cast_fp16, var_360_cast_fp16))[name = tensor("op_509_cast_fp16")]; + tensor aw_chunk_35_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_35_cast_fp16, y = var_506_to_fp16)[name = tensor("aw_chunk_35_cast_fp16")]; + tensor var_508_to_fp16 = const()[name = tensor("op_508_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_37_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_37_cast_fp16, y = var_508_to_fp16)[name = tensor("aw_chunk_37_cast_fp16")]; tensor var_510_to_fp16 = const()[name = tensor("op_510_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_43_cast_fp16 = mul(x = var_509_cast_fp16, y = var_510_to_fp16)[name = tensor("aw_chunk_43_cast_fp16")]; - tensor var_513_equation_0 = const()[name = tensor("op_513_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_513_cast_fp16 = einsum(equation = var_513_equation_0, values = (var_399_cast_fp16, var_367_cast_fp16))[name = tensor("op_513_cast_fp16")]; + tensor aw_chunk_39_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_39_cast_fp16, y = var_510_to_fp16)[name = tensor("aw_chunk_39_cast_fp16")]; + tensor var_512_to_fp16 = const()[name = tensor("op_512_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_41_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_41_cast_fp16, y = var_512_to_fp16)[name = tensor("aw_chunk_41_cast_fp16")]; tensor var_514_to_fp16 = const()[name = tensor("op_514_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_45_cast_fp16 = mul(x = var_513_cast_fp16, y = var_514_to_fp16)[name = tensor("aw_chunk_45_cast_fp16")]; - tensor var_517_equation_0 = const()[name = tensor("op_517_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_517_cast_fp16 = einsum(equation = var_517_equation_0, values = (var_399_cast_fp16, var_374_cast_fp16))[name = tensor("op_517_cast_fp16")]; + tensor aw_chunk_43_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_43_cast_fp16, y = var_514_to_fp16)[name = tensor("aw_chunk_43_cast_fp16")]; + tensor var_516_to_fp16 = const()[name = tensor("op_516_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_45_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_45_cast_fp16, y = var_516_to_fp16)[name = tensor("aw_chunk_45_cast_fp16")]; tensor var_518_to_fp16 = const()[name = tensor("op_518_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_47_cast_fp16 = mul(x = var_517_cast_fp16, y = var_518_to_fp16)[name = tensor("aw_chunk_47_cast_fp16")]; + tensor aw_chunk_47_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_47_cast_fp16, y = var_518_to_fp16)[name = tensor("aw_chunk_47_cast_fp16")]; tensor var_520_cast_fp16 = softmax(axis = var_129, x = aw_chunk_1_cast_fp16)[name = tensor("op_520_cast_fp16")]; tensor var_521_cast_fp16 = softmax(axis = var_129, x = aw_chunk_3_cast_fp16)[name = tensor("op_521_cast_fp16")]; tensor var_522_cast_fp16 = softmax(axis = var_129, x = aw_chunk_5_cast_fp16)[name = tensor("op_522_cast_fp16")]; @@ -638,102 +638,102 @@ program(1.0) tensor var_954_end_0 = const()[name = tensor("op_954_end_0"), val = tensor([1, 384, 1, 1500])]; tensor var_954_end_mask_0 = const()[name = tensor("op_954_end_mask_0"), val = tensor([true, false, true, true])]; tensor var_954_cast_fp16 = slice_by_index(begin = var_954_begin_0, end = var_954_end_0, end_mask = var_954_end_mask_0, x = value_3_cast_fp16)[name = tensor("op_954_cast_fp16")]; - tensor var_958_equation_0 = const()[name = tensor("op_958_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_958_cast_fp16 = einsum(equation = var_958_equation_0, values = (var_912_cast_fp16, var_746_cast_fp16))[name = tensor("op_958_cast_fp16")]; - tensor var_959_to_fp16 = const()[name = tensor("op_959_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_49_cast_fp16 = mul(x = var_958_cast_fp16, y = var_959_to_fp16)[name = tensor("aw_chunk_49_cast_fp16")]; - tensor var_962_equation_0 = const()[name = tensor("op_962_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_962_cast_fp16 = einsum(equation = var_962_equation_0, values = (var_912_cast_fp16, var_753_cast_fp16))[name = tensor("op_962_cast_fp16")]; - tensor var_963_to_fp16 = const()[name = tensor("op_963_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_51_cast_fp16 = mul(x = var_962_cast_fp16, y = var_963_to_fp16)[name = tensor("aw_chunk_51_cast_fp16")]; - tensor var_966_equation_0 = const()[name = tensor("op_966_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_966_cast_fp16 = einsum(equation = var_966_equation_0, values = (var_912_cast_fp16, var_760_cast_fp16))[name = tensor("op_966_cast_fp16")]; - tensor var_967_to_fp16 = const()[name = tensor("op_967_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_53_cast_fp16 = mul(x = var_966_cast_fp16, y = var_967_to_fp16)[name = tensor("aw_chunk_53_cast_fp16")]; - tensor var_970_equation_0 = const()[name = tensor("op_970_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_970_cast_fp16 = einsum(equation = var_970_equation_0, values = (var_912_cast_fp16, var_767_cast_fp16))[name = tensor("op_970_cast_fp16")]; - tensor var_971_to_fp16 = const()[name = tensor("op_971_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_55_cast_fp16 = mul(x = var_970_cast_fp16, y = var_971_to_fp16)[name = tensor("aw_chunk_55_cast_fp16")]; - tensor var_974_equation_0 = const()[name = tensor("op_974_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_974_cast_fp16 = einsum(equation = var_974_equation_0, values = (var_916_cast_fp16, var_774_cast_fp16))[name = tensor("op_974_cast_fp16")]; - tensor var_975_to_fp16 = const()[name = tensor("op_975_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_57_cast_fp16 = mul(x = var_974_cast_fp16, y = var_975_to_fp16)[name = tensor("aw_chunk_57_cast_fp16")]; - tensor var_978_equation_0 = const()[name = tensor("op_978_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_978_cast_fp16 = einsum(equation = var_978_equation_0, values = (var_916_cast_fp16, var_781_cast_fp16))[name = tensor("op_978_cast_fp16")]; - tensor var_979_to_fp16 = const()[name = tensor("op_979_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_59_cast_fp16 = mul(x = var_978_cast_fp16, y = var_979_to_fp16)[name = tensor("aw_chunk_59_cast_fp16")]; - tensor var_982_equation_0 = const()[name = tensor("op_982_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_982_cast_fp16 = einsum(equation = var_982_equation_0, values = (var_916_cast_fp16, var_788_cast_fp16))[name = tensor("op_982_cast_fp16")]; - tensor var_983_to_fp16 = const()[name = tensor("op_983_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_61_cast_fp16 = mul(x = var_982_cast_fp16, y = var_983_to_fp16)[name = tensor("aw_chunk_61_cast_fp16")]; - tensor var_986_equation_0 = const()[name = tensor("op_986_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_986_cast_fp16 = einsum(equation = var_986_equation_0, values = (var_916_cast_fp16, var_795_cast_fp16))[name = tensor("op_986_cast_fp16")]; - tensor var_987_to_fp16 = const()[name = tensor("op_987_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_63_cast_fp16 = mul(x = var_986_cast_fp16, y = var_987_to_fp16)[name = tensor("aw_chunk_63_cast_fp16")]; - tensor var_990_equation_0 = const()[name = tensor("op_990_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_990_cast_fp16 = einsum(equation = var_990_equation_0, values = (var_920_cast_fp16, var_802_cast_fp16))[name = tensor("op_990_cast_fp16")]; - tensor var_991_to_fp16 = const()[name = tensor("op_991_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_65_cast_fp16 = mul(x = var_990_cast_fp16, y = var_991_to_fp16)[name = tensor("aw_chunk_65_cast_fp16")]; - tensor var_994_equation_0 = const()[name = tensor("op_994_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_994_cast_fp16 = einsum(equation = var_994_equation_0, values = (var_920_cast_fp16, var_809_cast_fp16))[name = tensor("op_994_cast_fp16")]; - tensor var_995_to_fp16 = const()[name = tensor("op_995_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_67_cast_fp16 = mul(x = var_994_cast_fp16, y = var_995_to_fp16)[name = tensor("aw_chunk_67_cast_fp16")]; - tensor var_998_equation_0 = const()[name = tensor("op_998_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_998_cast_fp16 = einsum(equation = var_998_equation_0, values = (var_920_cast_fp16, var_816_cast_fp16))[name = tensor("op_998_cast_fp16")]; - tensor var_999_to_fp16 = const()[name = tensor("op_999_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_69_cast_fp16 = mul(x = var_998_cast_fp16, y = var_999_to_fp16)[name = tensor("aw_chunk_69_cast_fp16")]; - tensor var_1002_equation_0 = const()[name = tensor("op_1002_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1002_cast_fp16 = einsum(equation = var_1002_equation_0, values = (var_920_cast_fp16, var_823_cast_fp16))[name = tensor("op_1002_cast_fp16")]; - tensor var_1003_to_fp16 = const()[name = tensor("op_1003_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_71_cast_fp16 = mul(x = var_1002_cast_fp16, y = var_1003_to_fp16)[name = tensor("aw_chunk_71_cast_fp16")]; - tensor var_1006_equation_0 = const()[name = tensor("op_1006_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1006_cast_fp16 = einsum(equation = var_1006_equation_0, values = (var_924_cast_fp16, var_830_cast_fp16))[name = tensor("op_1006_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_49_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_49_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_49_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_49_equation_0, values = (var_912_cast_fp16, var_746_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_49_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_51_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_51_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_51_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_51_equation_0, values = (var_912_cast_fp16, var_753_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_51_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_53_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_53_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_53_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_53_equation_0, values = (var_912_cast_fp16, var_760_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_53_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_55_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_55_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_55_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_55_equation_0, values = (var_912_cast_fp16, var_767_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_55_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_57_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_57_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_57_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_57_equation_0, values = (var_916_cast_fp16, var_774_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_57_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_59_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_59_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_59_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_59_equation_0, values = (var_916_cast_fp16, var_781_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_59_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_61_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_61_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_61_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_61_equation_0, values = (var_916_cast_fp16, var_788_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_61_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_63_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_63_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_63_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_63_equation_0, values = (var_916_cast_fp16, var_795_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_63_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_65_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_65_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_65_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_65_equation_0, values = (var_920_cast_fp16, var_802_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_65_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_67_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_67_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_67_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_67_equation_0, values = (var_920_cast_fp16, var_809_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_67_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_69_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_69_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_69_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_69_equation_0, values = (var_920_cast_fp16, var_816_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_69_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_71_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_71_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_71_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_71_equation_0, values = (var_920_cast_fp16, var_823_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_71_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_73_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_73_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_73_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_73_equation_0, values = (var_924_cast_fp16, var_830_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_73_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_75_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_75_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_75_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_75_equation_0, values = (var_924_cast_fp16, var_837_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_75_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_77_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_77_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_77_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_77_equation_0, values = (var_924_cast_fp16, var_844_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_77_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_79_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_79_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_79_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_79_equation_0, values = (var_924_cast_fp16, var_851_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_79_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_81_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_81_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_81_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_81_equation_0, values = (var_928_cast_fp16, var_858_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_81_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_83_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_83_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_83_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_83_equation_0, values = (var_928_cast_fp16, var_865_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_83_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_85_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_85_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_85_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_85_equation_0, values = (var_928_cast_fp16, var_872_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_85_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_87_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_87_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_87_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_87_equation_0, values = (var_928_cast_fp16, var_879_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_87_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_89_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_89_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_89_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_89_equation_0, values = (var_932_cast_fp16, var_886_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_89_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_91_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_91_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_91_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_91_equation_0, values = (var_932_cast_fp16, var_893_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_91_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_93_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_93_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_93_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_93_equation_0, values = (var_932_cast_fp16, var_900_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_93_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_95_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_95_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_95_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_95_equation_0, values = (var_932_cast_fp16, var_907_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_95_cast_fp16")]; + tensor var_1005_to_fp16 = const()[name = tensor("op_1005_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_49_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_49_cast_fp16, y = var_1005_to_fp16)[name = tensor("aw_chunk_49_cast_fp16")]; tensor var_1007_to_fp16 = const()[name = tensor("op_1007_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_73_cast_fp16 = mul(x = var_1006_cast_fp16, y = var_1007_to_fp16)[name = tensor("aw_chunk_73_cast_fp16")]; - tensor var_1010_equation_0 = const()[name = tensor("op_1010_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1010_cast_fp16 = einsum(equation = var_1010_equation_0, values = (var_924_cast_fp16, var_837_cast_fp16))[name = tensor("op_1010_cast_fp16")]; + tensor aw_chunk_51_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_51_cast_fp16, y = var_1007_to_fp16)[name = tensor("aw_chunk_51_cast_fp16")]; + tensor var_1009_to_fp16 = const()[name = tensor("op_1009_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_53_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_53_cast_fp16, y = var_1009_to_fp16)[name = tensor("aw_chunk_53_cast_fp16")]; tensor var_1011_to_fp16 = const()[name = tensor("op_1011_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_75_cast_fp16 = mul(x = var_1010_cast_fp16, y = var_1011_to_fp16)[name = tensor("aw_chunk_75_cast_fp16")]; - tensor var_1014_equation_0 = const()[name = tensor("op_1014_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1014_cast_fp16 = einsum(equation = var_1014_equation_0, values = (var_924_cast_fp16, var_844_cast_fp16))[name = tensor("op_1014_cast_fp16")]; + tensor aw_chunk_55_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_55_cast_fp16, y = var_1011_to_fp16)[name = tensor("aw_chunk_55_cast_fp16")]; + tensor var_1013_to_fp16 = const()[name = tensor("op_1013_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_57_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_57_cast_fp16, y = var_1013_to_fp16)[name = tensor("aw_chunk_57_cast_fp16")]; tensor var_1015_to_fp16 = const()[name = tensor("op_1015_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_77_cast_fp16 = mul(x = var_1014_cast_fp16, y = var_1015_to_fp16)[name = tensor("aw_chunk_77_cast_fp16")]; - tensor var_1018_equation_0 = const()[name = tensor("op_1018_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1018_cast_fp16 = einsum(equation = var_1018_equation_0, values = (var_924_cast_fp16, var_851_cast_fp16))[name = tensor("op_1018_cast_fp16")]; + tensor aw_chunk_59_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_59_cast_fp16, y = var_1015_to_fp16)[name = tensor("aw_chunk_59_cast_fp16")]; + tensor var_1017_to_fp16 = const()[name = tensor("op_1017_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_61_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_61_cast_fp16, y = var_1017_to_fp16)[name = tensor("aw_chunk_61_cast_fp16")]; tensor var_1019_to_fp16 = const()[name = tensor("op_1019_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_79_cast_fp16 = mul(x = var_1018_cast_fp16, y = var_1019_to_fp16)[name = tensor("aw_chunk_79_cast_fp16")]; - tensor var_1022_equation_0 = const()[name = tensor("op_1022_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1022_cast_fp16 = einsum(equation = var_1022_equation_0, values = (var_928_cast_fp16, var_858_cast_fp16))[name = tensor("op_1022_cast_fp16")]; + tensor aw_chunk_63_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_63_cast_fp16, y = var_1019_to_fp16)[name = tensor("aw_chunk_63_cast_fp16")]; + tensor var_1021_to_fp16 = const()[name = tensor("op_1021_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_65_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_65_cast_fp16, y = var_1021_to_fp16)[name = tensor("aw_chunk_65_cast_fp16")]; tensor var_1023_to_fp16 = const()[name = tensor("op_1023_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_81_cast_fp16 = mul(x = var_1022_cast_fp16, y = var_1023_to_fp16)[name = tensor("aw_chunk_81_cast_fp16")]; - tensor var_1026_equation_0 = const()[name = tensor("op_1026_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1026_cast_fp16 = einsum(equation = var_1026_equation_0, values = (var_928_cast_fp16, var_865_cast_fp16))[name = tensor("op_1026_cast_fp16")]; + tensor aw_chunk_67_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_67_cast_fp16, y = var_1023_to_fp16)[name = tensor("aw_chunk_67_cast_fp16")]; + tensor var_1025_to_fp16 = const()[name = tensor("op_1025_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_69_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_69_cast_fp16, y = var_1025_to_fp16)[name = tensor("aw_chunk_69_cast_fp16")]; tensor var_1027_to_fp16 = const()[name = tensor("op_1027_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_83_cast_fp16 = mul(x = var_1026_cast_fp16, y = var_1027_to_fp16)[name = tensor("aw_chunk_83_cast_fp16")]; - tensor var_1030_equation_0 = const()[name = tensor("op_1030_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1030_cast_fp16 = einsum(equation = var_1030_equation_0, values = (var_928_cast_fp16, var_872_cast_fp16))[name = tensor("op_1030_cast_fp16")]; + tensor aw_chunk_71_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_71_cast_fp16, y = var_1027_to_fp16)[name = tensor("aw_chunk_71_cast_fp16")]; + tensor var_1029_to_fp16 = const()[name = tensor("op_1029_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_73_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_73_cast_fp16, y = var_1029_to_fp16)[name = tensor("aw_chunk_73_cast_fp16")]; tensor var_1031_to_fp16 = const()[name = tensor("op_1031_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_85_cast_fp16 = mul(x = var_1030_cast_fp16, y = var_1031_to_fp16)[name = tensor("aw_chunk_85_cast_fp16")]; - tensor var_1034_equation_0 = const()[name = tensor("op_1034_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1034_cast_fp16 = einsum(equation = var_1034_equation_0, values = (var_928_cast_fp16, var_879_cast_fp16))[name = tensor("op_1034_cast_fp16")]; + tensor aw_chunk_75_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_75_cast_fp16, y = var_1031_to_fp16)[name = tensor("aw_chunk_75_cast_fp16")]; + tensor var_1033_to_fp16 = const()[name = tensor("op_1033_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_77_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_77_cast_fp16, y = var_1033_to_fp16)[name = tensor("aw_chunk_77_cast_fp16")]; tensor var_1035_to_fp16 = const()[name = tensor("op_1035_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_87_cast_fp16 = mul(x = var_1034_cast_fp16, y = var_1035_to_fp16)[name = tensor("aw_chunk_87_cast_fp16")]; - tensor var_1038_equation_0 = const()[name = tensor("op_1038_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1038_cast_fp16 = einsum(equation = var_1038_equation_0, values = (var_932_cast_fp16, var_886_cast_fp16))[name = tensor("op_1038_cast_fp16")]; + tensor aw_chunk_79_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_79_cast_fp16, y = var_1035_to_fp16)[name = tensor("aw_chunk_79_cast_fp16")]; + tensor var_1037_to_fp16 = const()[name = tensor("op_1037_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_81_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_81_cast_fp16, y = var_1037_to_fp16)[name = tensor("aw_chunk_81_cast_fp16")]; tensor var_1039_to_fp16 = const()[name = tensor("op_1039_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_89_cast_fp16 = mul(x = var_1038_cast_fp16, y = var_1039_to_fp16)[name = tensor("aw_chunk_89_cast_fp16")]; - tensor var_1042_equation_0 = const()[name = tensor("op_1042_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1042_cast_fp16 = einsum(equation = var_1042_equation_0, values = (var_932_cast_fp16, var_893_cast_fp16))[name = tensor("op_1042_cast_fp16")]; + tensor aw_chunk_83_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_83_cast_fp16, y = var_1039_to_fp16)[name = tensor("aw_chunk_83_cast_fp16")]; + tensor var_1041_to_fp16 = const()[name = tensor("op_1041_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_85_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_85_cast_fp16, y = var_1041_to_fp16)[name = tensor("aw_chunk_85_cast_fp16")]; tensor var_1043_to_fp16 = const()[name = tensor("op_1043_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_91_cast_fp16 = mul(x = var_1042_cast_fp16, y = var_1043_to_fp16)[name = tensor("aw_chunk_91_cast_fp16")]; - tensor var_1046_equation_0 = const()[name = tensor("op_1046_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1046_cast_fp16 = einsum(equation = var_1046_equation_0, values = (var_932_cast_fp16, var_900_cast_fp16))[name = tensor("op_1046_cast_fp16")]; + tensor aw_chunk_87_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_87_cast_fp16, y = var_1043_to_fp16)[name = tensor("aw_chunk_87_cast_fp16")]; + tensor var_1045_to_fp16 = const()[name = tensor("op_1045_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_89_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_89_cast_fp16, y = var_1045_to_fp16)[name = tensor("aw_chunk_89_cast_fp16")]; tensor var_1047_to_fp16 = const()[name = tensor("op_1047_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_93_cast_fp16 = mul(x = var_1046_cast_fp16, y = var_1047_to_fp16)[name = tensor("aw_chunk_93_cast_fp16")]; - tensor var_1050_equation_0 = const()[name = tensor("op_1050_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1050_cast_fp16 = einsum(equation = var_1050_equation_0, values = (var_932_cast_fp16, var_907_cast_fp16))[name = tensor("op_1050_cast_fp16")]; + tensor aw_chunk_91_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_91_cast_fp16, y = var_1047_to_fp16)[name = tensor("aw_chunk_91_cast_fp16")]; + tensor var_1049_to_fp16 = const()[name = tensor("op_1049_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_93_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_93_cast_fp16, y = var_1049_to_fp16)[name = tensor("aw_chunk_93_cast_fp16")]; tensor var_1051_to_fp16 = const()[name = tensor("op_1051_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_95_cast_fp16 = mul(x = var_1050_cast_fp16, y = var_1051_to_fp16)[name = tensor("aw_chunk_95_cast_fp16")]; + tensor aw_chunk_95_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_95_cast_fp16, y = var_1051_to_fp16)[name = tensor("aw_chunk_95_cast_fp16")]; tensor var_1053_cast_fp16 = softmax(axis = var_662, x = aw_chunk_49_cast_fp16)[name = tensor("op_1053_cast_fp16")]; tensor var_1054_cast_fp16 = softmax(axis = var_662, x = aw_chunk_51_cast_fp16)[name = tensor("op_1054_cast_fp16")]; tensor var_1055_cast_fp16 = softmax(axis = var_662, x = aw_chunk_53_cast_fp16)[name = tensor("op_1055_cast_fp16")]; @@ -1051,102 +1051,102 @@ program(1.0) tensor var_1487_end_0 = const()[name = tensor("op_1487_end_0"), val = tensor([1, 384, 1, 1500])]; tensor var_1487_end_mask_0 = const()[name = tensor("op_1487_end_mask_0"), val = tensor([true, false, true, true])]; tensor var_1487_cast_fp16 = slice_by_index(begin = var_1487_begin_0, end = var_1487_end_0, end_mask = var_1487_end_mask_0, x = value_5_cast_fp16)[name = tensor("op_1487_cast_fp16")]; - tensor var_1491_equation_0 = const()[name = tensor("op_1491_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1491_cast_fp16 = einsum(equation = var_1491_equation_0, values = (var_1445_cast_fp16, var_1279_cast_fp16))[name = tensor("op_1491_cast_fp16")]; - tensor var_1492_to_fp16 = const()[name = tensor("op_1492_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_97_cast_fp16 = mul(x = var_1491_cast_fp16, y = var_1492_to_fp16)[name = tensor("aw_chunk_97_cast_fp16")]; - tensor var_1495_equation_0 = const()[name = tensor("op_1495_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1495_cast_fp16 = einsum(equation = var_1495_equation_0, values = (var_1445_cast_fp16, var_1286_cast_fp16))[name = tensor("op_1495_cast_fp16")]; - tensor var_1496_to_fp16 = const()[name = tensor("op_1496_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_99_cast_fp16 = mul(x = var_1495_cast_fp16, y = var_1496_to_fp16)[name = tensor("aw_chunk_99_cast_fp16")]; - tensor var_1499_equation_0 = const()[name = tensor("op_1499_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1499_cast_fp16 = einsum(equation = var_1499_equation_0, values = (var_1445_cast_fp16, var_1293_cast_fp16))[name = tensor("op_1499_cast_fp16")]; - tensor var_1500_to_fp16 = const()[name = tensor("op_1500_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_101_cast_fp16 = mul(x = var_1499_cast_fp16, y = var_1500_to_fp16)[name = tensor("aw_chunk_101_cast_fp16")]; - tensor var_1503_equation_0 = const()[name = tensor("op_1503_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1503_cast_fp16 = einsum(equation = var_1503_equation_0, values = (var_1445_cast_fp16, var_1300_cast_fp16))[name = tensor("op_1503_cast_fp16")]; - tensor var_1504_to_fp16 = const()[name = tensor("op_1504_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_103_cast_fp16 = mul(x = var_1503_cast_fp16, y = var_1504_to_fp16)[name = tensor("aw_chunk_103_cast_fp16")]; - tensor var_1507_equation_0 = const()[name = tensor("op_1507_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1507_cast_fp16 = einsum(equation = var_1507_equation_0, values = (var_1449_cast_fp16, var_1307_cast_fp16))[name = tensor("op_1507_cast_fp16")]; - tensor var_1508_to_fp16 = const()[name = tensor("op_1508_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_105_cast_fp16 = mul(x = var_1507_cast_fp16, y = var_1508_to_fp16)[name = tensor("aw_chunk_105_cast_fp16")]; - tensor var_1511_equation_0 = const()[name = tensor("op_1511_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1511_cast_fp16 = einsum(equation = var_1511_equation_0, values = (var_1449_cast_fp16, var_1314_cast_fp16))[name = tensor("op_1511_cast_fp16")]; - tensor var_1512_to_fp16 = const()[name = tensor("op_1512_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_107_cast_fp16 = mul(x = var_1511_cast_fp16, y = var_1512_to_fp16)[name = tensor("aw_chunk_107_cast_fp16")]; - tensor var_1515_equation_0 = const()[name = tensor("op_1515_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1515_cast_fp16 = einsum(equation = var_1515_equation_0, values = (var_1449_cast_fp16, var_1321_cast_fp16))[name = tensor("op_1515_cast_fp16")]; - tensor var_1516_to_fp16 = const()[name = tensor("op_1516_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_109_cast_fp16 = mul(x = var_1515_cast_fp16, y = var_1516_to_fp16)[name = tensor("aw_chunk_109_cast_fp16")]; - tensor var_1519_equation_0 = const()[name = tensor("op_1519_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1519_cast_fp16 = einsum(equation = var_1519_equation_0, values = (var_1449_cast_fp16, var_1328_cast_fp16))[name = tensor("op_1519_cast_fp16")]; - tensor var_1520_to_fp16 = const()[name = tensor("op_1520_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_111_cast_fp16 = mul(x = var_1519_cast_fp16, y = var_1520_to_fp16)[name = tensor("aw_chunk_111_cast_fp16")]; - tensor var_1523_equation_0 = const()[name = tensor("op_1523_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1523_cast_fp16 = einsum(equation = var_1523_equation_0, values = (var_1453_cast_fp16, var_1335_cast_fp16))[name = tensor("op_1523_cast_fp16")]; - tensor var_1524_to_fp16 = const()[name = tensor("op_1524_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_113_cast_fp16 = mul(x = var_1523_cast_fp16, y = var_1524_to_fp16)[name = tensor("aw_chunk_113_cast_fp16")]; - tensor var_1527_equation_0 = const()[name = tensor("op_1527_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1527_cast_fp16 = einsum(equation = var_1527_equation_0, values = (var_1453_cast_fp16, var_1342_cast_fp16))[name = tensor("op_1527_cast_fp16")]; - tensor var_1528_to_fp16 = const()[name = tensor("op_1528_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_115_cast_fp16 = mul(x = var_1527_cast_fp16, y = var_1528_to_fp16)[name = tensor("aw_chunk_115_cast_fp16")]; - tensor var_1531_equation_0 = const()[name = tensor("op_1531_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1531_cast_fp16 = einsum(equation = var_1531_equation_0, values = (var_1453_cast_fp16, var_1349_cast_fp16))[name = tensor("op_1531_cast_fp16")]; - tensor var_1532_to_fp16 = const()[name = tensor("op_1532_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_117_cast_fp16 = mul(x = var_1531_cast_fp16, y = var_1532_to_fp16)[name = tensor("aw_chunk_117_cast_fp16")]; - tensor var_1535_equation_0 = const()[name = tensor("op_1535_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1535_cast_fp16 = einsum(equation = var_1535_equation_0, values = (var_1453_cast_fp16, var_1356_cast_fp16))[name = tensor("op_1535_cast_fp16")]; - tensor var_1536_to_fp16 = const()[name = tensor("op_1536_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_119_cast_fp16 = mul(x = var_1535_cast_fp16, y = var_1536_to_fp16)[name = tensor("aw_chunk_119_cast_fp16")]; - tensor var_1539_equation_0 = const()[name = tensor("op_1539_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1539_cast_fp16 = einsum(equation = var_1539_equation_0, values = (var_1457_cast_fp16, var_1363_cast_fp16))[name = tensor("op_1539_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_97_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_97_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_97_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_97_equation_0, values = (var_1445_cast_fp16, var_1279_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_97_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_99_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_99_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_99_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_99_equation_0, values = (var_1445_cast_fp16, var_1286_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_99_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_101_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_101_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_101_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_101_equation_0, values = (var_1445_cast_fp16, var_1293_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_101_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_103_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_103_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_103_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_103_equation_0, values = (var_1445_cast_fp16, var_1300_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_103_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_105_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_105_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_105_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_105_equation_0, values = (var_1449_cast_fp16, var_1307_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_105_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_107_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_107_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_107_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_107_equation_0, values = (var_1449_cast_fp16, var_1314_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_107_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_109_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_109_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_109_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_109_equation_0, values = (var_1449_cast_fp16, var_1321_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_109_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_111_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_111_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_111_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_111_equation_0, values = (var_1449_cast_fp16, var_1328_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_111_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_113_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_113_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_113_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_113_equation_0, values = (var_1453_cast_fp16, var_1335_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_113_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_115_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_115_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_115_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_115_equation_0, values = (var_1453_cast_fp16, var_1342_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_115_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_117_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_117_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_117_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_117_equation_0, values = (var_1453_cast_fp16, var_1349_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_117_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_119_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_119_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_119_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_119_equation_0, values = (var_1453_cast_fp16, var_1356_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_119_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_121_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_121_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_121_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_121_equation_0, values = (var_1457_cast_fp16, var_1363_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_121_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_123_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_123_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_123_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_123_equation_0, values = (var_1457_cast_fp16, var_1370_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_123_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_125_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_125_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_125_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_125_equation_0, values = (var_1457_cast_fp16, var_1377_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_125_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_127_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_127_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_127_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_127_equation_0, values = (var_1457_cast_fp16, var_1384_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_127_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_129_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_129_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_129_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_129_equation_0, values = (var_1461_cast_fp16, var_1391_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_129_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_131_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_131_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_131_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_131_equation_0, values = (var_1461_cast_fp16, var_1398_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_131_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_133_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_133_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_133_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_133_equation_0, values = (var_1461_cast_fp16, var_1405_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_133_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_135_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_135_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_135_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_135_equation_0, values = (var_1461_cast_fp16, var_1412_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_135_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_137_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_137_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_137_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_137_equation_0, values = (var_1465_cast_fp16, var_1419_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_137_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_139_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_139_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_139_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_139_equation_0, values = (var_1465_cast_fp16, var_1426_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_139_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_141_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_141_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_141_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_141_equation_0, values = (var_1465_cast_fp16, var_1433_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_141_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_143_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_143_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_143_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_143_equation_0, values = (var_1465_cast_fp16, var_1440_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_143_cast_fp16")]; + tensor var_1538_to_fp16 = const()[name = tensor("op_1538_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_97_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_97_cast_fp16, y = var_1538_to_fp16)[name = tensor("aw_chunk_97_cast_fp16")]; tensor var_1540_to_fp16 = const()[name = tensor("op_1540_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_121_cast_fp16 = mul(x = var_1539_cast_fp16, y = var_1540_to_fp16)[name = tensor("aw_chunk_121_cast_fp16")]; - tensor var_1543_equation_0 = const()[name = tensor("op_1543_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1543_cast_fp16 = einsum(equation = var_1543_equation_0, values = (var_1457_cast_fp16, var_1370_cast_fp16))[name = tensor("op_1543_cast_fp16")]; + tensor aw_chunk_99_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_99_cast_fp16, y = var_1540_to_fp16)[name = tensor("aw_chunk_99_cast_fp16")]; + tensor var_1542_to_fp16 = const()[name = tensor("op_1542_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_101_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_101_cast_fp16, y = var_1542_to_fp16)[name = tensor("aw_chunk_101_cast_fp16")]; tensor var_1544_to_fp16 = const()[name = tensor("op_1544_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_123_cast_fp16 = mul(x = var_1543_cast_fp16, y = var_1544_to_fp16)[name = tensor("aw_chunk_123_cast_fp16")]; - tensor var_1547_equation_0 = const()[name = tensor("op_1547_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1547_cast_fp16 = einsum(equation = var_1547_equation_0, values = (var_1457_cast_fp16, var_1377_cast_fp16))[name = tensor("op_1547_cast_fp16")]; + tensor aw_chunk_103_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_103_cast_fp16, y = var_1544_to_fp16)[name = tensor("aw_chunk_103_cast_fp16")]; + tensor var_1546_to_fp16 = const()[name = tensor("op_1546_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_105_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_105_cast_fp16, y = var_1546_to_fp16)[name = tensor("aw_chunk_105_cast_fp16")]; tensor var_1548_to_fp16 = const()[name = tensor("op_1548_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_125_cast_fp16 = mul(x = var_1547_cast_fp16, y = var_1548_to_fp16)[name = tensor("aw_chunk_125_cast_fp16")]; - tensor var_1551_equation_0 = const()[name = tensor("op_1551_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1551_cast_fp16 = einsum(equation = var_1551_equation_0, values = (var_1457_cast_fp16, var_1384_cast_fp16))[name = tensor("op_1551_cast_fp16")]; + tensor aw_chunk_107_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_107_cast_fp16, y = var_1548_to_fp16)[name = tensor("aw_chunk_107_cast_fp16")]; + tensor var_1550_to_fp16 = const()[name = tensor("op_1550_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_109_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_109_cast_fp16, y = var_1550_to_fp16)[name = tensor("aw_chunk_109_cast_fp16")]; tensor var_1552_to_fp16 = const()[name = tensor("op_1552_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_127_cast_fp16 = mul(x = var_1551_cast_fp16, y = var_1552_to_fp16)[name = tensor("aw_chunk_127_cast_fp16")]; - tensor var_1555_equation_0 = const()[name = tensor("op_1555_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1555_cast_fp16 = einsum(equation = var_1555_equation_0, values = (var_1461_cast_fp16, var_1391_cast_fp16))[name = tensor("op_1555_cast_fp16")]; + tensor aw_chunk_111_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_111_cast_fp16, y = var_1552_to_fp16)[name = tensor("aw_chunk_111_cast_fp16")]; + tensor var_1554_to_fp16 = const()[name = tensor("op_1554_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_113_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_113_cast_fp16, y = var_1554_to_fp16)[name = tensor("aw_chunk_113_cast_fp16")]; tensor var_1556_to_fp16 = const()[name = tensor("op_1556_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_129_cast_fp16 = mul(x = var_1555_cast_fp16, y = var_1556_to_fp16)[name = tensor("aw_chunk_129_cast_fp16")]; - tensor var_1559_equation_0 = const()[name = tensor("op_1559_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1559_cast_fp16 = einsum(equation = var_1559_equation_0, values = (var_1461_cast_fp16, var_1398_cast_fp16))[name = tensor("op_1559_cast_fp16")]; + tensor aw_chunk_115_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_115_cast_fp16, y = var_1556_to_fp16)[name = tensor("aw_chunk_115_cast_fp16")]; + tensor var_1558_to_fp16 = const()[name = tensor("op_1558_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_117_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_117_cast_fp16, y = var_1558_to_fp16)[name = tensor("aw_chunk_117_cast_fp16")]; tensor var_1560_to_fp16 = const()[name = tensor("op_1560_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_131_cast_fp16 = mul(x = var_1559_cast_fp16, y = var_1560_to_fp16)[name = tensor("aw_chunk_131_cast_fp16")]; - tensor var_1563_equation_0 = const()[name = tensor("op_1563_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1563_cast_fp16 = einsum(equation = var_1563_equation_0, values = (var_1461_cast_fp16, var_1405_cast_fp16))[name = tensor("op_1563_cast_fp16")]; + tensor aw_chunk_119_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_119_cast_fp16, y = var_1560_to_fp16)[name = tensor("aw_chunk_119_cast_fp16")]; + tensor var_1562_to_fp16 = const()[name = tensor("op_1562_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_121_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_121_cast_fp16, y = var_1562_to_fp16)[name = tensor("aw_chunk_121_cast_fp16")]; tensor var_1564_to_fp16 = const()[name = tensor("op_1564_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_133_cast_fp16 = mul(x = var_1563_cast_fp16, y = var_1564_to_fp16)[name = tensor("aw_chunk_133_cast_fp16")]; - tensor var_1567_equation_0 = const()[name = tensor("op_1567_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1567_cast_fp16 = einsum(equation = var_1567_equation_0, values = (var_1461_cast_fp16, var_1412_cast_fp16))[name = tensor("op_1567_cast_fp16")]; + tensor aw_chunk_123_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_123_cast_fp16, y = var_1564_to_fp16)[name = tensor("aw_chunk_123_cast_fp16")]; + tensor var_1566_to_fp16 = const()[name = tensor("op_1566_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_125_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_125_cast_fp16, y = var_1566_to_fp16)[name = tensor("aw_chunk_125_cast_fp16")]; tensor var_1568_to_fp16 = const()[name = tensor("op_1568_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_135_cast_fp16 = mul(x = var_1567_cast_fp16, y = var_1568_to_fp16)[name = tensor("aw_chunk_135_cast_fp16")]; - tensor var_1571_equation_0 = const()[name = tensor("op_1571_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1571_cast_fp16 = einsum(equation = var_1571_equation_0, values = (var_1465_cast_fp16, var_1419_cast_fp16))[name = tensor("op_1571_cast_fp16")]; + tensor aw_chunk_127_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_127_cast_fp16, y = var_1568_to_fp16)[name = tensor("aw_chunk_127_cast_fp16")]; + tensor var_1570_to_fp16 = const()[name = tensor("op_1570_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_129_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_129_cast_fp16, y = var_1570_to_fp16)[name = tensor("aw_chunk_129_cast_fp16")]; tensor var_1572_to_fp16 = const()[name = tensor("op_1572_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_137_cast_fp16 = mul(x = var_1571_cast_fp16, y = var_1572_to_fp16)[name = tensor("aw_chunk_137_cast_fp16")]; - tensor var_1575_equation_0 = const()[name = tensor("op_1575_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1575_cast_fp16 = einsum(equation = var_1575_equation_0, values = (var_1465_cast_fp16, var_1426_cast_fp16))[name = tensor("op_1575_cast_fp16")]; + tensor aw_chunk_131_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_131_cast_fp16, y = var_1572_to_fp16)[name = tensor("aw_chunk_131_cast_fp16")]; + tensor var_1574_to_fp16 = const()[name = tensor("op_1574_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_133_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_133_cast_fp16, y = var_1574_to_fp16)[name = tensor("aw_chunk_133_cast_fp16")]; tensor var_1576_to_fp16 = const()[name = tensor("op_1576_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_139_cast_fp16 = mul(x = var_1575_cast_fp16, y = var_1576_to_fp16)[name = tensor("aw_chunk_139_cast_fp16")]; - tensor var_1579_equation_0 = const()[name = tensor("op_1579_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1579_cast_fp16 = einsum(equation = var_1579_equation_0, values = (var_1465_cast_fp16, var_1433_cast_fp16))[name = tensor("op_1579_cast_fp16")]; + tensor aw_chunk_135_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_135_cast_fp16, y = var_1576_to_fp16)[name = tensor("aw_chunk_135_cast_fp16")]; + tensor var_1578_to_fp16 = const()[name = tensor("op_1578_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_137_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_137_cast_fp16, y = var_1578_to_fp16)[name = tensor("aw_chunk_137_cast_fp16")]; tensor var_1580_to_fp16 = const()[name = tensor("op_1580_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_141_cast_fp16 = mul(x = var_1579_cast_fp16, y = var_1580_to_fp16)[name = tensor("aw_chunk_141_cast_fp16")]; - tensor var_1583_equation_0 = const()[name = tensor("op_1583_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_1583_cast_fp16 = einsum(equation = var_1583_equation_0, values = (var_1465_cast_fp16, var_1440_cast_fp16))[name = tensor("op_1583_cast_fp16")]; + tensor aw_chunk_139_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_139_cast_fp16, y = var_1580_to_fp16)[name = tensor("aw_chunk_139_cast_fp16")]; + tensor var_1582_to_fp16 = const()[name = tensor("op_1582_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_141_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_141_cast_fp16, y = var_1582_to_fp16)[name = tensor("aw_chunk_141_cast_fp16")]; tensor var_1584_to_fp16 = const()[name = tensor("op_1584_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_143_cast_fp16 = mul(x = var_1583_cast_fp16, y = var_1584_to_fp16)[name = tensor("aw_chunk_143_cast_fp16")]; + tensor aw_chunk_143_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_143_cast_fp16, y = var_1584_to_fp16)[name = tensor("aw_chunk_143_cast_fp16")]; tensor var_1586_cast_fp16 = softmax(axis = var_1195, x = aw_chunk_97_cast_fp16)[name = tensor("op_1586_cast_fp16")]; tensor var_1587_cast_fp16 = softmax(axis = var_1195, x = aw_chunk_99_cast_fp16)[name = tensor("op_1587_cast_fp16")]; tensor var_1588_cast_fp16 = softmax(axis = var_1195, x = aw_chunk_101_cast_fp16)[name = tensor("op_1588_cast_fp16")]; @@ -1464,102 +1464,102 @@ program(1.0) tensor var_2020_end_0 = const()[name = tensor("op_2020_end_0"), val = tensor([1, 384, 1, 1500])]; tensor var_2020_end_mask_0 = const()[name = tensor("op_2020_end_mask_0"), val = tensor([true, false, true, true])]; tensor var_2020_cast_fp16 = slice_by_index(begin = var_2020_begin_0, end = var_2020_end_0, end_mask = var_2020_end_mask_0, x = value_cast_fp16)[name = tensor("op_2020_cast_fp16")]; - tensor var_2024_equation_0 = const()[name = tensor("op_2024_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2024_cast_fp16 = einsum(equation = var_2024_equation_0, values = (var_1978_cast_fp16, var_1812_cast_fp16))[name = tensor("op_2024_cast_fp16")]; - tensor var_2025_to_fp16 = const()[name = tensor("op_2025_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_145_cast_fp16 = mul(x = var_2024_cast_fp16, y = var_2025_to_fp16)[name = tensor("aw_chunk_145_cast_fp16")]; - tensor var_2028_equation_0 = const()[name = tensor("op_2028_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2028_cast_fp16 = einsum(equation = var_2028_equation_0, values = (var_1978_cast_fp16, var_1819_cast_fp16))[name = tensor("op_2028_cast_fp16")]; - tensor var_2029_to_fp16 = const()[name = tensor("op_2029_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_147_cast_fp16 = mul(x = var_2028_cast_fp16, y = var_2029_to_fp16)[name = tensor("aw_chunk_147_cast_fp16")]; - tensor var_2032_equation_0 = const()[name = tensor("op_2032_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2032_cast_fp16 = einsum(equation = var_2032_equation_0, values = (var_1978_cast_fp16, var_1826_cast_fp16))[name = tensor("op_2032_cast_fp16")]; - tensor var_2033_to_fp16 = const()[name = tensor("op_2033_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_149_cast_fp16 = mul(x = var_2032_cast_fp16, y = var_2033_to_fp16)[name = tensor("aw_chunk_149_cast_fp16")]; - tensor var_2036_equation_0 = const()[name = tensor("op_2036_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2036_cast_fp16 = einsum(equation = var_2036_equation_0, values = (var_1978_cast_fp16, var_1833_cast_fp16))[name = tensor("op_2036_cast_fp16")]; - tensor var_2037_to_fp16 = const()[name = tensor("op_2037_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_151_cast_fp16 = mul(x = var_2036_cast_fp16, y = var_2037_to_fp16)[name = tensor("aw_chunk_151_cast_fp16")]; - tensor var_2040_equation_0 = const()[name = tensor("op_2040_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2040_cast_fp16 = einsum(equation = var_2040_equation_0, values = (var_1982_cast_fp16, var_1840_cast_fp16))[name = tensor("op_2040_cast_fp16")]; - tensor var_2041_to_fp16 = const()[name = tensor("op_2041_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_153_cast_fp16 = mul(x = var_2040_cast_fp16, y = var_2041_to_fp16)[name = tensor("aw_chunk_153_cast_fp16")]; - tensor var_2044_equation_0 = const()[name = tensor("op_2044_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2044_cast_fp16 = einsum(equation = var_2044_equation_0, values = (var_1982_cast_fp16, var_1847_cast_fp16))[name = tensor("op_2044_cast_fp16")]; - tensor var_2045_to_fp16 = const()[name = tensor("op_2045_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_155_cast_fp16 = mul(x = var_2044_cast_fp16, y = var_2045_to_fp16)[name = tensor("aw_chunk_155_cast_fp16")]; - tensor var_2048_equation_0 = const()[name = tensor("op_2048_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2048_cast_fp16 = einsum(equation = var_2048_equation_0, values = (var_1982_cast_fp16, var_1854_cast_fp16))[name = tensor("op_2048_cast_fp16")]; - tensor var_2049_to_fp16 = const()[name = tensor("op_2049_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_157_cast_fp16 = mul(x = var_2048_cast_fp16, y = var_2049_to_fp16)[name = tensor("aw_chunk_157_cast_fp16")]; - tensor var_2052_equation_0 = const()[name = tensor("op_2052_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2052_cast_fp16 = einsum(equation = var_2052_equation_0, values = (var_1982_cast_fp16, var_1861_cast_fp16))[name = tensor("op_2052_cast_fp16")]; - tensor var_2053_to_fp16 = const()[name = tensor("op_2053_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_159_cast_fp16 = mul(x = var_2052_cast_fp16, y = var_2053_to_fp16)[name = tensor("aw_chunk_159_cast_fp16")]; - tensor var_2056_equation_0 = const()[name = tensor("op_2056_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2056_cast_fp16 = einsum(equation = var_2056_equation_0, values = (var_1986_cast_fp16, var_1868_cast_fp16))[name = tensor("op_2056_cast_fp16")]; - tensor var_2057_to_fp16 = const()[name = tensor("op_2057_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_161_cast_fp16 = mul(x = var_2056_cast_fp16, y = var_2057_to_fp16)[name = tensor("aw_chunk_161_cast_fp16")]; - tensor var_2060_equation_0 = const()[name = tensor("op_2060_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2060_cast_fp16 = einsum(equation = var_2060_equation_0, values = (var_1986_cast_fp16, var_1875_cast_fp16))[name = tensor("op_2060_cast_fp16")]; - tensor var_2061_to_fp16 = const()[name = tensor("op_2061_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_163_cast_fp16 = mul(x = var_2060_cast_fp16, y = var_2061_to_fp16)[name = tensor("aw_chunk_163_cast_fp16")]; - tensor var_2064_equation_0 = const()[name = tensor("op_2064_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2064_cast_fp16 = einsum(equation = var_2064_equation_0, values = (var_1986_cast_fp16, var_1882_cast_fp16))[name = tensor("op_2064_cast_fp16")]; - tensor var_2065_to_fp16 = const()[name = tensor("op_2065_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_165_cast_fp16 = mul(x = var_2064_cast_fp16, y = var_2065_to_fp16)[name = tensor("aw_chunk_165_cast_fp16")]; - tensor var_2068_equation_0 = const()[name = tensor("op_2068_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2068_cast_fp16 = einsum(equation = var_2068_equation_0, values = (var_1986_cast_fp16, var_1889_cast_fp16))[name = tensor("op_2068_cast_fp16")]; - tensor var_2069_to_fp16 = const()[name = tensor("op_2069_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_167_cast_fp16 = mul(x = var_2068_cast_fp16, y = var_2069_to_fp16)[name = tensor("aw_chunk_167_cast_fp16")]; - tensor var_2072_equation_0 = const()[name = tensor("op_2072_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2072_cast_fp16 = einsum(equation = var_2072_equation_0, values = (var_1990_cast_fp16, var_1896_cast_fp16))[name = tensor("op_2072_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_145_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_145_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_145_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_145_equation_0, values = (var_1978_cast_fp16, var_1812_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_145_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_147_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_147_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_147_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_147_equation_0, values = (var_1978_cast_fp16, var_1819_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_147_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_149_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_149_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_149_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_149_equation_0, values = (var_1978_cast_fp16, var_1826_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_149_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_151_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_151_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_151_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_151_equation_0, values = (var_1978_cast_fp16, var_1833_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_151_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_153_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_153_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_153_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_153_equation_0, values = (var_1982_cast_fp16, var_1840_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_153_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_155_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_155_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_155_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_155_equation_0, values = (var_1982_cast_fp16, var_1847_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_155_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_157_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_157_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_157_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_157_equation_0, values = (var_1982_cast_fp16, var_1854_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_157_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_159_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_159_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_159_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_159_equation_0, values = (var_1982_cast_fp16, var_1861_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_159_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_161_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_161_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_161_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_161_equation_0, values = (var_1986_cast_fp16, var_1868_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_161_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_163_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_163_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_163_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_163_equation_0, values = (var_1986_cast_fp16, var_1875_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_163_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_165_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_165_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_165_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_165_equation_0, values = (var_1986_cast_fp16, var_1882_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_165_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_167_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_167_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_167_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_167_equation_0, values = (var_1986_cast_fp16, var_1889_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_167_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_169_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_169_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_169_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_169_equation_0, values = (var_1990_cast_fp16, var_1896_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_169_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_171_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_171_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_171_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_171_equation_0, values = (var_1990_cast_fp16, var_1903_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_171_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_173_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_173_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_173_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_173_equation_0, values = (var_1990_cast_fp16, var_1910_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_173_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_175_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_175_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_175_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_175_equation_0, values = (var_1990_cast_fp16, var_1917_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_175_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_177_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_177_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_177_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_177_equation_0, values = (var_1994_cast_fp16, var_1924_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_177_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_179_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_179_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_179_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_179_equation_0, values = (var_1994_cast_fp16, var_1931_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_179_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_181_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_181_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_181_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_181_equation_0, values = (var_1994_cast_fp16, var_1938_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_181_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_183_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_183_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_183_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_183_equation_0, values = (var_1994_cast_fp16, var_1945_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_183_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_185_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_185_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_185_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_185_equation_0, values = (var_1998_cast_fp16, var_1952_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_185_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_187_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_187_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_187_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_187_equation_0, values = (var_1998_cast_fp16, var_1959_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_187_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_189_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_189_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_189_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_189_equation_0, values = (var_1998_cast_fp16, var_1966_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_189_cast_fp16")]; + tensor _SplitHeadsQ__mh_w_equation_0 = const()[name = tensor("_SplitHeadsQ__mh_w_equation_0"), val = tensor("bkhc,bchq->bkhq")]; + tensor _SplitHeadsQ__mh_w_cast_fp16 = einsum(equation = _SplitHeadsQ__mh_w_equation_0, values = (var_1998_cast_fp16, var_1973_cast_fp16))[name = tensor("_SplitHeadsQ__mh_w_cast_fp16")]; + tensor var_2071_to_fp16 = const()[name = tensor("op_2071_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_145_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_145_cast_fp16, y = var_2071_to_fp16)[name = tensor("aw_chunk_145_cast_fp16")]; tensor var_2073_to_fp16 = const()[name = tensor("op_2073_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_169_cast_fp16 = mul(x = var_2072_cast_fp16, y = var_2073_to_fp16)[name = tensor("aw_chunk_169_cast_fp16")]; - tensor var_2076_equation_0 = const()[name = tensor("op_2076_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2076_cast_fp16 = einsum(equation = var_2076_equation_0, values = (var_1990_cast_fp16, var_1903_cast_fp16))[name = tensor("op_2076_cast_fp16")]; + tensor aw_chunk_147_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_147_cast_fp16, y = var_2073_to_fp16)[name = tensor("aw_chunk_147_cast_fp16")]; + tensor var_2075_to_fp16 = const()[name = tensor("op_2075_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_149_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_149_cast_fp16, y = var_2075_to_fp16)[name = tensor("aw_chunk_149_cast_fp16")]; tensor var_2077_to_fp16 = const()[name = tensor("op_2077_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_171_cast_fp16 = mul(x = var_2076_cast_fp16, y = var_2077_to_fp16)[name = tensor("aw_chunk_171_cast_fp16")]; - tensor var_2080_equation_0 = const()[name = tensor("op_2080_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2080_cast_fp16 = einsum(equation = var_2080_equation_0, values = (var_1990_cast_fp16, var_1910_cast_fp16))[name = tensor("op_2080_cast_fp16")]; + tensor aw_chunk_151_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_151_cast_fp16, y = var_2077_to_fp16)[name = tensor("aw_chunk_151_cast_fp16")]; + tensor var_2079_to_fp16 = const()[name = tensor("op_2079_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_153_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_153_cast_fp16, y = var_2079_to_fp16)[name = tensor("aw_chunk_153_cast_fp16")]; tensor var_2081_to_fp16 = const()[name = tensor("op_2081_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_173_cast_fp16 = mul(x = var_2080_cast_fp16, y = var_2081_to_fp16)[name = tensor("aw_chunk_173_cast_fp16")]; - tensor var_2084_equation_0 = const()[name = tensor("op_2084_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2084_cast_fp16 = einsum(equation = var_2084_equation_0, values = (var_1990_cast_fp16, var_1917_cast_fp16))[name = tensor("op_2084_cast_fp16")]; + tensor aw_chunk_155_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_155_cast_fp16, y = var_2081_to_fp16)[name = tensor("aw_chunk_155_cast_fp16")]; + tensor var_2083_to_fp16 = const()[name = tensor("op_2083_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_157_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_157_cast_fp16, y = var_2083_to_fp16)[name = tensor("aw_chunk_157_cast_fp16")]; tensor var_2085_to_fp16 = const()[name = tensor("op_2085_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_175_cast_fp16 = mul(x = var_2084_cast_fp16, y = var_2085_to_fp16)[name = tensor("aw_chunk_175_cast_fp16")]; - tensor var_2088_equation_0 = const()[name = tensor("op_2088_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2088_cast_fp16 = einsum(equation = var_2088_equation_0, values = (var_1994_cast_fp16, var_1924_cast_fp16))[name = tensor("op_2088_cast_fp16")]; + tensor aw_chunk_159_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_159_cast_fp16, y = var_2085_to_fp16)[name = tensor("aw_chunk_159_cast_fp16")]; + tensor var_2087_to_fp16 = const()[name = tensor("op_2087_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_161_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_161_cast_fp16, y = var_2087_to_fp16)[name = tensor("aw_chunk_161_cast_fp16")]; tensor var_2089_to_fp16 = const()[name = tensor("op_2089_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_177_cast_fp16 = mul(x = var_2088_cast_fp16, y = var_2089_to_fp16)[name = tensor("aw_chunk_177_cast_fp16")]; - tensor var_2092_equation_0 = const()[name = tensor("op_2092_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2092_cast_fp16 = einsum(equation = var_2092_equation_0, values = (var_1994_cast_fp16, var_1931_cast_fp16))[name = tensor("op_2092_cast_fp16")]; + tensor aw_chunk_163_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_163_cast_fp16, y = var_2089_to_fp16)[name = tensor("aw_chunk_163_cast_fp16")]; + tensor var_2091_to_fp16 = const()[name = tensor("op_2091_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_165_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_165_cast_fp16, y = var_2091_to_fp16)[name = tensor("aw_chunk_165_cast_fp16")]; tensor var_2093_to_fp16 = const()[name = tensor("op_2093_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_179_cast_fp16 = mul(x = var_2092_cast_fp16, y = var_2093_to_fp16)[name = tensor("aw_chunk_179_cast_fp16")]; - tensor var_2096_equation_0 = const()[name = tensor("op_2096_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2096_cast_fp16 = einsum(equation = var_2096_equation_0, values = (var_1994_cast_fp16, var_1938_cast_fp16))[name = tensor("op_2096_cast_fp16")]; + tensor aw_chunk_167_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_167_cast_fp16, y = var_2093_to_fp16)[name = tensor("aw_chunk_167_cast_fp16")]; + tensor var_2095_to_fp16 = const()[name = tensor("op_2095_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_169_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_169_cast_fp16, y = var_2095_to_fp16)[name = tensor("aw_chunk_169_cast_fp16")]; tensor var_2097_to_fp16 = const()[name = tensor("op_2097_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_181_cast_fp16 = mul(x = var_2096_cast_fp16, y = var_2097_to_fp16)[name = tensor("aw_chunk_181_cast_fp16")]; - tensor var_2100_equation_0 = const()[name = tensor("op_2100_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2100_cast_fp16 = einsum(equation = var_2100_equation_0, values = (var_1994_cast_fp16, var_1945_cast_fp16))[name = tensor("op_2100_cast_fp16")]; + tensor aw_chunk_171_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_171_cast_fp16, y = var_2097_to_fp16)[name = tensor("aw_chunk_171_cast_fp16")]; + tensor var_2099_to_fp16 = const()[name = tensor("op_2099_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_173_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_173_cast_fp16, y = var_2099_to_fp16)[name = tensor("aw_chunk_173_cast_fp16")]; tensor var_2101_to_fp16 = const()[name = tensor("op_2101_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_183_cast_fp16 = mul(x = var_2100_cast_fp16, y = var_2101_to_fp16)[name = tensor("aw_chunk_183_cast_fp16")]; - tensor var_2104_equation_0 = const()[name = tensor("op_2104_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2104_cast_fp16 = einsum(equation = var_2104_equation_0, values = (var_1998_cast_fp16, var_1952_cast_fp16))[name = tensor("op_2104_cast_fp16")]; + tensor aw_chunk_175_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_175_cast_fp16, y = var_2101_to_fp16)[name = tensor("aw_chunk_175_cast_fp16")]; + tensor var_2103_to_fp16 = const()[name = tensor("op_2103_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_177_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_177_cast_fp16, y = var_2103_to_fp16)[name = tensor("aw_chunk_177_cast_fp16")]; tensor var_2105_to_fp16 = const()[name = tensor("op_2105_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_185_cast_fp16 = mul(x = var_2104_cast_fp16, y = var_2105_to_fp16)[name = tensor("aw_chunk_185_cast_fp16")]; - tensor var_2108_equation_0 = const()[name = tensor("op_2108_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2108_cast_fp16 = einsum(equation = var_2108_equation_0, values = (var_1998_cast_fp16, var_1959_cast_fp16))[name = tensor("op_2108_cast_fp16")]; + tensor aw_chunk_179_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_179_cast_fp16, y = var_2105_to_fp16)[name = tensor("aw_chunk_179_cast_fp16")]; + tensor var_2107_to_fp16 = const()[name = tensor("op_2107_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_181_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_181_cast_fp16, y = var_2107_to_fp16)[name = tensor("aw_chunk_181_cast_fp16")]; tensor var_2109_to_fp16 = const()[name = tensor("op_2109_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_187_cast_fp16 = mul(x = var_2108_cast_fp16, y = var_2109_to_fp16)[name = tensor("aw_chunk_187_cast_fp16")]; - tensor var_2112_equation_0 = const()[name = tensor("op_2112_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2112_cast_fp16 = einsum(equation = var_2112_equation_0, values = (var_1998_cast_fp16, var_1966_cast_fp16))[name = tensor("op_2112_cast_fp16")]; + tensor aw_chunk_183_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_183_cast_fp16, y = var_2109_to_fp16)[name = tensor("aw_chunk_183_cast_fp16")]; + tensor var_2111_to_fp16 = const()[name = tensor("op_2111_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_185_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_185_cast_fp16, y = var_2111_to_fp16)[name = tensor("aw_chunk_185_cast_fp16")]; tensor var_2113_to_fp16 = const()[name = tensor("op_2113_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_189_cast_fp16 = mul(x = var_2112_cast_fp16, y = var_2113_to_fp16)[name = tensor("aw_chunk_189_cast_fp16")]; - tensor var_2116_equation_0 = const()[name = tensor("op_2116_equation_0"), val = tensor("bkhc,bchq->bkhq")]; - tensor var_2116_cast_fp16 = einsum(equation = var_2116_equation_0, values = (var_1998_cast_fp16, var_1973_cast_fp16))[name = tensor("op_2116_cast_fp16")]; + tensor aw_chunk_187_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_187_cast_fp16, y = var_2113_to_fp16)[name = tensor("aw_chunk_187_cast_fp16")]; + tensor var_2115_to_fp16 = const()[name = tensor("op_2115_to_fp16"), val = tensor(0x1p-3)]; + tensor aw_chunk_189_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_189_cast_fp16, y = var_2115_to_fp16)[name = tensor("aw_chunk_189_cast_fp16")]; tensor var_2117_to_fp16 = const()[name = tensor("op_2117_to_fp16"), val = tensor(0x1p-3)]; - tensor aw_chunk_cast_fp16 = mul(x = var_2116_cast_fp16, y = var_2117_to_fp16)[name = tensor("aw_chunk_cast_fp16")]; + tensor aw_chunk_cast_fp16 = mul(x = _SplitHeadsQ__mh_w_cast_fp16, y = var_2117_to_fp16)[name = tensor("aw_chunk_cast_fp16")]; tensor var_2119_cast_fp16 = softmax(axis = var_1728, x = aw_chunk_145_cast_fp16)[name = tensor("op_2119_cast_fp16")]; tensor var_2120_cast_fp16 = softmax(axis = var_1728, x = aw_chunk_147_cast_fp16)[name = tensor("op_2120_cast_fp16")]; tensor var_2121_cast_fp16 = softmax(axis = var_1728, x = aw_chunk_149_cast_fp16)[name = tensor("op_2121_cast_fp16")];