# Import spaces first: on Hugging Face ZeroGPU Spaces it must be imported
# before anything that can initialize CUDA (e.g. the torch.cuda calls below).
import spaces
import os
import shutil
import psutil
import torch
# Clear caches and temporary files on startup
def clear_caches():
    print("Clearing caches...")
    # Clear pip cache (os.system does not raise on failure, so check the exit code)
    if os.system("pip cache purge") == 0:
        print("Pip cache cleared")
    else:
        print("Warning: could not clear pip cache")
    # Clear ./tmp directory
    tmp_dir = "./tmp"
    if os.path.exists(tmp_dir):
        try:
            shutil.rmtree(tmp_dir)
            os.makedirs(tmp_dir, exist_ok=True)
            print("Temporary directory ./tmp cleared")
        except Exception as e:
            print(f"Error clearing {tmp_dir}: {e}")
# Print memory usage for debugging
def print_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    print(f"Memory usage: {mem_info.rss / 1024**2:.2f} MB")
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# Run cache clearing and memory check
clear_caches()
print_memory_usage()
# Pin the Gradio version before the first `import gradio` so the pinned
# version is the one actually loaded into this process.
os.system("pip install gradio==4.44.1 --no-cache-dir")
import gradio as gr
print(f"Gradio version: {gr.__version__}")
from datetime import datetime
import tempfile
import gc
from azure.storage.blob import BlobServiceClient
from PIL import Image
from decord import VideoReader, cpu
from yolo_detection import (
    detect_people_and_machinery,
    annotate_video_with_bboxes,
    is_image,
    is_video
)
from image_captioning import (
    analyze_image_activities,
    analyze_video_activities,
    process_video_chunk,
    load_model_and_tokenizer,
    MAX_NUM_FRAMES
)
# Azure Blob Storage configuration
CONTAINER_NAME = "timelapsevideo"
# Use .get() so a missing key does not raise, and never print the connection
# string: it is a secret.
connection_string = os.environ.get('AzureKey')
if not connection_string:
    print("Warning: AzureKey environment variable not found. Azure Blob functionality will be disabled.")

# Global storage for detected activities and the most recently prepared media path
global_activities = []
global_media_path = None

# Create tmp directory for storing frames
tmp_dir = os.path.join('.', 'tmp')
os.makedirs(tmp_dir, exist_ok=True)
def get_azure_videos():
    """List videos from Azure Blob Storage"""
    try:
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        container_client = blob_service_client.get_container_client(CONTAINER_NAME)
        blob_list = container_client.list_blobs()
        video_files = [blob.name for blob in blob_list if blob.name.lower().endswith(('.mp4', '.avi', '.mov'))]
        return video_files
    except Exception as e:
        print(f"Error listing Azure blobs: {str(e)}")
        return []
def download_azure_video(blob_name):
    """Download a video from Azure Blob Storage to a temporary file"""
    global global_media_path
    try:
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        blob_client = blob_service_client.get_blob_client(container=CONTAINER_NAME, blob=blob_name)
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            blob_data = blob_client.download_blob()
            temp_file.write(blob_data.readall())
            temp_path = temp_file.name
        global_media_path = temp_path
        return temp_path
    except Exception as e:
        print(f"Error downloading Azure blob {blob_name}: {str(e)}")
        return None
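
# Note: `readall()` buffers the entire blob in memory. For large timelapse
# videos, a hedged alternative is azure-storage-blob's chunked download
# iterator, which streams the blob to disk instead. `download_azure_video_streamed`
# below is an illustrative sketch, not part of the original app:
def download_azure_video_streamed(blob_name):
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    blob_client = blob_service_client.get_blob_client(container=CONTAINER_NAME, blob=blob_name)
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
        # StorageStreamDownloader.chunks() yields the blob piece by piece
        for chunk in blob_client.download_blob().chunks():
            temp_file.write(chunk)
        return temp_file.name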
def prepare_media(azure_video, media, current_prepared_path):
    """Prepare media by downloading from Azure or copying the uploaded file"""
    global global_media_path
    # Delete the previously prepared file so temp files do not accumulate
    if current_prepared_path and os.path.exists(current_prepared_path):
        try:
            os.remove(current_prepared_path)
        except Exception as e:
            print(f"Error deleting previous media: {str(e)}")
    if azure_video and azure_video != "None":
        temp_path = download_azure_video(azure_video)
        if temp_path:
            return temp_path, f"Video '{azure_video}' downloaded successfully."
        else:
            return None, "Error downloading Azure video."
    elif media is not None:
        try:
            file_ext = get_file_extension(media.name)
            with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
                temp_path = temp_file.name
                with open(media.name, 'rb') as f:
                    temp_file.write(f.read())
            # Keep the chat fallback path in sync with the latest uploaded media
            global_media_path = temp_path
            return temp_path, "Uploaded media prepared successfully."
        except Exception as e:
            return None, f"Error preparing uploaded media: {str(e)}"
    else:
        return None, "No media selected or uploaded."
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, prepared_media_path):
    """Process the site diary entry using pre-prepared media"""
    global global_activities
    if prepared_media_path is None:
        return [day, date, "Please prepare media first", "Please prepare media first", "Please prepare media first", None, None, [], None, []]
    media_path = prepared_media_path
    try:
        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(media_path)
        print(f"Detected people: {detected_people}, machinery: {detected_machinery}, types: {detected_machinery_types}")
        annotated_video_path = None
        detected_activities = analyze_image_activities(media_path) if is_image(media_path) else analyze_video_activities(media_path)
        print(f"Detected activities: {detected_activities}")
        global_activities = detected_activities
        if is_video(media_path):
            annotated_video_path = annotate_video_with_bboxes(media_path)
        detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
        chat_history = []
        activity_rows = []
        for activity in detected_activities:
            time = activity.get('time', 'Unknown')
            summary = activity.get('summary', 'No description available')
            activity_rows.append([time, summary])
        return [day, date, str(detected_people), str(detected_machinery),
                detected_types_str, gr.update(visible=True), annotated_video_path,
                detected_activities, chat_history, activity_rows]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error processing media", "Error processing media",
                "Error processing media", None, None, [], None, []]
def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle clicking on an activity card in the gallery"""
    global global_activities, global_media_path
    selected_idx = evt.index
    if selected_idx < 0 or selected_idx >= len(activity_indices):
        return [gr.update(visible=True), gr.update(visible=False), [], None]
    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if card_idx < 0 or card_idx >= len(global_activities):
        return [gr.update(visible=True), gr.update(visible=False), [], None]
    selected_activity = global_activities[card_idx]
    chunk_video_path = None
    if 'chunk_path' in selected_activity and os.path.exists(selected_activity['chunk_path']):
        chunk_video_path = selected_activity['chunk_path']
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        # Fall back to the full video if the chunk is not available
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")
    history = []
    history.append((None, f"🎬 Selected video at timestamp {selected_activity['time']}"))
    if 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
        thumbnail_path = selected_activity['thumbnail']
        history.append((None, f"📷 Video frame at {selected_activity['time']}"))
        history.append((None, thumbnail_path))
    activity_info = f"I detected the following activity:\n\n{selected_activity['summary']}"
    if selected_activity.get('objects'):
        activity_info += f"\n\nIdentified objects: {', '.join(selected_activity['objects'])}"
    history.append(("Tell me about this video segment", activity_info))
    return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
def chat_with_video(message, history):
    """Chat with the mPLUG model about the selected video segment"""
    global global_activities, global_media_path
    try:
        selected_chunk_idx = None
        selected_time = None
        selected_activity = None
        # Recover the selected timestamp from the system message posted by on_card_click
        for entry in history:
            if entry[0] is None and "Selected video at timestamp" in entry[1]:
                selected_time = entry[1].split("Selected video at timestamp ")[1].strip()
                break
        if selected_time:
            for i, activity in enumerate(global_activities):
                if activity.get('time') == selected_time:
                    selected_chunk_idx = activity.get('chunk_id')
                    selected_activity = activity
                    break
        # If we found the chunk, use the model to analyze it
        if selected_chunk_idx is not None and global_media_path and selected_activity:
            # Load model
            model, tokenizer, processor = load_model_and_tokenizer()
            context = f"This video shows construction site activities at timestamp {selected_time}."
            if selected_activity.get('objects'):
                context += f" The scene contains {', '.join(selected_activity.get('objects'))}."
            prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"
            # This would ideally use the specific chunk, but for simplicity we'll use the global path.
            # In a production system, you'd extract just that chunk of the video.
            vr = VideoReader(global_media_path, ctx=cpu(0))
            # Sample roughly one frame per second, then slice out this chunk's frames
            sample_fps = round(vr.get_avg_fps())
            frame_idx = [i for i in range(0, len(vr), sample_fps)]
            chunk_size = MAX_NUM_FRAMES
            start_idx = selected_chunk_idx * chunk_size
            end_idx = min(start_idx + chunk_size, len(frame_idx))
            chunk_frames = frame_idx[start_idx:end_idx]
            if chunk_frames:
                frames = vr.get_batch(chunk_frames).asnumpy()
                frames_pil = [Image.fromarray(v.astype('uint8')) for v in frames]
                response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
                # Free GPU memory between requests
                del model, tokenizer, processor
                torch.cuda.empty_cache()
                gc.collect()
                return history + [(message, response)]
            else:
                return history + [(message, "Could not extract frames for this segment.")]
        else:
            thumbnail = None
            response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site."
            if selected_activity and 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
                thumbnail = selected_activity['thumbnail']
                new_history = history + [(message, response_text)]
                new_history.append((None, f"📷 Video frame at {selected_time}"))
                new_history.append((None, thumbnail))
                return new_history
            return history + [(message, response_text)]
    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error: {str(e)}")]
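
# The comment in chat_with_video notes that production code would extract just
# the selected chunk instead of re-reading the full video. A hedged sketch of
# that idea: cut the chunk's time range with ffmpeg (assumes ffmpeg is on PATH,
# and that frames are sampled at ~1 fps above, so one chunk covers roughly
# MAX_NUM_FRAMES seconds). `extract_chunk_clip` is illustrative and not called
# by the app.
def extract_chunk_clip(video_path, chunk_id, out_path):
    import subprocess
    seconds_per_chunk = MAX_NUM_FRAMES  # ~1 frame per second sampling above
    start = chunk_id * seconds_per_chunk
    subprocess.run([
        "ffmpeg", "-y",
        "-ss", str(start), "-t", str(seconds_per_chunk),  # seek + duration
        "-i", video_path,
        "-c", "copy",  # stream copy: fast, but cut points land on keyframes
        out_path,
    ], check=True)
    return out_path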
def create_activity_cards_ui(activities):
    """Create activity cards using native Gradio components"""
    if not activities:
        # Return an empty Gallery (not HTML) so the type matches the output component
        return gr.Gallery(value=[], label="Activity Timeline"), []
    thumbnails = []
    captions = []
    activity_indices = []
    for i, activity in enumerate(activities):
        thumbnail = activity.get('thumbnail', '')
        time = activity.get('time', 'Unknown')
        summary = activity.get('summary', 'No description available')
        objects_list = activity.get('objects', [])
        objects_text = f"Objects: {', '.join(objects_list)}" if objects_list else ""
        if len(summary) > 150:
            summary = summary[:147] + "..."
        thumbnails.append(thumbnail)
        caption = f"Timestamp: {time} | {summary}"
        if objects_text:
            caption += f" | {objects_text}"
        captions.append(caption)
        activity_indices.append(i)
    gallery = gr.Gallery(
        value=[(path, caption) for path, caption in zip(thumbnails, captions)],
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline"
    )
    return gallery, activity_indices
with gr.Blocks(title="Digital Site Diary", css="") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    activity_data = gr.State([])
    activity_indices = gr.State([])
    prepared_media_path = gr.State(None)
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day", value='9')
                    date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
                    total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                    total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1"
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3
                    )
                    media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
                    azure_video = gr.Dropdown(
                        label="Select Video from Azure",
                        choices=["None"] + get_azure_videos(),
                        value="None"
                    )
                    prepare_btn = gr.Button("Prepare Media")
                    submit_btn = gr.Button("Submit", variant="primary")
                    status_message = gr.Textbox(label="Status", interactive=False)
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
            with gr.Row():
                gr.Markdown("#### Activities with Timestamps")
                model_activities = gr.Dataframe(
                    headers=["Time", "Activity Description"],
                    datatype=["str", "str"],
                    label="Detected Activities",
                    interactive=False,
                    wrap=True
                )
            with gr.Row():
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False
                    )
                    back_btn = gr.Button("← Back to Timeline")
    prepare_btn.click(
        fn=prepare_media,
        inputs=[azure_video, media, prepared_media_path],
        outputs=[prepared_media_path, status_message]
    )
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, prepared_media_path],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            timeline_view,
            model_annotated_video,
            activity_data,
            chatbot,
            model_activities
        ]
    )
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices]
    )
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video]
    )
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot]
    )
    back_btn.click(
        fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
        inputs=None,
        outputs=[timeline_view, chat_view]
    )
gr.HTML(""" | |
<style> | |
.gradio-container .gallery-item { | |
border: 1px solid #444444 !important; | |
border-radius: 8px !important; | |
padding: 8px !important; | |
margin: 10px !important; | |
cursor: pointer !important; | |
transition: all 0.3s !important; | |
background: #18181b !important; | |
box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important; | |
} | |
.gradio-container .gallery-item:hover { | |
transform: translateY(-2px) !important; | |
box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important; | |
border-color: #007bff !important; | |
background: #202025 !important; | |
} | |
.gradio-container .gallery-item.selected { | |
border: 2px solid #007bff !important; | |
background: #202030 !important; | |
} | |
.gradio-container .gallery-item img { | |
height: 180px !important; | |
object-fit: cover !important; | |
border-radius: 4px !important; | |
border: 1px solid #444444 !important; | |
margin-bottom: 8px !important; | |
} | |
.gradio-container .caption { | |
color: #e0e0e0 !important; | |
font-size: 0.9em !important; | |
margin-top: 8px !important; | |
line-height: 1.4 !important; | |
padding: 0 4px !important; | |
} | |
.gradio-container [id*='gallery'] > div:first-child { | |
background-color: #27272a !important; | |
padding: 15px !important; | |
border-radius: 10px !important; | |
} | |
.gradio-container .chatbot { | |
background-color: #27272a !important; | |
border-radius: 10px !important; | |
border: 1px solid #444444 !important; | |
} | |
.gradio-container .chatbot .message.user { | |
background-color: #18181b !important; | |
border-radius: 8px !important; | |
} | |
.gradio-container .chatbot .message.bot { | |
background-color: #202030 !important; | |
border-radius: 8px !important; | |
} | |
.gradio-container button.secondary { | |
background-color: #3d4452 !important; | |
color: white !important; | |
} | |
</style> | |
""") | |
if __name__ == "__main__":
    demo.launch(allowed_paths=["./tmp"])