program(1.0)
[buildInfo = dict, tensor>({{"coremlc-component-MIL", "5.33.5"}, {"coremlc-version", "1877.40.3"}, {"coremltools-component-torch", "2.2.1"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "7.1"}})]
{
    func main(tensor language, tensor task) {
        tensor var_6 = const()[name = tensor("op_6"), val = tensor(50259)];
        tensor var_7 = sub(x = language, y = var_6)[name = tensor("op_7")];
        tensor var_8 = const()[name = tensor("op_8"), val = tensor(2)];
        tensor var_9 = mul(x = var_7, y = var_8)[name = tensor("op_9")];
        tensor input = add(x = var_9, y = task)[name = tensor("input")];
        tensor var_15_axis_0 = const()[name = tensor("op_15_axis_0"), val = tensor(0)];
        tensor var_15_batch_dims_0 = const()[name = tensor("op_15_batch_dims_0"), val = tensor(0)];
        tensor var_15_validate_indices_0 = const()[name = tensor("op_15_validate_indices_0"), val = tensor(false)];
        tensor key_cache_lut_weight_to_fp16 = const()[name = tensor("key_cache_lut_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))];
        tensor input_to_int16_dtype_0 = const()[name = tensor("input_to_int16_dtype_0"), val = tensor("int16")];
        tensor cast_6 = cast(dtype = input_to_int16_dtype_0, x = input)[name = tensor("cast_6")];
        tensor var_15_cast_fp16_cast_int16 = gather(axis = var_15_axis_0, batch_dims = var_15_batch_dims_0, indices = cast_6, validate_indices = var_15_validate_indices_0, x = key_cache_lut_weight_to_fp16)[name = tensor("op_15_cast_fp16_cast_int16")];
        tensor var_20 = const()[name = tensor("op_20"), val = tensor([1, 40960, 1, 3])];
        tensor key_cache_prefill = reshape(shape = var_20, x = var_15_cast_fp16_cast_int16)[name = tensor("op_21_cast_fp16")];
        tensor var_25_axis_0 = const()[name = tensor("op_25_axis_0"), val = tensor(0)];
        tensor var_25_batch_dims_0 = const()[name = tensor("op_25_batch_dims_0"), val = tensor(0)];
        tensor var_25_validate_indices_0 = const()[name = tensor("op_25_validate_indices_0"), val = tensor(false)];
        tensor value_cache_lut_weight_to_fp16 = const()[name = tensor("value_cache_lut_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(49152128)))];
        tensor var_25_cast_fp16_cast_int16 = gather(axis = var_25_axis_0, batch_dims = var_25_batch_dims_0, indices = cast_6, validate_indices = var_25_validate_indices_0, x = value_cache_lut_weight_to_fp16)[name = tensor("op_25_cast_fp16_cast_int16")];
        tensor var_30 = const()[name = tensor("op_30"), val = tensor([1, 40960, 1, 3])];
        tensor value_cache_prefill = reshape(shape = var_30, x = var_25_cast_fp16_cast_int16)[name = tensor("op_31_cast_fp16")];
    } -> (key_cache_prefill, value_cache_prefill);
}