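"""Construction Site Safety Analyzer.

A Gradio app that sends uploaded construction-site images and videos to a
Groq-hosted Llama 3.2 vision model for hazard analysis, answers follow-up
questions in a chat panel, and exports a plain-text summary report.

Requires the GROQ_API_KEY environment variable to be set before launch.
"""
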
import os
import base64
import io
import logging
import tempfile
import traceback
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
from groq import Groq
from PIL import Image

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY is not set in environment variables")
    raise ValueError("GROQ_API_KEY is not set")

client = Groq(api_key=GROQ_API_KEY)
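
# Model and prompt shared by every completion call below. The model name is the
# one this script was written against; if Groq has retired the preview model,
# substitute a current vision-capable model name.
MODEL = "llama-3.2-90b-vision-preview"

SAFETY_INSTRUCTION = (
    "You are an AI assistant specialized in analyzing images for safety issues. "
    "Your task is first to explain what you see in the image and determine if the image shows a construction site. "
    "If it does, identify any safety issues or hazards, categorize them, provide a detailed description, "
    "and suggest steps to resolve them. If it's not a construction site, simply state that."
)
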
def encode_image(image):
    """Base64-encode an image given as a file path, PIL image, or OpenCV array."""
    try:
        if isinstance(image, str):
            with open(image, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        elif isinstance(image, Image.Image):
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode('utf-8')
        elif isinstance(image, np.ndarray):
            is_success, buffer = cv2.imencode(".png", image)
            if not is_success:
                raise ValueError("cv2.imencode failed to encode the frame as PNG")
            return base64.b64encode(buffer).decode('utf-8')
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
    except Exception as e:
        logger.error(f"Error encoding image: {str(e)}")
        raise

def resize_image(image, max_size=(800, 800)):
    """Shrink the image in place (PIL thumbnail) to avoid exceeding API size limits."""
    try:
        image.thumbnail(max_size, Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        logger.error(f"Error resizing image: {str(e)}")
        raise

def extract_frames_from_video(video_path, frame_points=(0.0, 0.5, 1.0), max_size=(800, 800)):
    """Extract key frames from the video at the given fractional time points."""
    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against missing FPS metadata
        duration = frame_count / fps

        frames = []
        for time_point in frame_points:
            # Seek just short of the end so the final read doesn't run off the clip.
            cap.set(cv2.CAP_PROP_POS_MSEC, min(time_point, 0.99) * duration * 1000)
            ret, frame = cap.read()
            if ret:
                # Downscale while preserving the aspect ratio; resizing straight
                # to (w, h) = max_size would distort the frame.
                h, w = frame.shape[:2]
                scale = min(max_size[0] / w, max_size[1] / h, 1.0)
                frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
                frames.append(frame)
        return frames
    finally:
        cap.release()
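
# Example usage (hypothetical clip name): with the default frame_points this
# returns up to three BGR numpy arrays, from the start, middle, and end:
#   frames = extract_frames_from_video("site_walkthrough.mp4")
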
def analyze_file(file_path):
    """Dispatch a single uploaded file (image or video) by its extension."""
    try:
        file_type = file_path.split('.')[-1].lower()
        if file_type in ('jpg', 'jpeg', 'png', 'bmp'):
            return analyze_image(file_path)
        elif file_type in ('mp4', 'avi', 'mov', 'webm'):
            return analyze_video(file_path)
        else:
            return "Unsupported file type. Please upload an image or video file."
    except Exception as e:
        logger.error(f"Error analyzing file: {str(e)}")
        return f"Error analyzing file: {str(e)}"

def analyze_image(image_path):
    """Send a single resized image to the vision model for safety analysis."""
    image = Image.open(image_path)
    resized_image = resize_image(image)
    image_data_url = f"data:image/png;base64,{encode_image(resized_image)}"

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"{SAFETY_INSTRUCTION}\n\nAnalyze this image. First, determine if it's a construction site. If it is, explain the image in detail, focusing on safety aspects. If it's not, briefly describe what you see."
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_data_url}
                }
            ]
        }
    ]

    completion = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0.7,
        max_tokens=1000,
        top_p=1,
        stream=False,
        stop=None,
    )
    return completion.choices[0].message.content

def analyze_video(video_path):
    """Analyze key frames of a video, one vision-model call per frame."""
    frames = extract_frames_from_video(video_path)
    results = []

    for i, frame in enumerate(frames):
        image_data_url = f"data:image/png;base64,{encode_image(frame)}"
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"{SAFETY_INSTRUCTION}\n\nAnalyze this frame from a video (Frame {i+1}/{len(frames)}). First, determine if it shows a construction site. If it does, explain what you observe, focusing on safety aspects. If it doesn't, briefly describe what you see."
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": image_data_url}
                    }
                ]
            }
        ]
        completion = client.chat.completions.create(
            model=MODEL,
            messages=messages,
            temperature=0.7,
            max_tokens=1000,
            top_p=1,
            stream=False,
            stop=None,
        )
        results.append(f"Frame {i+1} analysis:\n{completion.choices[0].message.content}\n")

    return "\n".join(results)

def chat_about_image(message, chat_history):
    """Answer a follow-up question, replaying the chat history for context."""
    try:
        messages = [
            {"role": "system", "content": "You are an AI assistant specialized in analyzing construction site images and answering questions about them. Use the information from the initial analysis to answer user queries."},
        ]

        chat_history = chat_history or []
        for human, ai in chat_history:
            if human:
                messages.append({"role": "user", "content": human})
            if ai:
                messages.append({"role": "assistant", "content": ai})

        messages.append({"role": "user", "content": message})

        completion = client.chat.completions.create(
            model=MODEL,
            messages=messages,
            temperature=0.7,
            max_tokens=500,
            top_p=1,
            stream=False,
            stop=None,
        )

        response = completion.choices[0].message.content
        chat_history.append((message, response))
        return "", chat_history
    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        return "", chat_history + [(message, f"Error: {str(e)}")]
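
# Note: (user, assistant) tuples match the default gr.Chatbot history format in
# Gradio 4.x; newer releases prefer type="messages" with role/content dicts, so
# adjust chat_about_image accordingly if you upgrade.
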
def generate_summary_report(chat_history):
    """Generate a plain-text summary report from the chat history."""
    report = "Construction Site Safety Analysis Report\n"
    report += "=" * 40 + "\n"
    report += f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"

    for i, (user, ai) in enumerate(chat_history or [], 1):
        if user:
            report += f"Query {i}:\n{user}\n\n"
        if ai:
            report += f"Analysis {i}:\n{ai}\n\n"
        report += "-" * 40 + "\n"

    return report

def download_report(chat_history):
    """Write the summary report to a temp file and return its path for download."""
    report = generate_summary_report(chat_history)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"safety_analysis_report_{timestamp}.txt"

    # Use the timestamped filename so the downloaded file has a meaningful name.
    file_path = os.path.join(tempfile.gettempdir(), filename)
    with open(file_path, "w") as f:
        f.write(report)

    return file_path

custom_css = """
.container { max-width: 1200px; margin: auto; padding-top: 1.5rem; }
.header { text-align: center; margin-bottom: 1rem; }
.header h1 { color: #2c3e50; font-size: 2.5rem; }
.subheader {
    color: #34495e;
    font-size: 1rem;
    line-height: 1.2;
    margin-bottom: 1.5rem;
    text-align: center;
    padding: 0 15px;
    white-space: nowrap;
    overflow: hidden;
    text-overflow: ellipsis;
}
.image-container, .file-container { border: 2px dashed #3498db; border-radius: 10px; padding: 1rem; text-align: center; margin-bottom: 1rem; }
.analyze-button { background-color: #2ecc71 !important; color: white !important; width: 100%; }
.clear-button { background-color: #e74c3c !important; color: white !important; width: 100px !important; }
.chatbot { border: 1px solid #bdc3c7; border-radius: 10px; padding: 1rem; height: 500px; overflow-y: auto; }
.chat-input { border: 1px solid #bdc3c7; border-radius: 5px; padding: 0.5rem; width: 100%; }
.groq-badge { position: fixed; bottom: 10px; right: 10px; background-color: #f39c12; color: white; padding: 5px 10px; border-radius: 5px; font-weight: bold; }
.chat-container { display: flex; flex-direction: column; height: 100%; }
.input-row { display: flex; align-items: center; margin-top: 10px; justify-content: space-between; }
.input-row > div:first-child { flex-grow: 1; margin-right: 10px; }
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
    gr.HTML(
        """
        <div class="container">
            <div class="header">
                <h1>🏗️ Construction Site Safety Analyzer</h1>
            </div>
            <p class="subheader">Enhance workplace safety and compliance with AI-powered image and video analysis using Llama 3.2 90B Vision and expert chat assistance.</p>
        </div>
        """
    )

    with gr.Row():
        file_input = gr.File(
            label="Upload Construction Site Images or Videos",
            file_count="multiple",
            type="filepath",
            elem_classes="file-container"
        )

    with gr.Row():
        analyze_button = gr.Button("🔍 Analyze Safety Hazards", elem_classes="analyze-button")

    with gr.Row():
        chatbot = gr.Chatbot(
            label="Safety Analysis Results and Expert Chat",
            elem_classes="chatbot",
            show_share_button=False,
            show_copy_button=False
        )

    with gr.Row():
        msg = gr.Textbox(
            label="Ask about safety measures or regulations",
            placeholder="E.g., 'What OSHA guidelines apply to this hazard?'",
            show_label=False,
            elem_classes="chat-input"
        )

    with gr.Row():
        clear = gr.Button("🗑️ Clear Chat", elem_classes="clear-button")
        download_button = gr.Button("📥 Download Report", elem_classes="download-button")

    report_file = gr.File(label="Download Safety Analysis Report")

    def process_files(files, history):
        """Analyze each uploaded file and append the results to the chat."""
        history = history or []
        for file_path in files or []:
            result = analyze_file(file_path)
            history.append((None, f"Analysis of {os.path.basename(file_path)}:\n\n{result}"))
        return history

    # Pass the current chat history in and return the updated history, rather
    # than relying on a postprocess callback (.click() has no such hook).
    analyze_button.click(
        process_files,
        inputs=[file_input, chatbot],
        outputs=[chatbot],
    )

    msg.submit(chat_about_image, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

    download_button.click(
        download_report,
        inputs=[chatbot],
        outputs=[report_file]
    )

    gr.HTML(
        """
        <div class="groq-badge">Powered by Groq</div>
        """
    )

if __name__ == "__main__":
    try:
        iface.launch(debug=True)
    except Exception as e:
        logger.error(f"Error when trying to launch the interface: {str(e)}")
        logger.error(traceback.format_exc())
        print("Failed to launch the Gradio interface. Please check the logs for more information.")
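
# To run locally (assuming this file is saved as app.py):
#   pip install gradio groq opencv-python pillow numpy
#   export GROQ_API_KEY="gsk_..."
#   python app.py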