# image-Style / tf_predict.py
# Uploaded by d-e-e-k-11 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d1bfee5 (verified).
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from tf_models import Generator
def load_image(image_file, size=(256, 256)):
    """Load an image file as a normalized single-image batch.

    The file is decoded to 3 channels, converted to float32 in [0, 1],
    resized to ``size`` and then rescaled to [-1, 1].

    Args:
        image_file: Path (or scalar string tensor) of the image to read.
        size: Target ``(height, width)`` passed to ``tf.image.resize``.
            Defaults to 256x256, the size the rest of this script uses.

    Returns:
        A float32 tensor of shape ``(1, height, width, 3)`` with values
        in [-1, 1].
    """
    raw = tf.io.read_file(image_file)
    # decode_image accepts JPEG/PNG/BMP/GIF; expand_animations=False keeps
    # the result rank-3 (first GIF frame) so tf.image.resize still applies.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> [0, 1]
    image = tf.image.resize(image, size)
    image = image * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    # Add a leading batch dimension of 1.
    return tf.expand_dims(image, 0)
def predict(model, image_path, output_path="tf_prediction.png"):
    """Run ``model`` on one image and save an input/prediction comparison.

    Args:
        model: Keras-style callable, invoked as ``model(batch, training=False)``
            on the batch returned by ``load_image``.
        image_path: Path of the input image file.
        output_path: Where the side-by-side PNG is written. Defaults to
            ``"tf_prediction.png"`` in the current working directory.
    """
    image = load_image(image_path)
    prediction = model(image, training=False)

    fig = plt.figure(figsize=(10, 5))
    try:
        panels = (("Input Image", image[0]), ("Predicted Image", prediction[0]))
        for idx, (title, tensor) in enumerate(panels, start=1):
            plt.subplot(1, 2, idx)
            plt.title(title)
            # Map [-1, 1] back to [0, 1] for display. Clip explicitly: imshow
            # clips out-of-range floats anyway, but emits a warning, and the
            # model's output range is not guaranteed here.
            plt.imshow(np.clip(np.asarray(tensor) * 0.5 + 0.5, 0.0, 1.0))
            plt.axis("off")
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until closed; release it.
        plt.close(fig)
    print(f"Prediction saved to {output_path}")
def main():
    """Build the generator, load the first usable checkpoint, run a demo.

    Candidate ``.h5`` weight files are tried in order. The first one whose
    ``load_weights`` call does not raise is used. If none loads, the freshly
    initialised model is used. The model is then run on a fixed horse2zebra
    test image, if that file exists.
    """
    model = Generator()
    # Keras cannot load HDF5 weights into a subclassed model whose variables
    # do not exist yet, and the broad except below would hide that failure.
    # One dummy forward pass creates the variables. It is skipped when the
    # model is already built (e.g. a functional model). 256x256x3 matches
    # what load_image produces.
    if not model.built:
        model(tf.zeros([1, 256, 256, 3]), training=False)

    # Attempt to load existing .h5 files if they exist
    potential_weights = ["GeneratorHtoZ.h5", "gen_g_epoch_0.h5"]
    for weight_path in potential_weights:
        if not os.path.exists(weight_path):
            continue
        try:
            # NOTE: with by_name/skip_mismatch, layers whose names or shapes
            # don't match are skipped silently, so a successful call does not
            # guarantee that every layer received weights.
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
        except Exception as e:  # best-effort: report and try the next file
            print(f"Could not load {weight_path}: {e}")
            continue
        print(f"Loaded weights from {weight_path}")
        break
    else:
        # Reached only when no candidate loaded successfully.
        print("Using untrained model.")

    test_image = "data/horse2zebra/testA/n02381460_1010.jpg"
    if os.path.exists(test_image):
        predict(model, test_image)
    else:
        print(f"Test image {test_image} not found.")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()