gabriel-fallen commited on
Commit
6ac585c
1 Parent(s): 49ef003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -48
app.py CHANGED
@@ -1,16 +1,6 @@
1
- import os
2
- import requests
3
- import json
4
- from io import BytesIO
5
-
6
- from flask import Flask, jsonify, render_template, request, send_file
7
-
8
- from modules.inference import infer_t5
9
- from modules.dataset import query_emotion
10
-
11
- # https://huggingface.co/settings/tokens
12
- # https://huggingface.co/spaces/{username}/{space}/settings
13
- API_TOKEN = os.getenv("BIG_GAN_TOKEN")
14
 
15
  app = Flask(__name__)
16
 
@@ -19,42 +9,43 @@ app = Flask(__name__)
19
  def index():
20
  return render_template("index.html")
21
 
22
-
23
- @app.route("/infer_biggan")
24
- def biggan():
25
- input = request.args.get("input")
26
-
27
- output = requests.request(
28
- "POST",
29
- "https://api-inference.huggingface.co/models/osanseviero/BigGAN-deep-128",
30
- headers={"Authorization": f"Bearer {API_TOKEN}"},
31
- data=json.dumps(input),
32
- )
33
-
34
- return send_file(BytesIO(output.content), mimetype="image/png")
35
-
36
-
37
- @app.route("/infer_t5")
38
- def t5():
39
- input = request.args.get("input")
40
-
41
- output = infer_t5(input)
42
-
43
- return jsonify({"output": output})
44
-
45
-
46
- @app.route("/query_emotion")
47
- def emotion():
48
- start = request.args.get("start")
49
- end = request.args.get("end")
50
-
51
- print(start)
52
- print(end)
53
-
54
- output = query_emotion(int(start), int(end))
55
-
56
- return jsonify({"output": output})
57
 
58
 
59
  if __name__ == "__main__":
 
60
  app.run(host="0.0.0.0", port=7860)
 
1
+ import os, pickle
2
+ import logging
3
+ from flask import Flask, jsonify, render_template, request
 
 
 
 
 
 
 
 
 
 
4
 
5
  app = Flask(__name__)
6
 
 
9
  def index():
10
  return render_template("index.html")
11
 
12
+ @app.route('/api/predict', methods=['POST'])
13
+ def predict():
14
+ try:
15
+ # load and parse input
16
+ data = request.json
17
+ vector = [
18
+ float(data['age']),
19
+ data['travel'],
20
+ data['department'],
21
+ float(data['distance']),
22
+ float(data['education']),
23
+ data['gender'],
24
+ float(data['satisfaction']),
25
+ data['maritalstatus'],
26
+ float(data['income']),
27
+ data['overtime'],
28
+ float(data['totalyears']),
29
+ float(data['years']),
30
+ float(data['lastpromotion'])
31
+ ]
32
+ # app.logging.info(f'vector: {vector}')
33
+ # print(f'vector: {vector}\n')
34
+
35
+ # load the model
36
+ with open(os.path.join('data', 'logistic.pkcls'), 'rb') as file:
37
+ model = pickle.load(file)
38
+
39
+ predictions = model(vector, 1)
40
+ # app.logging.info(f'predictions: {predictions}')
41
+ # print(f'predictions: {predictions}\n')
42
+
43
+ # send the response
44
+ return jsonify({ "predictions": { "leave": predictions[0], "stay": predictions[1] } })
45
+ except Exception as e:
46
+ return jsonify({"error": repr(e)})
47
 
48
 
49
  if __name__ == "__main__":
50
+ app.logger.setLevel(logging.INFO)
51
  app.run(host="0.0.0.0", port=7860)