Spaces:
Runtime error
Runtime error
# from transformers import AutoModel | |
import argparse | |
import logging | |
import os | |
import glob | |
import tqdm | |
import torch, re | |
import PIL | |
import cv2 | |
import numpy as np | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from utils import Config, Logger, CharsetMapper | |
import gradio as gr | |
import gdown | |
def _download(file_id, output):
    """Fetch a Google Drive file by id into ``output`` and return that path.

    Raises:
        RuntimeError: If ``output`` does not exist after the download attempt.
    """
    gdown.download(id=file_id, output=output)
    # Some gdown versions report a failed download by returning None instead
    # of raising; check the file itself so we fail here rather than later.
    if not os.path.isfile(output):
        raise RuntimeError(f"download of {output!r} (Drive id {file_id}) failed")
    return output


def _download_and_extract(file_id, output):
    """Download a zip archive from Google Drive and unpack it into the cwd."""
    import zipfile

    # zipfile overwrites existing files (unlike `unzip` without -o, which
    # prompts on a re-run) and raises on a corrupt or non-zip download
    # instead of the failure being hidden behind os.system's return code.
    with zipfile.ZipFile(_download(file_id, output)) as archive:
        archive.extractall()


_download('16PF_b4dURVkBt4OT7E-a-vq-SRxi0uDl', 'lol.pth')
_download('19rGjfo73P25O_keQv30snfe3IHrK0uV2', 'config.yaml')
_download_and_extract('1qyNV80qmYHx_r4KsG3_8PXQ6ff1a1dov', 'modules.zip')
_download_and_extract('1UMZ7i8SpfuNw0N2JvVY8euaNx9gu3x6N', 'configs.zip')
_download_and_extract('1yHD7_4DD_keUwGs2nenAYDaQ2CNEA5IU', 'data.zip')
def get_model(config):
    """Instantiate the model class named by ``config.model_name`` in eval mode.

    ``config.model_name`` is a dotted path of the form
    ``package.module.ClassName``. The module is imported dynamically, and the
    class is constructed with ``config`` as its only argument.

    Args:
        config: Object exposing a ``model_name`` attribute (dotted class path).

    Returns:
        The constructed model, after ``.eval()`` has been called on it.
    """
    import importlib

    # Everything before the last dot is the module; the tail is the class name.
    module_path, _, class_name = config.model_name.rpartition('.')
    model_cls = getattr(importlib.import_module(module_path), class_name)
    instance = model_cls(config)
    logging.info(instance)
    return instance.eval()
def load(model, file, device=None, strict=True):
    """Load weights from a checkpoint file into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        file: Path to a checkpoint readable by ``torch.load``.
        device: ``None`` to map tensors to CPU, an ``int`` CUDA device index,
            or any other value accepted as ``torch.load``'s ``map_location``.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        ``model`` itself, with the weights loaded.

    Raises:
        FileNotFoundError: If ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(file):
        raise FileNotFoundError(f"checkpoint not found: {file}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(file, map_location=device)
    # A checkpoint holding exactly the keys 'model' and 'opt' (presumably
    # weights + optimizer state) is unwrapped to its 'model' entry.
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
# Build the recognizer described by config.yaml and load the downloaded weights.
config = Config('config.yaml')
# Cleared before the model is built -- presumably so construction does not try
# to load a separate vision checkpoint; confirm against the model class.
config.model_vision_checkpoint = None
model = load(get_model(config), 'lol.pth')
def postprocess(output, charset, model_eval):
    """Greedily decode the logits of one model head into text.

    Args:
        output: Either a single result dict, or a list/tuple of result dicts
            that each carry a ``'name'`` key. For a sequence, the entry whose
            ``'name'`` equals ``model_eval`` is decoded; if several match, the
            last one wins. The selected dict must contain ``'logits'``.
        charset: Mapper providing ``get_text``, ``null_char`` and ``max_length``.
        model_eval: Name of the head to decode when ``output`` is a sequence.

    Returns:
        ``(texts, scores, lengths)``, each with one entry per batch item:
        - the decoded string, cut at the first ``charset.null_char``;
        - the per-position maximum softmax probability;
        - ``len(text) + 1`` (one slot for the end token), capped at
          ``charset.max_length``.

    Raises:
        ValueError: If ``output`` is a sequence with no entry named ``model_eval``.
    """
    def _get_output(last_output, model_eval):
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == model_eval]
        if not matches:
            # Fail with a clear message instead of an UnboundLocalError.
            names = [res['name'] for res in last_output]
            raise ValueError(f"no output named {model_eval!r}; available: {names}")
        return matches[-1]

    def _decode(logit):
        """Greedy (per-position argmax) decode."""
        # assumes logit is (batch, seq_len, num_classes) -- TODO confirm; softmax is over dim 2.
        out = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in out:
            text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            text = text.split(charset.null_char)[0]  # end at end-token
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            pt_lengths.append(min(len(text) + 1, charset.max_length))  # one for end-token
        return pt_text, pt_scores, pt_lengths

    return _decode(_get_output(output, model_eval)['logits'])
def preprocess(img, width, height):
    """Resize an image and turn it into a normalized single-image batch tensor.

    Args:
        img: Image convertible with ``np.array`` (here an RGB PIL image).
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A float tensor with a leading batch dimension of 1, normalized per
        channel with the widely used ImageNet mean/std.
    """
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(np.array(img), (width, height))
    batch = transforms.ToTensor()(resized)[None]
    # Reshape to (C, 1, 1) so the statistics broadcast over the spatial dims.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - channel_mean) / channel_std
def process_image(image):
    """Recognize the text in a single image.

    Args:
        image: A PIL image from the Gradio input, or ``None`` when the user
            submits without providing an image.

    Returns:
        The decoded string from the ``'alignment'`` output head, or ``''``
        when no image was given.
    """
    if image is None:
        return ''
    # +1 on max_length presumably reserves a slot for the end token -- confirm
    # against CharsetMapper.
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)
    img = preprocess(image.convert('RGB'), config.dataset_image_width,
                     config.dataset_image_height)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        res = model(img)
    return postprocess(res, charset, 'alignment')[0][0]
# Gradio UI: one PIL image in, the recognized text out.
# `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level components below work in both versions.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="8kun kek",
    description="Making Jim Watkins sheete because he is a techlet pedo",
    # article=article,
    # examples=glob.glob('figs/test/*.png')
)
iface.launch(debug=True)