Spaces:
Runtime error
Runtime error
# Construction Site Safety Analyzer - FIXED VERSION | |
# Using Local LLaVA + Llama 3 70B via Groq API | |
# Google Colab Implementation with JSON Error Handling | |
# ============================================================================ | |
# SETUP AND INSTALLATION | |
# ============================================================================ | |
# Cell 1: Install required packages | |
#!pip install transformers torch torchvision Pillow requests opencv-python | |
#!pip install groq accelerate bitsandbytes | |
#!pip install gradio ipywidgets | |
# Cell 2: Import libraries | |
import torch | |
import requests | |
import json | |
import base64 | |
import re | |
from PIL import Image | |
import io | |
import cv2 | |
import numpy as np | |
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
from groq import Groq | |
import gradio as gr | |
from google.colab import files | |
import matplotlib.pyplot as plt | |
from typing import Dict, List, Optional, Tuple | |
import warnings | |
warnings.filterwarnings('ignore') | |
# Cell 3: Configuration and API Setup | |
class Config:
    """Runtime settings shared by the analyzer components.

    Instance fields: ``groq_api_key`` (empty until set), ``llava_model_name``,
    ``max_qa_rounds`` and ``device`` ("cuda" when torch reports CUDA, else "cpu").
    """

    def __init__(self):
        # Choose the compute device first; the remaining fields are fixed defaults.
        cuda_ready = torch.cuda.is_available()
        self.device = "cuda" if cuda_ready else "cpu"
        self.max_qa_rounds = 5  # Reduced to prevent timeout issues
        self.llava_model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
        self.groq_api_key = ""  # Filled in later through set_groq_key()

    def set_groq_key(self, api_key: str):
        """Store *api_key* as the Groq credential."""
        self.groq_api_key = api_key
config = Config()

# Resolve the Groq API key. The GROQ_API_KEY environment variable is checked
# first so the script can start where no interactive prompt is available;
# only when it is unset or blank do we fall back to the hidden prompt.
import os
from getpass import getpass

groq_key = os.environ.get("GROQ_API_KEY", "").strip()
if not groq_key:
    # strip() guards against whitespace picked up when the key is pasted.
    groq_key = getpass("Enter your Groq API key: ").strip()
config.set_groq_key(groq_key)

print(f"Using device: {config.device}")
print(f"CUDA available: {torch.cuda.is_available()}")
# ============================================================================ | |
# LLAVA MODEL SETUP (LOCAL) | |
# ============================================================================ | |
# Cell 4: Load LLaVA Model | |
class LocalLLaVA:
    """Wrapper around a locally loaded LLaVA-NeXT model for image questions.

    The constructor loads the processor and the model; ``analyze_image`` runs
    one prompt/image pair through the model and returns the generated text.
    """

    # Prompt used when the caller does not supply a specific question.
    _OVERVIEW_PROMPT = "\n".join((
        "[INST] <image>",
        "You are a construction safety expert analyzing this construction site image.",
        "Please provide a detailed analysis covering:",
        "1. Overall scene description and type of construction work",
        "2. Workers present and their activities",
        "3. Heavy machinery and equipment visible",
        "4. Safety equipment and PPE compliance",
        "5. Visible hazards and safety concerns",
        "6. Site organization and conditions",
        "Be specific and detailed in your observations. Focus on safety-critical elements.",
        "[/INST]",
    ))

    def __init__(self, model_name: str, device: str):
        """Load the processor and model weights for *model_name*.

        Args:
            model_name: Hugging Face model identifier.
            device: "cuda" loads a 4-bit quantized float16 model with
                ``device_map="auto"``; any other value loads float32 weights
                and moves them to *device*.
        """
        print("Loading LLaVA model locally...")
        self.device = device
        self.processor = LlavaNextProcessor.from_pretrained(model_name)
        # Load model with appropriate settings for Colab
        if device == "cuda":
            self.model = LlavaNextForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                load_in_4bit=True,  # Use 4-bit quantization to save memory
                device_map="auto",
            )
        else:
            self.model = LlavaNextForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            # NOTE(review): SOURCE lost its indentation; the move is kept in the
            # CPU branch because the CUDA model is already placed by device_map.
            self.model.to(device)
        print("LLaVA model loaded successfully!")

    @staticmethod
    def _build_prompt(question: Optional[str]) -> str:
        """Return the overview prompt, or a prompt wrapping *question*."""
        if question is None:
            return LocalLLaVA._OVERVIEW_PROMPT
        return (
            "[INST] <image>\n"
            "As a construction safety expert, please answer this specific "
            "question about the construction site image:\n\n"
            f"{question}\n\n"
            "Provide a detailed and specific answer based on what you can "
            "observe in the image.[/INST]"
        )

    def analyze_image(self, image: "Image.Image", question: Optional[str] = None) -> str:
        """Analyze construction site image with optional specific question.

        Args:
            image: The image to analyze.
            question: Specific question to ask; ``None`` requests the
                comprehensive overview analysis.

        Returns:
            The generated answer, or a string starting with
            ``"Error analyzing image: "`` if anything fails (errors are
            reported in-band rather than raised).
        """
        prompt = self._build_prompt(question)
        try:
            # Keyword arguments keep the call independent of the processor's
            # positional parameter order.
            inputs = self.processor(
                images=image, text=prompt, return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=500,
                    do_sample=True,
                    temperature=0.1,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Decode only the tokens produced after the prompt instead of
            # splitting on "[/INST]", which breaks if the question itself
            # contains that marker.
            prompt_len = inputs["input_ids"].shape[1]
            generated = output[0][prompt_len:]
            return self.processor.decode(generated, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Error in LLaVA analysis: {e}")
            return f"Error analyzing image: {str(e)}"
# Initialize LLaVA: build the module-level wrapper from the shared config.
# The constructor loads the processor and model, so this line is slow.
llava_model = LocalLLaVA(config.llava_model_name, config.device)
# ============================================================================ | |
# GROQ LLAMA 3 70B INTEGRATION - FIXED JSON HANDLING | |
# ============================================================================ | |
# Cell 5: Groq Llama Integration with Error Handling | |
class GroqLlamaAnalyzer:
    """Question generation and final safety reporting through the Groq chat API.

    All public methods return plain dictionaries and report failures in-band
    (they print the error and return a fallback dictionary) instead of raising.
    """

    # Canned questions, one per round, used when the model reply is unparsable.
    _FALLBACK_QUESTIONS = (
        "What personal protective equipment (PPE) are workers wearing or missing?",
        "Are there any fall protection measures in place for workers at height?",
        "What heavy machinery is present and are proper safety protocols being followed?",
        "Are there any visible electrical hazards or unsafe conditions?",
        "Is the work area properly organized and free of debris or obstacles?",
    )

    # Report fields that must hold a list of strings / a single string.
    _LIST_FIELDS = (
        "identified_risks",
        "immediate_actions",
        "prevention_methods",
        "regulatory_compliance",
    )
    _TEXT_FIELDS = ("risk_level", "confidence_score", "executive_summary")

    # JSON shape the model is asked to produce for the final report.
    _REPORT_TEMPLATE = "\n".join((
        "{",
        '"risk_level": "LOW/MODERATE/HIGH/CRITICAL",',
        '"confidence_score": "85%",',
        '"executive_summary": "Brief overview of main safety findings",',
        '"identified_risks": [',
        '"Risk 1 with severity level",',
        '"Risk 2 with severity level"',
        "],",
        '"immediate_actions": [',
        '"Urgent action 1",',
        '"Urgent action 2"',
        "],",
        '"prevention_methods": [',
        '"Prevention method 1",',
        '"Prevention method 2"',
        "],",
        '"regulatory_compliance": [',
        '"Compliance issue 1",',
        '"Compliance issue 2"',
        "]",
        "}",
    ))

    def __init__(self, api_key: str, model_name: str = "llama3-70b-8192"):
        """Create the Groq client.

        Args:
            api_key: Groq API key.
            model_name: Chat model identifier; defaults to the id this file
                has always used.
                NOTE(review): confirm this model id is still served by Groq.
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name

    def extract_json_from_text(self, text: str) -> Optional[Dict]:
        """Extract the first JSON object from *text*, handling various formats.

        The whole text is tried first; otherwise every ``{`` is tried as the
        start of an object, which copes with surrounding prose, code fences,
        arbitrary nesting and braces inside string values.

        Returns:
            The decoded object as a dict, or ``None`` when no JSON object is
            found (non-object JSON such as a list or string is not accepted).
        """
        if not isinstance(text, str):
            return None

        try:
            whole = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            whole = None
        if isinstance(whole, dict):
            return whole

        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except ValueError:
                candidate = None
            if isinstance(candidate, dict):
                return candidate
            start = text.find("{", start + 1)
        return None

    @classmethod
    def _fallback_question(cls, round_num: int) -> Dict:
        """Return the canned question for *round_num*, or a completion marker."""
        if round_num < len(cls._FALLBACK_QUESTIONS):
            return {
                "action": "QUESTION",
                "question": cls._FALLBACK_QUESTIONS[round_num],
                "reasoning": "Systematic safety assessment",
            }
        return {
            "action": "ANALYSIS_COMPLETE",
            "reasoning": "Completed systematic safety review",
        }

    @classmethod
    def _complete_report(cls, report: Dict) -> Dict:
        """Fill missing report fields and wrap bare-string list fields."""
        for field in cls._TEXT_FIELDS:
            report.setdefault(field, "Not available")
        for field in cls._LIST_FIELDS:
            value = report.setdefault(field, ["Information not available"])
            # Consumers enumerate these fields; a bare string would be
            # rendered one character per item.
            if isinstance(value, str):
                report[field] = [value]
        return report

    def generate_question(self, context: str, round_num: int) -> Dict:
        """Generate dynamic questions based on context analysis.

        Args:
            context: Accumulated analysis text; only the first 2000 characters
                are sent.
            round_num: Zero-based Q&A round index.

        Returns:
            A dict with ``"action"`` equal to ``"QUESTION"`` (with
            ``"question"`` and ``"reasoning"``) or ``"ANALYSIS_COMPLETE"``.
            Any exception yields an ``ANALYSIS_COMPLETE`` dict describing it.
        """
        system_prompt = (
            "You are an expert construction safety analyst. Generate specific "
            "questions to gather detailed safety information about construction "
            "sites. Always respond in valid JSON format."
        )
        # Truncate to prevent token limits. Done outside the prompt so this
        # note is not sent to the model as part of the context.
        excerpt = context[:2000]
        user_prompt = (
            f"Based on the construction site analysis so far (Round {round_num + 1}):\n"
            f"{excerpt}\n"
            'Generate ONE specific question to identify safety risks, or respond "ANALYSIS_COMPLETE" if sufficient.\n'
            "Respond ONLY in this exact JSON format:\n"
            '{"action": "QUESTION", "question": "your specific safety question", "reasoning": "why this question matters for safety"}\n'
            "OR\n"
            '{"action": "ANALYSIS_COMPLETE", "reasoning": "sufficient information gathered"}'
        )
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.3,
                max_tokens=300,
            )
            response_text = (response.choices[0].message.content or "").strip()
            print(f"Raw Groq response: {response_text}")

            result = self.extract_json_from_text(response_text)
            if result is None:
                if "ANALYSIS_COMPLETE" in response_text:
                    # The prompt allows this bare reply; honour it rather
                    # than forcing another canned question.
                    result = {
                        "action": "ANALYSIS_COMPLETE",
                        "reasoning": "Model reported sufficient information",
                    }
                else:
                    result = self._fallback_question(round_num)

            # Validate result structure
            if "action" not in result:
                result["action"] = "ANALYSIS_COMPLETE"
            if result["action"] == "QUESTION" and "question" not in result:
                result["action"] = "ANALYSIS_COMPLETE"
            return result
        except Exception as e:
            print(f"Error generating question: {e}")
            return {
                "action": "ANALYSIS_COMPLETE",
                "reasoning": f"Error occurred: {str(e)}",
            }

    def final_analysis(self, context: str) -> Dict:
        """Generate comprehensive safety analysis with improved error handling.

        Args:
            context: Accumulated analysis text; only the first 3000 characters
                are sent.

        Returns:
            A report dict containing ``risk_level``, ``confidence_score``,
            ``executive_summary`` and the four list fields. On exception the
            dict additionally carries an ``"error"`` key.
        """
        system_prompt = (
            "You are a senior construction safety expert. Analyze the provided "
            "information and create a comprehensive safety assessment. You must "
            "respond ONLY in valid JSON format."
        )
        # Truncate to prevent token limits (kept out of the prompt text).
        excerpt = context[:3000]
        user_prompt = (
            "Based on all construction site information:\n"
            f"{excerpt}\n"
            "Create a comprehensive safety analysis in this EXACT JSON format:\n"
            f"{self._REPORT_TEMPLATE}\n"
            "Respond ONLY with valid JSON, no additional text."
        )
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.2,
                max_tokens=1500,
            )
            response_text = (response.choices[0].message.content or "").strip()
            print(f"Raw final analysis response: {response_text}")

            result = self.extract_json_from_text(response_text)
            if result is None:
                # Create a fallback analysis structure
                result = {
                    "risk_level": "MODERATE",
                    "confidence_score": "75%",
                    "executive_summary": "Analysis completed with limited data processing capabilities.",
                    "identified_risks": ["Unable to fully parse detailed risk assessment"],
                    "immediate_actions": ["Conduct manual safety review"],
                    "prevention_methods": ["Implement standard safety protocols"],
                    "regulatory_compliance": ["Review OSHA compliance standards"],
                }
            # Ensure all required fields exist
            return self._complete_report(result)
        except Exception as e:
            print(f"Error in final analysis: {e}")
            return {
                "error": str(e),
                "risk_level": "UNKNOWN",
                "confidence_score": "0%",
                "executive_summary": f"Analysis failed due to: {str(e)}",
                "identified_risks": [f"System error: {str(e)}"],
                "immediate_actions": ["Manual review required"],
                "prevention_methods": ["System troubleshooting needed"],
                "regulatory_compliance": ["Unable to assess due to system error"],
            }
# Initialize Groq analyzer: module-level instance built with the API key
# stored on the shared config object.
groq_analyzer = GroqLlamaAnalyzer(config.groq_api_key)
# ============================================================================ | |
# MAIN ANALYSIS SYSTEM - IMPROVED ERROR HANDLING | |
# ============================================================================ | |
# Cell 6: Complete Analysis System with Better Error Handling | |
class ConstructionSafetyAnalyzer:
    """Drives the full pipeline: image description, Q&A rounds, final report.

    ``qa_history`` and ``analysis_context`` hold the state of the most recent
    run and are reset by ``analyze_construction_site``.
    """

    # (result key, printed heading) for each list section of the report.
    _REPORT_SECTIONS = (
        ("identified_risks", "\nβ οΈ IDENTIFIED RISKS"),
        ("immediate_actions", "\nπ¨ IMMEDIATE ACTIONS REQUIRED"),
        ("prevention_methods", "\nπ‘οΈ PREVENTION METHODS"),
        ("regulatory_compliance", "\nπ REGULATORY COMPLIANCE ISSUES"),
    )

    def __init__(self, llava_model: "LocalLLaVA", groq_analyzer: "GroqLlamaAnalyzer"):
        """Keep references to the vision model and the Groq analyzer."""
        self.llava = llava_model
        self.groq = groq_analyzer
        self.qa_history = []
        self.analysis_context = ""

    def analyze_construction_site(self, image_path: str) -> Dict:
        """Complete construction site safety analysis with improved error handling.

        Args:
            image_path: Path of the image file to analyze.

        Returns:
            A dict with ``initial_analysis``, ``qa_rounds``, ``final_analysis``,
            ``total_rounds`` and ``status`` ("completed" or "failed"). On
            failure the dict also contains ``error``; no exception escapes.
        """
        try:
            # Load the pixels eagerly and release the file handle right away.
            with Image.open(image_path) as opened:
                image = opened.convert("RGB")

            # Display the image, then close the figure so repeated analyses
            # do not accumulate open matplotlib figures.
            fig = plt.figure(figsize=(10, 8))
            plt.imshow(image)
            plt.axis('off')
            plt.title("Construction Site Image for Analysis")
            plt.show()
            plt.close(fig)

            print("π Starting Construction Site Safety Analysis...")
            print("=" * 60)

            # Step 1: Initial LLaVA analysis
            print("π Step 1: Initial Image Analysis with LLaVA...")
            initial_analysis = self.llava.analyze_image(image)
            print("Initial Analysis:")
            print("-" * 30)
            print(initial_analysis)
            print("\n")

            # Initialize context
            self.analysis_context = f"Initial Visual Analysis:\n{initial_analysis}\n\n"
            self.qa_history = []

            # Step 2: Interactive Q&A rounds with error handling
            print("π€ Step 2: Dynamic Question Generation and Analysis...")
            print("=" * 60)
            consecutive_errors = 0
            for round_num in range(config.max_qa_rounds):
                print(f"\nπ Round {round_num + 1}:")
                print("-" * 20)
                try:
                    # Generate question with Llama
                    print("π§ Llama 3 70B analyzing and generating question...")
                    question_result = self.groq.generate_question(self.analysis_context, round_num)
                    if question_result["action"] == "ANALYSIS_COMPLETE":
                        print("β Analysis determined complete.")
                        print(f"Reasoning: {question_result.get('reasoning', 'Analysis complete')}")
                        break

                    question = question_result.get("question", "")
                    reasoning = question_result.get("reasoning", "")
                    if not question:
                        print("β οΈ No question generated, moving to final analysis.")
                        break
                    print(f"Generated Question: {question}")
                    print(f"Reasoning: {reasoning}")

                    # Get answer from LLaVA
                    print("ποΈ LLaVA analyzing specific aspect...")
                    answer = self.llava.analyze_image(image, question)
                    print(f"LLaVA Response: {answer}")

                    # Store Q&A
                    self.qa_history.append({
                        "round": round_num + 1,
                        "question": question,
                        "answer": answer,
                        "reasoning": reasoning,
                    })

                    # Update context
                    self.analysis_context += (
                        f"Q{round_num + 1}: {question}\n"
                        f"A{round_num + 1}: {answer}\n"
                        f"Reasoning: {reasoning}\n\n"
                    )
                    consecutive_errors = 0  # Reset error counter on success
                except Exception as e:
                    print(f"β οΈ Error in round {round_num + 1}: {e}")
                    consecutive_errors += 1
                    if consecutive_errors >= 3:
                        print("π Too many consecutive errors, proceeding to final analysis.")
                        break

            # Step 3: Final comprehensive analysis
            print("\nπ Step 3: Generating Comprehensive Safety Report...")
            print("=" * 60)
            final_analysis = self.groq.final_analysis(self.analysis_context)
            return {
                "initial_analysis": initial_analysis,
                "qa_rounds": self.qa_history,
                "final_analysis": final_analysis,
                "total_rounds": len(self.qa_history),
                "status": "completed",
            }
        except Exception as e:
            print(f"π¨ Critical error in analysis: {e}")
            return {
                "error": str(e),
                "status": "failed",
                "initial_analysis": "Failed to analyze image",
                "qa_rounds": [],
                "final_analysis": {
                    "risk_level": "UNKNOWN",
                    "confidence_score": "0%",
                    "executive_summary": f"Analysis failed: {str(e)}",
                    "identified_risks": [f"System error: {str(e)}"],
                    "immediate_actions": ["Manual analysis required"],
                    "prevention_methods": ["System troubleshooting needed"],
                    "regulatory_compliance": ["Unable to assess"],
                },
                "total_rounds": 0,
            }

    def display_results(self, results: Dict):
        """Display formatted analysis results with error handling.

        Prints the report to stdout. A result whose ``status`` is "failed"
        prints only the error. List sections that are empty or hold just the
        "Information not available" placeholder are skipped.
        """
        print("\n" + "=" * 80)
        print("ποΈ CONSTRUCTION SITE SAFETY ANALYSIS REPORT")
        print("=" * 80)

        if results.get("status") == "failed":
            print(f"\nβ ANALYSIS FAILED")
            print("-" * 40)
            print(f"Error: {results.get('error', 'Unknown error')}")
            return

        # Executive Summary
        final = results.get("final_analysis") or {}
        print(f"\nπ― EXECUTIVE SUMMARY")
        print("-" * 40)
        print(f"Risk Level: {final.get('risk_level', 'Unknown')}")
        print(f"Confidence: {final.get('confidence_score', 'Unknown')}")
        print(f"Summary: {final.get('executive_summary', 'No summary available')}")

        # Q&A Summary
        print(f"\nπ ANALYSIS PROCESS")
        print("-" * 40)
        print(f"Total Investigation Rounds: {results.get('total_rounds', 0)}")
        for qa in results.get("qa_rounds", []):
            print(f"\nRound {qa['round']}: {qa['question']}")
            answer = qa['answer']
            answer_preview = answer[:100] + "..." if len(answer) > 100 else answer
            print(f"Answer: {answer_preview}")

        # Risks, immediate actions, prevention methods, regulatory compliance
        for key, heading in self._REPORT_SECTIONS:
            items = final.get(key, [])
            if items and items != ["Information not available"]:
                print(heading)
                print("-" * 40)
                for i, item in enumerate(items, 1):
                    print(f"{i}. {item}")
# Initialize the complete system: wire the module-level LLaVA wrapper and
# Groq analyzer into the analyzer instance used by the Gradio callback.
analyzer = ConstructionSafetyAnalyzer(llava_model, groq_analyzer)
# ============================================================================ | |
# IMPROVED GRADIO INTERFACE | |
# ============================================================================ | |
# Cell 7: Create Improved Gradio Interface | |
def create_gradio_interface(): | |
def analyze_uploaded_image(image): | |
if image is None: | |
return "Please upload an image first." | |
# Save temporary image | |
temp_path = "/tmp/construction_site.jpg" | |
image.save(temp_path) | |
try: | |
# Run analysis | |
results = analyzer.analyze_construction_site(temp_path) | |
if results.get("status") == "failed": | |
return f"# β Analysis Failed\n\nError: {results.get('error', 'Unknown error')}\n\nPlease try again or check your API configuration." | |
# Format results for display | |
final = results.get("final_analysis", {}) | |
report = f""" | |
# ποΈ Construction Site Safety Analysis Report | |
## π― Executive Summary | |
- **Risk Level**: {final.get('risk_level', 'Unknown')} | |
- **Confidence**: {final.get('confidence_score', 'Unknown')} | |
- **Summary**: {final.get('executive_summary', 'No summary available')} | |
## π Analysis Process | |
- **Total Investigation Rounds**: {results.get('total_rounds', 0)} | |
- **Status**: {results.get('status', 'Unknown')} | |
### Question & Answer Rounds: | |
""" | |
for qa in results.get("qa_rounds", []): | |
report += f"\n**Round {qa['round']}**: {qa['question']}\n" | |
report += f"*Answer*: {qa['answer'][:200]}{'...' if len(qa['answer']) > 200 else ''}\n" | |
risks = final.get("identified_risks", []) | |
if risks and risks != ["Information not available"]: | |
report += "\n## β οΈ Identified Risks\n" | |
for i, risk in enumerate(risks, 1): | |
report += f"{i}. {risk}\n" | |
actions = final.get("immediate_actions", []) | |
if actions and actions != ["Information not available"]: | |
report += "\n## π¨ Immediate Actions Required\n" | |
for i, action in enumerate(actions, 1): | |
report += f"{i}. {action}\n" | |
methods = final.get("prevention_methods", []) | |
if methods and methods != ["Information not available"]: | |
report += "\n## π‘οΈ Prevention Methods\n" | |
for i, method in enumerate(methods, 1): | |
report += f"{i}. {method}\n" | |
return report | |
except Exception as e: | |
return f"# β Error During Analysis\n\n```\n{str(e)}\n```\n\nPlease check your configuration and try again." | |
# Create Gradio interface | |
iface = gr.Interface( | |
fn=analyze_uploaded_image, | |
inputs=gr.Image(type="pil", label="Upload Construction Site Image"), | |
outputs=gr.Markdown(label="Safety Analysis Report"), | |
title="ποΈ Construction Site Safety Analyzer (Fixed Version)", | |
description="Upload a construction site image for comprehensive safety analysis using LLaVA + Llama 3 70B. This version includes improved error handling and JSON parsing.", | |
examples=None | |
) | |
return iface | |
# ============================================================================ | |
# EXAMPLE USAGE AND TESTING | |
# ============================================================================ | |
# Cell 8: Test the Fixed System | |
def test_system():
    """Test the fixed system with better error handling.

    Prints the outcome of three smoke checks: the loaded model/client class
    names, a minimal Groq chat request, and the JSON extraction helper.
    Failures are printed, not raised.
    """
    print("π§ͺ Testing Fixed Construction Safety Analyzer System...")

    # Test 1: Check model loading
    print("β Test 1: Models loaded successfully")
    print(f" - LLaVA model: {llava_model.model.__class__.__name__}")
    print(f" - Groq client: {groq_analyzer.client.__class__.__name__}")

    # Test 2: Check API connectivity with better error handling
    try:
        # Probe the same model id the analyzer will use, not a duplicate literal.
        groq_analyzer.client.chat.completions.create(
            model=groq_analyzer.model_name,
            messages=[{"role": "user", "content": "Hello, this is a test."}],
            max_tokens=10,
        )
        print("β Test 2: Groq API connection successful")
    except Exception as e:
        print(f"β Test 2: Groq API connection failed: {e}")
        print(" Please check your API key and internet connection.")

    # Test 3: JSON parsing function
    test_json = '{"action": "QUESTION", "question": "Test question"}'
    result = groq_analyzer.extract_json_from_text(test_json)
    if result and "action" in result:
        print("β Test 3: JSON parsing function working")
    else:
        print("β Test 3: JSON parsing function failed")

    print("π System test completed!")
# Run system test
test_system()

# Build the Gradio interface
print("π Creating Fixed Gradio Interface...")
interface = create_gradio_interface()

# Print the usage banner BEFORE launching: launch(debug=True) blocks the main
# thread, so anything printed after it would only appear once the server stops.
print("""
ποΈ FIXED CONSTRUCTION SITE SAFETY ANALYZER - READY TO USE!
π§ IMPROVEMENTS MADE:
- β Fixed JSON parsing errors with robust extraction
- β Added comprehensive error handling
- β Reduced max Q&A rounds to prevent timeouts
- β Added fallback questions for systematic analysis
- β Improved response validation
- β Better error messages and debugging
π INSTRUCTIONS:
1. Ensure your Groq API key is set correctly
2. Upload a construction site image
3. The system will now handle JSON errors gracefully
4. View comprehensive safety analysis with improved reliability
π READY TO ANALYZE CONSTRUCTION SITE SAFETY WITH IMPROVED RELIABILITY!
""")

# Launch Gradio interface (blocks while the server is running)
interface.launch(share=True, debug=True)