Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Abubakar Abid
commited on
Commit
•
b9e6b57
1
Parent(s):
a10ea79
all files
Browse files- .gitattributes +0 -16
- .gitignore +1 -0
- app.py +90 -0
- img1.jpg +0 -0
- img2.jpg +0 -0
- requirements.txt +7 -0
- thumbnail.png +0 -0
.gitattributes
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
weights/*
|
app.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, os.path
|
2 |
+
from os.path import splitext
|
3 |
+
import numpy as np
|
4 |
+
import sys
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
import wget
|
9 |
+
|
10 |
+
destination_folder = "output"
|
11 |
+
destination_for_weights = "weights"
|
12 |
+
|
13 |
+
if os.path.exists(destination_for_weights):
|
14 |
+
print("The weights are at", destination_for_weights)
|
15 |
+
else:
|
16 |
+
print("Creating folder at ", destination_for_weights, " to store weights")
|
17 |
+
os.mkdir(destination_for_weights)
|
18 |
+
|
19 |
+
segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'
|
20 |
+
|
21 |
+
if not os.path.exists(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))):
|
22 |
+
print("Downloading Segmentation Weights, ", segmentationWeightsURL," to ",os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
|
23 |
+
filename = wget.download(segmentationWeightsURL, out = destination_for_weights)
|
24 |
+
else:
|
25 |
+
print("Segmentation Weights already present")
|
26 |
+
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
|
29 |
+
def collate_fn(x):
|
30 |
+
x, f = zip(*x)
|
31 |
+
i = list(map(lambda t: t.shape[1], x))
|
32 |
+
x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))
|
33 |
+
return x, f, i
|
34 |
+
|
35 |
+
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
|
36 |
+
model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)
|
37 |
+
|
38 |
+
print("loading weights from ", os.path.join(destination_for_weights, "deeplabv3_resnet50_random"))
|
39 |
+
|
40 |
+
if torch.cuda.is_available():
|
41 |
+
print("cuda is available, original weights")
|
42 |
+
device = torch.device("cuda")
|
43 |
+
model = torch.nn.DataParallel(model)
|
44 |
+
model.to(device)
|
45 |
+
checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
|
46 |
+
model.load_state_dict(checkpoint['state_dict'])
|
47 |
+
else:
|
48 |
+
print("cuda is not available, cpu weights")
|
49 |
+
device = torch.device("cpu")
|
50 |
+
checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)), map_location = "cpu")
|
51 |
+
state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
|
52 |
+
model.load_state_dict(state_dict_cpu)
|
53 |
+
|
54 |
+
model.eval()
|
55 |
+
|
56 |
+
def segment(inp):
|
57 |
+
x = inp.transpose([2, 0, 1]) # channels-first
|
58 |
+
x = np.expand_dims(x, axis=0) # adding a batch dimension
|
59 |
+
|
60 |
+
mean = x.mean(axis=(0, 2, 3))
|
61 |
+
std = x.std(axis=(0, 2, 3))
|
62 |
+
x = x - mean.reshape(1, 3, 1, 1)
|
63 |
+
x = x / std.reshape(1, 3, 1, 1)
|
64 |
+
|
65 |
+
with torch.no_grad():
|
66 |
+
x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
|
67 |
+
output = model(x)
|
68 |
+
|
69 |
+
y = output['out'].numpy()
|
70 |
+
y = y.squeeze()
|
71 |
+
|
72 |
+
out = y>0
|
73 |
+
|
74 |
+
mask = inp.copy()
|
75 |
+
mask[out] = np.array([0, 0, 255])
|
76 |
+
|
77 |
+
return mask
|
78 |
+
|
79 |
+
import gradio as gr
|
80 |
+
|
81 |
+
i = gr.inputs.Image(shape=(112, 112))
|
82 |
+
o = gr.outputs.Image()
|
83 |
+
|
84 |
+
examples = [["img1.jpg"], ["img2.jpg"]]
|
85 |
+
title = "Left Ventricle Segmentation"
|
86 |
+
description = "This semantic segmentation model identifies the left ventricle in echocardiogram videos. Accurate evaluation of the motion and size of the left ventricle is crucial for the assessment of cardiac function and ejection fraction. In this interface, the user inputs apical-4-chamber images from echocardiography videos and the model will output a prediction of the localization of the left ventricle in blue. This model was trained on the publicly released EchoNet-Dynamic dataset of 10k echocardiogram videos with 20k expert annotations of the left ventricle and published as part of ‘Video-based AI for beat-to-beat assessment of cardiac function’ by Ouyang et al. in Nature, 2020."
|
87 |
+
thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
|
88 |
+
|
89 |
+
gr.Interface(segment, i, o, examples=examples, allow_flagging=False, analytics_enabled=False,
|
90 |
+
title=title, description=description, thumbnail=thumbnail).launch()
|
img1.jpg
ADDED
img2.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-f https://download.pytorch.org/whl/torch_stable.html
|
2 |
+
numpy
|
3 |
+
matplotlib
|
4 |
+
wget
|
5 |
+
torch==1.6.0+cpu
|
6 |
+
torchvision==0.7.0+cpu
|
7 |
+
|
thumbnail.png
ADDED