# farmNet / app.py (Hugging Face Space by filipzawadka)
# Commit 67f9c4e: "model registry"
import gradio as gr
import os
import requests
from PIL import Image
from torchvision import transforms
import torch
import torchvision.models as models
import torch.nn as nn
import io
import wandb
# Fetch the trained FarmNet weights from the Weights & Biases model registry.
# A short-lived run tagged job_type="inference" is opened only so the artifact
# can be resolved and downloaded through it; the run is closed right after the
# files are on local disk. ``artifact_dir`` is the local directory holding them.
run = wandb.init(job_type="inference", project="farmnet")
artifact = run.use_artifact("farmnet_model_1:latest", type="model")
artifact_dir = artifact.download()
run.finish()
class FarmNet(nn.Module):
    """Small CNN mapping an RGB image batch to 2 class logits.

    Three stages of conv(3x3, padding=1) -> ReLU -> 2x2 max-pool reduce a
    3x400x400 input to 64x50x50. That is flattened into a 512-unit hidden
    layer and a 2-logit output head. ``fc1`` is sized for exactly 64*50*50
    features, so an input whose spatial size does not pool down to 50x50
    makes ``fc1`` raise a shape error instead of being silently reshaped.

    Attribute names (conv1..conv3, fc1, fc2) are the keys of the saved
    ``state_dict`` and must stay stable for checkpoint loading.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 50 * 50, 512)
        self.fc2 = nn.Linear(512, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape (batch, 2).

        Expects ``x`` shaped (batch, 3, H, W) with H and W pooling down to 50,
        e.g. 400x400.
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        # Flatten per sample while keeping the batch dimension. A
        # ``view(-1, 64 * 50 * 50)`` here would fold a wrongly sized image into
        # several fake "samples" and return nonsense predictions.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
# Build the network and load the trained weights from the directory the W&B
# artifact was downloaded into. The path is derived from ``artifact_dir``
# instead of a hard-coded "artifacts/farmnet_model_1:v0/..." string. That way
# it follows whichever version ":latest" resolved to, and whatever directory
# name wandb actually used on this platform.
# map_location="cpu" lets weights saved from a GPU session load on a CPU-only host.
# NOTE(review): torch.load unpickles the file; acceptable only because the
# checkpoint comes from our own registry. Consider weights_only=True once the
# deployed torch version is confirmed to support it.
model = FarmNet()
_weights_path = os.path.join(artifact_dir, "farmnet_model.pth")
model.load_state_dict(torch.load(_weights_path, map_location="cpu"))
model.eval()  # evaluation mode for inference
# Preprocessing applied to every fetched map tile before inference:
# resize to 400x400 (FarmNet's fc1 is sized for 64 * 50 * 50 features, i.e.
# 400 / 2 / 2 / 2 = 50 after three poolings), then convert to a tensor.
_preprocess_steps = [
    transforms.Resize((400, 400)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

# Class names indexed by the model's output logit position:
# index 0 -> 'not farm', index 1 -> 'farm'.
classes = ['not farm', 'farm']
def greet(latitude, longitude):
    """Classify the satellite tile centred on a coordinate as farm / not farm.

    Downloads a 400x400 satellite image at zoom level 17 from the Google
    Static Maps API (key read from the ``GOOGLE_API_KEY`` environment
    variable), runs it through the module-level ``model`` and returns the
    image together with the predicted class name from ``classes``.

    Example input: latitude=64.777466, longitude=-147.489792

    Args:
        latitude: Centre latitude in decimal degrees, within [-90, 90].
        longitude: Centre longitude in decimal degrees, within [-180, 180].

    Returns:
        Tuple of (``gr.Image`` wrapping the fetched tile, predicted label).

    Raises:
        gr.Error: if a coordinate is missing or out of range, or the map
            image cannot be downloaded.
        KeyError: if ``GOOGLE_API_KEY`` is not set.
    """
    # Gradio passes None for an empty Number field.
    if latitude is None or longitude is None:
        raise gr.Error("Please enter both a latitude and a longitude.")
    # Written as "not (in range)" so NaN is rejected as well.
    if not (-90 <= latitude <= 90 and -180 <= longitude <= 180):
        raise gr.Error(
            "Latitude must be within [-90, 90] and longitude within [-180, 180]."
        )

    # Let requests build and URL-encode the query string.
    params = {
        "center": f"{latitude},{longitude}",
        "zoom": 17,
        "size": "400x400",
        "maptype": "satellite",
        "key": os.environ["GOOGLE_API_KEY"],
    }
    try:
        response = requests.get(
            "https://maps.googleapis.com/maps/api/staticmap",
            params=params,
            timeout=10,  # never hang the UI on a stalled connection
        )
        # On an error status the body is not an image; fail here with a clear
        # message instead of letting PIL choke on it.
        response.raise_for_status()
    except requests.RequestException as err:
        # Deliberately not echoing ``err``: its text can include the full
        # request URL, which contains the API key.
        raise gr.Error(
            "Could not fetch the satellite image from Google Static Maps."
        ) from err

    pil_img = Image.open(io.BytesIO(response.content)).convert("RGB")
    batch = transform(pil_img).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1).item()
    return gr.Image(pil_img), classes[predicted]
# Web UI: two numeric inputs (passed to greet as latitude, longitude) and
# two outputs (the fetched satellite tile and the predicted class label).
iface = gr.Interface(
    fn=greet,
    inputs=["number", "number"],
    outputs=["image", "label"],
)
# Start the Gradio server.
iface.launch()