albin commited on
Commit
5350d90
·
1 Parent(s): 9b5992c

complete app.py to make the prediction

Browse files
Files changed (1) hide show
  1. app.py +75 -3
app.py CHANGED
@@ -1,7 +1,79 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ from fastapi import FastAPI, Form, Depends, Request
5
+ from fastapi.encoders import jsonable_encoder
6
+ from fastapi.responses import JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
+ import pickle
10
 
11
  app = FastAPI()
12
 
13
+ # Add CORS middleware
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"],
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ model_file = open('logistic_regression_model.pkl', 'rb')
23
+ model = pickle.load(model_file, encoding='bytes')
24
+
25
+
26
+ class Msg(BaseModel):
27
+ msg: str
28
+
29
+ class Req(BaseModel):
30
+ url: str
31
+
32
+ class Resp(BaseModel):
33
+ url: str
34
+ label: int
35
+
36
+
37
  @app.get("/")
38
+ async def root():
39
+ return {"message": "Hello World. Welcome to FastAPI!"}
40
+
41
+ def form_req(url: str = Form(...)):
42
+ return Req(url=str(url))
43
+
44
+
45
+ @app.get("/path")
46
+ async def demo_get():
47
+ return {"message": "This is /path endpoint, use a post request to transform the text to uppercase"}
48
+
49
+
50
+ @app.post("/path")
51
+ async def demo_post(inp: Msg):
52
+ return {"message": inp.msg.upper()}
53
+
54
+
55
+ @app.get("/path/{path_id}")
56
+ async def demo_get_path_id(path_id: int):
57
+ return {"message": f"This is /path/{path_id} endpoint, use post request to retrieve result"}
58
+
59
+
60
+ @app.get("/predict/{path_id}")
61
+ async def predict(path_id: int):
62
+ return {"message": f"This is /predict/{path_id} endpoint, use post request to retrieve result"}
63
+
64
+ @app.post("/predict")
65
+ async def predict(request: Request, requess: Req = Depends(form_req)):
66
+ '''
67
+ Predict if url is phishing or legitimate
68
+ and render the result to the html page
69
+ '''
70
+ url = requess.url
71
+
72
+ prediction = model.predict([str(url)])
73
+ output = prediction[0]
74
+
75
+ output_text = "Legitimate" if output == 1 else "Phishing"
76
+
77
+ # Render index.html with prediction results
78
+ json_compatible_resp_data = jsonable_encoder(Resp(url=requess.url, label=output_text))
79
+ return JSONResponse(content=json_compatible_resp_data)