|
from torch.fx.proxy import Proxy |
|
from ._compatibility import compatibility |
|
|
|
@compatibility(is_backward_compatible=False)
def annotate(val, type):
    """Record ``type`` on the node behind ``val`` when ``val`` is a ``Proxy``.

    Args:
        val: Either a ``Proxy`` or any other value.
        type: The type to store in ``val.node.type``.

    Returns:
        ``val`` itself, in every non-raising case.

    Raises:
        RuntimeError: If ``val`` is a ``Proxy`` whose node already has a
            truthy ``type``.
    """
    # Values that are not proxies are handed back untouched.
    if not isinstance(val, Proxy):
        return val

    node = val.node
    existing_type = node.type
    if existing_type:
        # Refuse to overwrite a type that is already present on the node.
        raise RuntimeError(
            "Tried to annotate a value that already had a type on it!"
            f" Existing type is {existing_type} "
            f"and new type is {type}. "
            "This could happen if you tried to annotate a function parameter "
            "value (in which case you should use the type slot "
            "on the function signature) or you called "
            "annotate on the same value twice"
        )

    node.type = type
    return val
|
|