# Uploaded by sedrickkeh ("Upload 13 files"), commit 016285f, 1.63 kB.
import os
import logging
import matplotlib.pyplot as plt
from PIL import Image
import nltk
def logging_handler(verbose, save_name, idx=0):
    """Configure and return the logger named ``str(idx)``.

    Messages go to a console stream handler and, when ``save_name`` is
    given, also to ``results/<save_name>/<idx>.log``. Both handlers emit the
    bare message text (format ``"%(message)s"``).

    Calling this again for the same ``idx`` reconfigures the logger instead
    of stacking extra handlers, so each message is emitted only once.

    Args:
        verbose: Currently unused; kept for call-site compatibility.
        save_name: Sub-directory under ``results/`` for the log file, or
            ``None`` to log to the console only.
        idx: Used both as the logger name and as the log file stem.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    logger = logging.getLogger(str(idx))
    logger.setLevel(logging.INFO)

    # getLogger() hands back the same object for the same name, so handlers
    # from an earlier call must be dropped (and closed, to release any open
    # log file); otherwise every message is duplicated once per call.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter("%(message)s")

    stream_logger = logging.StreamHandler()
    stream_logger.setFormatter(formatter)
    logger.addHandler(stream_logger)

    if save_name is not None:
        savepath = f"results/{save_name}"
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(savepath, exist_ok=True)
        file_logger = logging.FileHandler(f"{savepath}/{idx}.log")
        file_logger.setFormatter(formatter)
        logger.addHandler(file_logger)

    return logger
def image_saver(images, save_name, idx=0, interactive=True):
    """Save up to ten images as one 2x5 grid PNG.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``; only the
            first ten are drawn, and with fewer than ten the remaining grid
            cells are left blank.
        save_name: Output name. When ``interactive`` is true the figure is
            written to ``<save_name>.png``; otherwise to
            ``results/<save_name>/<idx>.png`` (the directory is created if
            missing).
        idx: File stem used when ``interactive`` is false.
        interactive: Selects between the two output locations above.
    """
    fig, axes = plt.subplots(2, 5)
    try:
        fig.set_size_inches(30, 15)
        # axes.flat walks the grid row-major, matching a[i // 5][i % 5];
        # zip stops at the shorter input so < 10 images no longer raises.
        for ax, image in zip(axes.flat, images):
            ax.imshow(image)
        # Styled after imshow so set_aspect('equal') has the final word.
        for ax in axes.flat:
            ax.axis('off')
            ax.set_aspect('equal')
        fig.tight_layout()
        fig.subplots_adjust(wspace=0, hspace=0)
        if interactive:
            fig.savefig(f"{save_name}.png")
        else:
            out_dir = f"results/{save_name}"
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f"{out_dir}/{idx}.png")
    finally:
        # pyplot keeps every open figure alive; close it so repeated calls
        # don't accumulate figures and memory.
        plt.close(fig)
def assert_checks(args):
    """Validate option combinations on ``args``.

    ``args`` is presumably an argparse Namespace — TODO confirm against the
    caller; only ``question_strategy`` and (for "gpt3") ``include_what`` are
    read.

    Raises:
        AssertionError: if ``args.question_strategy == "gpt3"`` while
            ``args.include_what`` is falsy.
    """
    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; the exception type is unchanged for existing callers.
    if args.question_strategy == "gpt3" and not args.include_what:
        raise AssertionError(
            'question_strategy="gpt3" requires include_what to be set'
        )
def extract_nouns(sents):
    """Return the common nouns found in each sentence, lower-cased.

    Each sentence is split on whitespace and tagged with ``nltk.pos_tag``.
    Tokens tagged ``NN`` or ``NNS`` are kept, with every ``.`` removed and
    converted to lower case. Other punctuation is left in place.

    Args:
        sents: Iterable of sentence strings.

    Returns:
        One list of noun strings per input sentence, in input order.
    """
    noun_tags = {"NN", "NNS"}  # built once; set membership per token
    noun_list = []
    for sentence in sents:
        tagged = nltk.pos_tag(sentence.split())
        noun_list.append(
            [word.replace('.', '').lower() for word, tag in tagged if tag in noun_tags]
        )
    return noun_list