Spaces:
Sleeping
Sleeping
from lib import * | |
from twokenize import tokenizeRawTweetText | |
import re | |
def muscaps_tokenize(raw):
    """Split *raw* into lowercase, punctuation-free word tokens.

    The text is lowercased, every character listed in
    ``string.punctuation`` is turned into a space, and the result is split
    on runs of whitespace.

    Args:
        raw: The text to tokenize.

    Returns:
        A list of tokens; empty when *raw* holds only punctuation and/or
        whitespace.
    """
    # One translation table maps each punctuation character to a space, so
    # the whole substitution happens in a single pass over the text.
    punctuation_to_space = str.maketrans(
        string.punctuation, ' ' * len(string.punctuation)
    )
    return raw.lower().translate(punctuation_to_space).split()
def get_device(device_id: int) -> torch.device:
    """Pick the torch device to run on.

    Args:
        device_id: Requested CUDA device index.

    Returns:
        ``torch.device('cpu')`` when CUDA is unavailable. Otherwise a CUDA
        device whose index is *device_id* capped at the highest available
        index (``torch.cuda.device_count() - 1``). Only the upper bound is
        enforced: a negative *device_id* is passed through unchanged.
    """
    if torch.cuda.is_available():
        highest_index = torch.cuda.device_count() - 1
        # Cap the request so an out-of-range index falls back to the last GPU.
        chosen_index = device_id if device_id < highest_index else highest_index
        return torch.device(f'cuda:{chosen_index}')
    return torch.device('cpu')
def preproc(caption, tokenizer, stop=True):
    """Encode a caption with *tokenizer*, marking periods as ``<STOP>``.

    Every '.' in *caption* is replaced by the literal text '<STOP>' before
    the caption is handed to ``tokenizer.encode``.

    Args:
        caption: The caption text.
        tokenizer: Any object exposing ``encode(text)``.
        stop: When true, the encoding of '.' is appended (with ``+=``) to
            the encoded caption.

    Returns:
        Whatever ``tokenizer.encode`` produced for the rewritten caption
        (presumably a list of token ids -- confirm against the tokenizer in
        use), extended with the encoding of '.' when *stop* is true.
    """
    encoded = tokenizer.encode(caption.replace('.', '<STOP>'))
    if not stop:
        return encoded
    # In-place extension, so the terminator lands on the same object that
    # encode() returned.
    encoded += tokenizer.encode('.')
    return encoded
def postproc(caption):
    """Convert a generated caption back to plain text.

    Every '<STOP>' marker is replaced with '.', and then a single trailing
    '.' is removed if the caption ends with one.

    Args:
        caption: The caption string, possibly containing '<STOP>' markers.

    Returns:
        The cleaned caption. An empty caption is returned unchanged.
    """
    caption = caption.replace('<STOP>', '.')
    # endswith() is safe on an empty string, whereas indexing caption[-1]
    # raised IndexError when the caption was empty.
    if caption.endswith('.'):
        caption = caption[:-1]
    return caption
class CheckpointManager:
    """Save and load torch checkpoints.

    ``save_checkpoint`` writes files under ``checkpoint_dir``;
    ``get_checkpoint`` reads from the exact path it is given.
    """

    # Default location used when no directory is passed to the constructor.
    DEFAULT_CHECKPOINT_DIR = '/home/nsrivats/Repositories/MusicCaptioning/checkpoints'

    def __init__(self, checkpoint_dir=None):
        """Create a manager.

        Args:
            checkpoint_dir: Directory that ``save_checkpoint`` writes into.
                Defaults to ``DEFAULT_CHECKPOINT_DIR`` when omitted.
        """
        if checkpoint_dir is None:
            checkpoint_dir = self.DEFAULT_CHECKPOINT_DIR
        self.checkpoint_dir = checkpoint_dir

    def get_checkpoint(self, checkpoint, map_location=None):
        """Load a checkpoint with ``torch.load``.

        Args:
            checkpoint: Path of the checkpoint file. It is opened as given
                and is NOT resolved against ``checkpoint_dir``.
            map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'``
                to load a checkpoint saved on a GPU onto a CPU-only host.

        Returns:
            The object stored in the checkpoint file.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from sources you trust.
        with open(checkpoint, 'rb') as infile:
            return torch.load(infile, map_location=map_location)

    def save_checkpoint(self, state_dict, checkpoint):
        """Write *state_dict* to ``<checkpoint_dir>/<checkpoint>``.

        Args:
            state_dict: The object to serialize with ``torch.save``.
            checkpoint: File name (or relative path) inside
                ``checkpoint_dir``.
        """
        import os

        filename = f'{self.checkpoint_dir}/{checkpoint}'
        # Create the target directory on first use so saving does not fail
        # with FileNotFoundError on a fresh machine.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, 'wb') as outfile:
            torch.save(state_dict, outfile)

    def save_logs(self, logdir):
        """Placeholder: log saving is not implemented; this does nothing."""
        pass