import importlib.util
import os


def _default_degradations_path():
    """Return the expected path of basicsr's ``data/degradations.py``.

    The installed ``basicsr`` package is located with
    ``importlib.util.find_spec``, which searches ``sys.path`` without
    executing the package (importing basicsr would fail on the very import
    being patched). If the package cannot be found, fall back to the legacy
    guess ``<dir of os.py>/site-packages/basicsr/data/degradations.py``.
    """
    try:
        spec = importlib.util.find_spec("basicsr")
    except (ImportError, ValueError):
        # ValueError: basicsr is in sys.modules with no __spec__ set.
        spec = None
    if spec is not None and spec.submodule_search_locations:
        package_dir = next(iter(spec.submodule_search_locations))
        return os.path.join(package_dir, "data", "degradations.py")
    return os.path.join(
        os.path.dirname(os.__file__),
        "site-packages", "basicsr", "data", "degradations.py",
    )


def fix_basicsr_import(file_path=None):
    """Rewrite basicsr's ``rgb_to_grayscale`` import to use torchvision's public API.

    ``basicsr/data/degradations.py`` imports ``rgb_to_grayscale`` from
    ``torchvision.transforms.functional_tensor``. This function rewrites
    that line to import it from ``torchvision.transforms.functional``.
    The file is only rewritten when the old import is actually present.

    Args:
        file_path: File to patch. Defaults to ``data/degradations.py``
            inside the installed ``basicsr`` package.

    Returns:
        True if the file was modified. False if the file does not exist
        or does not contain the old import (for example, already patched).
    """
    if file_path is None:
        file_path = _default_degradations_path()

    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"

    if not os.path.isfile(file_path):
        print(f"File {file_path} does not exist. Please check the path.")
        return False

    # Explicit UTF-8 avoids depending on the locale's default encoding.
    # newline="" keeps the file's original line endings on the round trip.
    with open(file_path, "r", encoding="utf-8", newline="") as file:
        data = file.read()

    if old_import not in data:
        # Nothing to change: don't touch the file or claim a fix.
        print(f"No basicsr import fix needed in {file_path}.")
        return False

    with open(file_path, "w", encoding="utf-8", newline="") as file:
        file.write(data.replace(old_import, new_import))

    print("Fixed basicsr import issue.")
    return True


# Apply the patch whenever this script is run (or imported).
fix_basicsr_import()