AndresZarta commited on
Commit
f8e4331
1 Parent(s): e1e22f4

App and Models

Browse files
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shiny import App, render, ui
2
+ from shiny.ui import output_image, input_file, file_input, panel_main, panel_sidebar, layout_sidebar
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import torch # Assuming the model is a PyTorch model
6
+ from PIL import Image
7
+ import io
8
+
9
+ # Load your pre-trained model
10
+ model = torch.load('path/to/your/model.pth')
11
+ model.eval()
12
+
13
+ def predict(image_file):
14
+ # Open image
15
+ image = Image.open(image_file)
16
+ # Process image as per model requirements (resize, normalize, etc.)
17
+ # Convert to tensor, e.g., if model expects tensor input
18
+ image_tensor = torch.Tensor(np.array(image) / 255).unsqueeze(0) # adjust processing as needed
19
+ with torch.no_grad():
20
+ output = model(image_tensor)
21
+ # Post-process output, e.g., binarize, segment
22
+ segmentation = (output > 0.5).squeeze().numpy() # adjust as needed
23
+ return segmentation
24
+
25
+ def server(input, output, session):
26
+ @output
27
+ @render.image
28
+ def segmented_image():
29
+ # Check if a file has been uploaded
30
+ if not input.image_file:
31
+ return None
32
+
33
+ # Predict segmentation
34
+ uploaded_file = input.image_file()[0]
35
+ segmentation = predict(uploaded_file)
36
+
37
+ # Create figure to show segmentation
38
+ plt.imshow(segmentation, cmap='gray')
39
+ plt.axis('off')
40
+
41
+ # Save to buffer and return
42
+ buf = io.BytesIO()
43
+ plt.savefig(buf, format='png')
44
+ buf.seek(0)
45
+ return buf
46
+
47
+ app_ui = ui.page_fluid(
48
+ layout_sidebar(
49
+ panel_sidebar(
50
+ file_input("image_file", "Upload an image:", multiple=False),
51
+ ),
52
+ panel_main(
53
+ output_image("segmented_image"),
54
+ ),
55
+ )
56
+ )
57
+
58
+ app = App(app_ui, server)
models/model_checkpoint_trained_on_train.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95f8d4de38c23599b69a8a99b60b5b2c4cf1671847ca2cdfe8760a9a8ec4d302
3
+ size 375066348
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ shiny
2
+ torch
3
+ numpy
4
+ pillow
5
+ matplotlib