import io
import logging
from functools import wraps

import requests
from flask import Flask, request, jsonify, send_file
from PIL import Image
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

API_URL = "https://api-inference.huggingface.co/models/Paresh1879/stable-diffusion-xl-thumbsup-extend"
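# Placeholder: replace with a real Hugging Face access token, ideally loaded from an
# environment variable rather than committed to source. The same value is also the
# key that clients of this service must send (see require_api_key).
API_KEY = "your-api-key"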

app = Flask(__name__)

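# In-memory usage counts, keyed by the full Authorization header value.
# Counts reset on restart and are not shared between worker processes.
api_key_usage = {}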


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

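# HTTP session that retries transient upstream failures (rate limiting, 5xx, and the
# 503 the Inference API returns while a model is loading) with exponential backoff.
# POST is listed explicitly because urllib3 does not retry it by default.
session = requests.Session()
retry = Retry(
    total=10,
    backoff_factor=2,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=["HEAD", "GET", "OPTIONS", "POST"]
)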
adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter)
session.mount("https://", adapter)
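
# Illustrative sketch (hypothetical helper, not called by the app): the sleep schedule
# the Retry settings above imply. It assumes urllib3's documented backoff formula,
# backoff_factor * 2 ** (number of previous retries), with no sleep before the first
# retry, the default 120 s cap (backoff_max), and no Retry-After header from the
# server (when a 429/503 response carries one, urllib3 waits as instructed instead).
def example_backoff_schedule(backoff_factor, total, backoff_max=120.0):
    """Return the assumed sleep, in seconds, before each of `total` retries."""
    delays = []
    for previous_retries in range(total):
        if previous_retries == 0:
            delays.append(0.0)  # the first retry is sent immediately
        else:
            delays.append(float(min(backoff_max, backoff_factor * 2 ** previous_retries)))
    return delays

# Usage: example_backoff_schedule(2, 10)
#   -> [0.0, 4.0, 8.0, 16.0, 32.0, 64.0, 120.0, 120.0, 120.0, 120.0]
# That is about 604 s of sleeping alone, so a persistently failing upstream can keep a
# /generate_image request open for roughly ten minutes before the client sees a 503.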

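def query(payload):
    """POST `payload` to the Hugging Face Inference API and return the raw response bytes."""
    try:
        # Timeouts are assumed values (10 s to connect, 120 s to read); without one,
        # a stalled upstream connection could block this worker indefinitely.
        response = session.post(API_URL, headers={"Authorization": f"Bearer {API_KEY}"},
                                json=payload, timeout=(10, 120))
        response.raise_for_status()
        return response.content
    except requests.exceptions.RequestException as e:
        logger.error("Request failed: %s", e)
        raise


# Decorator: reject requests whose Authorization header is not "Bearer <API_KEY>",
# and count each accepted request per header value.
def require_api_key(f):
    @wraps(f)
    def check_api_key(*args, **kwargs):
        api_key = request.headers.get('Authorization')
        if api_key != f"Bearer {API_KEY}":
            return jsonify({"error": "Unauthorized"}), 401
        # Read-modify-write on a shared dict; not atomic under a threaded server.
        api_key_usage[api_key] = api_key_usage.get(api_key, 0) + 1
        return f(*args, **kwargs)
    return check_api_key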
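

# Illustrative sketch (hypothetical helpers, not wired into the app): a hardened
# variant of the check above. Flask's development server handles requests on
# multiple threads, so the dict update in check_api_key can, in principle, lose
# increments under concurrent load; a lock makes each increment atomic.
# hmac.compare_digest compares the header in constant time, so response timing
# does not reveal how many leading characters of the key matched.
import hmac
import threading


class ExampleUsageCounter:
    """Thread-safe per-key request counter kept in process memory."""

    def __init__(self):
        self._counts = {}
        self._lock = threading.Lock()

    def increment(self, key):
        # Hold the lock across the read and the write so no update is lost.
        with self._lock:
            self._counts[key] = self._counts.get(key, 0) + 1
            return self._counts[key]

    def get(self, key):
        with self._lock:
            return self._counts.get(key, 0)


def example_keys_match(provided, expected):
    """Constant-time comparison of two header values; a missing header never matches."""
    if provided is None:
        return False
    # Compare bytes: compare_digest raises TypeError for non-ASCII str arguments.
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))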

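# /generate_image endpoint: POST a JSON body {"prompt": "..."} and receive a PNG image.
@app.route('/generate_image', methods=['POST'])
@require_api_key
def generate_image():
    # get_json(silent=True) returns None for a missing or malformed JSON body
    # instead of raising, so those cases fall through to the 400 below.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Missing prompt"}), 400

    try:
        image_bytes = query({"inputs": prompt})
    except requests.exceptions.RequestException:
        # query() has already logged the underlying error.
        return jsonify({"error": "Service is temporarily unavailable. Please try again later."}), 503

    try:
        # Re-encode the model output as PNG. PIL raises UnidentifiedImageError
        # (a subclass of OSError) if the bytes are not a recognizable image.
        image = Image.open(io.BytesIO(image_bytes))
        img_io = io.BytesIO()
        image.save(img_io, 'PNG')
    except OSError:
        logger.exception("Upstream response could not be converted to PNG")
        return jsonify({"error": "Upstream service returned an invalid image."}), 502

    img_io.seek(0)
    return send_file(img_io, mimetype='image/png')
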
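# /api_key_usage endpoint: returns how many authorized requests this key has made,
# including this one (the decorator counts it before the handler runs).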
@app.route('/api_key_usage', methods=['GET'])
@require_api_key
def get_api_key_usage():
    api_key = request.headers.get('Authorization')
    usage_count = api_key_usage.get(api_key, 0)
    return jsonify({"usage_count": usage_count})
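

# Illustrative client-side sketch (hypothetical helper, not used by the server):
# how a caller might build a request for /generate_image with only the standard
# library. The base URL is an assumption (Flask's default local address); the
# path, JSON body, and Bearer header match what the endpoint above expects.
import json
import urllib.request


def example_generate_image_request(prompt, api_key, base_url="http://127.0.0.1:5000"):
    """Build (but do not send) a POST request for /generate_image."""
    body = json.dumps({"prompt": prompt}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/generate_image",
        data=body,
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        method="POST",
    )

# Usage (requires the server to be running):
#   req = example_generate_image_request("a photo of a cat", "your-api-key")
#   with urllib.request.urlopen(req) as resp, open("out.png", "wb") as f:
#       f.write(resp.read())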

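if __name__ == '__main__':
    # debug=True enables the interactive debugger and auto-reloader; use it only for local development.
    app.run(debug=True)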