Spaces:
Sleeping
Sleeping
File size: 3,209 Bytes
63c4e31 bdc28b7 63c4e31 ae690b8 63c4e31 c839aa3 63c4e31 e471491 e3ef8a2 63c4e31 c839aa3 344060f 63c4e31 344060f 63c4e31 ae690b8 63c4e31 99d22e0 63c4e31 e471491 c839aa3 63c4e31 457ff50 63c4e31 c839aa3 63c4e31 e471491 63c4e31 bdc28b7 63c4e31 ae690b8 63c4e31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import ast
import os
import tempfile
import traceback

import numpy as np
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.optimizers import Adam
# FastAPI application; the /predict and /retrain routes are registered on it below.
app = FastAPI()
@app.post("/predict")
async def predict(model: UploadFile = File(...), data: str = None):
    """Run an uploaded Keras model on a 12-point series and return its forecast.

    The series is min-max scaled (scaler fitted on ``data`` itself), fed to the
    model as a ``(1, 12, 1)`` batch, and the model output is mapped back to the
    original scale with the same scaler.

    Args:
        model: Uploaded Keras model file. It is written to a temporary ``.h5``
            file (removed before returning) and loaded with ``load_model``.
        data: Python-literal sequence of numbers, e.g. ``"[1, 2, ..., 12]"``.
            Exactly 12 values are required because of the ``(1, 12, 1)`` reshape.

    Returns:
        ``{"predictions": [...]}`` -- the rescaled model output, flattened.

    Raises:
        HTTPException: 400 if ``data`` is missing, cannot be parsed as a numeric
            literal, or does not contain exactly 12 values; 500 for any other
            failure (the traceback is printed to stdout).
    """
    temp_model_path = None
    try:
        if data is None:
            raise HTTPException(status_code=400, detail="Query parameter 'data' is required")
        try:
            # SECURITY: literal_eval only accepts Python literals. The previous
            # eval() executed arbitrary code supplied by the client.
            ds = np.array(ast.literal_eval(data), dtype=float).reshape(-1, 1)
        except (ValueError, SyntaxError, TypeError) as exc:
            raise HTTPException(status_code=400, detail=f"Invalid 'data': {exc}") from exc
        if ds.shape[0] != 12:
            raise HTTPException(
                status_code=400,
                detail=f"'data' must contain exactly 12 values, got {ds.shape[0]}",
            )

        # load_model needs a filesystem path, so persist the upload first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as temp_model_file:
            temp_model_file.write(await model.read())
            temp_model_path = temp_model_file.name

        scaler = MinMaxScaler()
        ds_normalized = scaler.fit_transform(ds)

        # NOTE(review): TensorFlow work below is blocking inside an async
        # endpoint and stalls the event loop; consider run_in_threadpool.
        keras_model = load_model(temp_model_path, compile=False)
        # Kept to mirror /retrain; predict() itself does not require compilation.
        keras_model.compile(optimizer=Adam(learning_rate=0.001), loss="mse", run_eagerly=True)

        raw_predictions = keras_model.predict(ds_normalized.reshape(1, 12, 1))
        # The scaler was fitted on a single feature, so inverse_transform needs a
        # column vector; this also supports models emitting several steps.
        predictions_rescaled = (
            scaler.inverse_transform(np.asarray(raw_predictions).reshape(-1, 1))
            .flatten()
            .tolist()
        )
        return {"predictions": predictions_rescaled}
    except HTTPException:
        # Client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # The temporary copy of the uploaded model is never needed after this call.
        if temp_model_path is not None and os.path.exists(temp_model_path):
            os.remove(temp_model_path)
@app.post("/retrain")
async def retrain(model: UploadFile = File(...), data: str = None):
    """Fine-tune an uploaded Keras model for one epoch and send it back.

    The series in ``data`` is min-max scaled, then cut into sliding windows:
    each input is 12 consecutive values and its target is the value right
    after them. The model is fitted on these pairs for one epoch (batch 32).

    Args:
        model: Uploaded Keras model file, written to a temporary ``.h5`` file
            and loaded with ``load_model``.
        data: Python-literal sequence of numbers, e.g. ``"[1.0, 2.5, ...]"``.
            More than 12 values are needed to form at least one window.

    Returns:
        A ``FileResponse`` streaming the updated model as ``updated_model.h5``.
        The server-side copy is deleted once the response has been sent.

    Raises:
        HTTPException: 400 if ``data`` is missing, cannot be parsed as a numeric
            literal, or has 12 or fewer values; 500 for any other failure (the
            traceback is printed to stdout).
    """
    temp_model_path = None
    updated_model_path = None
    try:
        if data is None:
            raise HTTPException(status_code=400, detail="Query parameter 'data' is required")
        try:
            # SECURITY: literal_eval only accepts Python literals. The previous
            # eval() executed arbitrary code supplied by the client.
            ds = np.array(ast.literal_eval(data), dtype=float).reshape(-1, 1)
        except (ValueError, SyntaxError, TypeError) as exc:
            raise HTTPException(status_code=400, detail=f"Invalid 'data': {exc}") from exc
        if ds.shape[0] <= 12:
            raise HTTPException(
                status_code=400,
                detail=f"'data' must contain more than 12 values, got {ds.shape[0]}",
            )

        # load_model needs a filesystem path, so persist the upload first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as temp_model_file:
            temp_model_file.write(await model.read())
            temp_model_path = temp_model_file.name

        # NOTE(review): TensorFlow work below is blocking inside an async
        # endpoint and stalls the event loop; consider run_in_threadpool.
        keras_model = load_model(temp_model_path, compile=False)
        keras_model.compile(optimizer=Adam(learning_rate=0.001), loss="mse", run_eagerly=True)

        scaler = MinMaxScaler()
        ds_normalized = scaler.fit_transform(ds)
        # Sample i: the 12 values before index i -> the value at index i.
        x_train = np.array([ds_normalized[i - 12:i] for i in range(12, len(ds_normalized))])
        y_train = ds_normalized[12:]
        keras_model.fit(x_train, y_train, epochs=1, batch_size=32)

        # splitext only touches the extension, unlike str.replace on the full path.
        base, ext = os.path.splitext(temp_model_path)
        updated_model_path = f"{base}_updated{ext}"
        keras_model.save(updated_model_path)

        # FileResponse streams the file after this function returns, so it can
        # only be deleted by a task that runs once the response has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, updated_model_path)
        response = FileResponse(
            path=updated_model_path,
            filename="updated_model.h5",
            media_type="application/octet-stream",
            headers={"Content-Disposition": "attachment; filename=updated_model.h5"},
            background=cleanup,
        )
        # Ownership of the updated file now belongs to the background task.
        updated_model_path = None
        return response
    except HTTPException:
        # Client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always drop the uploaded copy; drop the updated copy only if it was
        # never handed to the response (i.e. a failure happened after saving).
        for path in (temp_model_path, updated_model_path):
            if path is not None and os.path.exists(path):
                os.remove(path)
|