pawlo2013 committed on
Commit
7b8a66e
1 Parent(s): e4acec1

init commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.st filter=lfs diff=lfs merge=lfs -text
__pycache__/app.cpython-310.pyc ADDED
Binary file (2.06 kB). View file
 
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ import torch
5
+ import random
6
+ import os
7
+ from models.structure.model import SketchKeras
8
+ from safetensors.torch import load_model
9
+ import cv2
10
+ import numpy as np
11
+
12
+
13
# Resolve the safetensors weight file relative to this script so the app
# works no matter what the current working directory is.
path_to_weights = os.path.join(
    os.path.dirname(__file__), "models/weights/sketch_keras.st"
)

# Module-level model instance, shared by every Gradio request.
model = SketchKeras()
load_model(model, path_to_weights)
model.eval()  # inference only: fixes batch-norm to its loaded running stats
19
+
20
+
21
def preprocess(img):
    """Build the high-pass input canvas expected by SketchKeras.

    Subtracts a Gaussian-blurred copy from ``img`` to keep only the
    high-frequency detail (edges), normalizes it to [..1], and pastes the
    result into the top-left corner of a fixed 512x512x3 zero canvas.

    Args:
        img: H x W x C numpy image with H, W <= 512 (the caller resizes the
            longer edge to 512 first — TODO confirm callers always do).

    Returns:
        float64 array of shape (512, 512, 3): normalized high-pass signal in
        the top-left H x W x C region, zeros elsewhere.
    """
    h, w, c = img.shape
    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    highpass = img.astype(int) - blurred.astype(int)
    highpass = highpass.astype(float) / 128.0
    # Guard against a perfectly uniform image: the high-pass signal is all
    # zeros, so the original unconditional division produced NaN/inf.
    peak = np.max(highpass)
    if peak > 0:
        highpass /= peak

    ret = np.zeros((512, 512, 3), dtype=float)
    ret[0:h, 0:w, 0:c] = highpass
    return ret
31
+
32
+
33
def postprocess(pred, thresh=0.18, smooth=False):
    """Convert raw model output into an 8-bit sketch image.

    Args:
        pred: float array whose leading axis is reduced with a max
            (e.g. (C, H, W) model output -> (H, W)).
        thresh: responses below this value are treated as background;
            must lie in [0, 1].
        smooth: if True, apply a 3x3 median blur to the final image.

    Returns:
        uint8 array of shape (H, W): dark strokes on a white background.

    Raises:
        ValueError: if ``thresh`` is outside [0, 1].
    """
    # Real validation, not `assert`: asserts are stripped under `python -O`.
    if not 0.0 <= thresh <= 1.0:
        raise ValueError(f"thresh must be in [0, 1], got {thresh!r}")

    out = np.amax(pred, 0)      # strongest response across the leading axis
    out[out < thresh] = 0       # drop weak responses (background)
    out = 1 - out               # invert: strong stroke -> dark pixel
    out *= 255
    out = np.clip(out, 0, 255).astype(np.uint8)
    if smooth:
        out = cv2.medianBlur(out, 3)
    return out
44
+
45
+
46
def output_sketch(img):
    """Gradio handler: turn an H x W x C numpy image into a pencil sketch.

    Resizes the longer edge to 512 px (keeping aspect ratio), runs the
    module-level SketchKeras model, and crops the padding back off.
    """
    # Fit the longer edge to 512 while preserving the aspect ratio.
    h, w = float(img.shape[0]), float(img.shape[1])
    if w > h:
        target_w, target_h = 512, int(512 / w * h)
    else:
        target_w, target_h = int(512 / h * w), 512
    resized = cv2.resize(img, (target_w, target_h))

    # (512, 512, 3) canvas -> (3, 1, 512, 512): each color channel becomes
    # its own single-channel sample in the batch.
    canvas = preprocess(resized)
    batch = canvas.reshape(1, *canvas.shape).transpose(3, 0, 1, 2)
    tensor = torch.tensor(batch).float()

    with torch.no_grad():
        pred = model(tensor)

    raw = pred.squeeze().cpu().detach().numpy()
    sketch = postprocess(raw, thresh=0.1, smooth=False)
    # Drop the zero padding that preprocess added beyond the resized image.
    return sketch[:target_h, :target_w]
70
+
71
+
72
# Wire the sketch function into a simple image-in / image-out Gradio demo
# and start the server.
demo = gr.Interface(
    title="Turn Any Image Into a Sketch",
    fn=output_sketch,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)
demo.launch()
flagged/img/11968894a849e0c35c0b/a1555132334_10.jpg ADDED
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ img,output,flag,username,timestamp
2
+ "{""path"":""flagged/img/11968894a849e0c35c0b/a1555132334_10.jpg"",""url"":""http://localhost:7860/file=/private/var/folders/cm/2c5qsgtd2z9_c__1cmj825dm0000gn/T/gradio/2ec5829897639b70466458f600fc46c698660291/a1555132334_10.jpg"",""size"":424307,""orig_name"":""a1555132334_10.jpg"",""mime_type"":""""}","{""path"":""flagged/output/0d8f13ee9e4c0ae5521d/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}",,,2023-12-08 15:35:22.482224
flagged/output/0d8f13ee9e4c0ae5521d/image.png ADDED
models/structure/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.55 kB). View file
 
models/structure/model.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class SketchKeras(nn.Module):
    """U-Net-style line-extraction network (PyTorch port of sketchKeras).

    Input is a single-channel high-pass image of shape (N, 1, H, W); output
    is a single-channel stroke-response map of the same spatial size.

    The original code hard-coded the decoder's ``nn.Upsample`` target sizes
    to the 512x512 pipeline; each stage simply doubles the spatial size, so
    ``scale_factor=2`` is used instead.  This is identical for 512x512 input
    (Upsample has no learned weights, so safetensors loading is unaffected)
    and additionally accepts any H, W divisible by 16.
    """

    def __init__(self):
        super(SketchKeras, self).__init__()

        # Encoder: stride-2 convs halve the spatial size while channel
        # count grows 32 -> 64 -> 128 -> 256 -> 512.  BatchNorm eps/momentum
        # match the converted Keras weights (momentum=0 leaves the loaded
        # running statistics unchanged).
        self.downblock_1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_3 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_4 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_5 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 512, kernel_size=4, stride=2),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )
        self.downblock_6 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, kernel_size=3, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Decoder: each block doubles the spatial size, then the asymmetric
        # (1, 2, 1, 2) reflection pad + k=4/s=1 conv keeps it unchanged.
        self.upblock_1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(1024, 512, kernel_size=4, stride=1),
            nn.BatchNorm2d(512, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_2 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(512, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 128, kernel_size=3, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_3 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(256, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        self.upblock_4 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d((1, 2, 1, 2)),
            nn.Conv2d(128, 64, kernel_size=4, stride=1),
            nn.BatchNorm2d(64, eps=1e-3, momentum=0),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32, eps=1e-3, momentum=0),
            nn.ReLU(),
        )

        # Final 1-channel projection over the cat(d1, u4) feature map.
        self.last_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.last_conv = nn.Conv2d(64, 1, kernel_size=3, stride=1)

    def forward(self, x):
        """Run the encoder, then decode with U-Net skip concatenations.

        Args:
            x: tensor of shape (N, 1, H, W), H and W divisible by 16.

        Returns:
            Tensor of shape (N, 1, H, W) with raw stroke responses.
        """
        d1 = self.downblock_1(x)
        d2 = self.downblock_2(d1)
        d3 = self.downblock_3(d2)
        d4 = self.downblock_4(d3)
        d5 = self.downblock_5(d4)
        d6 = self.downblock_6(d5)

        # Skip connections: concatenate encoder features on the channel dim.
        u1 = torch.cat((d5, d6), dim=1)
        u1 = self.upblock_1(u1)
        u2 = torch.cat((d4, u1), dim=1)
        u2 = self.upblock_2(u2)
        u3 = torch.cat((d3, u2), dim=1)
        u3 = self.upblock_3(u3)
        u4 = torch.cat((d2, u3), dim=1)
        u4 = self.upblock_4(u4)
        u5 = torch.cat((d1, u4), dim=1)

        out = self.last_conv(self.last_pad(u5))

        return out
models/weights/sketch_keras.st ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e1237e0231dc99879bc6f3aa8371186b149054a38b827dd0d39ec5e394dff02
3
+ size 74588236
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops
2
+ datasets
3
+ tqdm
4
+ accelerate
5
+ torchinfo
6
+ diffusers
7
+ transformers
8
+ pathlib
9
+ safetensors
10
+ torchvision
11
+ Pillow