import asyncio
import base64
import os
import re
import tempfile
import uuid

import boto3
import gradio as gr
import requests
from botocore.config import Config
from dotenv import load_dotenv

from document_to_gloss import DocumentToASLConverter
from document_parsing import DocumentParser
from vectorizer import Vectorizer
from video_gen import create_multi_stitched_video

load_dotenv()
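
# Cloudflare R2 storage configuration, read from the environment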
R2_ASL_VIDEOS_URL = os.environ.get("R2_ASL_VIDEOS_URL")
R2_ENDPOINT = os.environ.get("R2_ENDPOINT")
R2_ACCESS_KEY_ID = os.environ.get("R2_ACCESS_KEY_ID")
R2_SECRET_ACCESS_KEY = os.environ.get("R2_SECRET_ACCESS_KEY")

if not all([R2_ASL_VIDEOS_URL, R2_ENDPOINT, R2_ACCESS_KEY_ID,
            R2_SECRET_ACCESS_KEY]):
    raise ValueError(
        "Missing required R2 environment variables. "
        "Please check your .env file."
    )
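
# Gradio UI copy; `inputs`/`outputs` are module-level defaults
# (create_interface() builds its own components inline).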
title = "AI-SL"
description = "Convert text to ASL!"
article = ("<p style='text-align: center'><a href='https://github.com/deenasun' "
           "target='_blank'>Deena Sun on Github</a></p>")
inputs = gr.File(label="Upload Document (pdf, txt, docx, or epub)")
outputs = [
    gr.JSON(label="Processing Results"),
    gr.Video(label="ASL Video Output"),
    gr.HTML(label="Download Link")
]

parser = DocumentParser()
asl_converter = DocumentToASLConverter()
vectorizer = Vectorizer()
session = boto3.session.Session()
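
# R2 exposes an S3-compatible API, so a standard boto3 S3 client works;
# R2 requires SigV4 request signing.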
s3 = session.client(
    service_name='s3',
    region_name='auto',
    endpoint_url=R2_ENDPOINT,
    aws_access_key_id=R2_ACCESS_KEY_ID,
    aws_secret_access_key=R2_SECRET_ACCESS_KEY,
    config=Config(signature_version='s3v4')
)


def clean_gloss_token(token):
    """Clean a single gloss token."""
    if not token:
        return None

    # Strip punctuation, lowercase, and collapse internal whitespace
    cleaned = re.sub(r'[^\w\s]', '', token).lower().strip()
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()

    return cleaned if cleaned else None
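

# Browsers reliably play H.264 ('avc1') in MP4 containers; other codecs
# (e.g. mp4v) often fail to play in HTML5 <video> elements.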
def verify_video_format(video_path):
    """
    Verify that a video file is in a browser-compatible format (H.264 MP4).
    Returns (is_compatible, message).
    """
    try:
        import cv2
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Could not open video file"

        # Decode the FOURCC integer into its four-character codec tag
        fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
        codec = "".join([chr((fourcc >> 8 * i) & 0xFF) for i in range(4)])
        cap.release()

        if codec in ['avc1', 'H264', 'h264']:
            return True, f"Video is H.264 encoded ({codec})"
        else:
            return False, f"Video codec {codec} may not be browser compatible"

    except Exception as e:
        return False, f"Error checking video format: {e}"


def upload_video_to_r2(video_path, bucket_name="asl-videos"):
    """
    Upload a video file to R2 and return a public URL, or None on failure.
    """
    try:
        # Warn early if the video may not play in browsers
        is_compatible, message = verify_video_format(video_path)
        print(f"Video format check: {message}")

        # Use a UUID so uploaded filenames never collide
        file_extension = os.path.splitext(video_path)[1]
        unique_filename = f"{uuid.uuid4()}{file_extension}"

        with open(video_path, 'rb') as video_file:
            s3.upload_fileobj(
                video_file,
                bucket_name,
                unique_filename,
                ExtraArgs={
                    'ACL': 'public-read',
                    'ContentType': 'video/mp4; codecs="avc1.42E01E"',
                    'CacheControl': 'max-age=86400',
                    'ContentDisposition': 'inline'
                })

        if R2_ENDPOINT:
            # Internal R2 URL, logged for debugging
            public_domain = (R2_ENDPOINT.replace('https://', '')
                             .split('.')[0])
            video_url = (f"https://{public_domain}.r2.cloudflarestorage.com/"
                         f"{bucket_name}/{unique_filename}")
            print(f"Video uploaded to R2: {video_url}")

            # Public-facing URL served from the configured R2 public domain
            public_video_url = f"{R2_ASL_VIDEOS_URL}/{unique_filename}"
            print(f"Public video url: {public_video_url}")
            return public_video_url
        else:
            print("R2_ENDPOINT is not configured")
            return None

    except Exception as e:
        print(f"Error uploading video to R2: {e}")
        return None
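

# A base64 data URI lets the client save the video without another request
# to R2, at the cost of roughly 33% size overhead.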
def video_to_base64(video_path):
    """
    Convert a video file to a base64 data URI for direct download.
    """
    try:
        with open(video_path, 'rb') as video_file:
            video_data = video_file.read()
            base64_data = base64.b64encode(video_data).decode('utf-8')
            return f"data:video/mp4;base64,{base64_data}"
    except Exception as e:
        print(f"Error converting video to base64: {e}")
        return None


def download_video_from_url(video_url):
    """
    Download a video from a public R2 URL.
    Returns the local file path where the video is saved, or None on failure.
    """
    try:
        # Create a named temp file and close it so it can be reopened for writing
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        temp_path = temp_file.name
        temp_file.close()

        print(f"Downloading video from: {video_url}")
        response = requests.get(video_url, stream=True, timeout=60)
        response.raise_for_status()

        # Stream to disk in 8 KB chunks to avoid loading the file into memory
        with open(temp_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        print(f"Video downloaded to: {temp_path}")
        return temp_path

    except Exception as e:
        print(f"Error downloading video: {e}")
        return None
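

# Used for per-token clips after stitching and for the stitched file after
# it has been served (see delayed_cleanup in predict_unified).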
def cleanup_temp_video(file_path):
    """
    Clean up a temporary video file.
    """
    try:
        if file_path and os.path.exists(file_path):
            os.unlink(file_path)
            print(f"Cleaned up: {file_path}")
    except Exception as e:
        print(f"Error cleaning up file: {e}")
def determine_input_type(input_data):
    """
    Determine the type of input data and return a standardized format.
    Returns: (input_type, processed_data) where input_type is 'text',
    'file_path', or 'unknown'.
    """
    if isinstance(input_data, str):
        # A string containing a known document extension is treated as a path
        if any(ext in input_data.lower()
               for ext in ['.pdf', '.txt', '.docx', '.doc', '.epub']):
            return 'file_path', input_data
        # Stringified Gradio FileData payload
        elif input_data.startswith('{') and 'gradio.FileData' in input_data:
            try:
                import ast
                import json

                try:
                    file_data = json.loads(input_data)
                except json.JSONDecodeError:
                    # FileData strings may use Python literal syntax, not JSON
                    file_data = ast.literal_eval(input_data)

                if isinstance(file_data, dict) and 'path' in file_data:
                    print(f"Parsed FileData: {file_data}")
                    return 'file_path', file_data['path']
            except (ValueError, SyntaxError) as e:
                print(f"Error parsing FileData string: {e}")
                print(f"Input data: {input_data}")
            # Parsing failed or no 'path' key: report an unknown input
            return 'unknown', None
        else:
            # Any other string is treated as raw text
            return 'text', input_data.strip()
    elif isinstance(input_data, dict) and 'path' in input_data:
        # Gradio FileData dict
        return 'file_path', input_data['path']
    elif hasattr(input_data, 'name'):
        # File-like object (e.g. a tempfile wrapper from Gradio)
        return 'file_path', input_data.name
    else:
        return 'unknown', None


def process_input(input_data):
    """
    Extract text content from various input types.
    Returns gloss text ready for video lookup (file inputs are converted to
    ASL gloss; raw text is passed through as-is), or None on failure.
    """
    input_type, processed_data = determine_input_type(input_data)

    if input_type == 'text':
        return processed_data
    elif input_type == 'file_path':
        try:
            print(f"Processing file: {processed_data}")
            # Parse the document and convert its text to ASL gloss
            gloss = asl_converter.convert_document(processed_data)
            print(f"Converted gloss: {gloss[:100]}...")
            return gloss
        except Exception as e:
            print(f"Error processing file: {e}")
            return None
    else:
        print(f"Unsupported input type: {type(input_data)}")
        return None
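

# Core pipeline: text/document -> ASL gloss -> per-token video lookup
# -> stitch clips -> upload to R2.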
async def parse_vectorize_and_search_unified(input_data):
    """
    Unified function that handles both text and file inputs.
    Returns (results_dict, local_path_to_stitched_video).
    """
    gloss = process_input(input_data)
    if not gloss:
        return {
            "status": "error",
            "message": "Failed to process input"
        }, None

    print("ASL", gloss)

    # Normalize each gloss token before looking it up
    gloss_tokens = gloss.split()
    cleaned_tokens = []
    for token in gloss_tokens:
        cleaned = clean_gloss_token(token)
        if cleaned:
            cleaned_tokens.append(cleaned)

    print("Cleaned tokens:", cleaned_tokens)

    videos = []
    video_files = []

    # Look up a sign video for each token via vector search
    for g in cleaned_tokens:
        print(f"Processing {g}")
        try:
            result = await vectorizer.vector_query_from_supabase(query=g)
            print("result", result)
            if result.get("match", False):
                video_url = result["video_url"]
                videos.append(video_url)

                # Download locally so the clips can be stitched together
                local_path = download_video_from_url(video_url)
                if local_path:
                    video_files.append(local_path)
        except Exception as e:
            print(f"Error processing {g}: {e}")
            continue

    # Stitch the individual sign clips into one video
    stitched_video_path = None
    if len(video_files) > 1:
        try:
            print(f"Creating stitched video from {len(video_files)} videos...")
            stitched_video_path = tempfile.NamedTemporaryFile(
                delete=False, suffix='.mp4'
            ).name
            create_multi_stitched_video(video_files, stitched_video_path)
            print(f"Stitched video created: {stitched_video_path}")
        except Exception as e:
            print(f"Error creating stitched video: {e}")
            stitched_video_path = None
    elif len(video_files) == 1:
        # A single clip needs no stitching
        stitched_video_path = video_files[0]

    # Upload the final video so callers get a shareable URL
    video_download_url = None
    if stitched_video_path:
        video_download_url = upload_video_to_r2(stitched_video_path)

    # Remove the per-token clips; keep only the stitched result
    for video_file in video_files:
        if video_file != stitched_video_path:
            cleanup_temp_video(video_file)

    video64 = (video_to_base64(stitched_video_path)
               if stitched_video_path else None)

    return {
        "status": "success",
        "videos": videos,
        "video_count": len(videos),
        "gloss": gloss,
        "cleaned_tokens": cleaned_tokens,
        "video_download_url": video_download_url,
        "video_as_base_64": video64
    }, stitched_video_path
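

# Synchronous entry point for the async pipeline, used by the Gradio handler.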
def parse_vectorize_and_search_unified_sync(input_data):
    return asyncio.run(parse_vectorize_and_search_unified(input_data))


def predict_unified(input_data):
    """
    Unified prediction function that handles both text and file inputs.
    """
    try:
        if input_data is None:
            return {
                "status": "error",
                "message": "Please provide text or upload a document"
            }, None

        result = parse_vectorize_and_search_unified_sync(input_data)
        json_data, local_video_path = result

        if local_video_path and json_data.get("status") == "success":
            import threading
            import time

            # Give Gradio time to serve the file before deleting it
            def delayed_cleanup(video_path):
                time.sleep(30)
                cleanup_temp_video(video_path)

            cleanup_thread = threading.Thread(
                target=delayed_cleanup,
                args=(local_video_path,)
            )
            cleanup_thread.daemon = True
            cleanup_thread.start()

            return json_data, local_video_path

        return result

    except Exception as e:
        print(f"Error in predict_unified function: {e}")
        return {
            "status": "error",
            "message": f"An error occurred: {str(e)}"
        }, None


def create_interface():
    """Create and configure the Gradio interface."""
    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(
                label="Enter text to convert to ASL",
                placeholder="Type or paste your text here...",
                lines=5
            ),
            gr.File(
                label="Upload Document (pdf, txt, docx, or epub)",
                file_types=[".pdf", ".txt", ".docx", ".epub"]
            )
        ],
        outputs=[
            gr.JSON(label="Results"),
            gr.Video(label="ASL Video")
        ],
        title=title,
        description=description,
        article=article
    )
    return interface
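

# Gradio passes both components' values; non-empty text takes precedence
# over an uploaded file.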
def predict(text, file):
    """
    Predict function for Hugging Face API access.
    This function will be available as the /predict endpoint.
    """
    if text and text.strip():
        input_data = text.strip()
    elif file is not None:
        input_data = file
    else:
        return {
            "status": "error",
            "message": "Please provide either text or upload a file"
        }, None

    print("Input to the prediction function:", input_data)
    print("Input type:", type(input_data))

    return predict_unified(input_data)
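

# Bind to 0.0.0.0 so the app is reachable from outside a container;
# share=True additionally creates a temporary public Gradio link.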
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )