LegacyLeague commited on
Commit
b01849c
1 Parent(s): 0658f26

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import fastai
3
+ from fastai.vision import *
4
+ from fastai.utils.mem import *
5
+ from fastai.vision import open_image, load_learner, image, torch
6
+ import numpy as np4
7
+ import urllib.request
8
+ import PIL.Image
9
+ from io import BytesIO
10
+ import torchvision.transforms as T
11
+ from PIL import Image
12
+ import requests
13
+ from io import BytesIO
14
+ import fastai
15
+ from fastai.vision import *
16
+ from fastai.utils.mem import *
17
+ from fastai.vision import open_image, load_learner, image, torch
18
+ import numpy as np
19
+ import urllib.request
20
+ from urllib.request import urlretrieve
21
+ import PIL.Image
22
+ from io import BytesIO
23
+ import torchvision.transforms as T
24
+ import torchvision.transforms as tfms
25
+
26
+ class FeatureLoss(nn.Module):
27
+ def __init__(self, m_feat, layer_ids, layer_wgts):
28
+ super().__init__()
29
+ self.m_feat = m_feat
30
+ self.loss_features = [self.m_feat[i] for i in layer_ids]
31
+ self.hooks = hook_outputs(self.loss_features, detach=False)
32
+ self.wgts = layer_wgts
33
+ self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
34
+ ] + [f'gram_{i}' for i in range(len(layer_ids))]
35
+
36
+ def make_features(self, x, clone=False):
37
+ self.m_feat(x)
38
+ return [(o.clone() if clone else o) for o in self.hooks.stored]
39
+
40
+ def forward(self, input, target):
41
+ out_feat = self.make_features(target, clone=True)
42
+ in_feat = self.make_features(input)
43
+ self.feat_losses = [base_loss(input,target)]
44
+ self.feat_losses += [base_loss(f_in, f_out)*w
45
+ for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
46
+ self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
47
+ for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
48
+ self.metrics = dict(zip(self.metric_names, self.feat_losses))
49
+ return sum(self.feat_losses)
50
+
51
+ def __del__(self): self.hooks.remove()
52
+
53
+ MODEL_URL = "https://www.dropbox.com/s/daf70v42oo93kym/Legacy_best.pkl?dl=1"
54
+ urllib.request.urlretrieve(MODEL_URL, "Legacy_best.pkl")
55
+ path = Path(".")
56
+ learn=load_learner(path, 'Legacy_best.pkl')
57
+
58
+ urlretrieve("https://s.hdnux.com/photos/01/07/33/71/18726490/5/1200x0.jpg","soccer1.jpg")
59
+ urlretrieve("https://cdn.vox-cdn.com/thumbor/4J8EqJBsS2qEQltIBuFOJWSn8dc=/1400x1400/filters:format(jpeg)/cdn.vox-cdn.com/uploads/chorus_asset/file/22466347/1312893179.jpg","soccer2.jpg")
60
+ urlretrieve("https://cdn.vox-cdn.com/thumbor/VHa7adj0Oie2Ao12RwKbs40i58s=/0x0:2366x2730/1200x800/filters:focal(1180x774:1558x1152)/cdn.vox-cdn.com/uploads/chorus_image/image/69526697/E5GnQUTWEAEK445.0.jpg","baseball.jpg")
61
+ urlretrieve("https://baseball.ca/uploads/images/content/Diodati(1).jpeg","baseball2.jpeg")
62
+
63
+ sample_images = [["soccer1.jpg"],
64
+ ["soccer2.jpg"],
65
+ ["baseball.jpg"],
66
+ ["baseball2.jpeg"]]
67
+
68
+
69
+ def predict(input):
70
+ img_t = T.ToTensor()(input)
71
+ img_fast = Image(img_t)
72
+ p,img_hr,b = learn.predict(img_fast)
73
+ x = np.minimum(np.maximum(image2np(img_hr.data*255), 0), 255).astype(np.uint8)
74
+ img = PIL.Image.fromarray(x)
75
+ return img
76
+
77
+ gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(), outputs="image", title='Legacy-League',examples=sample_images).launch();