"""Load the fossil-classification ResNet models and run single-image inference."""

import tensorflow as tf

# Hide every GPU so TensorFlow runs on CPU only. Visible devices cannot be
# changed once the runtime has initialized them, so this runs before anything
# else touches TensorFlow/Keras.
tf.config.set_visible_devices([], 'GPU')

from keras.applications import resnet
import keras
import os
import matplotlib.pyplot as plt
from typing import Tuple
from huggingface_hub import snapshot_download
from labels import lookup_140
import numpy as np

# Fetch the model files from the Hugging Face Hub on first run only; an
# existing 'model_classification' directory is treated as already downloaded.
if not os.path.exists('model_classification'):
    REPO_ID = 'Serrelab/fossil_classification_models'
    token = os.getenv('READ_TOKEN')
    # Never echo the token itself: it is a credential and would end up in logs.
    print('read token set:', token is not None)
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(
        repo_id=REPO_ID,
        token=token,
        repo_type='model',
        local_dir='model_classification',
    )


def get_resnet_model(model_path):
    """Load a saved Keras model and split it into two sub-models.

    Args:
        model_path: Path handed to ``keras.models.load_model``.

    Returns:
        A tuple ``(model, g, h)``:
          * ``model`` -- the full loaded model.
          * ``g`` -- maps ``model.input`` to the output of layer index 2.
          * ``h`` -- maps the input of layer index 3 to the model's final output.
        NOTE(review): presumably ``g`` is a feature extractor and ``h`` the head
        applied to its output; confirm against the saved architecture.
    """
    # The saved model refers to a custom object named "cce"; supply it so
    # deserialization can resolve the name.
    cce = tf.keras.losses.categorical_crossentropy
    model = keras.models.load_model(model_path, custom_objects={"cce": cce})
    g = keras.Model(model.input, model.layers[2].output)
    h = keras.Model(model.layers[3].input, model.layers[-1].output)
    return model, g, h


def select_top_n(preds, n=10):
    """Return the indices of the ``n`` largest entries of ``preds``, largest first.

    A non-positive ``n`` yields an empty index array. If ``n`` exceeds
    ``len(preds)``, every index is returned.
    """
    # Reverse the ascending argsort, then take the first n. For n > 0 this
    # gives exactly the same order as taking the last n and reversing.
    return np.argsort(preds)[::-1][:max(n, 0)]


def parse_results(top_n, logits):
    """Map each index in ``top_n`` to its label (via ``lookup_140``) and score.

    Args:
        top_n: Iterable of class indices.
        logits: Indexable scores; ``logits[i]`` is converted to ``float``.

    Returns:
        Dict of ``{label: float(score)}`` in the iteration order of ``top_n``.
    """
    return {lookup_140[n]: float(logits[n]) for n in top_n}


def _preprocess(x, size):
    """Resize an image to ``(size, size, 3)``, divide by 255, and add a batch axis."""
    x = tf.image.resize(x, (size, size))
    # Reshape to the requested size (previously hard-coded to 384, which broke
    # any other `size`). Division by 255 assumes 0-255 pixel values -- TODO confirm.
    x = tf.reshape(x, (size, size, 3)) / 255
    return np.array([x])


def inference_resnet_embedding_v2(x, model, size=384, n_classes=140, n_top=10):
    """Run ``model`` on a single image and return ``predict(...)[0][0]``.

    Args:
        x: Image tensor accepted by ``tf.image.resize``.
        model: Keras model to run.
        size: Side length the image is resized to.
        n_classes, n_top: Unused; kept for call compatibility.

    Returns:
        The first element of the first entry of ``model.predict``'s output.
        NOTE(review): presumably the embedding of the single batch item.
    """
    return model.predict(_preprocess(x, size))[0][0]


def inference_resnet_finer_v2(x, model, size=384, n_classes=142, n_top=10):
    """Classify a single image and return its top-``n_top`` labels with probabilities.

    Args:
        x: Image tensor accepted by ``tf.image.resize``.
        model: Keras model whose ``predict`` output is indexable, with class
            logits at index 1.
        size: Side length the image is resized to.
        n_classes: Unused; kept for call compatibility.
        n_top: Number of highest-probability classes to return.

    Returns:
        Dict ``{label: probability}`` for the ``n_top`` most probable classes.
    """
    outputs = model.predict(_preprocess(x, size))
    # Output index 1 holds the logits; [0] selects the single batch item.
    # `.numpy()` already returns a host ndarray -- the old `.cpu()` call is a
    # PyTorch idiom that TF2 tensors do not reliably provide.
    probs = tf.nn.softmax(outputs[1][0]).numpy()
    top_n = select_top_n(probs, n=n_top)
    return parse_results(top_n, probs)