AdarshRavis commited on
Commit
1848a98
1 Parent(s): 6a2a70d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from keras.models import load_model
3
+ import numpy as np
4
+ from PIL import Image
5
+ from matplotlib import pyplot as plt
6
+ import requests
7
+
8
+
9
+ # Load your trained model
10
+ model = load_model('Demarker_v1.h5')
11
+
12
+ # Load and preprocess your custom image
13
+ def load_and_preprocess_image(image_path, target_shape=(256, 256)):
14
+ # Open the image using PIL
15
+ custom_image = Image.open(image_path)
16
+
17
+ # Resize the image to the desired shape
18
+ custom_image = custom_image.resize(target_shape)
19
+
20
+ # Convert the image to an array
21
+ custom_image = np.array(custom_image)
22
+
23
+ # Assuming the input image is in RGB format
24
+ # Check if the image needs normalization (pixel values in [0, 255] range)
25
+ if custom_image.max() > 1.0:
26
+ # Normalize the pixel values to the range [-1, 1]
27
+ custom_image = (custom_image.astype(np.float32) - 127.5) / 127.50
28
+
29
+ # Add a batch dimension
30
+ custom_image = np.expand_dims(custom_image, axis=0)
31
+
32
+ return custom_image
33
+
34
+
35
+ # Plot source and generated images
36
+ def plot_images(src_img, gen_img):
37
+ images = np.vstack((src_img, gen_img))
38
+ # Scale from [-1,1] to [0,1]
39
+ images = (images + 1) / 2.0
40
+ titles = ['Original', 'Generated']
41
+ # Plot images row by row
42
+ for i in range(len(images)):
43
+ # Define subplot
44
+ plt.subplot(1, 2, 1 + i)
45
+ # Turn off axis
46
+ plt.axis('off')
47
+ # Plot raw pixel data
48
+ plt.imshow(images[i])
49
+ # Show title
50
+ plt.title(titles[i])
51
+ st.pyplot()
52
+
53
+ # Streamlit app
54
+ st.title("GAN Image Generator")
55
+
56
+ # Upload custom image
57
+ custom_image = st.file_uploader("Upload a custom image", type=["png", "jpg", "jpeg"])
58
+
59
+ # Check if an image is uploaded
60
+ if custom_image:
61
+ # Generate an image from your custom source
62
+ gen_image = model.predict(load_and_preprocess_image(custom_image))
63
+
64
+ # Plot both images together
65
+ plot_images(load_and_preprocess_image(custom_image), gen_image)