def identity_with_cast(q, k, v, offset: int = 0):
    """Return q and k cast to v.dtype, and v unchanged; `offset` is accepted but unused."""
    return q.to(v.dtype), k.to(v.dtype), v