Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import gradio as gr | |
import torchvision | |
from PIL import Image | |
from utils import * | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from huggingface_hub import Repository, upload_file | |
from torch.utils.data import Dataset | |
import numpy as np | |
from collections import Counter | |
# Stylesheet for the Gradio Blocks UI. An explicit encoding keeps the result
# independent of the host's locale.
with open('app.css', 'r', encoding='utf-8') as f:
    BLOCK_CSS = f.read()

# Training / evaluation hyper-parameters.
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # `train` prints, records the loss and checkpoints every this many batches
random_seed = 1

# `flag` retrains whenever the number of flagged samples is a multiple of this.
TRAIN_CUTOFF = 10
# WHAT_TO_DO is not defined in this file; presumably a template from `utils` (star import).
WHAT_TO_DO = WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)

METRIC_PATH = './metrics.json'
REPOSITORY_DIR = "data"  # folder inside the HF dataset repo that flagged samples are uploaded to
LOCAL_DIR = 'data_local'  # local staging folder for flagged samples before upload
os.makedirs(LOCAL_DIR, exist_ok=True)

HF_TOKEN = os.getenv("HF_TOKEN")
HF_DATASET = "mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
# Local clone of the dataset repo; it is re-pulled before reading flagged samples.
repo = Repository(
    local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()

torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
class MNISTAdversarial_Dataset(Dataset):
    """Dataset of user-flagged digits read from the local clone of the HF dataset repo.

    Every sub-folder of ``<data_dir>/data`` that contains both ``image.png`` and
    ``metadata.jsonl`` contributes one sample. The label is the
    ``correct_number`` field of the first metadata record. Folders missing
    either file are skipped.

    :param data_dir: root of the local dataset clone (e.g. ``./data_mnist``).
    :param transform: callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, data_dir, transform):
        # Refresh the clone so that recently flagged samples are included.
        repo.git_pull()
        self.data_dir = os.path.join(data_dir, 'data')
        self.transform = transform
        self.images = []
        self.numbers = []
        # Sorted so that the sample order does not depend on the filesystem.
        folder_names = sorted(entry.name for entry in os.scandir(self.data_dir))
        for name in folder_names:
            folder = os.path.join(self.data_dir, name)
            metadata_path = os.path.join(folder, 'metadata.jsonl')
            image_path = os.path.join(folder, 'image.png')
            if not (os.path.exists(image_path) and os.path.exists(metadata_path)):
                continue
            # Read the label first so that a bad metadata file cannot leave
            # `images` and `numbers` out of step.
            metadata = read_json_lines(metadata_path)
            number = metadata[0]['correct_number']
            # Image.open is lazy and keeps the file handle open until the pixels
            # are loaded. Copying inside `with` loads them and closes the file.
            with Image.open(image_path) as img:
                image = img.copy()
            self.images.append(image)
            self.numbers.append(number)
        assert len(self.images) == len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for sample ``idx``."""
        img, label = self.images[idx], self.numbers[idx]
        img = self.transform(img)
        return img, label
class MNISTCorrupted_By_Digit(Dataset):
    """MNIST-C test samples of a single digit, pooled across all corruption folders.

    Every sub-folder of ``corrupted_dir`` must contain ``test_images.npy`` and
    ``test_labels.npy``. The folders are read in sorted name order and their
    arrays are stacked. The stack is then filtered to ``digit`` and the first
    ``limit`` matches are kept.

    :param transform: optional callable applied to a PIL image. If falsy, the raw
        numpy array is returned by ``__getitem__``.
    :param digit: label value to keep.
    :param limit: maximum number of samples kept; fewer are kept if fewer exist.
    :param corrupted_dir: root directory of the MNIST-C dump.
    """

    def __init__(self, transform, digit, limit=300, corrupted_dir="./mnist_c"):
        self.transform = transform
        self.digit = digit
        # Sorted so the chosen `limit` samples do not depend on filesystem order.
        folder_names = sorted(entry.name for entry in os.scandir(corrupted_dir))
        images = [np.load(os.path.join(corrupted_dir, name, 'test_images.npy')) for name in folder_names]
        labels = [np.load(os.path.join(corrupted_dir, name, 'test_labels.npy')) for name in folder_names]
        self.data = np.vstack(images)
        self.labels = np.hstack(labels)
        if self.data.shape[0] != self.labels.shape[0]:
            raise ValueError(
                f"MNIST-C images/labels mismatch: {self.data.shape[0]} images vs "
                f"{self.labels.shape[0]} labels in {corrupted_dir!r}"
            )
        mask = self.labels == self.digit
        # Slicing past the end is safe, so a `limit` larger than the number of
        # matching samples simply keeps all of them.
        self.data_for_use = self.data[mask][:limit]
        self.labels_for_use = self.labels[mask][:limit]

    def __len__(self):
        return len(self.data_for_use)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``idx`` may be an int or a 0-d tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data_for_use[idx]
        label = self.labels_for_use[idx]
        if self.transform:
            image_pil = torchvision.transforms.ToPILImage()(image)  # Need to transform to PIL before using default transforms
            image = self.transform(image_pil)
        return image, label
class MNISTCorrupted(Dataset):
    """Leading MNIST-C test samples of every corruption folder, stacked into one dataset.

    Every sub-folder of ``corrupted_dir`` must contain ``test_images.npy`` and
    ``test_labels.npy``. The folders are read in sorted name order, and the
    first ``per_corruption_limit`` samples of each are kept.

    :param transform: optional callable applied to a PIL image. If falsy, the raw
        numpy array is returned by ``__getitem__``.
    :param corrupted_dir: root directory of the MNIST-C dump.
    :param per_corruption_limit: number of leading samples taken from each folder.
    """

    def __init__(self, transform, corrupted_dir="./mnist_c", per_corruption_limit=300):
        self.transform = transform
        # Sorted so the sample order does not depend on the filesystem.
        folder_names = sorted(entry.name for entry in os.scandir(corrupted_dir))
        images = [np.load(os.path.join(corrupted_dir, name, 'test_images.npy'))[:per_corruption_limit] for name in folder_names]
        labels = [np.load(os.path.join(corrupted_dir, name, 'test_labels.npy'))[:per_corruption_limit] for name in folder_names]
        self.data = np.vstack(images)
        self.labels = np.hstack(labels)
        if self.data.shape[0] != self.labels.shape[0]:
            raise ValueError(
                f"MNIST-C images/labels mismatch: {self.data.shape[0]} images vs "
                f"{self.labels.shape[0]} labels in {corrupted_dir!r}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``idx`` may be an int or a 0-d tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            image_pil = torchvision.transforms.ToPILImage()(image)  # Need to transform to PIL before using default transforms
            image = self.transform(image_pil)
        return image, label
# Preprocessing shared by adversarial training, the MNIST-C evaluation loader
# and the per-digit metrics. It converts a PIL image to a float tensor, then
# normalises it with a fixed mean/std (0.1307 / 0.3081, presumably the usual
# MNIST statistics).
TRAIN_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Fixed-order evaluation loader over the corrupted MNIST-C subset; used by `test()`.
test_loader = torch.utils.data.DataLoader(
    MNISTCorrupted(TRAIN_TRANSFORM),
    shuffle=False,
    batch_size=batch_size_test,
)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small two-conv CNN for single-channel 28x28 digits that returns log-probabilities.

    The hard-coded 320 flattened features (20 channels x 4 x 4) only hold for
    28x28 inputs: 28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a ``(batch, 1, 28, 28)`` tensor to ``(batch, 10)`` log-probabilities."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise over the class dimension explicitly. The implicit-dim form
        # is deprecated and warns on every call; it resolved to dim=1 for this
        # 2-D tensor anyway.
        return F.log_softmax(x, dim=1)
def train(epochs, network, optimizer, train_loader, log_every=None,
          model_path='model.pth', optimizer_path='optimizer.pth'):
    """Train ``network`` in place with NLL loss for ``epochs`` passes over ``train_loader``.

    The network is expected to output log-probabilities. Every ``log_every``
    batches the loss is printed and recorded, and the model and optimizer state
    dicts are checkpointed. A final checkpoint is always written after the
    last step, so the files on disk match the in-memory weights. Otherwise,
    batches after the last logged one would be missing from the checkpoint.

    :param epochs: number of passes over ``train_loader``.
    :param network: model to train (switched to train mode).
    :param optimizer: optimizer over ``network``'s parameters.
    :param train_loader: iterable of ``(data, target)`` batches with a ``dataset``.
    :param log_every: logging/checkpoint period in batches; defaults to the
        module-level ``log_interval``.
    :param model_path: where the model state dict is saved.
    :param optimizer_path: where the optimizer state dict is saved.
    :return: list of the loss values recorded at each logged batch.
    """
    if log_every is None:
        log_every = log_interval

    def save_checkpoint():
        torch.save(network.state_dict(), model_path)
        torch.save(optimizer.state_dict(), optimizer_path)

    train_losses = []
    network.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                train_losses.append(loss.item())
                save_checkpoint()
    # Persist the final state; the periodic checkpoint can miss trailing batches.
    save_checkpoint()
    return train_losses
def test(model=None, loader=None):
    """Evaluate a model with NLL loss and accuracy, and format a summary line.

    :param model: model returning log-probabilities; defaults to the
        module-level ``network``. It is switched to eval mode.
    :param loader: evaluation loader; defaults to the module-level ``test_loader``.
    :return: ``(metric_markdown, accuracy_percent)``. The first element is the
        formatted summary string; the second is accuracy in percent as a float.
    """
    model = network if model is None else model
    loader = test_loader if loader is None else loader
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    n_samples = len(loader.dataset)
    test_loss /= n_samples
    acc = 100. * correct / n_samples
    test_metric = 'γ½Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
        test_loss, acc)
    return test_metric, acc
# Build the shared model and optimizer. Re-seeding here makes the initial
# weights reproducible.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

# Resume from the checkpoints written by `train`, if present. The optimizer
# file name must match the one `train` saves ('optimizer.pth'); the previous
# misspelling meant the optimizer state was never restored.
model_state_dict = 'model.pth'
optimizer_state_dict = 'optimizer.pth'
if os.path.exists(model_state_dict):
    network.load_state_dict(torch.load(model_state_dict, map_location='cpu'))
if os.path.exists(optimizer_state_dict):
    optimizer.load_state_dict(torch.load(optimizer_state_dict, map_location='cpu'))
def image_classifier(inp):
    """Classify a drawn digit and return per-class confidences.

    The image is preprocessed with ``TRAIN_TRANSFORM``, the same tensor
    conversion and normalisation used for training and evaluation, so
    inference sees the same input distribution as training. The shared
    ``network`` is run in eval mode (dropout off), and its previous mode is
    restored afterwards.

    :param inp: the image to be classified (PIL image from the canvas), or None.
    :return: dict mapping class index to softmax confidence, highest first. It
        is empty when no image was provided.
    """
    if inp is None:
        return {}
    input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
    finally:
        network.train(was_training)
    values, indices = torch.sort(prediction, descending=True)
    return {int(cls): float(conf) for cls, conf in zip(indices.tolist(), values.tolist())}
def train_and_test():
    """Fine-tune the shared model on all flagged samples, then evaluate it.

    Side effects:
      - ``network`` and ``optimizer`` are updated in place by ``train``, which
        also writes the checkpoints.
      - The overall MNIST-C accuracy (key ``'all'``) is appended to the JSON
        metrics file at ``METRIC_PATH``.
      - The per-digit MNIST-C accuracies (keys ``'0'``..``'9'``) are appended
        to the same file.

    :return: the formatted metric string produced by ``test()``.
    """
    train_dataset = MNISTAdversarial_Dataset('./data_mnist', TRAIN_TRANSFORM)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size_train, shuffle=True
    )
    train(n_epochs, network, optimizer, train_loader)
    test_metric, test_acc = test()

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    # Always append the new overall accuracy, creating the history if needed.
    metric_dict.setdefault('all', []).append(test_acc)

    network.eval()
    with torch.no_grad():
        for digit in range(10):
            digit_dataset = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM, digit)
            if len(digit_dataset) == 0:
                continue  # DataLoader rejects batch_size=0
            digit_loader = torch.utils.data.DataLoader(
                digit_dataset, batch_size=len(digit_dataset), shuffle=False
            )
            # Python 3 iterators have no `.next()` method; use the builtin.
            images, labels = next(iter(digit_loader))
            pred = network(images).argmax(dim=1, keepdim=True)
            correct = pred.eq(labels.view_as(pred)).sum()
            acc = (100. * correct / len(images)).item()
            metric_dict.setdefault(str(digit), []).append(acc)

    dump_json(thing=metric_dict, file=METRIC_PATH)
    return test_metric
def flag(input_image, correct_result, adversarial_number):
    """Save a user-flagged ``(image, correct label)`` pair and upload it to the HF dataset.

    The image and a metadata record are written under ``LOCAL_DIR`` and
    uploaded to ``REPOSITORY_DIR/<unique name>/`` in the dataset repo.
    Whenever the number of entries in the refreshed local clone is a multiple
    of ``TRAIN_CUTOFF``, the model is retrained via ``train_and_test``.

    :param input_image: PIL image from the canvas, or None.
    :param correct_result: digit selected in the dropdown, or None if unselected.
    :param adversarial_number: per-session count of flagged samples (None treated as 0).
    :return: ``(HTML status message, updated count)``.
    """
    if adversarial_number is None:
        adversarial_number = 0
    # Refuse incomplete submissions. An unlabeled sample would be uploaded to
    # the shared dataset, and a None label later breaks batch collation in training.
    if input_image is None or correct_result is None:
        return (
            '<div> Please draw a digit and select the correct number before flagging. </div>',
            adversarial_number,
        )

    metadata_name = get_unique_name()
    save_dir = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(save_dir, exist_ok=True)
    image_output_filename = os.path.join(save_dir, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        raise Exception("Had issues saving PIL image to file") from err

    json_file_path = os.path.join(save_dir, 'metadata.jsonl')
    metadata = {'id': metadata_name, 'file_name': 'image.png',
                'correct_number': correct_result}
    dump_json(metadata, json_file_path)

    # Upload image and metadata. Repo paths always use '/', independent of the host OS.
    for local_path, file_name in ((image_output_filename, 'image.png'),
                                  (json_file_path, 'metadata.jsonl')):
        upload_file(path_or_fileobj=local_path,
                    path_in_repo=f"{REPOSITORY_DIR}/{metadata_name}/{file_name}",
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN)

    adversarial_number += 1
    output = f'<div> β ({adversarial_number}) Successfully saved your adversarial data. </div>'
    repo.git_pull()
    length_of_dataset = sum(1 for _ in os.scandir("./data_mnist/data"))
    if length_of_dataset % TRAIN_CUTOFF == 0:
        # The returned metric string is not shown: the UI only receives
        # (output, adversarial_number).
        train_and_test()
        output = f'<div> β ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
    return output, adversarial_number
def get_number_dict(DATA_DIR):
    """Count flagged samples per correct digit.

    A sub-folder is counted only if it holds both ``image.png`` and
    ``metadata.jsonl``, the same check ``MNISTAdversarial_Dataset`` uses. This
    keeps stray files or half-uploaded samples from crashing the dashboard.

    :param DATA_DIR: directory whose sub-folders each hold one flagged sample.
    :return: ``(digits, counts)`` as parallel lists in first-seen order.
    """
    numbers = []
    for entry in os.scandir(DATA_DIR):
        folder = os.path.join(DATA_DIR, entry.name)
        metadata_path = os.path.join(folder, 'metadata.jsonl')
        if os.path.exists(metadata_path) and os.path.exists(os.path.join(folder, 'image.png')):
            numbers.append(read_json_lines(metadata_path)[0]['correct_number'])
    numbers_count = Counter(numbers)
    return list(numbers_count.keys()), list(numbers_count.values())
def get_statistics():
    """Reload the latest checkpoints and build the dashboard figures.

    ``plot_bar``, ``read_json``, ``STATS_EXPLANATION`` and ``plt`` are not
    defined in this file; presumably they come from ``utils`` (star import).

    :return: a 4-tuple of
        - the bar plot of flagged samples per digit,
        - the figure of per-digit MNIST-C accuracy per train step,
        - the status HTML,
        - the formatted STATS_EXPLANATION markdown.
    """
    # Pick up the state written by the most recent `train` run. The optimizer
    # file name must match the one `train` saves ('optimizer.pth').
    model_state_dict = 'model.pth'
    optimizer_state_dict = 'optimizer.pth'
    if os.path.exists(model_state_dict):
        network.load_state_dict(torch.load(model_state_dict, map_location='cpu'))
    if os.path.exists(optimizer_state_dict):
        optimizer.load_state_dict(torch.load(optimizer_state_dict, map_location='cpu'))

    repo.git_pull()
    DATA_DIR = './data_mnist/data'
    numbers_count_keys, numbers_count_values = get_number_dict(DATA_DIR)
    STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples=sum(numbers_count_values))
    plt_digits = plot_bar(numbers_count_values, numbers_count_keys, 'Number of adversarial samples', "Digit", "Distribution of adversarial samples over digits")

    fig_d, ax_d = plt.subplots(tight_layout=True)
    plotted_any = False
    if os.path.exists(METRIC_PATH):
        metric_dict = read_json(METRIC_PATH)
        for digit in range(10):
            history = metric_dict.get(str(digit))
            if not history:
                continue  # no accuracy recorded yet for this digit
            steps = list(range(1, len(history) + 1))
            ax_d.plot(steps, history, label=str(digit))
            plotted_any = True
    if plotted_any:
        fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy', title="Test Accuracy over digits per train step")
    done_html = """<div style="color: green">
<p> β Statistics loaded successfully!</p>
</div>
"""
    return plt_digits, fig_d, done_html, STATS_EXPLANATION_
def main():
    """Build and launch the Gradio UI.

    The UI has two tabs:
      - 'MNIST': draw a digit, classify it, and flag a wrong prediction
        together with the correct digit.
      - 'Dashboard': statistics produced by ``get_statistics`` when the tab is
        selected.

    TITLE, description, WHAT_TO_DO, MODEL_IS_WRONG and DASHBOARD_EXPLANATION
    are not defined in this file; presumably they come from ``utils``.
    NOTE(review): gr.inputs / gr.outputs / gr.Variable / ``default=`` are
    older Gradio APIs. Confirm they match the pinned Gradio version.
    """
    block = gr.Blocks(css=BLOCK_CSS)
    with block:
        gr.Markdown(TITLE)
        gr.Markdown(description)
        with gr.Tabs():
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                #test_metric = gr.outputs.HTML("")
                with gr.Row():
                    # 28x28 single-channel ('L') canvas, colour-inverted, handed over as a PIL image.
                    image_input =gr.inputs.Image(source="canvas",shape=(28,28),invert_colors=True,image_mode="L",type="pil")
                    label_output = gr.outputs.Label(num_top_classes=2)
                submit = gr.Button("Submit")
                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                # Counter passed into `flag` and replaced by its second return value.
                adversarial_number = gr.Variable(value=0)
                submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
                flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
            with gr.TabItem('Dashboard') as dashboard:
                notification = gr.HTML("""<div style="color: green">
<p> β Creating statistics... </p>
</div>
""")
                stats = gr.Markdown()
                stat_adv_image =gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results=gr.Plot(type="matplotlib")
                # Recompute the dashboard whenever this tab is selected; the
                # outputs follow the order of get_statistics' return tuple.
                dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
    block.launch()


# Launch the app only when run as a script.
if __name__ == "__main__":
    main()