Overglitch's picture
Update app.py
a2a680b verified
raw
history blame
2.78 kB
import asyncio
import logging
import math
import pickle
from io import BytesIO

import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
def load_model(som_path='somlucuma.pkl', mm_path='matrizMM.txt'):
    """Load the pretrained SOM and the matrix indexed by its winning node.

    Relative paths are resolved against the current working directory.

    Args:
        som_path: Path to the pickled SOM object.
        mm_path: Path to a space-delimited text matrix readable by
            ``np.loadtxt``.

    Returns:
        ``(som, MM)``: the unpickled SOM object and the loaded matrix.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load trusted model artifacts shipped with the app, never uploads.
    with open(som_path, 'rb') as fid:
        som = pickle.load(fid)
    MM = np.loadtxt(mm_path, delimiter=" ")
    return som, MM
def sobel(I):
    """Return the horizontal and vertical Sobel gradients of a 2-D array.

    Every 3x3 neighbourhood is multiplied element-wise by the Sobel kernels
    (cross-correlation, the kernels are not flipped) and summed.

    Args:
        I: 2-D numeric array of shape (m, n).

    Returns:
        (Gx, Gy): float32 arrays of shape (m - 2, n - 2).

    NOTE(review): the loops only fill output rows 0..m-4 and columns
    0..n-4, so the last row and last column of Gx and Gy stay 0. This looks
    like an off-by-one, but it is kept on purpose: the pickled SOM was
    presumably trained on features computed exactly this way. Retrain the
    model before "fixing" it.
    """
    rows, cols = I.shape
    Gx = np.zeros([rows - 2, cols - 2], np.float32)
    Gy = np.zeros([rows - 2, cols - 2], np.float32)
    kernel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
    kernel_y = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
    for r in range(rows - 3):
        for c in range(cols - 3):
            window = I[r:r + 3, c:c + 3]
            Gx[r, c] = sum(sum(window * kernel_x))
            Gy[r, c] = sum(sum(window * kernel_y))
    return Gx, Gy
def medfilt2(G, d=3):
    """Apply a 3x3 median filter with zero padding to a 2-D array.

    Args:
        G: 2-D numeric array of shape (m, n).
        d: Filter size. Only d=3 gives a true median. The window is always
            3x3; ``d`` merely sets the padding (``d // 2``) and which sorted
            element is taken (index ``d + 1``, i.e. the median of 9 only
            when d=3).

    Returns:
        float32 array of shape (m, n).

    NOTE(review): the loops fill only rows 0..m-2 and columns 0..n-2, so
    the last row and column of the output stay 0. This is kept as-is
    because the SOM features presumably depend on it.
    """
    rows, cols = G.shape
    pad = d // 2
    padded = np.zeros([rows + 2 * pad, cols + 2 * pad], np.float32)
    out = np.zeros([rows, cols], np.float32)
    padded[1:rows + 1, 1:cols + 1] = G
    rank = d + 1
    for r in range(rows - 1):
        for c in range(cols - 1):
            neighbourhood = padded[r:r + 3, c:c + 3].ravel()
            out[r, c] = np.sort(neighbourhood)[rank]
    return out
def orientacion(patron, w):
    """Estimate a block-wise orientation field, in degrees, for a 2-D image.

    Gradients come from ``sobel``, each smoothed with ``medfilt2``. For every
    block the angle is

        theta = 0.5 * atan2(sum(2*Gx*Gy), sum(Gx**2 - Gy**2)) + pi/2

    converted to degrees, so every value lies in [0, 180].

    Args:
        patron: 2-D grayscale array.
        w: Block height in pixels (see NOTE about the block width).

    Returns:
        float32 array of shape (m // w, n // w), where (m, n) is
        ``patron.shape`` minus 2 on each axis (the Sobel output size).

    NOTE(review): each "block" reads the single column ``j`` (slice
    ``j:j+1``) instead of the w-wide strip ``j*w:(j+1)*w``. Only the first
    n // w columns of the gradient maps are ever used. This looks like a bug,
    but it is preserved because the pickled SOM was presumably trained on
    these exact features.
    """
    gx_map, gy_map = sobel(patron)
    gx_map = medfilt2(gx_map)
    gy_map = medfilt2(gy_map)
    rows, cols = gx_map.shape
    n_row_blocks = rows // w
    n_col_blocks = cols // w
    half_pi = math.pi / 2.0
    rad_to_deg = 180.0 / math.pi
    field = np.zeros([n_row_blocks, n_col_blocks], np.float32)
    for bi in range(n_row_blocks):
        top, bottom = bi * w, (bi + 1) * w
        for bj in range(n_col_blocks):
            gx = gx_map[top:bottom, bj:bj + 1]
            gy = gy_map[top:bottom, bj:bj + 1]
            num = sum(sum(2 * gx * gy))
            den = sum(sum(gx ** 2 - gy ** 2))
            field[bi, bj] = (0.5 * math.atan2(num, den) + half_pi) * rad_to_deg
    return field
def representativo(imarray):
    """Turn an image into the flat feature vector matched against the SOM.

    The array is squeezed to 2-D and a one-pixel border is trimmed. The
    block orientation field from ``orientacion`` (block size 14) is then
    flattened in row-major order.

    Args:
        imarray: Image array that squeezes to 2-D, e.g. a 256x256
            grayscale image with or without a trailing axis of size 1.

    Returns:
        1-D array of orientations in degrees. A 256x256 input gives
        18 * 18 = 324 values.
    """
    img = np.squeeze(imarray)
    height, width = img.shape
    inner = img[1:height - 1, 1:width - 1]
    field = orientacion(inner, 14)
    return np.asarray(field).reshape(-1)
# Pretrained artifacts are loaded once, at import time; the request handler
# below reads these module-level `som` and `MM` objects on every call.
som, MM = load_model()

# ASGI application exposing the prediction endpoint.
app = FastAPI()
_logger = logging.getLogger(__name__)


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the pretrained SOM.

    The upload is decoded with PIL and converted to 8-bit grayscale ('L').
    It is reduced to a feature vector by ``representativo`` and matched to
    the SOM's winning node. The response holds the entry of ``MM`` at that
    node.

    Args:
        file: Uploaded image; it must be exactly 256x256 pixels.

    Returns:
        ``{"prediction": MM[winner]}``, converted to built-in Python types so
        the response is always JSON-serializable.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image or
            is not 256x256 (client errors). 500 when feature extraction or
            the SOM lookup fails.
    """
    contents = await file.read()
    try:
        image = np.asarray(Image.open(BytesIO(contents)).convert('L'))
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # Undecodable, corrupt or oversized upload: the client's fault.
        # (PIL's UnidentifiedImageError is a subclass of OSError.)
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Validated outside any broad `except`, so this 400 is not turned into
    # a 500 by the generic handler below.
    if image.shape != (256, 256):
        raise HTTPException(
            status_code=400, detail="La imagen debe ser de tamaño 256x256."
        )
    _logger.debug("Imagen convertida a matriz: %s", image.shape)

    try:
        # representativo runs pure-Python loops over every pixel (slow,
        # CPU-bound). Running it in the default executor keeps the event loop
        # serving other requests. It squeezes its input, so the 2-D array
        # can be passed directly.
        loop = asyncio.get_running_loop()
        representative_data = await loop.run_in_executor(
            None, representativo, image
        )
        _logger.debug(
            "Datos representativos de la imagen: %s", representative_data.shape
        )

        # NOTE(review): som.winner is deliberately called on the event-loop
        # thread, so it never runs concurrently. SOM implementations such as
        # MiniSom store per-call state on the instance; confirm for this model.
        w = som.winner(representative_data.reshape(1, -1))
        _logger.debug("Índice ganador del SOM: %s", w)
        prediction = MM[w]
        _logger.debug("Predicción: %s", prediction)
    except Exception as e:
        _logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # numpy scalar/array -> plain Python float/list for the JSON encoder.
    return {"prediction": np.asarray(prediction).tolist()}