File size: 3,209 Bytes
63c4e31
bdc28b7
63c4e31
 
 
ae690b8
63c4e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c839aa3
 
 
 
 
 
63c4e31
 
e471491
e3ef8a2
63c4e31
c839aa3
344060f
63c4e31
344060f
63c4e31
 
ae690b8
63c4e31
 
 
99d22e0
63c4e31
 
 
 
 
 
 
 
 
e471491
 
c839aa3
 
63c4e31
 
457ff50
63c4e31
c839aa3
 
 
63c4e31
e471491
63c4e31
 
 
 
 
bdc28b7
 
 
 
 
 
 
63c4e31
ae690b8
63c4e31
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import ast
import os
import tempfile
import traceback

import numpy as np
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.optimizers import Adam

# ASGI application instance; the /predict and /retrain endpoints in this
# module are registered on it through @app.post decorators.
app = FastAPI()

@app.post("/predict")
async def predict(model: UploadFile = File(...), data: str = None):
    """Run an uploaded Keras model on a short numeric series.

    ``data`` is parsed into a column of floats and min-max scaled with a
    scaler fit on that same series. It is then fed to the uploaded model as
    a single sample of shape ``(1, 12, 1)``. The model output is mapped back
    to the original scale before being returned.

    Args:
        model: Uploaded Keras model file in HDF5 (``.h5``) format.
        data: String holding a list literal of numbers, e.g.
            ``"[1, 2, 3, ...]"``. Exactly 12 values are required because the
            model input is reshaped to ``(1, 12, 1)``.

    Returns:
        ``{"predictions": [...]}``: the flattened model output, rescaled to
        the value range of ``data``.

    Raises:
        HTTPException: 400 if ``data`` is missing, is not a list of numbers,
            or does not contain exactly 12 values; 500 for any other failure
            (the traceback is printed to stdout).
    """
    window = 12  # timesteps per input sample (same window /retrain trains on)
    temp_model_path = None
    try:
        if data is None:
            raise HTTPException(status_code=400, detail="'data' is required")
        # SECURITY: this string comes from the client. The previous eval()
        # allowed arbitrary code execution; literal_eval only accepts
        # Python literals such as lists and numbers.
        try:
            ds = np.asarray(ast.literal_eval(data), dtype=float).reshape(-1, 1)
        except (ValueError, SyntaxError, TypeError) as err:
            raise HTTPException(
                status_code=400, detail=f"'data' must be a list of numbers: {err}"
            ) from err
        if ds.shape[0] != window:
            raise HTTPException(
                status_code=400,
                detail=f"'data' must contain exactly {window} values, got {ds.shape[0]}",
            )

        # Normalize the data to [0, 1]; the same scaler undoes it below.
        scaler = MinMaxScaler()
        ds_normalized = scaler.fit_transform(ds)

        # Keras loads models from a path, so persist the upload first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as temp_model_file:
            temp_model_file.write(await model.read())
            temp_model_path = temp_model_file.name

        # Separate name so the `model` upload parameter is not shadowed.
        keras_model = load_model(temp_model_path, compile=False)
        keras_model.compile(optimizer=Adam(learning_rate=0.001), loss="mse", run_eagerly=True)
        raw_predictions = keras_model.predict(ds_normalized.reshape(1, window, 1))

        # The scaler was fit on a single column, so inverse_transform needs
        # an (n, 1) array; reshaping also handles models with >1 outputs.
        predictions_rescaled = (
            scaler.inverse_transform(np.asarray(raw_predictions).reshape(-1, 1))
            .flatten()
            .tolist()
        )
        return {"predictions": predictions_rescaled}

    except HTTPException:
        # Client errors raised above must keep their status code.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Previously the uploaded model was left on disk after every request.
        if temp_model_path is not None and os.path.exists(temp_model_path):
            os.remove(temp_model_path)

@app.post("/retrain")
async def retrain(model: UploadFile = File(...), data: str = None):
    """Train an uploaded Keras model for one epoch on ``data`` and return it.

    ``data`` is parsed into a column of floats and min-max scaled. It is then
    cut into sliding windows of 12 consecutive values, where each window's
    target is the value that follows it. The uploaded model is fit on these
    samples for one epoch and saved as HDF5. The saved file is streamed back
    as a download.

    Args:
        model: Uploaded Keras model file in HDF5 (``.h5``) format.
        data: String holding a list literal of numbers, e.g. ``"[1, 2, 3]"``.
            It must contain more than 12 values so at least one training
            sample can be formed.

    Returns:
        A ``FileResponse`` with the updated model as ``updated_model.h5``.
        The file on disk is deleted by a background task after the response
        has been sent.

    Raises:
        HTTPException: 400 if ``data`` is missing, is not a list of numbers,
            or is too short; 500 for any other failure (the traceback is
            printed to stdout).
    """
    window = 12  # timesteps per training sample (same window /predict uses)
    # Pre-bind so the cleanup code never hits an unbound name when the
    # upload itself fails.
    temp_model_path = None
    updated_model_path = None
    try:
        if data is None:
            raise HTTPException(status_code=400, detail="'data' is required")
        # SECURITY: this string comes from the client. The previous eval()
        # allowed arbitrary code execution; literal_eval only accepts
        # Python literals such as lists and numbers.
        try:
            ds = np.asarray(ast.literal_eval(data), dtype=float).reshape(-1, 1)
        except (ValueError, SyntaxError, TypeError) as err:
            raise HTTPException(
                status_code=400, detail=f"'data' must be a list of numbers: {err}"
            ) from err
        if ds.shape[0] <= window:
            # Otherwise x_train would be empty and model.fit would fail obscurely.
            raise HTTPException(
                status_code=400,
                detail=f"'data' needs more than {window} values to form a training sample",
            )

        # Normalize the data to [0, 1] and build (window -> next value) pairs.
        scaler = MinMaxScaler()
        ds_normalized = scaler.fit_transform(ds)
        x_train = np.array(
            [ds_normalized[i - window:i] for i in range(window, len(ds_normalized))]
        )
        y_train = ds_normalized[window:]

        # Keras loads models from a path, so persist the upload first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as temp_model_file:
            temp_model_file.write(await model.read())
            temp_model_path = temp_model_file.name

        # Separate name so the `model` upload parameter is not shadowed; the
        # model is compiled once (the original compiled it twice).
        keras_model = load_model(temp_model_path, compile=False)
        keras_model.compile(optimizer=Adam(learning_rate=0.001), loss="mse", run_eagerly=True)
        keras_model.fit(x_train, y_train, epochs=1, batch_size=32)

        # splitext only touches the real extension, unlike str.replace,
        # which would also rewrite any ".h5" elsewhere in the path.
        root, _ = os.path.splitext(temp_model_path)
        updated_model_path = f"{root}_updated.h5"
        keras_model.save(updated_model_path)

        # FileResponse streams the file after this function returns, so it
        # cannot be deleted in `finally`. Remove it once the response is sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, updated_model_path)
        return FileResponse(
            path=updated_model_path,
            filename="updated_model.h5",
            media_type="application/octet-stream",
            headers={"Content-Disposition": "attachment; filename=updated_model.h5"},
            background=cleanup,
        )
    except HTTPException:
        # Client errors raised above must keep their status code.
        raise
    except Exception as e:
        print(traceback.format_exc())
        # No response will stream the updated file, so drop it now.
        if updated_model_path is not None and os.path.exists(updated_model_path):
            os.remove(updated_model_path)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded copy of the original model is no longer needed.
        if temp_model_path is not None and os.path.exists(temp_model_path):
            os.remove(temp_model_path)