Paresh1879 commited on
Commit
44ad3d0
1 Parent(s): 4444f02

Upload API Server File

Browse files
Files changed (1) hide show
  1. SDXL_API_Server.py +82 -0
SDXL_API_Server.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import io
3
+ import logging
4
+ from PIL import Image
5
+ from flask import Flask, request, jsonify, send_file
6
+ from functools import wraps
7
+ from requests.adapters import HTTPAdapter
8
+ from requests.packages.urllib3.util.retry import Retry
9
+
10
+ API_URL = "https://api-inference.huggingface.co/models/Paresh1879/stable-diffusion-xl-thumbsup-extend"
11
+ API_KEY = "your-api-key"
12
+
13
+ app = Flask(__name__)
14
+
15
+ # Storage for API key usage counts
16
+ api_key_usage = {}
17
+
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Configure HTTP Session with Retry Logic
23
+ session = requests.Session()
24
+ retry = Retry(
25
+ total=10,
26
+ backoff_factor=2,
27
+ status_forcelist=[429, 500, 502, 503, 504],
28
+ allowed_methods=["HEAD", "GET", "OPTIONS", "POST"]
29
+ )
30
+ adapter = HTTPAdapter(max_retries=retry)
31
+ session.mount("http://", adapter)
32
+ session.mount("https://", adapter)
33
+
34
+ def query(payload):
35
+ try:
36
+ response = session.post(API_URL, headers={"Authorization": f"Bearer {API_KEY}"}, json=payload)
37
+ response.raise_for_status()
38
+ return response.content
39
+ except requests.exceptions.RequestException as e:
40
+ logger.error(f"Request failed: {e}")
41
+ raise
42
+ # checking api_key
43
+ def require_api_key(f):
44
+ @wraps(f)
45
+ def check_api_key(*args, **kwargs):
46
+ api_key = request.headers.get('Authorization')
47
+ if api_key != f"Bearer {API_KEY}":
48
+ return jsonify({"error": "Unauthorized"}), 401
49
+ if api_key not in api_key_usage:
50
+ api_key_usage[api_key] = 0
51
+ api_key_usage[api_key] += 1
52
+ return f(*args, **kwargs)
53
+ return check_api_key
54
+
55
+ # /generate_image endpoint
56
+ @app.route('/generate_image', methods=['POST'])
57
+ @require_api_key
58
+ def generate_image():
59
+ prompt = request.json.get('prompt')
60
+ if not prompt:
61
+ return jsonify({"error": "Missing prompt"}), 400
62
+
63
+ try:
64
+ image_bytes = query({"inputs": prompt})
65
+ image = Image.open(io.BytesIO(image_bytes))
66
+ img_io = io.BytesIO()
67
+ image.save(img_io, 'PNG')
68
+ img_io.seek(0)
69
+ return send_file(img_io, mimetype='image/png')
70
+ except requests.exceptions.RequestException as e:
71
+ return jsonify({"error": "Service is temporarily unavailable. Please try again later."}), 503
72
+
73
+ # api_key_usage endpoint
74
+ @app.route('/api_key_usage', methods=['GET'])
75
+ @require_api_key
76
+ def get_api_key_usage():
77
+ api_key = request.headers.get('Authorization')
78
+ usage_count = api_key_usage.get(api_key, 0)
79
+ return jsonify({"usage_count": usage_count})
80
+
81
+ if __name__ == '__main__':
82
+ app.run(debug=True)