File size: 2,444 Bytes
944ee1c
7dfa79f
91f49a8
8b15eea
7dfa79f
8b15eea
9a4e478
 
8b15eea
7dfa79f
 
 
 
8b15eea
 
 
 
9a4e478
 
 
8b15eea
5f311b3
8b15eea
5f311b3
 
 
 
 
944ee1c
 
7dfa79f
944ee1c
7dfa79f
5f311b3
7dfa79f
5f311b3
8b15eea
 
7dfa79f
 
9a4e478
8b15eea
91f49a8
7dfa79f
 
 
91f49a8
7dfa79f
944ee1c
 
91f49a8
7dfa79f
91f49a8
 
 
 
7dfa79f
91f49a8
 
 
7dfa79f
91f49a8
 
 
 
53724ab
91f49a8
53724ab
165d244
 
 
91f49a8
165d244
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import datetime
import html
import logging
import os
from os import getenv
import time

import gradio as gr
import requests

# Configure root logging at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Endpoint URL and bearer token for the generation API, taken from the environment
# (either may be None when the variable is unset).
API_URL = os.environ.get('API_URL')
BEARER = os.environ.get('BEARER')

# Request headers shared by every generation call made through call_jais.
headers = {"Authorization": "Bearer " + str(BEARER), "Content-Type": "application/json"}


def call_jais(payload):
    """POST *payload* to the JAIS endpoint and return the decoded JSON reply.

    The request goes to the module-level ``API_URL`` using the module-level
    ``headers``.

    Args:
        payload: JSON-serialisable request body (``generate`` sends
            ``{'inputs': '', 'prompt': ...}``).

    Returns:
        The response body parsed as JSON.

    Raises:
        gr.Error: for any HTTP error status or other failure. Gradio displays
            the message to the user. 5XX responses get a hint that the endpoint
            may still be starting up.
    """
    try:
        # Bound only the connection phase (10 s). The read timeout stays None,
        # so slow generations are still waited for without a cutoff.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=(10, None))
        response.raise_for_status()  # This will raise an exception for HTTP error codes
        return response.json()
    except requests.exceptions.HTTPError as http_err:
        # gr.Warning only shows a toast and is not an exception class, so
        # `raise gr.Warning(...)` ended in a TypeError. gr.Error is the
        # exception Gradio surfaces to the user.
        # Check if the error is a 5XX server error
        if 500 <= http_err.response.status_code < 600:
            raise gr.Error("The endpoint is loading, it takes about 4 min from the first call.") from http_err
        raise gr.Error(f"An error occurred while processing the request. {http_err}") from http_err
    except Exception as err:
        raise gr.Error(f"Check Inference Endpoint Status. An error occurred while processing the request. {err}") from err


def generate(prompt: str):
    """Send *prompt* to the JAIS endpoint and return its reply.

    The wall-clock duration of the request is logged whether the call succeeds
    or raises.

    Args:
        prompt: User text, forwarded as the ``prompt`` field of the payload.

    Returns:
        Whatever ``call_jais`` returns (the endpoint's decoded JSON response).

    Raises:
        Any exception raised by ``call_jais`` propagates unchanged.
    """
    payload = {'inputs': '', 'prompt': prompt}

    start_time = time.perf_counter()
    try:
        return call_jais(payload)
    finally:
        # Log in `finally` so failed or slow calls are also timed.
        elapsed_time = time.perf_counter() - start_time
        logger.info("Function took %.1f seconds to execute", elapsed_time)


def check_endpoint_status():
    """Query the Inference Endpoint's metadata and render its state as HTML.

    The endpoint-info URL comes from the ``ENDPOINT_URL`` environment variable
    and the bearer token from ``BEARER``. The reported ``status.state`` and
    ``status.message`` are embedded in a ``<div>``. Any state other than
    "running" (case-insensitive) is rendered large, bold and red. The text
    "Sending a request will restart the Endpoint" in the message is replaced
    with "Click Wake Up Endpoint".

    Returns:
        str: The HTML snippet, or a ``<div>`` describing why the status could
        not be fetched (network/HTTP error, or a response body that is not
        JSON).
    """
    api_url = os.getenv("ENDPOINT_URL")
    request_headers = {
        'accept': 'application/json',
        'Authorization': f'Bearer {os.getenv("BEARER")}',
    }

    try:
        # Bounded wait: this backs a UI action and must not block indefinitely.
        response = requests.get(api_url, headers=request_headers, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests < 2.27, where the JSON
        # decode error is not a RequestException subclass.
        return f"<div>Failed to get status: {html.escape(str(e), quote=False)}</div>"

    # Tolerate a non-dict body or a missing/null/non-dict 'status' object
    # instead of failing with AttributeError.
    status_info = data.get('status') if isinstance(data, dict) else None
    if not isinstance(status_info, dict):
        status_info = {}

    state = status_info.get('state')
    status = 'No status found' if state is None else str(state)
    raw_message = status_info.get('message')
    message = ('No message found' if raw_message is None else str(raw_message)).replace(
        'Sending a request will restart the Endpoint', 'Click Wake Up Endpoint')

    # Escape API-provided text before embedding it in markup.
    safe_status = html.escape(status, quote=False)
    safe_message = html.escape(message, quote=False)

    if status.lower() != "running":
        return f"<div style='color: red; font-size: 20px; font-weight: bold;'>Status: {safe_status}<br>Message: {safe_message}</div>"
    return f"<div>Status: {safe_status}<br>Message: {safe_message}</div>"