# plantvision / plantvision.py
# BraydenMoore's picture
# Add cache
# 6573ecd
import requests
from io import BytesIO
from PIL import Image, ImageOps
import torchvision.transforms as T
import torch
import gc
import pickle as pkl
from pathlib import Path
THIS_FOLDER = Path(__file__).parent.resolve()
import datetime as dt
from transformers import AutoModel
import torch.nn as nn
import torch.nn.functional as F
import logging
# Silence library log output. logging.disable(level) suppresses that level and
# everything below it, so the WARNING call supersedes the INFO call.
logging.disable(logging.INFO)
logging.disable(logging.WARNING)
# Load the pretrained ViT backbone once, at import time. cache_dir=None leaves
# the download cache location at the library default.
visionTransformer = AutoModel.from_pretrained(r"google/vit-base-patch16-224-in21k",
cache_dir=None
)
class PlantVision(nn.Module):
    """Image classifier: a frozen ViT trunk followed by two linear layers.

    The trunk comes from the module-level ``visionTransformer``. Its children
    are split into ``vitTop`` (every child except the last two) and
    ``vitNorm`` (the child at index 2). The normalised hidden states are
    flattened and passed through ``vitLinear`` -> ReLU -> ``fc``.

    Args:
        num_classes: Number of output classes; sizes both linear layers.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.vit = visionTransformer

        # Freeze the first three children of the backbone so only the
        # classification head carries trainable parameters.
        for frozenChild in list(self.vit.children())[:3]:
            for weight in frozenChild.parameters():
                weight.requires_grad = False

        self.vitLayers = list(self.vit.children())
        self.vitTop = nn.Sequential(*self.vitLayers[:-2])
        self.vitNorm = self.vitLayers[2]

        # The pieces that are needed are referenced above; drop the direct
        # handle on the full backbone.
        self.vit = None
        gc.collect()

        self.vitFlatten = nn.Flatten()
        # 151296 == 197 * 768 -- presumably (patch tokens + CLS) x hidden size
        # for ViT-base/16 at 224x224; TODO confirm against the backbone config.
        self.vitLinear = nn.Linear(151296, num_classes)
        self.fc = nn.Linear(num_classes, num_classes)

    def forward(self, input):
        """Return one row of ``num_classes`` scores per item in ``input``."""
        hidden = self.vitTop(input).last_hidden_state
        normed = self.vitNorm(hidden)
        flat = self.vitFlatten(normed)
        activated = F.relu(self.vitLinear(flat))
        return self.fc(activated)
# Inference is pinned to the CPU. If GPU auto-detection is ever wanted, note
# that torch.cuda.is_available must be *called*; the bare attribute is always
# truthy.
device = 'cpu'

# Class-label collections for each plant feature, unpickled from the bundled
# resources folder. Each one sizes its model and maps class indices to labels.
with open(fr'{THIS_FOLDER}/resources/flowerLabelSet.pkl', 'rb') as f:
    flowerLabelSet = pkl.load(f)

with open(fr'{THIS_FOLDER}/resources/leafLabelSet.pkl', 'rb') as f:
    leafLabelSet = pkl.load(f)

with open(fr'{THIS_FOLDER}/resources/fruitLabelSet.pkl', 'rb') as f:
    fruitLabelSet = pkl.load(f)
def _fetchStateDict(feature, layerName):
    """Download and deserialize the saved weights for one layer of one model.

    Args:
        feature: Model family used in the object name ('flower', 'leaf', ...).
        layerName: Attribute name of the layer on ``PlantVision``.

    Returns:
        The object produced by ``torch.load`` for that weights file, mapped
        onto ``device``.

    Raises:
        requests.HTTPError: The bucket answered with a 4xx/5xx status.
        requests.Timeout: The server did not respond in time.
    """
    url = f"https://storage.googleapis.com/bmllc-plant-model-bucket/{feature}-{layerName}-weights.pt"
    # (connect, read) timeouts: without them a stalled connection would hang
    # the import of this module forever.
    response = requests.get(url, timeout=(10, 60))
    # Fail here with a clear HTTP error rather than handing an error page to
    # torch.load, which would surface as an opaque unpickling failure.
    response.raise_for_status()
    # NOTE(security): torch.load unpickles the payload, so this trusts the
    # remote bucket completely. Consider weights_only=True if the installed
    # torch version supports it.
    return torch.load(BytesIO(response.content), map_location=torch.device(device))


def loadModel(feature, labelSet):
    """Build a ``PlantVision`` model for ``feature`` and load its head weights.

    Args:
        feature: Model family name; selects which weight files are downloaded.
        labelSet: Collection of class labels; its length sets ``num_classes``.

    Returns:
        The model after ``.half()`` has been applied.

    Raises:
        requests.RequestException: A weights download failed.
    """
    model = PlantVision(num_classes=len(labelSet))
    # Same three layers, in the same order, as the original per-layer calls.
    # strict=False tolerates key mismatches between file and layer.
    for layerName in ('vitFlatten', 'vitLinear', 'fc'):
        stateDict = _fetchStateDict(feature, layerName)
        getattr(model, layerName).load_state_dict(stateDict, strict=False)
    model = model.half()
    return model
# Build all three classifiers eagerly at import time and report how long the
# construction (including the weight downloads done by loadModel) took.
start = dt.datetime.now()
flower = loadModel('flower',flowerLabelSet)
leaf = loadModel('leaf',leafLabelSet)
fruit = loadModel('fruit',fruitLabelSet)
# Elapsed wall-clock time, printed as a timedelta.
print(dt.datetime.now() - start)
def processImage(imagePath, feature):
    """Load an image and turn it into a normalised 224x224 tensor.

    Args:
        imagePath: Path (or file object) accepted by ``PIL.Image.open``.
        feature: Model family name; selects which ``<feature>MeansAndStds.pkl``
            resource supplies the normalisation statistics.

    Returns:
        The tensor produced by the torchvision pipeline below.
    """
    with open(fr'{THIS_FOLDER}/resources/{feature}MeansAndStds.pkl', 'rb') as f:
        meansAndStds = pkl.load(f)

    # Image.open keeps the underlying file open. Close it as soon as the RGB
    # copy exists so repeated calls do not leak file handles.
    with Image.open(imagePath) as img:
        rgb = img.convert('RGB')

    cropped = ImageOps.fit(rgb, (224, 224), Image.Resampling.LANCZOS)

    # CenterCrop and ConvertImageDtype look redundant after ImageOps.fit and
    # ToTensor; they are kept so the pipeline matches the original exactly.
    process = T.Compose([
        T.CenterCrop(224),
        T.ToTensor(),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(
            mean=meansAndStds['mean'],
            std=meansAndStds['std']),
    ])
    return process(cropped)
def see(tensor, feature, k):
    """Return the ``k`` highest-scoring labels for one preprocessed image.

    Args:
        tensor: A single image tensor; a batch dimension is added here.
        feature: Which model to use: 'flower', 'leaf' or 'fruit'.
        k: Number of top predictions to return (passed to ``torch.topk``).

    Returns:
        A list of ``k`` entries from the matching label set, best score first.

    Raises:
        ValueError: ``feature`` is not one of the supported names.
    """
    if feature == 'flower':
        model, labelSet = flower, flowerLabelSet
    elif feature == 'leaf':
        model, labelSet = leaf, leafLabelSet
    elif feature == 'fruit':
        model, labelSet = fruit, fruitLabelSet
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError; fail up front with a clear message instead.
        raise ValueError(
            f"unknown feature {feature!r}; expected 'flower', 'leaf' or 'fruit'")

    # NOTE: Module.float() casts the module in place and returns the same
    # object, so the global model stays float32 after the first call.
    model = model.float()

    with torch.no_grad():
        output = model(tensor.unsqueeze(0))

    # Plain Python ints index any label container reliably, unlike 0-d tensors.
    topIndices = torch.topk(output, k, dim=1).indices[0].tolist()
    predictedSpecies = [labelSet[i] for i in topIndices]

    gc.collect()
    return predictedSpecies