|
from pathlib import Path |
|
|
|
from datasets import disable_caching, enable_caching, is_caching_enabled |
|
from datasets.utils.py_utils import get_imports |
|
|
|
from .file_utils import get_all_files_in_dir |
|
|
|
HF_CACHING_ENABLED = False |
|
|
|
|
|
class HFCachingContextManager: |
|
def __init__(self, should_cache): |
|
self.should_cache = should_cache |
|
self.was_caching_enabled = None |
|
|
|
def __enter__(self): |
|
self.was_caching_enabled = is_caching_enabled() |
|
if self.should_cache: |
|
if not self.was_caching_enabled: |
|
enable_caching() |
|
else: |
|
if self.was_caching_enabled: |
|
disable_caching() |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
if self.was_caching_enabled != is_caching_enabled(): |
|
if self.was_caching_enabled: |
|
enable_caching() |
|
else: |
|
disable_caching() |
|
|
|
|
|
def get_missing_imports(file, exclude=[]): |
|
src_dir = Path(__file__).parent |
|
python_files = get_all_files_in_dir(src_dir, file_extension=".py") |
|
|
|
required_modules = [Path(p).stem for p in python_files] |
|
imports = get_imports(file) |
|
imported_modules = [i[1] for i in imports if i[0] == "internal"] |
|
missing_imports = [i for i in required_modules if i not in imported_modules and i not in exclude] |
|
return missing_imports |
|
|