Fill-Mask
Transformers
PyTorch
esm
Inference Endpoints
svincoff's picture
adding utility files used throughout FusOn-pLM training and benchmarking
ffaff91
raw
history blame
3.81 kB
from datetime import datetime
from contextlib import contextmanager
import sys
import pytz
import os
class CustomParams:
    """
    Lightweight attribute bag: every keyword argument given to the
    constructor becomes an instance attribute, so ``CustomParams(lr=0.1).lr``
    evaluates to ``0.1``.
    """

    def __init__(self, **kwargs):
        # Copy the keyword arguments straight into the instance namespace.
        vars(self).update(kwargs)

    def print_config(self, indent=''):
        """
        Print every instance attribute on its own line as
        ``<indent><name>: <value>``, in insertion order.

        Args:
            indent (str): prefix written before each line. Defaults to ''.
        """
        for key, val in vars(self).items():
            print(f"{indent}{key}: {val}")
def log_update(text: str):
    """
    Print ``text`` to the current ``sys.stdout`` and flush it right away.

    The immediate flush means that when stdout has been redirected to a file
    (e.g. by ``open_logfile``), the line shows up there without waiting for
    the buffer to fill.

    Args:
        text (str): the message to log.
    """
    stream = sys.stdout
    print(text, file=stream)
    stream.flush()
@contextmanager
def open_logfile(log_path, mode='w'):
    """
    Redirect ``sys.stdout`` into a file for the duration of a ``with`` block.

    Args:
        log_path (str): path of the file that receives the redirected output.
        mode (str): mode passed to ``open``. The default ``'w'`` truncates the
            file; ``'a'`` appends to it.

    Yields:
        The open file object that ``sys.stdout`` points to inside the block.
    """
    saved_stdout = sys.stdout
    handle = open(log_path, mode)
    sys.stdout = handle
    try:
        yield handle
    finally:
        # Put the previous stream back before closing, so nothing is left
        # pointing at a closed file.
        sys.stdout = saved_stdout
        handle.close()
@contextmanager
def open_errfile(log_path, mode='w'):
    """
    Redirect ``sys.stderr`` into a file for the duration of a ``with`` block.

    Anything written to ``sys.stderr`` while the block runs goes to
    ``log_path`` instead of the console.

    Args:
        log_path (str): path of the error-log file.
        mode (str): mode passed to ``open``. The default ``'w'`` truncates the
            file; ``'a'`` appends to it.

    Yields:
        The open file object that ``sys.stderr`` points to inside the block.
    """
    saved_stderr = sys.stderr
    err_handle = open(log_path, mode)
    sys.stderr = err_handle
    try:
        yield err_handle
    finally:
        # Restore the previous stderr first, then close the error-log file.
        sys.stderr = saved_stderr
        err_handle.close()
def print_configpy(module):
    """
    Log every attribute of a config module via ``log_update``.

    Logs a header line first, then one tab-indented line ``<name>: <value>``
    for each name from ``dir(module)`` that does not start with a double
    underscore. Single-underscore names and anything imported into the
    module are included.

    Args:
        module: the imported config module (any object with attributes works).
    """
    log_update("All configurations:")
    public_names = (name for name in dir(module) if not name.startswith("__"))
    for name in public_names:
        log_update(f"\t{name}: {getattr(module, name)}")
def get_local_time(timezone_str='US/Eastern'):
    """
    Return the current wall-clock time in a given timezone as a string.

    Args:
        timezone_str (str): pytz timezone name. Defaults to 'US/Eastern'.

    Returns:
        str: the current time formatted as ``MM-DD-YYYY-HH:MM:SS`` (24-hour).
        If pytz does not recognize ``timezone_str``, no exception is raised;
        the string ``"Unknown timezone: <timezone_str>"`` is returned instead.
    """
    try:
        tz = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone_str}"
    else:
        # Take an aware UTC timestamp, then convert it to the target zone.
        now_utc = datetime.now(pytz.utc)
        return now_utc.astimezone(tz).strftime('%m-%d-%Y-%H:%M:%S')
def get_local_date_yr(timezone_str='US/Eastern'):
    """
    Get the current calendar date (including the year) in the specified timezone.

    Only the date is returned, with no time-of-day component. The underscore
    separators keep the result usable inside file and directory names.

    Args:
        timezone_str (str): pytz timezone name. Defaults to 'US/Eastern'.

    Returns:
        str: today's date in that timezone formatted as ``MM_DD_YYYY``.
        If pytz does not recognize ``timezone_str``, no exception is raised;
        the string ``"Unknown timezone: <timezone_str>"`` is returned instead.
        NOTE(review): callers that build paths from the result may want to
        check for that error string.
    """
    try:
        timezone = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone_str}"
    # Convert an aware UTC "now" so the date reflects the target zone's
    # midnight boundary, not the host machine's local timezone.
    current_datetime = datetime.now(pytz.utc).astimezone(timezone)
    return current_datetime.strftime('%m_%d_%Y')
def find_fuson_plm_directory(start_dir=None):
    """
    Locate the ``fuson_plm`` directory by walking upward from a starting point.

    This avoids hard-coded absolute paths, which helps for docker containers.
    Starting at ``start_dir`` (or the current working directory), each
    directory and then each of its ancestors is checked for a ``fuson_plm``
    subdirectory. The first match is returned.

    Args:
        start_dir (str, optional): directory to start searching from.
            Defaults to the current working directory.

    Returns:
        str: absolute path of the first ``fuson_plm`` directory found.

    Raises:
        FileNotFoundError: if neither ``start_dir`` nor any of its ancestors
            contains a ``fuson_plm`` directory.
    """
    current_dir = os.path.abspath(start_dir if start_dir is not None else os.getcwd())
    while True:
        candidate = os.path.join(current_dir, 'fuson_plm')
        # Check the candidate path directly instead of listing the parent:
        # this skips plain files named "fuson_plm", and it does not need read
        # permission on every ancestor directory (os.listdir would raise
        # PermissionError there).
        if os.path.isdir(candidate):
            return candidate
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:  # reached the filesystem root
            raise FileNotFoundError("fuson_plm directory not found.")
        current_dir = parent_dir