Spaces:
Sleeping
Sleeping
import os | |
import uuid | |
from datetime import timedelta, datetime | |
from pathlib import Path | |
from time import time | |
import gradio as gr | |
import pytz | |
from tinydb import TinyDB, where | |
from app.fn import send_email, get_timezone_by_ip | |
# Directory holding the job database file and per-job files; override with the DATA env var.
SERVER_DATA_DIR = os.environ.get('DATA', 'results')
# Retention window in seconds (48 hours) added to a job's end/start time to compute its expiry.
DB_EXPIRY = timedelta(days=2).total_seconds()
def init_job_db():
    """Open the job database and fail any job still marked RUNNING.

    Intended to run at startup: a job left in RUNNING state by a previous
    process cannot still be running, so its status is flipped to FAILED.

    :return: the opened ``JobDB`` backed by ``<SERVER_DATA_DIR>/job_db.json``.
    """
    job_db = JobDB(f'{SERVER_DATA_DIR}/job_db.json')
    stale_ids = [record['id'] for record in job_db.all() if record['status'] == 'RUNNING']
    for stale_id in stale_ids:
        job_db.update({'status': 'FAILED'}, where('id') == stale_id)
    return job_db
def ts_to_str(timestamp, timezone):
    """Render a UNIX timestamp as a wall-clock string in the given timezone.

    :param timestamp: seconds since the epoch; interpreted as UTC.
    :param timezone: a timezone name accepted by ``pytz.timezone``.
    :return: text formatted as ``'%Y-%m-%d %H:%M:%S (%Z%z)'``, i.e. the local
        date/time followed by the zone abbreviation and UTC offset.
    """
    moment_utc = datetime.fromtimestamp(timestamp, pytz.utc)
    zone = pytz.timezone(timezone)
    local_moment = moment_utc.astimezone(zone)
    return local_moment.strftime('%Y-%m-%d %H:%M:%S (%Z%z)')
class JobDB(TinyDB):
    """TinyDB-backed store of prediction-job records.

    Records are created by ``job_initiate`` with the keys ``id`` (UUID4
    string), ``status``, ``email``, ``ip``, ``cookies``, ``start_time``,
    ``end_time``, ``expiry_time`` and ``error``. Timestamps are UNIX seconds
    from ``time()``. This class sets ``status`` to 'RUNNING' or 'FAILED';
    other values presumably arrive via ``job_update``'s ``update_info``.
    """

    # Maximum wall-clock time (seconds) a job may stay RUNNING before check_expiry fails it.
    _MAX_RUNNING_SECONDS = 4 * 60 * 60

    @staticmethod
    def _client_ip(request):
        """Return the originating client IP for a Gradio/Starlette request.

        Uses the first entry of ``X-Forwarded-For`` (the original client when
        behind a proxy) and falls back to the direct peer address.
        """
        forwarded = request.headers.get('x-forwarded-for')
        if forwarded:
            return forwarded.split(',')[0].strip()
        return request.client.host if request.client is not None else None

    @staticmethod
    def _remove_job_files(job_id):
        """Delete regular files in SERVER_DATA_DIR whose name starts with *job_id*.

        Only canonical UUID strings (the id format produced by ``job_initiate``)
        are acted on. A short or empty id would otherwise prefix-match unrelated
        files, including the database file itself.
        """
        job_id = str(job_id)
        try:
            if str(uuid.UUID(job_id)) != job_id:
                return
        except ValueError:
            return
        # NOTE(review): assumes per-job files are named '<job_id>...' -- confirm against the writer.
        for file_path in Path(SERVER_DATA_DIR).glob(f'{job_id}*'):
            if file_path.is_file():
                file_path.unlink(missing_ok=True)

    def remove_job_record(self, job_id):
        """Delete the record for *job_id* and its files in SERVER_DATA_DIR."""
        self.remove(where('id') == job_id)
        self._remove_job_files(job_id)

    def check_expiry(self):
        """Purge expired jobs and time out jobs that have run too long.

        A non-RUNNING job is removed (record and files) once its
        ``expiry_time`` has passed, or ``start_time + DB_EXPIRY`` when no
        expiry was recorded. A RUNNING job older than 4 hours is marked
        FAILED, and if it has an email the updated record is sent via
        ``send_email``.
        """
        now = time()
        for job in self.all():
            if job['status'] != 'RUNNING':
                expiry_time = job.get('expiry_time')
                if expiry_time is None:
                    expiry_time = job['start_time'] + DB_EXPIRY
                if expiry_time < now:
                    self.remove_job_record(job['id'])
            elif now - job['start_time'] > self._MAX_RUNNING_SECONDS:
                job_query = where('id') == job['id']
                self.update(
                    {
                        'status': 'FAILED',
                        'error': 'Job has timed out by exceeding the maximum running time of 4 hours.'
                    },
                    job_query
                )
                if job.get('email'):
                    # Re-read so the email reflects the FAILED status and error, not the stale snapshot.
                    updated = self.search(job_query)
                    send_email(updated[0] if updated else job)

    def check_user_running_job(self, email, request):
        """Return a message if the user already has a RUNNING job, else False.

        The user is identified by exactly one criterion, chosen in order:
        *email* if given; otherwise any request cookie matching a stored
        cookie; otherwise the client IP.

        :raises gr.Error: if the lookup itself fails.
        """
        message = ("You already have a running prediction job (ID: {id}) under this {reason}. "
                   "Please wait for it to complete before submitting another job.")
        running = where('status') == "RUNNING"
        try:
            if email:
                job = self.search((where('email') == email) & running)
                if job:
                    return message.format(id=job[0]['id'], reason="email")
            # check if a job is running for the session
            elif request.cookies:
                for key, value in request.cookies.items():
                    # Index by the cookie name; `.key` would query a field literally named "key".
                    job = self.search((where('cookies')[key] == value) & running)
                    if job:
                        return message.format(id=job[0]['id'], reason="session")
            # check if a job is running for the IP
            else:
                # Records store the address under 'ip' (see job_initiate), resolved the same way.
                job = self.search((where('ip') == self._client_ip(request)) & running)
                if job:
                    return message.format(id=job[0]['id'], reason="IP")
            return False
        except Exception as e:
            raise gr.Error(f'Failed to validate user running jobs due to error: {str(e)}') from e

    def job_lookup(self, job_id):
        """Return a copy of the job record with timestamps as local-time strings.

        Timestamps are rendered with ``ts_to_str`` in the timezone that
        ``get_timezone_by_ip`` resolves for the record's ``ip``.

        :return: the formatted record dict, or None if *job_id* is unknown.
        """
        jobs = self.search(where('id') == job_id)
        if not jobs:
            return None
        job = dict(jobs[0])
        time_keys = [key for key in ('start_time', 'end_time', 'expiry_time')
                     if job.get(key) is not None]
        if time_keys:
            # Resolve the user's timezone once instead of once per timestamp.
            timezone = get_timezone_by_ip(job['ip'])
            for key in time_keys:
                job[key] = ts_to_str(job[key], timezone)
        return job

    def job_initiate(self, email, session_info):
        """Insert and return a new RUNNING job record for the requesting user.

        :param email: optional notification address stored on the record.
        :param session_info: the Gradio request; provides client IP and cookies.
        """
        gr.Info('Finished processing inputs. Initiating the GenFBDD job... '
                'You will be redirected to Job Status page.')
        job_info = {'id': str(uuid.uuid4()),
                    'status': 'RUNNING',
                    'email': email,
                    'ip': self._client_ip(session_info),
                    'cookies': dict(session_info.cookies),
                    'start_time': time(),
                    'end_time': None,
                    'expiry_time': None,
                    'error': None}
        self.insert(job_info)
        return job_info

    def job_update(self, job_id, update_info):
        """Merge *update_info* into the job, stamp end/expiry times, and notify.

        ``end_time`` is set to now and ``expiry_time`` to now + DB_EXPIRY.
        If the (still existing) record has an email, ``send_email`` is
        called with the updated record.
        """
        job_query = where('id') == job_id
        end_time = time()
        self.update(
            update_info | {
                'end_time': end_time,
                'expiry_time': end_time + DB_EXPIRY
            },
            job_query
        )
        # The record may have been removed concurrently; don't index an empty result.
        matches = self.search(job_query)
        if matches and matches[0].get('email'):
            send_email(matches[0])