# ai-sl-api / app.py
# Last commit f37f939 (deenasun): fix for catching Gradio DataFile objects
# when they are passed from API calls as strings
from document_to_gloss import DocumentToASLConverter
from document_parsing import DocumentParser
from vectorizer import Vectorizer
from video_gen import create_multi_stitched_video
import gradio as gr
import asyncio
import re
import boto3
import os
from botocore.config import Config
from dotenv import load_dotenv
import requests
import tempfile
import uuid
import base64
# Load environment variables from .env file
load_dotenv()

# Load R2/S3 environment secrets
R2_ASL_VIDEOS_URL = os.environ.get("R2_ASL_VIDEOS_URL")  # public base URL used to build download links
R2_ENDPOINT = os.environ.get("R2_ENDPOINT")  # S3-compatible R2 endpoint URL
R2_ACCESS_KEY_ID = os.environ.get("R2_ACCESS_KEY_ID")
R2_SECRET_ACCESS_KEY = os.environ.get("R2_SECRET_ACCESS_KEY")

# Validate that required environment variables are set
# Fail fast at import time so the app never starts half-configured.
if not all([R2_ASL_VIDEOS_URL, R2_ENDPOINT, R2_ACCESS_KEY_ID,
            R2_SECRET_ACCESS_KEY]):
    raise ValueError(
        "Missing required R2 environment variables. "
        "Please check your .env file."
    )

# Copy shown on the Gradio page.
title = "AI-SL"
description = "Convert text to ASL!"
article = ("<p style='text-align: center'><a href='https://github.com/deenasun' "
           "target='_blank'>Deena Sun on Github</a></p>")

# NOTE(review): these `inputs`/`outputs` components look unused —
# create_interface() below builds its own components. Confirm before removing.
inputs = gr.File(label="Upload Document (pdf, txt, docx, or epub)")
outputs = [
    gr.JSON(label="Processing Results"),
    gr.Video(label="ASL Video Output"),
    gr.HTML(label="Download Link")
]

# Shared pipeline singletons: document parsing, document-to-gloss conversion,
# and vector search over the ASL video index.
parser = DocumentParser()
asl_converter = DocumentToASLConverter()
vectorizer = Vectorizer()

# S3-compatible client pointed at Cloudflare R2, configured for SigV4 signing.
session = boto3.session.Session()
s3 = session.client(
    service_name='s3',
    region_name='auto',
    endpoint_url=R2_ENDPOINT,
    aws_access_key_id=R2_ACCESS_KEY_ID,
    aws_secret_access_key=R2_SECRET_ACCESS_KEY,
    config=Config(signature_version='s3v4')
)
def clean_gloss_token(token):
    """Normalize one gloss token: drop punctuation, lowercase, collapse spaces.

    Returns the cleaned token, or None when the input is empty or nothing
    usable remains after cleaning.
    """
    if not token:
        return None
    # Keep only word characters and whitespace, then lowercase and trim.
    stripped = re.sub(r'[^\w\s]', '', token).lower().strip()
    # Collapse internal whitespace runs into single spaces.
    collapsed = re.sub(r'\s+', ' ', stripped).strip()
    return collapsed or None
def verify_video_format(video_path):
    """
    Verify that a video file is in a browser-compatible format (H.264 MP4)

    Returns:
        (ok, message) tuple; never raises — any failure (including a missing
        cv2 install or unreadable file) is reported as (False, reason).
    """
    try:
        import cv2

        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            return False, "Could not open video file"
        # Unpack the FOURCC integer into its four-character codec tag.
        fourcc_value = int(capture.get(cv2.CAP_PROP_FOURCC))
        codec = "".join(chr((fourcc_value >> 8 * shift) & 0xFF)
                        for shift in range(4))
        capture.release()
        # Any of these FOURCC spellings indicates H.264.
        if codec not in ('avc1', 'H264', 'h264'):
            return False, f"Video codec {codec} may not be browser compatible"
        return True, f"Video is H.264 encoded ({codec})"
    except Exception as e:
        return False, f"Error checking video format: {e}"
def upload_video_to_r2(video_path, bucket_name="asl-videos"):
    """
    Upload a video file to R2 and return a public URL

    Args:
        video_path: local path of the video to upload.
        bucket_name: destination R2 bucket (defaults to "asl-videos").

    Returns:
        The public download URL string, or None on any failure.
    """
    try:
        # Log browser-compatibility info; the upload proceeds either way.
        is_compatible, message = verify_video_format(video_path)
        print(f"Video format check: {message}")

        # Random object key so repeated uploads never collide.
        extension = os.path.splitext(video_path)[1]
        object_key = f"{uuid.uuid4()}{extension}"

        with open(video_path, 'rb') as video_file:
            s3.upload_fileobj(
                video_file,
                bucket_name,
                object_key,
                ExtraArgs={
                    'ACL': 'public-read',
                    'ContentType': 'video/mp4; codecs="avc1.42E01E"',  # H.264
                    'CacheControl': 'max-age=86400',  # Cache for 24 hours
                    'ContentDisposition': 'inline'  # Force inline display
                })

        if not R2_ENDPOINT:
            print("R2_ENDPOINT is not configured")
            return None

        # Replace the endpoint with the domain for uploading
        account_domain = R2_ENDPOINT.replace('https://', '').split('.')[0]
        video_url = (f"https://{account_domain}.r2.cloudflarestorage.com/"
                     f"{bucket_name}/{object_key}")
        print(f"Video uploaded to R2: {video_url}")
        public_video_url = f"{R2_ASL_VIDEOS_URL}/{object_key}"
        print(f"Public video url: {public_video_url}")
        return public_video_url
    except Exception as e:
        print(f"Error uploading video to R2: {e}")
        return None
def video_to_base64(video_path):
    """
    Convert a video file to base64 string for direct download

    Returns a `data:video/mp4;base64,...` data URL, or None when the file
    cannot be read.
    """
    try:
        with open(video_path, 'rb') as handle:
            payload = base64.b64encode(handle.read()).decode('utf-8')
        return f"data:video/mp4;base64,{payload}"
    except Exception as e:
        print(f"Error converting video to base64: {e}")
        return None
def download_video_from_url(video_url):
    """
    Download a video from a public R2 URL
    Returns the local file path where the video is saved

    Returns None on any download or filesystem error.
    """
    try:
        # Reserve a temp .mp4 path; close it immediately so it can be
        # reopened for writing below.
        reserved = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        local_path = reserved.name
        reserved.close()

        print(f"Downloading video from: {video_url}")
        response = requests.get(video_url, stream=True)
        response.raise_for_status()

        # Stream the body to disk in 8 KiB chunks.
        with open(local_path, 'wb') as out:
            for chunk in response.iter_content(chunk_size=8192):
                out.write(chunk)

        print(f"Video downloaded to: {local_path}")
        return local_path
    except Exception as e:
        print(f"Error downloading video: {e}")
        return None
def cleanup_temp_video(file_path):
    """
    Clean up temporary video file

    Silently does nothing when the path is falsy or missing; errors during
    deletion are printed, never raised.
    """
    try:
        if not file_path or not os.path.exists(file_path):
            return
        os.unlink(file_path)
        print(f"Cleaned up: {file_path}")
    except Exception as e:
        print(f"Error cleaning up file: {e}")
def determine_input_type(input_data):
    """
    Determine the type of input data and return a standardized format.

    Returns:
        (input_type, processed_data) where input_type is 'text',
        'file_path', or 'unknown' — always a 2-tuple, so callers can
        safely unpack.
    """
    if isinstance(input_data, str):
        # A stringified gradio.FileData dict (sent by API callers) must be
        # checked BEFORE the extension test: its embedded path usually
        # contains ".pdf"/".txt"/etc. and would otherwise be mistaken for a
        # plain file path, returning the raw JSON string as the "path".
        if input_data.startswith('{') and 'gradio.FileData' in input_data:
            import ast
            import json
            try:
                # Try to parse as JSON first
                try:
                    file_data = json.loads(input_data)
                except json.JSONDecodeError:
                    # Fall back for Python-repr style dicts (single quotes).
                    file_data = ast.literal_eval(input_data)
                if isinstance(file_data, dict) and 'path' in file_data:
                    print(f"Parsed FileData: {file_data}")
                    return 'file_path', file_data['path']
            except (ValueError, SyntaxError) as e:
                # json.JSONDecodeError is a ValueError subclass.
                print(f"Error parsing FileData string: {e}")
                print(f"Input data: {input_data}")
            # Unparseable or path-less FileData string: report 'unknown'
            # instead of falling through (the previous code implicitly
            # returned None here, which crashed callers unpacking the tuple).
            return 'unknown', None
        # A bare string containing a known document extension is treated as
        # a path on disk.
        if any(ext in input_data.lower()
               for ext in ['.pdf', '.txt', '.docx', '.doc', '.epub']):
            return 'file_path', input_data
        # Anything else is raw text to convert.
        return 'text', input_data.strip()
    elif isinstance(input_data, dict) and 'path' in input_data:
        # This is a gradio.FileData object from API calls
        return 'file_path', input_data['path']
    elif hasattr(input_data, 'name'):
        # This is a regular file object
        return 'file_path', input_data.name
    else:
        return 'unknown', None
def process_input(input_data):
    """
    Extract text content from various input types.
    Returns the text content ready for ASL conversion.

    Returns None when the input cannot be classified or the file
    conversion fails.
    """
    kind, payload = determine_input_type(input_data)

    if kind == 'text':
        return payload

    if kind == 'file_path':
        try:
            print(f"Processing file: {payload}")
            # Use document converter for all file types
            gloss = asl_converter.convert_document(payload)
            print(f"Converted gloss: {gloss[:100]}...")
            return gloss
        except Exception as e:
            print(f"Error processing file: {e}")
            return None

    print(f"Unsupported input type: {type(input_data)}")
    return None
async def parse_vectorize_and_search_unified(input_data):
    """
    Unified function that handles both text and file inputs

    Pipeline: input -> gloss text -> cleaned tokens -> per-token vector
    search -> download matched clips -> stitch -> upload to R2.

    Returns:
        (results_dict, stitched_video_path) where stitched_video_path is a
        local file path or None when no clips matched.
    """
    # Process the input to get gloss
    gloss = process_input(input_data)
    if not gloss:
        return {
            "status": "error",
            "message": "Failed to process input"
        }, None
    print("ASL", gloss)
    # Split by spaces and clean each token
    gloss_tokens = gloss.split()
    cleaned_tokens = []
    for token in gloss_tokens:
        cleaned = clean_gloss_token(token)
        if cleaned:  # Only add non-empty tokens
            cleaned_tokens.append(cleaned)
    print("Cleaned tokens:", cleaned_tokens)
    videos = []
    video_files = []  # Store local file paths for stitching
    # Look up each token in the vector index; a failed token is skipped,
    # never fatal for the whole request.
    for g in cleaned_tokens:
        print(f"Processing {g}")
        try:
            result = await vectorizer.vector_query_from_supabase(query=g)
            print("result", result)
            if result.get("match", False):
                video_url = result["video_url"]
                videos.append(video_url)
                # Download the video
                local_path = download_video_from_url(video_url)
                if local_path:
                    video_files.append(local_path)
        except Exception as e:
            print(f"Error processing {g}: {e}")
            continue
    # Create stitched video if we have multiple videos
    stitched_video_path = None
    if len(video_files) > 1:
        try:
            print(f"Creating stitched video from {len(video_files)} videos...")
            stitched_video_path = tempfile.NamedTemporaryFile(
                delete=False, suffix='.mp4'
            ).name
            create_multi_stitched_video(video_files, stitched_video_path)
            print(f"Stitched video created: {stitched_video_path}")
        except Exception as e:
            print(f"Error creating stitched video: {e}")
            stitched_video_path = None
    elif len(video_files) == 1:
        # If only one video, just use it directly
        stitched_video_path = video_files[0]
    # Upload final video to R2 and get public URL
    video_download_url = None
    if stitched_video_path:
        video_download_url = upload_video_to_r2(stitched_video_path)
        # Don't clean up the local file yet - let frontend use it first
    # Clean up individual video files after stitching
    for video_file in video_files:
        if video_file != stitched_video_path:  # Don't delete the final output
            cleanup_temp_video(video_file)
    # When no clips matched, stitched_video_path is None; video_to_base64
    # handles that by printing an error and returning None.
    video64 = video_to_base64(stitched_video_path)
    # Return simplified results
    return {
        "status": "success",
        "videos": videos,
        "video_count": len(videos),
        "gloss": gloss,
        "cleaned_tokens": cleaned_tokens,
        "video_download_url": video_download_url,
        "video_as_base_64": video64
    }, stitched_video_path
def parse_vectorize_and_search_unified_sync(input_data):
    """Synchronous wrapper: run the async pipeline to completion on a fresh
    event loop and return its (results_dict, video_path) tuple."""
    return asyncio.run(parse_vectorize_and_search_unified(input_data))
def predict_unified(input_data):
    """
    Unified prediction function that handles both text and file inputs

    Returns:
        (results_dict, local_video_path); on missing input or an internal
        error, an error dict and None.
    """
    try:
        if input_data is None:
            return {
                "status": "error",
                "message": "Please provide text or upload a document"
            }, None

        # Use the unified processing function
        json_data, local_video_path = (
            parse_vectorize_and_search_unified_sync(input_data)
        )

        # On success, delete the temp video in the background after a delay
        # so Gradio has time to load and display it first.
        if local_video_path and json_data.get("status") == "success":
            import threading
            import time

            def delayed_cleanup(video_path):
                time.sleep(30)  # Wait 30 seconds before cleanup
                cleanup_temp_video(video_path)

            reaper = threading.Thread(
                target=delayed_cleanup,
                args=(local_video_path,)
            )
            reaper.daemon = True
            reaper.start()

        return json_data, local_video_path
    except Exception as e:
        print(f"Error in predict_unified function: {e}")
        return {
            "status": "error",
            "message": f"An error occurred: {str(e)}"
        }, None
# Create the Gradio interface
def create_interface():
    """Create and configure the Gradio interface"""
    text_box = gr.Textbox(
        label="Enter text to convert to ASL",
        placeholder="Type or paste your text here...",
        lines=5
    )
    file_box = gr.File(
        label="Upload Document (pdf, txt, docx, or epub)",
        file_types=[".pdf", ".txt", ".docx", ".epub"]
    )
    json_out = gr.JSON(label="Results")
    video_out = gr.Video(label="ASL Video")

    # `predict` is looked up at call time, so its later definition is fine.
    return gr.Interface(
        fn=predict,
        inputs=[text_box, file_box],
        outputs=[json_out, video_out],
        title=title,
        description=description,
        article=article
    )
# Add a predict function for Hugging Face API access
def predict(text, file):
    """
    Predict function for Hugging Face API access.
    This function will be available as the /predict endpoint.

    Args:
        text: Raw text to convert; takes priority when non-blank.
        file: Uploaded document (file object, path string, or gradio
              FileData payload); used only when `text` is blank.

    Returns:
        (results_dict, video_path) tuple from the unified pipeline, or an
        error dict and None when neither input was provided.
    """
    # Determine which input to use
    if text and text.strip():
        # Use text input
        input_data = text.strip()
    elif file is not None:
        # Use file input - let the centralized processor handle the type
        input_data = file
    else:
        # No input provided
        return {
            "status": "error",
            "message": "Please provide either text or upload a file"
        }, None
    print("Input to the prediction function", input_data)
    # Bug fix: log the chosen input's type — the old code printed
    # type(input), i.e. the builtin function, not the data.
    print("Input type:", type(input_data))
    # Process using the unified function
    return predict_unified(input_data)
# For Hugging Face Spaces, use the Interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # bind on all interfaces (required in Spaces)
        server_port=7860,  # standard Gradio/Spaces port
        share=True  # Set to True for local testing with public URL
    )