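def identity_with_cast(q, k, v, offset: int = 0):
    # Pass q, k, v through unchanged except that q and k are cast to v's dtype.
    # `offset` is accepted but not used by this function.
    return q.to(v.dtype), k.to(v.dtype), v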