Add files
Browse files

- .kiro/settings/mcp.json +7 -0
- .kiro/specs/gemini-multimodal-refactor/design.md +2309 -0
- .kiro/specs/gemini-multimodal-refactor/requirements.md +153 -0
- .kiro/specs/gemini-multimodal-refactor/tasks.md +403 -0
- .kiro/steering/product.md +30 -0
- .kiro/steering/structure.md +85 -0
- .kiro/steering/tech.md +69 -0
- app.py +23 -0
- requirements.txt +1 -1
- src/mortis/__init__.py +0 -0
- src/mortis/app.py +815 -0
- src/mortis/async_executor.py +554 -0
- src/mortis/calibrate.py +28 -0
- src/mortis/data_collector.py +288 -0
- src/mortis/gemini_client.py +482 -0
- src/mortis/intent_router.py +295 -0
- src/mortis/lerobot_async_client.py +668 -0
- src/mortis/models.py +215 -0
- src/mortis/robot.py +180 -0
- src/mortis/setup_dataset.py +146 -0
- src/mortis/setup_train.py +385 -0
- src/mortis/smolvla_executor.py +1040 -0
- src/mortis/stt_service.py +383 -0
- src/mortis/tools.py +418 -0
- src/mortis/tts_service.py +225 -0
.kiro/settings/mcp.json
ADDED
@@ -0,0 +1,7 @@
{
  "mcpServers": {
    "hf-mcp-server": {
      "url": "https://huggingface.co/mcp?login"
    }
  }
}
.kiro/specs/gemini-multimodal-refactor/design.md
ADDED
@@ -0,0 +1,2309 @@
# Design Document

## Overview

This design document outlines the architecture for refactoring the Mortis interactive AI Halloween experience to support multi-modal (voice and text) interaction using the Google Gemini API and SmolVLA-based robotic control. The refactor transforms Mortis from a simple gesture-based system into a sophisticated manipulation robot capable of executing precise tasks through natural language commands.

### Key Design Goals

1. Replace the existing LLM API with the Google Gemini API for conversational AI
2. Add voice input (STT) and voice output (TTS) capabilities
3. Integrate the SmolVLA model for vision-language-action robotic control
4. Implement asynchronous execution to maintain UI responsiveness
5. Support both conversational gestures and precise manipulation tasks
6. Maintain backward compatibility with existing features
7. Enable local deployment with GPU support for SmolVLA inference

### System Context

The current Mortis system uses:
- Gradio web interface with chat and webcam view
- Generic LLM API with structured tool calling for gesture control
- LeRobot SO101Follower for predefined gesture sequences
- Synchronous execution model

The refactored system will add:
- Google Gemini API integration with intent detection
- Audio input/output components in Gradio
- SmolVLA model for learned manipulation behaviors
- Asynchronous task execution with message queuing
- Dataset collection and training infrastructure

## Architecture

### High-Level Architecture Diagram

```mermaid
graph TB
    subgraph "Gradio Web Interface"
        UI[User Interface]
        Audio[Audio Input/Output]
        Chat[Chat Interface]
        Video[Webcam View]
    end

    subgraph "Application Layer"
        STT[Speech-to-Text Service]
        TTS[Text-to-Speech Service]
        Gemini[Gemini API Client]
        IntentRouter[Intent Router]
    end

    subgraph "Execution Layer"
        Queue[Message Queue]
        GestureExec[Gesture Executor]
        SmolVLAExec[SmolVLA Executor]
    end

    subgraph "Robot Control"
        SO101[SO101 Follower Driver]
        SmolVLA[SmolVLA Model]
        Camera[Camera Feed]
    end

    subgraph "Training Infrastructure"
        DataCollect[Data Collection]
        Dataset[LeRobot Dataset]
        Training[Training Pipeline]
    end

    UI --> Audio
    UI --> Chat
    Audio --> STT
    Chat --> Gemini
    STT --> Gemini
    Gemini --> IntentRouter
    IntentRouter --> Queue
    Queue --> GestureExec
    Queue --> SmolVLAExec
    GestureExec --> SO101
    SmolVLAExec --> SmolVLA
    SmolVLA --> SO101
    Gemini --> TTS
    TTS --> Audio
    Camera --> Video
    Camera --> SmolVLA
    DataCollect --> Dataset
    Dataset --> Training
    Training --> SmolVLA
```

### Architecture Layers

#### 1. Presentation Layer (Gradio Interface)
- Handles user interaction through the web browser
- Provides audio input component for voice recording
- Displays chat messages and system responses
- Shows webcam feed for visual monitoring
- Plays audio responses through the browser

#### 2. Application Layer (Business Logic)
- Gemini API client for conversational AI
- STT service for voice-to-text conversion
- TTS service for text-to-voice conversion
- Intent router to distinguish between conversational and manipulation commands
- Response formatter for structured outputs

#### 3. Execution Layer (Asynchronous Processing)
- Message queue for decoupling UI from long-running operations
- Gesture executor for predefined movement sequences
- SmolVLA executor for learned manipulation tasks
- Status tracking and progress reporting

#### 4. Robot Control Layer (Hardware Interface)
- SO101Follower driver for low-level servo control
- SmolVLA model for vision-language-action inference
- Camera interface for visual observations
- Safety monitoring and error recovery

#### 5. Training Infrastructure (Offline)
- Data collection tools for recording demonstrations
- LeRobot dataset management
- Training pipeline for the SmolVLA model
- Model evaluation and validation

## Components and Interfaces

### 1. Gemini API Integration

#### Component: `GeminiClient`

**Purpose:** Manages all interactions with the Google Gemini API for conversational AI and intent detection.

**Key Methods:**
- `send_message(user_input: str, conversation_history: list) -> GeminiResponse`
- `detect_intent(user_input: str) -> Intent`
- `configure_model(model_name: str, temperature: float)`

**Configuration:**

```bash
# Environment variables
GEMINI_API_KEY=your_google_api_key
GEMINI_MODEL=gemini-2.0-flash-exp  # or gemini-1.5-pro
GEMINI_TEMPERATURE=0.2
```

**System Prompt Design:**

The Gemini system prompt must accomplish two critical functions:

1. **Character Maintenance:** Preserve the Mortis personality (mischievous Halloween spirit)
2. **Intent Detection:** Identify manipulation task commands vs. conversational input

```python
GEMINI_SYSTEM_PROMPT = """
You are Mortis, a mischievous Halloween spirit inhabiting a robotic arm.

MANIPULATION TASKS:
You can perform these exact manipulation tasks:
- "Pick up the skull and place it in the green cup"
- "Pick up the skull and place it in the orange cup"
- "Pick up the skull and place it in the purple cup"
- "Pick up the eyeball and place it in the green cup"
- "Pick up the eyeball and place it in the orange cup"
- "Pick up the eyeball and place it in the purple cup"

RESPONSE FORMAT:
If user input matches a manipulation task (even with variations):
{
  "type": "manipulation",
  "command": "<exact_task_string>",
  "message": "<short in-character response, <=30 words>",
  "mood": "<ominous|playful|angry|nervous|triumphant|mischievous|sinister|curious|neutral>"
}

If user input is conversational:
{
  "type": "conversation",
  "message": "<short in-character response, <=30 words>",
  "mood": "<mood>",
  "gesture": "<idle|wave|point_left|point_right|grab|drop>"
}

Keep responses brief, in-character, no emojis or markdown.
"""
```

**Google SDK Usage:**

```python
import os

import google.generativeai as genai

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
model = genai.GenerativeModel('gemini-2.0-flash-exp')

# For structured output, use JSON mode
generation_config = {
    "temperature": 0.2,
    "response_mime_type": "application/json"
}
```
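
For reference, a minimal `send_message` sketch could tie these pieces together. This is illustrative only, assuming a recent `google-generativeai` SDK that supports `system_instruction`; `GEMINI_SYSTEM_PROMPT` is the prompt above and `GeminiResponse.from_json` is the data model defined in the Data Models section below.

```python
import json
import os

import google.generativeai as genai


class GeminiClient:
    """Sketch: wraps the Gemini SDK with the Mortis system prompt and JSON mode."""

    def __init__(self, model_name: str = "gemini-2.0-flash-exp", temperature: float = 0.2):
        genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
        self.model = genai.GenerativeModel(
            model_name,
            system_instruction=GEMINI_SYSTEM_PROMPT,
            generation_config={
                "temperature": temperature,
                "response_mime_type": "application/json",
            },
        )
        # A chat session carries the conversation history across turns
        self.chat = self.model.start_chat(history=[])

    def send_message(self, user_input: str) -> "GeminiResponse":
        # With JSON mode enabled, response.text should be a single JSON object
        response = self.chat.send_message(user_input)
        return GeminiResponse.from_json(json.loads(response.text))
```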

### 2. Speech-to-Text (STT) Integration

#### Component: `STTService`

**Purpose:** Convert user voice input to text for processing by Gemini.

**Architecture Decision: Cloud vs. Local STT**

| Approach | Pros | Cons | Recommendation |
|----------|------|------|----------------|
| **Google Speech-to-Text API** | High accuracy, fast, supports streaming, integrates with Gemini ecosystem | Requires internet, API costs, data leaves local system | **Recommended for production** |
| **Local Whisper (Hugging Face)** | Privacy-preserving, no API costs, works offline | Slower inference, requires GPU/CPU resources, lower accuracy for accents | Good for offline/privacy scenarios |
| **Gemini Audio Input** | Single API integration, context-aware | Limited to Gemini models with audio support, less control | **Best option if available** |

**Recommended Implementation: Gemini Native Audio**

Gemini 2.0 models support native audio input, eliminating the need for a separate STT step:

```python
import google.generativeai as genai

# Upload audio file
audio_file = genai.upload_file(path="user_audio.wav")

# Send to Gemini with audio
response = model.generate_content([
    "Transcribe and respond to this audio as Mortis:",
    audio_file
])
```

**Fallback Implementation: Google Speech-to-Text**

```python
from google.cloud import speech_v1

def transcribe_audio(audio_bytes: bytes) -> str:
    client = speech_v1.SpeechClient()

    audio = speech_v1.RecognitionAudio(content=audio_bytes)
    config = speech_v1.RecognitionConfig(
        encoding=speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="en-US",
    )

    response = client.recognize(config=config, audio=audio)
    if not response.results:
        # No speech detected (e.g., silence); return an empty transcript
        return ""
    return response.results[0].alternatives[0].transcript
```

**Gradio Integration:**

```python
with gr.Blocks() as demo:
    audio_input = gr.Audio(
        sources=["microphone"],
        type="filepath",
        label="Speak to Mortis"
    )

    audio_input.change(
        fn=process_audio_input,
        inputs=[audio_input],
        outputs=[chatbot]
    )
```

### 3. Text-to-Speech (TTS) Integration

#### Component: `TTSService`

**Purpose:** Convert Gemini text responses to audio for voice output.

**Recommended Approach: Google Text-to-Speech API**

```python
from google.cloud import texttospeech

def synthesize_speech(text: str, output_path: str) -> str:
    client = texttospeech.TextToSpeechClient()

    synthesis_input = texttospeech.SynthesisInput(text=text)

    # Configure voice (creepy/ominous for Mortis)
    voice = texttospeech.VoiceSelectionParams(
        language_code="en-US",
        name="en-US-Neural2-D",  # Deep male voice
        ssml_gender=texttospeech.SsmlVoiceGender.MALE
    )

    audio_config = texttospeech.AudioConfig(
        audio_encoding=texttospeech.AudioEncoding.MP3,
        speaking_rate=0.9,  # Slightly slower for ominous effect
        pitch=-2.0  # Lower pitch for spooky voice
    )

    response = client.synthesize_speech(
        input=synthesis_input,
        voice=voice,
        audio_config=audio_config
    )

    with open(output_path, "wb") as out:
        out.write(response.audio_content)

    return output_path
```

**Alternative: Local TTS (pyttsx3 or gTTS)**

For setups without Google Cloud credentials (note that gTTS still requires internet access; pyttsx3 runs fully offline):

```python
from gtts import gTTS

def synthesize_speech_local(text: str, output_path: str) -> str:
    tts = gTTS(text=text, lang='en', slow=True)
    tts.save(output_path)
    return output_path
```

**Gradio Integration:**

```python
def mortis_reply_with_voice(message, history, model_name):
    # Get text response from Gemini
    response_text, mood, action = process_with_gemini(message, model_name)

    # Generate audio
    audio_path = synthesize_speech(response_text, f"outputs/response_{time.time()}.mp3")

    return response_text, audio_path

with gr.Blocks() as demo:
    audio_output = gr.Audio(
        label="Mortis speaks",
        autoplay=True,
        type="filepath"
    )
```
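
To make the voice round-trip concrete, one possible wiring of `mortis_reply_with_voice` into the Blocks layout is sketched below. The `gr.Chatbot`/`gr.Textbox` components, the tuple-style history, and the hard-coded model name are illustrative assumptions, not part of the original layout.

```python
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Mortis")
    msg = gr.Textbox(label="Your message")
    audio_output = gr.Audio(label="Mortis speaks", autoplay=True, type="filepath")

    def respond(message, history):
        reply_text, audio_path = mortis_reply_with_voice(
            message, history, "gemini-2.0-flash-exp"
        )
        # Append the exchange to the chat history and clear the textbox
        history = history + [(message, reply_text)]
        return "", history, audio_path

    # Submitting the textbox updates the chat and plays the new audio clip
    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])

demo.launch()
```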

### 4. Intent Router

#### Component: `IntentRouter`

**Purpose:** Parse Gemini responses and route them to the appropriate execution path.

**Design:**

```python
from enum import Enum
from dataclasses import dataclass
from typing import Optional

class IntentType(Enum):
    CONVERSATION = "conversation"
    MANIPULATION = "manipulation"

@dataclass
class Intent:
    type: IntentType
    message: str
    mood: str
    gesture: Optional[str] = None
    command: Optional[str] = None

class IntentRouter:
    def __init__(self):
        self.valid_commands = [
            "Pick up the skull and place it in the green cup",
            "Pick up the skull and place it in the orange cup",
            "Pick up the skull and place it in the purple cup",
            "Pick up the eyeball and place it in the green cup",
            "Pick up the eyeball and place it in the orange cup",
            "Pick up the eyeball and place it in the purple cup",
        ]

    def parse_gemini_response(self, response_json: dict) -> Intent:
        """Parse structured JSON response from Gemini."""
        intent_type = IntentType(response_json.get("type", "conversation"))

        if intent_type == IntentType.MANIPULATION:
            return Intent(
                type=IntentType.MANIPULATION,
                message=response_json["message"],
                mood=response_json["mood"],
                command=response_json["command"]
            )
        else:
            return Intent(
                type=IntentType.CONVERSATION,
                message=response_json["message"],
                mood=response_json["mood"],
                gesture=response_json.get("gesture", "idle")
            )

    def validate_command(self, command: str) -> bool:
        """Verify command is in the trained task set."""
        return command in self.valid_commands
```

**Execution Flow:**

```python
def process_user_input(user_input: str, model_name: str):
    # 1. Send to Gemini
    gemini_response = gemini_client.send_message(user_input)

    # 2. Parse intent
    intent = intent_router.parse_gemini_response(gemini_response)

    # 3. Route to appropriate executor
    if intent.type == IntentType.MANIPULATION:
        if intent_router.validate_command(intent.command):
            # Queue for async SmolVLA execution
            task_queue.put({
                "type": "manipulation",
                "command": intent.command,
                "message": intent.message
            })
        else:
            # Invalid command, treat as conversation
            execute_gesture(intent.gesture or "idle")
    else:
        # Execute gesture immediately
        execute_gesture(intent.gesture)

    # 4. Generate voice response
    audio_path = tts_service.synthesize(intent.message)

    return intent.message, audio_path
```

### 5. Asynchronous Execution System

#### Component: `AsyncExecutor`

**Purpose:** Decouple long-running SmolVLA inference from the Gradio UI to maintain responsiveness.

**Architecture Decision: Message Queue vs. Background Processing**

| Approach | Pros | Cons | Recommendation |
|----------|------|------|----------------|
| **Redis Queue** | Robust, scalable, persistent, supports distributed workers | External dependency, overkill for single-machine | Good for production/multi-worker |
| **Python asyncio.Queue** | Built-in, simple, no dependencies | Single process only, not persistent | **Recommended for this use case** |
| **multiprocessing.Queue** | True parallelism, GPU isolation | Complex IPC, harder debugging | Good if GPU contention is an issue |
| **Threading + Queue** | Simple, shared memory | GIL limitations, not ideal for CPU-bound | Not recommended for ML inference |

**Recommended Implementation: Background Worker Thread with `queue.Queue`**

```python
from queue import Queue, Empty
from threading import Thread

class AsyncExecutor:
    def __init__(self):
        self.task_queue = Queue()
        self.status_queue = Queue()
        self.worker_thread = None
        self.running = False

    def start(self):
        """Start background worker thread."""
        self.running = True
        self.worker_thread = Thread(target=self._worker_loop, daemon=True)
        self.worker_thread.start()

    def stop(self):
        """Stop background worker."""
        self.running = False
        if self.worker_thread:
            self.worker_thread.join(timeout=5)

    def _worker_loop(self):
        """Background thread that processes tasks."""
        while self.running:
            try:
                task = self.task_queue.get(timeout=1)
            except Empty:
                # No task queued; loop again so `stop()` can take effect
                continue
            self._execute_task(task)

    def _execute_task(self, task):
        """Execute a single task."""
        try:
            if task["type"] == "manipulation":
                self.status_queue.put({"status": "running", "task": task["command"]})

                # Execute SmolVLA inference (blocking)
                smolvla_executor.execute(task["command"])

                self.status_queue.put({"status": "complete", "task": task["command"]})
            elif task["type"] == "gesture":
                mortis_arm.move_arm(task["gesture"])
                self.status_queue.put({"status": "complete", "task": task["gesture"]})
        except Exception as e:
            self.status_queue.put({"status": "error", "error": str(e)})

    def submit_task(self, task: dict):
        """Submit task for async execution."""
        self.task_queue.put(task)

    def get_status(self) -> dict:
        """Get latest status update (non-blocking)."""
        try:
            return self.status_queue.get_nowait()
        except Empty:
            return None

# Global executor instance
async_executor = AsyncExecutor()
```

**Gradio Integration with Status Updates:**

```python
def mortis_reply(message, history, model_name):
    # Process with Gemini
    intent = process_with_gemini(message, model_name)

    # Submit task asynchronously
    if intent.type == IntentType.MANIPULATION:
        async_executor.submit_task({
            "type": "manipulation",
            "command": intent.command
        })
        status_msg = f"🤖 Executing: {intent.command}..."
    else:
        async_executor.submit_task({
            "type": "gesture",
            "gesture": intent.gesture
        })
        status_msg = f"👻 {intent.gesture}"

    # Generate audio response
    audio_path = tts_service.synthesize(intent.message)

    return intent.message, audio_path, status_msg

def check_status():
    """Periodic status checker for Gradio."""
    status = async_executor.get_status()
    if status:
        if status["status"] == "complete":
            return f"✅ Completed: {status['task']}"
        elif status["status"] == "running":
            return f"⏳ Running: {status['task']}"
        elif status["status"] == "error":
            return f"❌ Error: {status['error']}"
    return "Idle"

with gr.Blocks() as demo:
    status_display = gr.Textbox(label="Robot Status", value="Idle")

    # Update status every 500ms
    demo.load(
        fn=check_status,
        outputs=[status_display],
        every=0.5
    )
```
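
One detail the snippets above leave implicit is the executor lifecycle. A minimal sketch of the startup/shutdown wiring, assuming the `AsyncExecutor` and `demo` defined above:

```python
import atexit

async_executor.start()                 # begin draining the task queue
atexit.register(async_executor.stop)   # join the worker on interpreter exit

demo.launch()                          # blocks until the Gradio server stops
```

Starting the worker before `launch()` and stopping it on exit avoids an orphaned thread holding the robot after the UI goes away.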

### 6. SmolVLA Model Integration

#### Component: `SmolVLAExecutor`

**Purpose:** Execute vision-language-action inference for manipulation tasks.

**LeRobot SmolVLA Overview:**

SmolVLA is a vision-language-action model that:
- Takes visual observations (camera images) as input
- Accepts natural language task descriptions
- Outputs robot actions (joint positions/velocities)
- Is trained end-to-end on demonstration data

**Model Architecture:**

```python
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
import torch

class SmolVLAExecutor:
    def __init__(self, checkpoint_path: str, device: str = "cuda"):
        self.device = device
        self.policy = self._load_model(checkpoint_path)
        self.camera = self._init_camera()

    def _load_model(self, checkpoint_path: str) -> SmolVLAPolicy:
        """Load trained SmolVLA model from checkpoint."""
        config = SmolVLAConfig.from_pretrained(checkpoint_path)
        policy = SmolVLAPolicy.from_pretrained(
            checkpoint_path,
            config=config
        )
        policy.to(self.device)
        policy.eval()
        return policy

    def _init_camera(self):
        """Initialize camera for visual observations."""
        from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
        camera = OpenCVCamera(camera_index=0, fps=30, width=640, height=480)
        camera.connect()
        return camera

    def execute(self, command: str, max_steps: int = 500):
        """
        Execute manipulation task using SmolVLA.

        Args:
            command: Natural language task description
            max_steps: Maximum inference steps
        """
        print(f"SmolVLA executing: {command}")

        with torch.no_grad():
            for step in range(max_steps):
                # Capture current observation
                observation = self._get_observation()

                # Add task instruction
                observation["task"] = command

                # Run inference
                action = self.policy.select_action(observation)

                # Send action to robot
                self._send_action(action)

                # Check if task complete (implementation-specific)
                if self._is_task_complete(observation, step):
                    break

        print(f"SmolVLA completed: {command}")

    def _get_observation(self) -> dict:
        """Get current robot observation."""
        # Capture image
        image = self.camera.read()

        # Get robot state (mortis_arm is the global SO101 wrapper
        # defined elsewhere in the app)
        robot_state = mortis_arm.robot.get_state()

        return {
            "observation.image": torch.from_numpy(image).to(self.device),
            "observation.state": torch.tensor(robot_state).to(self.device)
        }

    def _send_action(self, action: torch.Tensor):
        """Send predicted action to robot."""
        action_dict = self._action_to_dict(action)
        mortis_arm.robot.send_action(action_dict)

    def _action_to_dict(self, action: torch.Tensor) -> dict:
        """Convert action tensor to SO101 command format."""
        # Map action dimensions to joint names
        joint_names = [
            "shoulder_pan.pos",
            "shoulder_lift.pos",
            "elbow_flex.pos",
            "wrist_flex.pos",
            "wrist_roll.pos",
            "gripper.pos"
        ]

        return {
            name: float(action[i].cpu().numpy())
            for i, name in enumerate(joint_names)
        }

    def _is_task_complete(self, observation: dict, step: int) -> bool:
        """Determine if task is complete (heuristic or learned)."""
        # Simple heuristic: fixed number of steps
        # In practice, could use a learned termination classifier
        return step >= 400

# Global SmolVLA executor
smolvla_executor = None

def init_smolvla(checkpoint_path: str):
    global smolvla_executor
    smolvla_executor = SmolVLAExecutor(checkpoint_path)
```
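
Tying this into the application, a hypothetical startup call might look like the following; the checkpoint path is a placeholder, not a path defined by this design:

```python
# Placeholder checkpoint path; substitute the real training output directory
init_smolvla("outputs/train/mortis_smolvla/checkpoints/last")

# A queued manipulation task then reduces to one blocking call, which the
# AsyncExecutor runs off the UI thread:
smolvla_executor.execute("Pick up the skull and place it in the green cup")
```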

### 7. Dataset Collection Infrastructure

#### Component: `DataCollector`

**Purpose:** Record human demonstrations for training the SmolVLA model.

**LeRobot Dataset Format:**

LeRobot uses a standardized dataset format with:
- Episodes: Individual task demonstrations
- Observations: Camera images, robot states
- Actions: Robot joint commands
- Metadata: Task descriptions, timestamps

**Data Collection Script:**

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from pathlib import Path
import numpy as np
import time

class DataCollector:
    def __init__(self, dataset_name: str, repo_id: str):
        self.dataset_name = dataset_name
        self.repo_id = repo_id
        self.dataset_dir = Path(f"data/{dataset_name}")
        self.dataset_dir.mkdir(parents=True, exist_ok=True)

        self.dataset = LeRobotDataset.create(
            repo_id=repo_id,
            fps=30,
            robot_type="so101",
            keys=["observation.image", "observation.state", "action"]
        )

    def record_episode(self, task_description: str, duration: float = 30.0):
        """
        Record a single demonstration episode.

        Args:
            task_description: Natural language task description
            duration: Maximum recording duration in seconds
        """
        print(f"Recording episode: {task_description}")
        print("Press ENTER to start recording...")
        input()

        episode_data = {
            "observation.image": [],
            "observation.state": [],
            "action": [],
            "timestamp": [],
            "task": task_description
        }

        start_time = time.time()
        frame_count = 0

        print("Recording... Press CTRL+C to stop")

        try:
            while time.time() - start_time < duration:
                # Capture observation
                image = camera.read()
                state = mortis_arm.robot.get_state()

                # Record current state as "action" (for behavior cloning)
                action = state.copy()

                # Store data
                episode_data["observation.image"].append(image)
                episode_data["observation.state"].append(state)
                episode_data["action"].append(action)
                episode_data["timestamp"].append(time.time() - start_time)

                frame_count += 1
                time.sleep(1/30)  # 30 FPS

        except KeyboardInterrupt:
            print(f"\nRecording stopped. Captured {frame_count} frames")

        # Save episode to dataset
        self._save_episode(episode_data)

        print(f"Episode saved: {task_description}")

    def _save_episode(self, episode_data: dict):
        """Save episode to LeRobot dataset."""
        episode_index = len(self.dataset)

        # Convert to numpy arrays
        images = np.array(episode_data["observation.image"])
        states = np.array(episode_data["observation.state"])
        actions = np.array(episode_data["action"])

        # Add to dataset
        self.dataset.add_episode({
            "observation.image": images,
            "observation.state": states,
            "action": actions,
            "episode_index": episode_index,
            "task": episode_data["task"]
        })

        # Save to disk
        self.dataset.save_to_disk(self.dataset_dir)

    def push_to_hub(self):
        """Upload dataset to Hugging Face Hub."""
        self.dataset.push_to_hub(self.repo_id)
        print(f"Dataset pushed to: https://huggingface.co/datasets/{self.repo_id}")

# Usage script
def collect_demonstrations():
    collector = DataCollector(
        dataset_name="mortis_manipulation",
        repo_id="your-username/mortis-manipulation"
    )

    tasks = [
        "Pick up the skull and place it in the green cup",
        "Pick up the skull and place it in the orange cup",
        "Pick up the skull and place it in the purple cup",
        "Pick up the eyeball and place it in the green cup",
        "Pick up the eyeball and place it in the orange cup",
        "Pick up the eyeball and place it in the purple cup",
    ]

    for task in tasks:
        print(f"\n{'='*60}")
        print(f"Task: {task}")
        print(f"{'='*60}")

        # Record multiple demonstrations per task
        for demo_num in range(5):
            print(f"\nDemonstration {demo_num + 1}/5")
            collector.record_episode(task)

    # Upload to Hugging Face
    collector.push_to_hub()
```

### 8. Training Pipeline

#### Component: `TrainingPipeline`

**Purpose:** Train the SmolVLA model on collected demonstration data.

**LeRobot Training Configuration:**

```yaml
# config/train_smolvla.yaml
defaults:
  - _self_
  - policy: smolvla

seed: 1000
dataset_repo_id: your-username/mortis-manipulation
video_backend: pyav
fps: 30  # dataset recording rate; referenced by delta_timestamps below

training:
  offline_steps: 100000
  online_steps: 0
  eval_freq: 10000
  save_freq: 10000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  lr: 1e-4
  lr_scheduler: cosine
  lr_warmup_steps: 1000
  adam_betas: [0.9, 0.999]
  adam_weight_decay: 1e-6
  grad_clip_norm: 10.0

  delta_timestamps:
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"

eval:
  n_episodes: 10
  batch_size: 10

policy:
  name: smolvla

  # Input dimensions
  input_shapes:
    observation.image: [3, 224, 224]
    observation.state: [6]  # 6 joints

  output_shapes:
    action: [6]  # 6 joint commands

  # Model architecture
  vision_backbone: "google/siglip-so400m-patch14-384"
  pretrained_backbone_weights: "google/siglip-so400m-patch14-384"

  # Action prediction
  chunk_size: 50  # Predict 50 steps ahead
  n_action_steps: 50

  # Training
  use_language_conditioning: true
  dropout: 0.1

device: cuda
use_amp: true  # Automatic mixed precision
```

**Training Script:**

```python
from lerobot.scripts.train import train
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from pathlib import Path
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="train_smolvla", version_base="1.2")
def train_smolvla(cfg: DictConfig):
    """
    Train SmolVLA model using the LeRobot training pipeline.

    Usage:
        python -m mortis.train
    """
    # Load dataset
    dataset = LeRobotDataset(
        repo_id=cfg.dataset_repo_id,
        split="train"
    )

    print(f"Dataset loaded: {len(dataset)} episodes")
    print(f"Training for {cfg.training.offline_steps} steps")

    # Run training
    train(cfg)

    print("Training complete!")
    print(f"Checkpoints saved to: outputs/train/{cfg.run_name}")

if __name__ == "__main__":
    train_smolvla()
```

**Simplified Training Command:**

```bash
# Using lerobot CLI
python -m lerobot.scripts.train \
    policy=smolvla \
    env=so101 \
    dataset_repo_id=your-username/mortis-manipulation \
    training.offline_steps=100000 \
    training.batch_size=8 \
    training.save_freq=10000 \
    device=cuda \
    wandb.enable=true \
    wandb.project=mortis-smolvla
```

**Training Monitoring:**

```python
# Integration with Weights & Biases for tracking
import wandb

wandb.init(
    project="mortis-smolvla",
    config={
        "dataset": "mortis-manipulation",
        "policy": "smolvla",
        "batch_size": 8,
        "learning_rate": 1e-4
    }
)

# Logged automatically by LeRobot:
# - Training loss
# - Validation loss
# - Action prediction accuracy
# - Episode success rate
# - Sample predictions (videos)
```

## Data Models

### 1. Gemini Response Model

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class ResponseType(Enum):
    CONVERSATION = "conversation"
    MANIPULATION = "manipulation"

class Mood(Enum):
    OMINOUS = "ominous"
    PLAYFUL = "playful"
    ANGRY = "angry"
    NERVOUS = "nervous"
    TRIUMPHANT = "triumphant"
    MISCHIEVOUS = "mischievous"
    SINISTER = "sinister"
    CURIOUS = "curious"
    NEUTRAL = "neutral"

class Gesture(Enum):
    IDLE = "idle"
    WAVE = "wave"
    POINT_LEFT = "point_left"
    POINT_RIGHT = "point_right"
    GRAB = "grab"
    DROP = "drop"

@dataclass
class GeminiResponse:
    """Structured response from Gemini API."""
    type: ResponseType
    message: str
    mood: Mood
    gesture: Optional[Gesture] = None
    command: Optional[str] = None

    @classmethod
    def from_json(cls, data: dict) -> 'GeminiResponse':
        """Parse JSON response from Gemini."""
        response_type = ResponseType(data["type"])

        if response_type == ResponseType.MANIPULATION:
            return cls(
                type=response_type,
                message=data["message"],
                mood=Mood(data["mood"]),
                command=data["command"]
            )
        else:
            return cls(
                type=response_type,
                message=data["message"],
                mood=Mood(data["mood"]),
                gesture=Gesture(data.get("gesture", "idle"))
            )
```
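
A quick illustration of the parsing path, with a hand-written sample payload (not captured API output):

```python
import json

raw = (
    '{"type": "manipulation",'
    ' "command": "Pick up the skull and place it in the green cup",'
    ' "message": "The skull obeys...",'
    ' "mood": "sinister"}'
)
response = GeminiResponse.from_json(json.loads(raw))
assert response.type is ResponseType.MANIPULATION
assert response.mood is Mood.SINISTER
print(response.command)  # the exact trained task string
```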

### 2. Task Execution Model

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import time

class TaskStatus(Enum):
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETE = "complete"
    FAILED = "failed"

class TaskType(Enum):
    GESTURE = "gesture"
    MANIPULATION = "manipulation"

@dataclass
class Task:
    """Represents a robot task for execution."""
    id: str
    type: TaskType
    status: TaskStatus
    created_at: float
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[str] = None

    # Task-specific data
    gesture: Optional[str] = None
    command: Optional[str] = None

    @classmethod
    def create_gesture_task(cls, gesture: str) -> 'Task':
        """Create a gesture execution task."""
        return cls(
            id=f"gesture_{time.time()}",
            type=TaskType.GESTURE,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            gesture=gesture
        )

    @classmethod
    def create_manipulation_task(cls, command: str) -> 'Task':
        """Create a manipulation execution task."""
        return cls(
            id=f"manipulation_{time.time()}",
            type=TaskType.MANIPULATION,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            command=command
        )

    def start(self):
        """Mark task as started."""
        self.status = TaskStatus.RUNNING
        self.started_at = time.time()

    def complete(self):
        """Mark task as completed."""
        self.status = TaskStatus.COMPLETE
        self.completed_at = time.time()

    def fail(self, error: str):
        """Mark task as failed."""
        self.status = TaskStatus.FAILED
        self.completed_at = time.time()
        self.error = error

    @property
    def duration(self) -> Optional[float]:
        """Get task execution duration."""
        if self.started_at and self.completed_at:
            return self.completed_at - self.started_at
        return None
```
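
For clarity, the intended lifecycle runs QUEUED → RUNNING → COMPLETE/FAILED; a short usage sketch of the model above:

```python
task = Task.create_manipulation_task("Pick up the skull and place it in the green cup")
assert task.status is TaskStatus.QUEUED

task.start()       # status -> RUNNING, records started_at
task.complete()    # status -> COMPLETE, records completed_at
print(f"took {task.duration:.3f}s")  # duration derives from the two timestamps

failed = Task.create_gesture_task("wave")
failed.start()
failed.fail("servo timeout")  # status -> FAILED, with the error message retained
```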
### 3. Dataset Episode Model

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class Episode:
    """Represents a single demonstration episode."""
    episode_index: int
    task_description: str
    images: np.ndarray      # Shape: (T, H, W, 3)
    states: np.ndarray      # Shape: (T, 6)
    actions: np.ndarray     # Shape: (T, 6)
    timestamps: np.ndarray  # Shape: (T,)

    @property
    def length(self) -> int:
        """Number of timesteps in episode."""
        return len(self.timestamps)

    @property
    def duration(self) -> float:
        """Episode duration in seconds."""
        return self.timestamps[-1] - self.timestamps[0]

    def validate(self) -> bool:
        """Validate episode data consistency."""
        lengths = [
            len(self.images),
            len(self.states),
            len(self.actions),
            len(self.timestamps)
        ]
        return len(set(lengths)) == 1  # All same length
```

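As a usage sketch, a collector would run `validate()` before persisting an episode; the timestep count, camera resolution, and durations below are illustrative, not prescribed:

```python
import numpy as np

T = 150  # illustrative number of timesteps
episode = Episode(
    episode_index=0,
    task_description="Pick up the skull and place it in the green cup",
    images=np.zeros((T, 480, 640, 3), dtype=np.uint8),
    states=np.zeros((T, 6), dtype=np.float32),
    actions=np.zeros((T, 6), dtype=np.float32),
    timestamps=np.linspace(0.0, 5.0, T),
)
assert episode.validate(), "episode arrays have inconsistent lengths"
print(f"{episode.length} steps over {episode.duration:.1f}s")
```
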
## Error Handling

### Error Categories and Recovery Strategies

#### 1. Gemini API Errors

**Error Types:**
- Authentication failures (invalid API key)
- Rate limiting (quota exceeded)
- Network timeouts
- Invalid responses (malformed JSON)

**Recovery Strategy:**

```python
import time
from typing import Optional

import google.generativeai as genai
# Note: the exception class names below are illustrative of the design; depending
# on SDK version, rate limits may surface as google.api_core.exceptions.ResourceExhausted.

class GeminiAPIError(Exception):
    """Base exception for Gemini API errors."""
    pass

class GeminiClient:
    def __init__(self, api_key: str, max_retries: int = 3):
        self.api_key = api_key
        self.max_retries = max_retries

    def send_message_with_retry(
        self,
        message: str,
        retry_count: int = 0
    ) -> Optional[GeminiResponse]:
        """Send message with exponential backoff retry."""
        try:
            response = self._send_message(message)
            return response

        except genai.types.BlockedPromptException as e:
            # Content safety filter triggered
            print(f"Prompt blocked by safety filter: {e}")
            return self._get_fallback_response()

        except genai.types.RateLimitError as e:
            if retry_count < self.max_retries:
                wait_time = 2 ** retry_count  # Exponential backoff
                print(f"Rate limited. Retrying in {wait_time}s...")
                time.sleep(wait_time)
                return self.send_message_with_retry(message, retry_count + 1)
            else:
                raise GeminiAPIError("Max retries exceeded for rate limit")

        except Exception as e:
            print(f"Gemini API error: {e}")
            return self._get_fallback_response()

    def _get_fallback_response(self) -> GeminiResponse:
        """Return safe fallback response on API failure."""
        return GeminiResponse(
            type=ResponseType.CONVERSATION,
            message="The spirits are restless... try again.",
            mood=Mood.OMINOUS,
            gesture=Gesture.IDLE
        )
```

#### 2. STT/TTS Errors

**Error Types:**
- Audio format incompatibility
- Service unavailable
- Transcription failures (unclear audio)

**Recovery Strategy:**

```python
from typing import Optional

# transcribe_with_gemini, transcribe_with_google_stt, and synthesize_with_google_tts
# are the STT/TTS service wrappers defined in their respective components.

class AudioProcessingError(Exception):
    """Base exception for audio processing errors."""
    pass

def process_audio_with_fallback(audio_path: str) -> str:
    """Process audio with fallback to text input."""
    try:
        # Try Gemini native audio
        transcript = transcribe_with_gemini(audio_path)
        return transcript

    except Exception as e:
        print(f"Gemini audio processing failed: {e}")

        try:
            # Fallback to Google STT
            transcript = transcribe_with_google_stt(audio_path)
            return transcript

        except Exception as e:
            print(f"Google STT failed: {e}")
            raise AudioProcessingError(
                "Could not process audio. Please use text input."
            )

def synthesize_speech_with_fallback(text: str) -> Optional[str]:
    """Synthesize speech with fallback to text-only."""
    try:
        audio_path = synthesize_with_google_tts(text)
        return audio_path

    except Exception as e:
        print(f"TTS failed: {e}. Returning text only.")
        return None  # UI will display text without audio
```

#### 3. SmolVLA Inference Errors

**Error Types:**
- Model loading failures
- GPU out of memory
- Invalid observations
- Action execution failures

**Recovery Strategy:**

```python
import torch

class SmolVLAError(Exception):
    """Base exception for SmolVLA errors."""
    pass

class SmolVLAExecutor:
    def execute_with_safety(self, command: str) -> bool:
        """Execute command with safety checks and recovery."""
        try:
            # Pre-execution validation
            if not self._validate_command(command):
                raise SmolVLAError(f"Invalid command: {command}")

            if not self._check_workspace_clear():
                raise SmolVLAError("Workspace not clear. Remove obstacles.")

            # Execute with timeout
            success = self._execute_with_timeout(command, timeout=60.0)

            if not success:
                raise SmolVLAError("Execution timeout")

            return True

        except torch.cuda.OutOfMemoryError:
            print("GPU OOM. Clearing cache and retrying...")
            torch.cuda.empty_cache()
            return self._execute_with_timeout(command, timeout=60.0)

        except Exception as e:
            print(f"SmolVLA execution failed: {e}")
            # Return to safe position
            self._emergency_stop()
            return False

    def _emergency_stop(self):
        """Return robot to safe idle position."""
        print("Emergency stop: returning to idle position")
        mortis_arm.move_arm("idle")

    def _validate_command(self, command: str) -> bool:
        """Validate command is in trained set."""
        return command in self.valid_commands

    def _check_workspace_clear(self) -> bool:
        """Check if workspace is safe for execution."""
        # Could use computer vision to detect obstacles
        # For now, assume clear
        return True
```

#### 4. Robot Hardware Errors

**Error Types:**
- Connection failures
- Servo errors
- Position limits exceeded
- Communication timeouts

**Recovery Strategy:**

```python
import time

class RobotError(Exception):
    """Base exception for robot hardware errors."""
    pass

class MortisArm:
    def move_arm_safe(self, gesture_name: str) -> bool:
        """Execute gesture with error handling."""
        if not self.connected:
            try:
                self.connect()
            except Exception as e:
                print(f"Failed to connect to robot: {e}")
                return False

        try:
            self.move_arm(gesture_name)
            return True

        except Exception as e:
            print(f"Gesture execution failed: {e}")

            # Attempt recovery
            try:
                print("Attempting to reconnect...")
                self.disconnect()
                time.sleep(1)
                self.connect()
                self.move_arm("idle")
                # Recovery succeeded, but the requested gesture still failed
                return False

            except Exception as e:
                print(f"Recovery failed: {e}")
                self.connected = False
                return False
```

### Error Reporting to User

```python
def format_error_message(error: Exception) -> str:
    """Format error for user display."""
    error_messages = {
        GeminiAPIError: "🔮 The spirits are not responding. Please try again.",
        AudioProcessingError: "🎤 Could not understand audio. Please try text input.",
        SmolVLAError: "🤖 Mortis cannot perform that action right now.",
        RobotError: "⚠️ Robot connection lost. Attempting to reconnect...",
    }

    # Use isinstance so subclasses of each base exception are matched too
    for error_type, message in error_messages.items():
        if isinstance(error, error_type):
            return message
    return "❌ An unexpected error occurred."
```

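As a sketch of how this mapping is consumed, a top-level handler (the `handle_user_message` wrapper here is hypothetical) catches everything and converts it into a display string:

```python
def handle_user_message(message: str) -> str:
    """Hypothetical top-level handler: returns either the reply or a friendly error."""
    try:
        response = gemini_client.send_message_with_retry(message)
        return response.message
    except Exception as e:
        logger.error("Request failed", exc_info=e)
        return format_error_message(e)
```
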

## Testing Strategy

### 1. Unit Testing

**Components to Test:**
- Gemini API client (with mocked responses)
- Intent router (parsing and validation)
- Data models (serialization/deserialization)
- Audio processing utilities

**Example Test:**

```python
import pytest
from unittest.mock import Mock, patch

import google.generativeai as genai  # needed for the exception class used in the mock

from mortis.gemini_client import GeminiClient, GeminiResponse, ResponseType

def test_gemini_response_parsing():
    """Test parsing of Gemini JSON responses."""
    # Test conversation response
    conv_data = {
        "type": "conversation",
        "message": "Beware, mortal...",
        "mood": "ominous",
        "gesture": "wave"
    }
    response = GeminiResponse.from_json(conv_data)
    assert response.type == ResponseType.CONVERSATION
    assert response.gesture.value == "wave"

    # Test manipulation response
    manip_data = {
        "type": "manipulation",
        "message": "As you wish...",
        "mood": "sinister",
        "command": "Pick up the skull and place it in the green cup"
    }
    response = GeminiResponse.from_json(manip_data)
    assert response.type == ResponseType.MANIPULATION
    assert response.command is not None

@patch('google.generativeai.GenerativeModel')
def test_gemini_client_retry(mock_model):
    """Test retry logic for API failures."""
    client = GeminiClient(api_key="test_key", max_retries=3)

    # Simulate rate limit error then success
    mock_model.return_value.generate_content.side_effect = [
        genai.types.RateLimitError("Rate limited"),
        Mock(text='{"type": "conversation", "message": "Hello", "mood": "neutral", "gesture": "idle"}')
    ]

    response = client.send_message_with_retry("Hello")
    assert response is not None
    assert mock_model.return_value.generate_content.call_count == 2
```

### 2. Integration Testing

**Test Scenarios:**
- End-to-end voice input → Gemini → gesture execution
- Text input → intent detection → SmolVLA execution
- Dataset collection → training → inference pipeline
- Error recovery flows

**Example Test:**

```python
import pytest
import torch
from unittest.mock import patch

# process_audio, gemini_client, mortis_arm, and execute_gesture are assumed to be
# importable from the application modules once implemented.

@pytest.mark.integration
def test_voice_to_gesture_flow():
    """Test complete voice input to gesture execution."""
    # Record test audio
    test_audio = "tests/fixtures/test_wave.wav"

    # Process audio
    transcript = process_audio(test_audio)
    assert "wave" in transcript.lower()

    # Send to Gemini
    response = gemini_client.send_message(transcript)
    assert response.type == ResponseType.CONVERSATION
    assert response.gesture == Gesture.WAVE

    # Execute gesture (with mock robot)
    with patch.object(mortis_arm, 'move_arm') as mock_move:
        execute_gesture(response.gesture)
        mock_move.assert_called_once_with("wave")

@pytest.mark.integration
@pytest.mark.slow
def test_smolvla_inference():
    """Test SmolVLA model inference (requires GPU)."""
    if not torch.cuda.is_available():
        pytest.skip("GPU not available")

    # Load test checkpoint
    executor = SmolVLAExecutor("tests/fixtures/test_checkpoint")

    # Execute test command
    command = "Pick up the skull and place it in the green cup"
    success = executor.execute(command, max_steps=10)

    assert success
```

### 3. System Testing

**Test Scenarios:**
- Multi-user concurrent access
- Long-running operation stability
- Resource usage (GPU memory, CPU)
- Network failure recovery

**Performance Benchmarks:**

```python
import time

@pytest.mark.benchmark
def test_gemini_response_time():
    """Benchmark Gemini API response time."""
    times = []
    for _ in range(10):
        start = time.time()
        response = gemini_client.send_message("Hello Mortis")
        elapsed = time.time() - start
        times.append(elapsed)

    avg_time = sum(times) / len(times)
    assert avg_time < 2.0, f"Average response time {avg_time}s exceeds 2s threshold"

@pytest.mark.benchmark
def test_smolvla_inference_time():
    """Benchmark SmolVLA inference speed."""
    executor = SmolVLAExecutor("checkpoints/best_model")

    start = time.time()
    executor.execute("Pick up the skull and place it in the green cup", max_steps=100)
    elapsed = time.time() - start

    assert elapsed < 30.0, f"Inference time {elapsed}s exceeds 30s threshold"
```

### 4. User Acceptance Testing

**Test Scenarios:**
- Voice recognition accuracy with different accents
- Task success rate for manipulation commands
- UI responsiveness during long operations
- Error message clarity and helpfulness

**Manual Test Checklist:**

```markdown
## Voice Input Testing
- [ ] Clear speech recognized correctly
- [ ] Background noise handled gracefully
- [ ] Multiple languages supported (if applicable)
- [ ] Audio feedback provided to user

## Manipulation Task Testing
- [ ] All 6 trained tasks execute successfully
- [ ] Task variations handled appropriately
- [ ] Robot returns to safe position after completion
- [ ] Visual feedback clear during execution

## Error Handling Testing
- [ ] API failures display helpful messages
- [ ] Robot errors trigger safe shutdown
- [ ] Network issues handled gracefully
- [ ] Recovery procedures work as expected

## UI/UX Testing
- [ ] Interface remains responsive during tasks
- [ ] Status updates clear and timely
- [ ] Audio playback works correctly
- [ ] Webcam feed displays properly
```

### 5. Safety Testing

**Critical Safety Tests:**

```python
import time
from threading import Thread

def test_emergency_stop():
    """Test emergency stop functionality."""
    executor = SmolVLAExecutor("checkpoints/best_model")

    # Start execution in the background
    task_thread = Thread(target=executor.execute, args=("test command",))
    task_thread.start()

    # Trigger emergency stop
    time.sleep(1)
    executor._emergency_stop()

    # Verify robot in safe position (HOME_POSE is the calibrated idle pose)
    state = mortis_arm.robot.get_state()
    assert state == HOME_POSE

def test_workspace_collision_detection():
    """Test collision detection and avoidance."""
    # Place obstacle in workspace
    # Attempt manipulation task
    # Verify task aborted safely
    pass
```


## Deployment and Configuration

### Environment Configuration

**Required Environment Variables:**

```bash
# .env file
# Gemini API
GEMINI_API_KEY=your_google_api_key
GEMINI_MODEL=gemini-2.0-flash-exp
GEMINI_TEMPERATURE=0.2

# Google Cloud (for STT/TTS if not using Gemini native)
GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account-key.json

# Robot Configuration
ROBOT_PORT=/dev/ttyACM1
ROBOT_CALIBRATION_DIR=.cache/calibration/so101/

# SmolVLA Model
SMOLVLA_CHECKPOINT_PATH=checkpoints/smolvla_best.pt
SMOLVLA_DEVICE=cuda

# Application
PORT=7860
DEBUG=false

# Optional: Weights & Biases for training
WANDB_API_KEY=your_wandb_key
WANDB_PROJECT=mortis-smolvla
```

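Since python-dotenv is already a project dependency, this configuration can be loaded once at startup. A minimal sketch, with variable names taken from the `.env` above; the fallback defaults are assumptions:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]  # required: fail fast if missing
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
GEMINI_TEMPERATURE = float(os.getenv("GEMINI_TEMPERATURE", "0.2"))
ROBOT_PORT = os.getenv("ROBOT_PORT", "/dev/ttyACM1")
SMOLVLA_CHECKPOINT_PATH = os.getenv("SMOLVLA_CHECKPOINT_PATH", "checkpoints/smolvla_best.pt")
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
```
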
### Dependency Management

**Updated pyproject.toml:**

```toml
[project]
name = "mortis"
version = "0.2.0"
description = "Mortis: Multi-modal AI Halloween Experience with SmolVLA"
requires-python = ">=3.12"
dependencies = [
    "gradio>=5.49.1",
    "lerobot[async,feetech,intelrealsense,smolvla]>=0.4.0",
    "python-dotenv>=1.2.1",

    # Gemini and Google Cloud
    "google-generativeai>=0.8.0",
    "google-cloud-speech>=2.26.0",
    "google-cloud-texttospeech>=2.16.0",

    # ML and Vision
    "torch>=2.0.0",
    "torchvision>=0.15.0",
    "transformers>=4.40.0",
    "pillow>=10.0.0",

    # Data and utilities
    "numpy>=1.24.0",
    "opencv-python>=4.8.0",
    "datasets>=2.14.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.4.0",
    "pytest-asyncio>=0.21.0",
    "pytest-benchmark>=4.0.0",
    "black>=23.0.0",
    "ruff>=0.1.0",
]

training = [
    "wandb>=0.16.0",
    "hydra-core>=1.3.0",
    "tensorboard>=2.14.0",
]

[project.scripts]
mortis = "mortis.app:main"
calibrate = "mortis.calibrate:main"
collect-data = "mortis.collect_data:main"
train-smolvla = "mortis.train:main"
```

### Installation Steps

```bash
# 1. Clone repository
git clone https://github.com/your-username/mortis.git
cd mortis

# 2. Install dependencies
make install

# 3. Configure environment
cp .env.example .env
# Edit .env with your API keys

# 4. Calibrate robot (first time only)
make calibrate

# 5. Download or train SmolVLA model
# Option A: Download pre-trained model
python -m mortis.download_model --checkpoint smolvla_mortis_v1

# Option B: Train from scratch
make collect-data
make train-smolvla

# 6. Run application
make run
```

### Docker Deployment (Optional)

**Dockerfile:**

```dockerfile
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

# Install Python and system dependencies
# NOTE: Ubuntu 22.04's default repositories ship Python 3.10; installing
# python3.12 requires an additional source such as the deadsnakes PPA.
RUN apt-get update && apt-get install -y \
    python3.12 \
    python3-pip \
    libusb-1.0-0 \
    udev \
    && rm -rf /var/lib/apt/lists/*

# Install uv package manager
RUN pip install uv

WORKDIR /app

# Copy project files
COPY pyproject.toml uv.lock ./
COPY src/ ./src/
COPY assets/ ./assets/

# Install dependencies
RUN uv sync --frozen

# Expose Gradio port
EXPOSE 7860

# Run application
CMD ["uv", "run", "mortis"]
```

**docker-compose.yml:**

```yaml
version: '3.8'

services:
  mortis:
    build: .
    ports:
      - "7860:7860"
    devices:
      - /dev/ttyACM1:/dev/ttyACM1  # Robot USB connection
    volumes:
      - ./.env:/app/.env
      - ./checkpoints:/app/checkpoints
      - ./.cache:/app/.cache
    environment:
      - NVIDIA_VISIBLE_DEVICES=all
    runtime: nvidia
    restart: unless-stopped
```

### System Requirements

**Minimum Requirements:**
- CPU: 4 cores
- RAM: 16 GB
- GPU: NVIDIA GPU with 8GB VRAM (for SmolVLA inference)
- Storage: 50 GB (for models and datasets)
- OS: Ubuntu 22.04 or later
- USB: Available port for SO101 robot

**Recommended Requirements:**
- CPU: 8+ cores
- RAM: 32 GB
- GPU: NVIDIA RTX 3090 or better (24GB VRAM)
- Storage: 100 GB SSD
- Network: Stable internet for Gemini API

### Monitoring and Logging

**Logging Configuration:**

```python
import logging
import time
from pathlib import Path

# Configure logging
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_DIR / f"mortis_{time.time()}.log"),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger("mortis")

# Log important events
logger.info("Application started")
logger.info(f"Gemini model: {GEMINI_MODEL}")
logger.info(f"SmolVLA checkpoint: {SMOLVLA_CHECKPOINT_PATH}")
```

**Metrics to Monitor:**
- Gemini API response times
- SmolVLA inference times
- Task success rates
- Error frequencies
- GPU memory usage
- Robot connection status

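One lightweight way to capture the latency metrics above is a decorator around the relevant call sites. A minimal sketch that reuses the `mortis` logger configured earlier:

```python
import functools
import time

def timed(metric_name: str):
    """Log the wall-clock duration of each call as a parseable metric line."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - start
                logger.info(f"metric {metric_name} duration_s={elapsed:.3f}")
        return wrapper
    return decorator

# Example: decorate GeminiClient.send_message_with_retry with @timed("gemini_response")
```
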
## Migration Strategy

### Phase 1: Gemini API Integration (Week 1)

**Goals:**
- Replace existing LLM API with Gemini
- Maintain current gesture functionality
- Add structured JSON response parsing

**Tasks:**
1. Create `GeminiClient` class
2. Update system prompt for Gemini
3. Modify `ask_mortis()` to use Gemini API
4. Test with existing gestures
5. Update environment configuration

**Validation:**
- All existing gestures work with Gemini
- Response times comparable to previous API
- Character personality maintained

### Phase 2: Voice Input/Output (Week 2)

**Goals:**
- Add audio input component to Gradio
- Implement STT using Gemini native audio or Google STT
- Add TTS for voice responses
- Test multi-modal interaction

**Tasks:**
1. Add audio input/output components to UI
2. Implement STT service
3. Implement TTS service
4. Update UI to handle audio flows
5. Test voice interaction end-to-end

**Validation:**
- Voice input transcribed accurately
- Audio responses play correctly
- Text input still works
- UI remains responsive

### Phase 3: Dataset Collection (Week 3)

**Goals:**
- Set up data collection infrastructure
- Record demonstrations for all 6 tasks
- Validate and upload dataset to Hugging Face

**Tasks:**
1. Create `DataCollector` class
2. Set up camera and robot for recording
3. Record 5-10 demonstrations per task
4. Validate dataset quality
5. Push to Hugging Face Hub

**Validation:**
- All 6 tasks have sufficient demonstrations
- Data quality is high (clear images, smooth motions)
- Dataset loads correctly in LeRobot

### Phase 4: SmolVLA Training (Week 4)

**Goals:**
- Train SmolVLA model on collected data
- Evaluate model performance
- Select best checkpoint

**Tasks:**
1. Configure training pipeline
2. Run training for 100k steps
3. Monitor training metrics
4. Evaluate on validation set
5. Select and save best checkpoint

**Validation:**
- Training converges (loss decreases)
- Validation performance acceptable
- Model can execute at least 3/6 tasks successfully

### Phase 5: Intent Detection and Routing (Week 5)

**Goals:**
- Implement intent detection in Gemini prompt
- Create intent router
- Add command validation

**Tasks:**
1. Update Gemini system prompt with task definitions
2. Create `IntentRouter` class
3. Implement command validation (a sketch follows this phase)
4. Test intent detection accuracy
5. Handle edge cases

**Validation:**
- Manipulation commands detected correctly (>90% accuracy)
- Conversational inputs routed to gestures
- Invalid commands handled gracefully

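A minimal sketch of the validation side of the router; the exact interface will be settled during implementation, and it reuses the response and task models defined earlier:

```python
class IntentRouter:
    """Routes a parsed GeminiResponse to gesture or manipulation execution."""

    def __init__(self, valid_commands: set[str]):
        self.valid_commands = valid_commands  # the trained Task Strings

    def route(self, response: GeminiResponse) -> Task:
        if response.type == ResponseType.MANIPULATION:
            if response.command in self.valid_commands:
                return Task.create_manipulation_task(response.command)
            # Unknown command: fall back to a safe conversational gesture
            return Task.create_gesture_task("idle")
        return Task.create_gesture_task(response.gesture.value)
```
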
### Phase 6: Asynchronous Execution (Week 6)

**Goals:**
- Implement async task execution
- Add status tracking and UI updates
- Test UI responsiveness

**Tasks:**
1. Create `AsyncExecutor` class
2. Implement task queue
3. Add status display to UI
4. Test with long-running tasks
5. Handle concurrent requests

**Validation:**
- UI remains responsive during SmolVLA execution
- Status updates appear correctly
- Multiple tasks can be queued
- Errors don't crash the system

### Phase 7: Integration and Testing (Week 7)

**Goals:**
- Integrate all components
- Comprehensive testing
- Bug fixes and optimization

**Tasks:**
1. Integration testing
2. Performance optimization
3. Error handling improvements
4. Documentation updates
5. User acceptance testing

**Validation:**
- All features work together
- Performance meets requirements
- Error handling robust
- Documentation complete

### Phase 8: Deployment and Monitoring (Week 8)

**Goals:**
- Deploy to production environment
- Set up monitoring
- Create user documentation

**Tasks:**
1. Prepare deployment environment
2. Configure monitoring and logging
3. Create user guide
4. Deploy application
5. Monitor initial usage

**Validation:**
- Application runs stably
- Monitoring captures key metrics
- Users can operate system successfully

### Rollback Plan

If critical issues arise during migration:

1. **Immediate Rollback:**
   - Revert to previous LLM API
   - Disable voice features
   - Use gesture-only mode

2. **Partial Rollback:**
   - Keep Gemini API
   - Disable SmolVLA (gestures only)
   - Disable voice features

3. **Data Preservation:**
   - All datasets backed up to Hugging Face
   - Model checkpoints saved to cloud storage
   - Configuration files version controlled

### Risk Mitigation

**Risk: Gemini API costs exceed budget**
- Mitigation: Set API usage limits, implement caching, use smaller models

**Risk: SmolVLA training fails to converge**
- Mitigation: Collect more data, adjust hyperparameters, use pre-trained weights

**Risk: Voice recognition accuracy too low**
- Mitigation: Use better STT service, add noise filtering, provide text fallback

**Risk: GPU memory insufficient for SmolVLA**
- Mitigation: Reduce batch size, use model quantization, upgrade hardware

**Risk: Robot safety issues during autonomous execution**
- Mitigation: Implement workspace monitoring, add emergency stop, limit motion range

## Design Decisions and Rationale

### 1. Why Gemini API over Other LLMs?

**Decision:** Use Google Gemini API as the primary LLM.

**Rationale:**
- Native multi-modal support (audio, images, text)
- Structured output via JSON mode
- Strong intent detection capabilities
- Integrated with Google Cloud ecosystem (STT/TTS)
- Competitive pricing and performance
- Good documentation and Python SDK

**Alternatives Considered:**
- OpenAI GPT-4: More expensive, separate APIs for audio
- Anthropic Claude: No native audio support
- Local LLMs: Insufficient quality for intent detection

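The structured-output point carries the most weight here: with JSON mode, the response schema from the data-model section can be requested directly rather than parsed out of free text. A minimal sketch with the google-generativeai SDK; the system prompt constant is defined elsewhere in the app:

```python
import json
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
model = genai.GenerativeModel(
    os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp"),
    system_instruction=MORTIS_SYSTEM_PROMPT,  # character prompt, defined elsewhere
)

result = model.generate_content(
    "Pick up the skull and put it in the green cup",
    generation_config=genai.GenerationConfig(
        response_mime_type="application/json",  # ask for JSON rather than free text
        temperature=0.2,
    ),
)
response = GeminiResponse.from_json(json.loads(result.text))
```
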
### 2. Why asyncio.Queue over Redis?

**Decision:** Use Python's asyncio.Queue for task management.

**Rationale:**
- Single-machine deployment (no distributed workers needed)
- No external dependencies
- Simpler implementation and debugging
- Sufficient for expected load (single user at a time)
- Lower latency than network-based queue

**When to Reconsider:**
- Multiple robot arms
- Distributed deployment
- High concurrent user load
- Need for task persistence across restarts

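The resulting pattern is small enough to show inline. A minimal sketch of the queue-and-worker shape this decision implies; `execute_task` stands in for the real gesture/SmolVLA dispatch:

```python
import asyncio

task_queue: asyncio.Queue = asyncio.Queue()

async def worker():
    """Single consumer: executes queued tasks one at a time."""
    while True:
        task = await task_queue.get()
        try:
            # Run blocking robot code off the event loop so the UI stays live
            await asyncio.to_thread(execute_task, task)
        finally:
            task_queue.task_done()

async def submit(task):
    await task_queue.put(task)  # returns immediately without blocking the UI
```
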
### 3. Why SmolVLA over Other Robot Learning Approaches?

**Decision:** Use SmolVLA for manipulation tasks.

**Rationale:**
- Vision-language-action model (understands natural language)
- Integrated with LeRobot framework
- End-to-end learning (no manual feature engineering)
- Proven performance on manipulation tasks
- Active development and community support

**Alternatives Considered:**
- Reinforcement Learning: Requires extensive training, safety concerns
- Classical Motion Planning: Requires manual programming, less flexible
- Behavior Cloning (non-VLA): No language understanding

### 4. Why Hybrid Gesture + SmolVLA Approach?

**Decision:** Keep predefined gestures for conversational responses, add SmolVLA for manipulation.

**Rationale:**
- Gestures are fast and reliable (no inference needed)
- SmolVLA reserved for complex manipulation tasks
- Reduces GPU usage for simple interactions
- Maintains backward compatibility
- Clear separation of concerns

**Benefits:**
- Lower latency for conversational interactions
- More robust (gestures can't fail inference)
- Better resource utilization

### 5. Why Gradio over Custom Web Framework?

**Decision:** Continue using Gradio for the web interface.

**Rationale:**
- Already integrated in existing system
- Excellent support for audio/video components
- Built-in WebSocket handling for real-time updates
- Rapid prototyping and iteration
- Good documentation and examples

**Limitations Acknowledged:**
- Less customization than React/Vue
- Limited styling options
- Not ideal for production-scale applications

**When to Reconsider:**
- Need for complex custom UI
- Mobile app requirements
- High-scale deployment (>100 concurrent users)

### 6. Why Google TTS over Local Alternatives?

**Decision:** Use Google Cloud Text-to-Speech for voice output.

**Rationale:**
- High-quality neural voices
- Consistent with Gemini ecosystem
- Low latency
- Voice customization options (pitch, speed)
- Reliable service

**Alternatives Considered:**
- pyttsx3: Lower quality, robotic voice
- gTTS: Limited voice options, requires internet anyway
- Local neural TTS: High GPU usage, slower

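The pitch/speed customization mentioned above maps directly onto Cloud TTS request parameters. A minimal sketch with the google-cloud-texttospeech client; the voice name and tuning values are illustrative, not the final character voice:

```python
from google.cloud import texttospeech

client = texttospeech.TextToSpeechClient()

response = client.synthesize_speech(
    input=texttospeech.SynthesisInput(text="Beware, mortal..."),
    voice=texttospeech.VoiceSelectionParams(
        language_code="en-US",
        name="en-US-Neural2-D",  # illustrative; pick a voice that fits the character
    ),
    audio_config=texttospeech.AudioConfig(
        audio_encoding=texttospeech.AudioEncoding.MP3,  # browser-friendly output
        pitch=-4.0,          # lower pitch for a more ominous delivery
        speaking_rate=0.9,   # slightly slower than default
    ),
)
with open("mortis_reply.mp3", "wb") as f:
    f.write(response.audio_content)
```
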
### 7. Why Separate Training and Inference Scripts?

**Decision:** Keep training infrastructure separate from runtime application.

**Rationale:**
- Training is an offline, one-time process
- Different hardware requirements (training needs more VRAM)
- Cleaner code organization
- Easier to update training without affecting production
- Can train on a different machine than the deployment target

**Implementation:**
- Training scripts in `mortis/train.py`
- Inference in `mortis/smolvla_executor.py`
- Shared model configuration

### 8. Why Not Use Gemini for Robot Control Directly?

**Decision:** Use Gemini for intent detection, SmolVLA for action generation.

**Rationale:**
- LLMs are not designed for precise motor control
- SmolVLA trained specifically on robot demonstrations
- Gemini would require extensive prompting for each action
- SmolVLA provides closed-loop visual feedback
- Separation of concerns (language understanding vs. motor control)

**Gemini's Role:**
- Understand user intent
- Detect manipulation commands
- Generate conversational responses
- Maintain character personality

**SmolVLA's Role:**
- Generate precise robot actions
- Process visual observations
- Execute manipulation tasks
- Handle low-level control

### 9. Why Store Checkpoints Locally vs. Cloud?

**Decision:** Store model checkpoints locally with optional cloud backup.

**Rationale:**
- Faster loading (no network latency)
- No cloud storage costs during development
- Privacy (model stays on local machine)
- Simpler deployment

**Cloud Backup Strategy:**
- Push final models to Hugging Face Hub
- Version control with git-lfs
- Disaster recovery

### 10. Why 6 Specific Manipulation Tasks?

**Decision:** Start with 6 predefined manipulation tasks (skull/eyeball × 3 cups); the full set is enumerated in the sketch below.

**Rationale:**
- Manageable scope for initial implementation
- Sufficient variety to demonstrate capability
- Fits the Halloween theme
- Realistic data collection effort (30-60 demonstrations)
- Can expand later with more tasks

**Expansion Path:**
- Add more objects (pumpkin, spider, etc.)
- Add more target locations
- Add multi-step tasks
- Add task composition

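Assuming the command template shown earlier ("Pick up the skull and place it in the green cup") generalizes across objects and cups, the full set can be generated mechanically; only the green cup is confirmed above, so the other colors are placeholders:

```python
from itertools import product

OBJECTS = ["skull", "eyeball"]
CUPS = ["green", "red", "blue"]  # green is confirmed; red/blue are placeholder colors

TASK_STRINGS = [
    f"Pick up the {obj} and place it in the {cup} cup"
    for obj, cup in product(OBJECTS, CUPS)
]
assert len(TASK_STRINGS) == 6
```
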
## Future Enhancements

### Short-term (3-6 months)

1. **Expanded Task Set**
   - Add 10-20 more manipulation tasks
   - Support task composition ("pick up skull, then eyeball")
   - Add multi-object interactions

2. **Improved Voice Interaction**
   - Wake word detection ("Hey Mortis")
   - Continuous conversation mode
   - Voice activity detection
   - Speaker identification

3. **Enhanced Safety**
   - Computer vision-based collision detection
   - Force/torque sensing
   - Workspace boundary enforcement
   - Automatic emergency stop

4. **Performance Optimization**
   - Model quantization for faster inference
   - Action caching for repeated tasks
   - Parallel processing for multiple requests
   - GPU memory optimization

### Medium-term (6-12 months)

1. **Advanced Learning**
   - Online learning from corrections
   - Few-shot task learning
   - Transfer learning to new objects
   - Self-supervised improvement

2. **Multi-Robot Support**
   - Control multiple SO101 arms
   - Coordinated multi-arm tasks
   - Load balancing across robots
   - Distributed task execution

3. **Enhanced Perception**
   - 3D object detection
   - Depth estimation
   - Object tracking
   - Scene understanding

4. **User Personalization**
   - User profiles and preferences
   - Adaptive difficulty
   - Custom task definitions
   - Voice profile learning

### Long-term (12+ months)

1. **Autonomous Task Planning**
   - High-level goal specification
   - Automatic task decomposition
   - Multi-step planning
   - Failure recovery strategies

2. **Natural Language Programming**
   - Teach new tasks through conversation
   - Automatic demonstration collection
   - Interactive refinement
   - Task library management

3. **Advanced Interaction**
   - Gesture recognition (human gestures)
   - Facial expression detection
   - Emotion-aware responses
   - Proactive assistance

4. **Production Deployment**
   - Multi-user support
   - Cloud-based inference
   - Mobile app interface
   - API for third-party integration

## Conclusion

This design provides a comprehensive architecture for refactoring Mortis into a multi-modal, SmolVLA-powered robotic system. The design emphasizes:

- **Modularity:** Clear separation between components (Gemini, STT/TTS, SmolVLA, robot control)
- **Scalability:** Asynchronous execution and queue-based architecture
- **Reliability:** Comprehensive error handling and recovery strategies
- **Maintainability:** Well-defined interfaces and data models
- **Extensibility:** Clear paths for future enhancements

The phased migration strategy allows for incremental development and validation, reducing risk and enabling early feedback. The hybrid approach of combining predefined gestures with learned manipulation behaviors provides both reliability and flexibility.

Key technical decisions prioritize:
- Google ecosystem integration (Gemini, Cloud STT/TTS)
- Local deployment with GPU support
- LeRobot framework for robotics
- Gradio for rapid UI development
- Python-native solutions (asyncio, threading)

The design is ready for implementation following the task list in the next phase of the spec workflow.

.kiro/specs/gemini-multimodal-refactor/requirements.md
ADDED

@@ -0,0 +1,153 @@
# Requirements Document

## Introduction

This document specifies the requirements for refactoring the Mortis interactive AI Halloween experience to use Google Gemini API with multi-modal (voice and text) interaction capabilities. The refactor replaces the existing LLM API integration and adds SmolVLA-based robotic control for specific manipulation tasks. The system must maintain the character-driven conversational experience while enabling precise robotic manipulation through voice or text commands.

## Glossary

- **Mortis System**: The complete interactive AI Halloween experience including web UI, conversational AI, and robotic arm control
- **Gemini API**: Google's large language model API service used for conversational AI and intent detection
- **SmolVLA Model**: A vision-language-action model trained using LeRobot for specific robotic manipulation tasks
- **Gradio Interface**: The web-based user interface framework for the Mortis System
- **SO101 Arm**: The SeeedStudio SO101 robotic arm hardware controlled by the Mortis System
- **STT Service**: Speech-to-Text service that converts audio input to text
- **TTS Service**: Text-to-Speech service that converts text responses to audio output
- **Task String**: A specific command format recognized by the SmolVLA Model (e.g., "Pick up the skull and place it in the green cup")
- **LeRobot Framework**: The robotics framework used for dataset management, model training, and inference
- **Message Queue**: An asynchronous communication mechanism for decoupling robotic execution from the web interface
- **Cloud-Agnostic Architecture**: A system design that does not depend on vendor-specific cloud platform services (like AWS Lambda, Azure Functions, or GCP Cloud Run), allowing deployment on any infrastructure including local hardware

## Requirements

### Requirement 1: Gemini API Integration

**User Story:** As a developer, I want to replace the existing LLM API with Google Gemini API, so that the system uses Google's language model for all conversational interactions.

#### Acceptance Criteria

1. THE Mortis System SHALL use the Google Gemini API for all language model interactions
2. THE Mortis System SHALL support multiple Gemini model variants through configuration
3. THE Mortis System SHALL authenticate with the Gemini API using API keys stored in environment variables
4. THE Mortis System SHALL handle Gemini API errors gracefully and provide user feedback when API calls fail
5. THE Mortis System SHALL maintain response times under 5 seconds for typical conversational interactions

### Requirement 2: Multi-Modal Voice Input

**User Story:** As a user, I want to speak to Mortis through my microphone, so that I can interact naturally without typing.

#### Acceptance Criteria

1. THE Gradio Interface SHALL provide an audio input component for capturing user voice
2. WHEN a user provides voice input, THE Mortis System SHALL convert the audio to text using a Speech-to-Text service
3. THE Mortis System SHALL support both cloud-based STT services and local STT models as configurable options
4. THE Mortis System SHALL process voice input with latency under 3 seconds for utterances under 10 seconds
5. THE Mortis System SHALL display the transcribed text to the user for confirmation

### Requirement 3: Intent Detection and Command Routing

**User Story:** As a system, I want to detect when user input matches a specific robotic task command, so that I can route the request to the appropriate control mechanism.

#### Acceptance Criteria

1. THE Gemini API SHALL receive a system prompt that defines all valid SmolVLA Task Strings
2. WHEN the Gemini API processes user input, THE Mortis System SHALL determine if the input matches a valid Task String
3. IF the user input matches a valid Task String, THEN THE Mortis System SHALL extract the exact command string for robotic execution
4. IF the user input does not match a valid Task String, THEN THE Mortis System SHALL generate a standard conversational response with gesture control
5. THE Mortis System SHALL return both a conversational response and a command indicator in a structured format

### Requirement 4: Dataset Creation and Collection

**User Story:** As a developer, I want to create and collect demonstration data for robotic manipulation tasks, so that I have training data for the SmolVLA model.

#### Acceptance Criteria

1. THE Mortis System SHALL provide a data collection script for recording SO101 Arm demonstrations
2. THE Mortis System SHALL capture synchronized camera observations and robot actions during demonstrations
3. THE Mortis System SHALL save collected demonstrations in LeRobot-compatible format
4. THE Mortis System SHALL support labeling demonstrations with corresponding Task String commands
5. THE Mortis System SHALL validate collected data for completeness before adding to the training dataset

### Requirement 5: SmolVLA Model Training Infrastructure

**User Story:** As a developer, I want to train a SmolVLA model using LeRobot with collected demonstration data, so that the robot can perform precise manipulation tasks.

#### Acceptance Criteria

1. THE Mortis System SHALL provide a training script that loads datasets from local LeRobot databases or Hugging Face
2. THE Mortis System SHALL create and manage LeRobot dataset databases for training data
3. THE Mortis System SHALL configure SmolVLA training using lerobot-train with appropriate hyperparameters
4. THE Mortis System SHALL save trained model checkpoints to a configurable directory
5. THE Mortis System SHALL log training metrics including loss, accuracy, and validation performance

### Requirement 6: SmolVLA Inference Execution

**User Story:** As a system, I want to execute SmolVLA model inference when a valid task command is detected, so that the robot performs the requested manipulation.

#### Acceptance Criteria

1. THE Mortis System SHALL load the trained SmolVLA Model from saved checkpoints
2. WHEN a valid Task String is received, THE Mortis System SHALL execute SmolVLA inference with the command as input
3. THE Mortis System SHALL control the SO101 Arm through the SmolVLA Model output actions
4. THE Mortis System SHALL provide visual feedback during robotic execution through the webcam view
5. THE Mortis System SHALL handle inference errors and return the robot to a safe idle state

### Requirement 7: Asynchronous Robotic Execution

**User Story:** As a user, I want the web interface to remain responsive while the robot executes tasks, so that I can monitor progress without the UI freezing.

#### Acceptance Criteria

1. THE Mortis System SHALL execute SmolVLA inference asynchronously without blocking the Gradio Interface
2. THE Mortis System SHALL use a message queue or background processing mechanism to decouple inference from the web interface
3. WHILE SmolVLA inference is executing, THE Gradio Interface SHALL display a status indicator showing task progress
4. THE Mortis System SHALL allow users to view the robot's actions through the webcam during execution
5. WHEN robotic execution completes, THE Mortis System SHALL update the interface with completion status

### Requirement 8: Voice Output Integration

**User Story:** As a user, I want to hear Mortis speak responses aloud, so that I can experience a fully voice-based interaction.

#### Acceptance Criteria

1. THE Mortis System SHALL convert Gemini API text responses to audio using a Text-to-Speech service
2. THE Mortis System SHALL support Google TTS or equivalent widely-available TTS services
3. THE Gradio Interface SHALL play generated audio responses automatically after receiving them
4. THE Mortis System SHALL generate audio in a format compatible with web browsers (MP3 or WAV)
5. THE Mortis System SHALL maintain character voice consistency across all audio responses

### Requirement 9: Architecture and Deployment

**User Story:** As a developer, I want a system that can run on local hardware without vendor-specific cloud dependencies, so that I can deploy it flexibly while using Google APIs for LLM services.

#### Acceptance Criteria

1. THE Mortis System SHALL NOT depend on vendor-specific cloud platform services such as AWS Lambda, Azure Functions, or GCP Cloud Run
2. THE Mortis System SHALL support deployment on local hardware with GPU access for SmolVLA inference
3. THE Mortis System SHALL use standard Python libraries and open-source frameworks for all non-Google API components
4. THE Mortis System SHALL document all external service dependencies in the environment configuration
5. THE Mortis System SHALL provide configuration options for switching between cloud-based and local STT and TTS processing

### Requirement 10: Backward Compatibility and Migration

**User Story:** As a developer, I want to migrate from the existing LLM API to Gemini without losing existing functionality, so that users experience a seamless transition.

#### Acceptance Criteria

1. THE Mortis System SHALL maintain all existing gesture capabilities during the refactor
2. THE Mortis System SHALL preserve the Halloween character theme and response style
3. THE Mortis System SHALL continue to support text-only interaction for users without microphones
4. THE Mortis System SHALL maintain the existing Gradio Interface layout and visual design
5. THE Mortis System SHALL provide a migration guide documenting configuration changes

### Requirement 11: Error Handling and Robustness

**User Story:** As a user, I want the system to handle errors gracefully, so that temporary failures don't break my interaction experience.

#### Acceptance Criteria

1. IF the Gemini API is unavailable, THEN THE Mortis System SHALL display an error message and allow retry
2. IF STT conversion fails, THEN THE Mortis System SHALL prompt the user to try again or use text input
3. IF SmolVLA inference fails, THEN THE Mortis System SHALL return the SO101 Arm to idle position safely
4. IF TTS generation fails, THEN THE Mortis System SHALL display the text response without audio
5. THE Mortis System SHALL log all errors with sufficient detail for debugging
.kiro/specs/gemini-multimodal-refactor/tasks.md
ADDED
@@ -0,0 +1,403 @@
# Implementation Plan

This implementation plan breaks down the Gemini multi-modal refactor into discrete, actionable coding tasks. Each task builds incrementally on previous work, following the 8-phase migration strategy outlined in the design document.

## Important Note: Hybrid Async Execution System

**Phase 7** uses a **hybrid approach** for asynchronous execution:

1. **AsyncExecutor** (`src/mortis/async_executor.py`): Simple Python threading system for quick gesture tasks
   - Use for: wave, point, idle, grab, drop gestures
   - Advantages: Simple, fast (1-2s), low overhead
   - Implementation: Task queue + worker thread + status queue (sketched below)

2. **LeRobotAsyncClient** (`src/mortis/lerobot_async_client.py`): Wrapper over LeRobot's async inference system
   - Use for: Complex manipulation tasks with SmolVLA
   - Advantages: Optimized for continuous inference, handles action chunks, real-time control
   - Implementation: PolicyServer + RobotClient + gRPC communication

This hybrid approach provides the best of both worlds: simplicity for gestures and power for manipulation.
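To make the AsyncExecutor side concrete, here is a minimal, self-contained sketch of the task queue + worker thread + status queue pattern. It is illustrative only: the class and attribute names below are hypothetical and not the actual `src/mortis/async_executor.py` API.

```python
import queue
import threading

class MiniAsyncExecutor:
    """Minimal task queue + worker thread + status queue (illustrative)."""

    def __init__(self, task_executor):
        self.task_executor = task_executor  # callable that runs one task (e.g. a gesture)
        self.tasks = queue.Queue()          # pending tasks
        self.status = queue.Queue()         # status updates the UI can poll
        self.running = False

    def start(self):
        self.running = True
        threading.Thread(target=self._worker_loop, daemon=True).start()

    def stop(self):
        self.running = False
        self.tasks.put(None)                # sentinel to unblock the worker

    def submit_task(self, task):
        self.status.put(("queued", task))
        self.tasks.put(task)

    def _worker_loop(self):
        while self.running:
            task = self.tasks.get()
            if task is None:
                break                       # stop() was called
            self.status.put(("running", task))
            try:
                self.task_executor(task)
                self.status.put(("complete", task))
            except Exception as exc:
                self.status.put(("failed", (task, exc)))
```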
## Phase 1: Gemini API Integration

- [x] 1. Set up Gemini API client infrastructure
  - Create `src/mortis/gemini_client.py` module
  - Implement `GeminiClient` class with configuration management
  - Add environment variable handling for `GEMINI_API_KEY`, `GEMINI_MODEL`, `GEMINI_TEMPERATURE`
  - Implement basic `send_message()` method using `google.generativeai` SDK
  - _Requirements: 1.1, 1.2, 1.3_

- [x] 2. Implement structured response parsing
  - Create `src/mortis/models.py` for data models
  - Implement `GeminiResponse`, `ResponseType`, `Mood`, `Gesture` enums and dataclasses
  - Add `from_json()` method for parsing Gemini JSON responses
  - Implement response validation logic
  - _Requirements: 1.1, 3.5_

- [x] 3. Design and implement Gemini system prompt
  - Create system prompt with Mortis character definition
  - Add manipulation task definitions (6 tasks) to prompt
  - Implement JSON response format specification in prompt
  - Configure Gemini to use JSON mode (`response_mime_type: application/json`) (see the sketch below)
  - _Requirements: 3.1, 3.2, 9.2_
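Task 3's JSON mode can be enabled through the SDK's generation config. A hedged sketch with the `google.generativeai` package; the prompt string and the temperature default are placeholders, and the project's actual `GeminiClient` wrapper may differ:

```python
import json
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
model = genai.GenerativeModel(
    model_name=os.getenv("GEMINI_MODEL", "gemini-2.5-flash"),
    generation_config={
        "response_mime_type": "application/json",   # force JSON output
        "temperature": float(os.getenv("GEMINI_TEMPERATURE", "0.7")),
    },
)

response = model.generate_content("You are Mortis... Reply as JSON with message/mood/gesture.")
data = json.loads(response.text)  # safe to parse: JSON mode guarantees a JSON body
```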
- [x] 4. Implement error handling and retry logic
  - Add exponential backoff retry for rate limiting (see the sketch below)
  - Handle `BlockedPromptException` with fallback responses
  - Implement timeout handling for API calls
  - Add error logging and user-friendly error messages
  - _Requirements: 1.4, 11.1_
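The retry logic in task 4 usually amounts to exponential backoff with jitter around the API call. A generic sketch; the exception type and limits are placeholders, and the real client should catch only the SDK's specific rate-limit and timeout errors:

```python
import random
import time

def call_with_backoff(send_fn, max_retries: int = 3, base_delay: float = 1.0):
    """Retry send_fn with exponential backoff plus jitter; re-raise after max_retries."""
    for attempt in range(max_retries + 1):
        try:
            return send_fn()
        except Exception:  # in practice: catch only rate-limit/timeout errors
            if attempt == max_retries:
                raise
            time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.5))
```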
- [x] 5. Replace existing LLM API in tools.py
  - Refactor `ask_mortis()` function to use `GeminiClient`
  - Update response parsing to use new data models
  - Maintain backward compatibility with gesture execution
  - Update environment configuration documentation
  - _Requirements: 1.1, 9.1, 9.4_

- [ ]* 5.1 Write integration tests for Gemini client
  - Test successful API calls and response parsing
  - Test retry logic with mocked rate limit errors
  - Test fallback responses on API failures
  - Verify character personality maintained in responses
  - _Requirements: 1.1, 1.4_


## Phase 2: Voice Input and Output

- [x] 6. Implement Speech-to-Text service
  - Create `src/mortis/stt_service.py` module
  - Implement `STTService` class with Gemini native audio support (see the sketch below)
  - Add fallback to Google Cloud Speech-to-Text API
  - Implement audio file format validation and conversion
  - Add configuration for STT service selection (Gemini vs Google STT)
  - _Requirements: 2.1, 2.2, 2.3_
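For task 6's Gemini native audio path, one approach is to upload the recording via the Files API and ask the model to transcribe it. A hedged sketch with `google.generativeai`; the model name and prompt are illustrative, and the actual `STTService` may pass inline audio bytes instead:

```python
import google.generativeai as genai

def transcribe_with_gemini(audio_path: str) -> str:
    """Upload an audio file and ask Gemini to transcribe it verbatim."""
    audio_file = genai.upload_file(audio_path)          # Files API upload
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(
        ["Transcribe this audio verbatim. Return only the transcript.", audio_file]
    )
    return response.text.strip()
```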
- [x] 7. Implement Text-to-Speech service
  - Create `src/mortis/tts_service.py` module
  - Implement `TTSService` class using Google Cloud TTS
  - Configure voice parameters (pitch, speed) for Mortis character
  - Implement audio file generation (MP3 format)
  - Add local TTS fallback (gTTS) for offline scenarios (see the sketch below)
  - _Requirements: 8.1, 8.2, 8.4, 8.5_
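The gTTS fallback in task 7 is only a few lines. In this sketch, `slow=True` is a guess at a suitably drawn-out Mortis delivery, and the output directory matches the `outputs/` convention from task 9:

```python
import os

from gtts import gTTS

def synthesize_fallback(text: str, out_dir: str = "outputs") -> str:
    """Fallback TTS: write a browser-compatible MP3 and return its path."""
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "mortis_reply.mp3")
    gTTS(text=text, lang="en", slow=True).save(out_path)  # slow=True: drawn-out, eerie read
    return out_path
```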
- [x] 8. Update Gradio UI for audio input
  - Add `gr.Audio` component for microphone input to `app.py`
  - Implement audio input handler function
  - Connect audio input to STT service
  - Display transcribed text to user for confirmation
  - Handle audio processing errors gracefully
  - _Requirements: 2.1, 2.5, 11.2_

- [x] 9. Update Gradio UI for audio output
  - Add `gr.Audio` component for audio playback
  - Implement audio response generation in `mortis_reply()`
  - Configure autoplay for audio responses
  - Create `outputs/` directory for temporary audio files
  - Implement audio file cleanup mechanism
  - _Requirements: 8.3, 8.4_

- [x] 10. Integrate voice flow with Gemini
  - Update `ask_mortis()` to accept audio input
  - Implement voice-to-text-to-Gemini-to-TTS pipeline
  - Maintain text input compatibility
  - Add latency monitoring for voice processing
  - _Requirements: 2.4, 9.3_

- [ ]* 10.1 Write tests for audio processing
  - Test STT with sample audio files
  - Test TTS output quality and format
  - Test audio input/output in Gradio UI
  - Verify fallback mechanisms work correctly
  - _Requirements: 2.2, 8.2, 11.2, 11.4_


## Phase 3: Dataset Collection Infrastructure

- [x] 11. Set up LeRobot dataset infrastructure
  - Create `src/mortis/data_collector.py` module
  - Implement `DataCollector` class with LeRobot dataset integration
  - Configure dataset directory structure (`data/mortis_manipulation/`)
  - Implement dataset metadata management (task descriptions, episode counts)
  - _Requirements: 4.3, 5.2_

- [ ]* 12. Implement camera integration for data collection
  - Add camera initialization in `DataCollector`
  - Implement synchronized image capture with robot state
  - Configure camera parameters (resolution, FPS)
  - Add camera calibration utilities
  - _Requirements: 4.2_

- [ ]* 13. Implement episode recording functionality
  - Create `record_episode()` method for capturing demonstrations
  - Implement real-time data capture loop (30 FPS)
  - Add keyboard controls for start/stop recording
  - Implement episode data validation
  - Save episodes in LeRobot-compatible format
  - _Requirements: 4.1, 4.2, 4.5_

- [ ]* 14. Implement task labeling system
  - Add task description input for each episode
  - Create task label validation against predefined task set
  - Implement episode metadata storage
  - Add episode review and re-recording capability
  - _Requirements: 4.4_

- [ ]* 15. Create data collection CLI script
  - Create `src/mortis/collect_data.py` entry point
  - Implement interactive data collection workflow
  - Add progress tracking (episodes per task)
  - Implement dataset statistics display
  - Add Hugging Face Hub upload functionality
  - _Requirements: 4.1, 4.3, 5.1_

- [ ]* 15.1 Write data validation tests
  - Test episode data format compliance
  - Verify synchronized timestamps
  - Check image quality and dimensions
  - Validate action sequences
  - _Requirements: 4.5_


## Phase 4: SmolVLA Training Pipeline

- [ ]* 16. Create training configuration
  - Create `config/train_smolvla.yaml` with Hydra configuration
  - Configure SmolVLA policy parameters (vision backbone, chunk size)
  - Set training hyperparameters (batch size, learning rate, steps)
  - Configure evaluation settings
  - Add Weights & Biases integration configuration
  - _Requirements: 5.3, 5.5_

- [x] 17. Implement training script
  - Create `src/mortis/train.py` module
  - Implement dataset loading from local or Hugging Face
  - Configure LeRobot training pipeline
  - Add checkpoint saving logic
  - Implement training progress logging
  - _Requirements: 5.1, 5.2, 5.4, 5.5_

- [ ]* 18. Set up training monitoring
  - Integrate Weights & Biases for metric tracking
  - Log training loss, validation loss, learning rate
  - Add sample prediction visualization
  - Implement early stopping based on validation performance
  - _Requirements: 5.5_

- [ ]* 19. Create training execution commands
  - Add `train-smolvla` target to Makefile
  - Document training command with all parameters
  - Add GPU memory optimization flags
  - Create training resume functionality for interrupted runs
  - _Requirements: 5.3, 5.4_

- [ ]* 19.1 Write training validation tests
  - Test dataset loading and batching
  - Verify model architecture initialization
  - Test checkpoint saving and loading
  - Validate training loop executes without errors
  - _Requirements: 5.2, 5.4_


## Phase 5: SmolVLA Inference Integration

- [x] 20. Implement SmolVLA executor
  - Create `src/mortis/smolvla_executor.py` module
  - Implement `SmolVLAExecutor` class with model loading
  - Add checkpoint loading from configurable path
  - Implement GPU device management
  - Add model initialization and warmup
  - _Requirements: 6.1, 8.2_

- [x] 21. Implement observation capture
  - Add camera integration for visual observations
  - Implement robot state capture from SO101
  - Create observation dictionary formatting for SmolVLA
  - Add tensor conversion and device placement
  - _Requirements: 6.2, 6.4_

- [x] 22. Implement action execution loop
  - Create `execute()` method for task execution
  - Implement inference loop with visual feedback
  - Add action tensor to SO101 command conversion
  - Implement step-by-step action execution
  - Add task completion detection logic
  - _Requirements: 6.2, 6.3_

- [x] 23. Implement safety and error handling
  - Add command validation against trained task set
  - Implement workspace safety checks
  - Add emergency stop functionality
  - Implement timeout handling for long-running tasks
  - Add GPU out-of-memory recovery
  - _Requirements: 6.5, 11.3_

- [ ]* 23.1 Write SmolVLA inference tests
  - Test model loading from checkpoint
  - Test observation capture and formatting
  - Test action prediction and execution
  - Verify emergency stop functionality
  - _Requirements: 6.1, 6.3, 6.5_


## Phase 6: Intent Detection and Routing

- [x] 24. Implement intent router
  - Create `src/mortis/intent_router.py` module
  - Implement `IntentRouter` class with task definitions
  - Add `parse_gemini_response()` method for JSON parsing
  - Implement command validation logic
  - Create `Intent` dataclass for structured intent representation
  - _Requirements: 3.2, 3.3, 3.4, 3.5_

- [x] 25. Update Gemini prompt for intent detection
  - Enhance system prompt with all 6 manipulation task definitions
  - Add clear response format specification for manipulation vs conversation
  - Implement intent type detection in prompt
  - Add examples of manipulation and conversational inputs
  - _Requirements: 3.1, 3.2_

- [x] 26. Integrate intent routing in main flow
  - Update `ask_mortis()` to use `IntentRouter`
  - Implement routing logic for manipulation vs gesture execution
  - Add command validation before SmolVLA execution
  - Implement fallback to gestures for invalid commands
  - _Requirements: 3.3, 3.4, 3.5_

- [ ]* 26.1 Write intent detection tests
  - Test parsing of manipulation responses
  - Test parsing of conversational responses
  - Test command validation logic
  - Verify fallback behavior for invalid commands
  - Test edge cases and malformed responses
  - _Requirements: 3.2, 3.3, 3.4_


## Phase 7: Asynchronous Execution System (Hybrid Approach)

**Note**: This phase uses a hybrid execution system:
- **AsyncExecutor**: Simple threading for quick gestures (wave, point, idle)
- **LeRobotAsyncClient**: LeRobot async inference (PolicyServer + RobotClient) for complex manipulation tasks with SmolVLA

- [x] 27. Implement async executor infrastructure for gestures
  - Create `src/mortis/async_executor.py` module
  - Implement `AsyncExecutor` class with task queue
  - Add background worker thread for task processing
  - Implement status queue for progress updates
  - Add start/stop methods for executor lifecycle
  - Create `Task` and `StatusUpdate` dataclasses
  - Add comprehensive tests (15 tests, all passing)
  - _Requirements: 7.1, 7.2_

- [x] 28. Implement LeRobot async client for manipulation
  - Create `src/mortis/lerobot_async_client.py` module
  - Implement `LeRobotAsyncClient` wrapper class
  - Integrate PolicyServer and RobotClient from LeRobot
  - Add `ManipulationTask` and `ManipulationStatus` models
  - Implement lifecycle management (start/stop)
  - Add task execution with status tracking
  - Create demo scripts and documentation
  - _Requirements: 7.1, 7.2, 7.5_

- [x] 29. Integrate hybrid execution in main application
  - Initialize both AsyncExecutor and LeRobotAsyncClient in app.py
  - Update `mortis_reply()` to route gestures to AsyncExecutor
  - Update `mortis_reply()` to route manipulation to LeRobotAsyncClient
  - Implement proper lifecycle management (start on app load, stop on unload)
  - Handle errors from both systems
  - _Requirements: 7.1, 7.2, 7.5_

- [x] 30. Add hybrid status display to Gradio UI
  - Add status textbox component to UI for robot status
  - Implement `check_status()` function that monitors both systems
  - Check AsyncExecutor for gesture status updates
  - Check LeRobotAsyncClient for manipulation status
  - Configure Gradio to poll status every 500ms (see the sketch below)
  - Display appropriate icons and messages for each system
  - Add visual indicators for different task states (idle, running, complete, failed)
  - _Requirements: 7.3, 7.4, 7.5_
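One way to wire task 30's 500 ms polling in Gradio 5 is a `gr.Timer` driving the `check_status()` function that `src/mortis/app.py` defines. This is a sketch, not the app's exact layout:

```python
import gradio as gr

from mortis.app import check_status  # returns a formatted status string

with gr.Blocks() as demo:
    status_box = gr.Textbox(label="Robot status", interactive=False)
    timer = gr.Timer(0.5)                         # fires every 500 ms
    timer.tick(fn=check_status, outputs=status_box)
```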
- [x] 31. Test and validate hybrid execution system
  - Test gesture execution via AsyncExecutor
  - Test manipulation execution via LeRobotAsyncClient
  - Verify both systems can run concurrently
  - Test status updates from both systems
  - Verify UI remains responsive during long manipulation tasks
  - Test error handling in both systems
  - Validate proper cleanup on app shutdown
  - _Requirements: 7.1, 7.2, 7.3, 7.4, 7.5_

- [ ]* 31.1 Write integration tests for hybrid system
  - Test AsyncExecutor with mock gesture executor
  - Test LeRobotAsyncClient with mock PolicyServer/RobotClient
  - Test routing logic (gesture vs manipulation)
  - Test concurrent execution of gestures and manipulation
  - Verify status updates from both systems
  - Test error recovery and fallback behavior
  - _Requirements: 7.1, 7.2, 7.3, 7.5_


## Phase 8: Integration, Testing, and Deployment

- [ ]* 32. Update project dependencies
  - Update `pyproject.toml` with new dependencies (google-generativeai, google-cloud-speech, google-cloud-texttospeech)
  - Add optional dependencies for training (wandb, hydra-core)
  - Update Makefile with new commands (collect-data, train-smolvla)
  - Run `make install` to sync dependencies
  - _Requirements: 8.4, 9.5_

- [ ]* 33. Update environment configuration
  - Create `.env.example` with all required variables (see the sketch below)
  - Document Gemini API key setup
  - Document Google Cloud credentials setup
  - Add SmolVLA checkpoint path configuration
  - Update README with new environment variables
  - _Requirements: 1.3, 8.4_
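A plausible `.env.example` for task 33, assembled from the variables referenced across this spec, the steering docs, and `src/mortis/app.py`. Defaults shown are the ones the code falls back to; the temperature value is a placeholder:

```
# Gemini API
GEMINI_API_KEY=your_gemini_api_key
GEMINI_MODEL=gemini-2.5-flash
GEMINI_TEMPERATURE=0.7

# Robot
ROBOT_PORT=/dev/ttyACM1
ROBOT_MODE=physical              # or: simulation
ENABLE_MANIPULATION=false
SMOLVLA_MODEL_PATH=jlamperez/kiroween-potion-smolvla
MANIPULATION_TIMEOUT=60.0

# Server
PORT=7860
```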
- [ ]* 34. Implement logging and monitoring
  - Add structured logging throughout application
  - Log Gemini API calls and response times
  - Log SmolVLA inference times and success rates
  - Add error logging with stack traces
  - Create log rotation and cleanup
  - _Requirements: 11.5_

- [x] 35. Create comprehensive documentation
  - Update README with new features and setup instructions
  - Document data collection workflow
  - Document training process
  - Create user guide for voice interaction
  - Add troubleshooting section
  - _Requirements: 8.4, 9.5_

- [ ]* 36. Perform end-to-end integration testing
  - Test complete voice input → Gemini → SmolVLA → audio output flow
  - Test text input → intent detection → gesture execution flow
  - Test error handling and recovery across all components
  - Verify UI responsiveness during long operations
  - Test with all 6 manipulation tasks
  - _Requirements: 9.1, 9.2, 9.3, 9.4_

- [ ]* 36.1 Write system-level tests
  - Test multi-modal interaction flows
  - Test concurrent user requests
  - Test resource usage (GPU memory, CPU)
  - Benchmark performance metrics
  - _Requirements: 1.5, 2.4, 7.3_

- [ ]* 37. Optimize performance
  - Profile Gemini API response times
  - Optimize SmolVLA inference speed
  - Reduce audio processing latency
  - Implement caching where appropriate
  - Optimize GPU memory usage
  - _Requirements: 1.5, 2.4_

- [ ]* 38. Final deployment preparation
  - Create deployment checklist
  - Set up monitoring and alerting
  - Prepare rollback procedures
  - Create backup of current system
  - Document deployment process
  - _Requirements: 8.1, 8.2, 8.3_
.kiro/steering/product.md
ADDED
@@ -0,0 +1,30 @@
---
inclusion: always
---

# Product Overview

**Mortis** is an interactive AI Halloween experience that combines conversational AI with physical robotics. It's a Gradio web application where users chat with "Mortis," a mischievous Halloween spirit powered by LLMs.

## Core Concept

Mortis responds to user messages with:
- Text responses (character-driven, in-character dialogue)
- Emotional moods (ominous, playful, angry, etc.)
- Physical gestures via a SeeedStudio SO101 robotic arm controlled through LeRobot

## Key Features

- Web UI with Halloween-themed background
- Multi-model LLM support via API
- Structured tool calling for coordinated text + gesture responses
- Real-time robotic arm control synchronized with AI responses
- Local webcam view (browser-only, no upload)

## Character Guidelines

When working with Mortis dialogue:
- Keep responses ≤30 words, ≤120 characters
- No emojis or markdown in character responses
- Maintain Halloween/haunted theme
- Responses should feel mischievous, spectral, or ominous
.kiro/steering/structure.md
ADDED
@@ -0,0 +1,85 @@
---
inclusion: always
---

# Project Structure

## Directory Layout

```
mortis/
├── src/mortis/          # Main application package
│   ├── app.py           # Gradio UI and main entry point
│   ├── tools.py         # LLM API integration and tool calling
│   ├── robot.py         # Robot arm control and gesture definitions
│   └── calibrate.py     # Robot calibration script
├── examples/            # Example/demo scripts
│   └── demo.py          # Simple demo runner
├── assets/              # Static assets (images, backgrounds)
│   └── image.png        # Halloween background image
├── .cache/              # Runtime cache (calibration data)
├── .env                 # Environment variables (not committed)
├── pyproject.toml       # Project metadata and dependencies
├── uv.lock              # Locked dependency versions
├── Makefile             # Build and run commands
└── README.md            # User documentation
```

## Module Organization

### `src/mortis/app.py`
- Gradio UI construction
- Chat interface setup
- Model selection dropdown
- CSS styling with base64-encoded background
- Main entry point (`main()` function)

### `src/mortis/tools.py`
- LLM API client
- Tool definition for structured outputs
- `ask_mortis()` function: sends user message, receives structured response
- Coordinates LLM response with robot gesture execution
- Manages global `mortis_arm` instance

### `src/mortis/robot.py`
- `MortisArm` class: robot connection and control
- `GESTURES` dictionary: predefined gesture sequences
- Each gesture is a list of (pose_dict, delay) tuples (see the sketch below)
- Available gestures: idle, wave, point_left, point_right, grab, drop
- Pose dictionaries specify joint positions in degrees
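A minimal sketch of the gesture-table shape described above; the joint names and angle values are illustrative, not the values actually used in `robot.py`:

```python
# Each gesture is a list of (pose_dict, delay_seconds) tuples; joints in degrees.
GESTURES = {
    "idle": [
        ({"shoulder_pan": 0.0, "elbow_flex": 0.0, "wrist_roll": 0.0}, 0.5),
    ],
    "wave": [
        ({"shoulder_pan": -20.0, "wrist_roll": 30.0}, 0.4),   # swing one way
        ({"shoulder_pan": 20.0, "wrist_roll": -30.0}, 0.4),   # swing back
        ({"shoulder_pan": 0.0, "wrist_roll": 0.0}, 0.5),      # return to center
    ],
}
```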
### `src/mortis/calibrate.py`
- Standalone calibration script
- Configures SO101Follower with calibration directory
- Interactive calibration process

## Code Conventions

### Import Style
- Standard library imports first
- Third-party imports second
- Local imports last
- Use `from .module import` for intra-package imports

### Path Handling
- Use `pathlib.Path` for all file paths
- `REPO_ROOT` defined as `Path(__file__).resolve().parents[2]`
- Relative paths from repo root for assets and config

### Robot Control Pattern
- Always check `mortis_arm.connected` before operations
- Connect once, reuse connection
- Disconnect on app unload (Gradio `demo.unload()`)
- Gestures execute synchronously with blocking delays

### API Response Handling
- Structured tool calling enforced via `tool_choice`
- Parse `tool_calls[0].function.arguments` as JSON (see the sketch below)
- Extract: message (str), mood (enum), gesture (enum)
- Execute gesture immediately after parsing response
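A sketch of that parsing step, assuming the OpenAI-style chat-completions response shape; the envelope fields outside `tool_calls` are an assumption, since only the `tool_calls[0].function.arguments` path is documented above:

```python
import json

def parse_mortis_act(api_response: dict) -> tuple[str, str, str]:
    """Extract (message, mood, gesture) from the structured tool call."""
    tool_call = api_response["choices"][0]["message"]["tool_calls"][0]
    args = json.loads(tool_call["function"]["arguments"])  # arguments arrive as a JSON string
    return args["message"], args["mood"], args["gesture"]
```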
## Entry Points

Defined in `pyproject.toml`:
- `mortis` → `mortis.app:main` (run the Gradio app)
- `calibrate` → `mortis.calibrate:main` (calibrate robot)
.kiro/steering/tech.md
ADDED
@@ -0,0 +1,69 @@
---
inclusion: always
---

# Tech Stack

## Core Technologies

- **Python**: 3.12+ (required)
- **Package Manager**: `uv` (modern Python dependency manager)
- **Web Framework**: Gradio 5.49.1+
- **Robotics**: LeRobot 0.4.0+ with Feetech servo support
- **API Client**: requests library for LLM API
- **Environment**: python-dotenv for configuration

## Build System

The project uses a **Makefile** for all common operations. Always prefer `make` commands over direct CLI invocations.

### Common Commands

```bash
# Setup and dependencies
make install        # Install/sync dependencies
make sync           # Alias for install
make upgrade        # Upgrade all dependencies

# Running the application
make run            # Run via CLI entrypoint (mortis)
make run-m          # Run as Python module
make demo           # Run example script

# Robot operations
make calibrate      # Calibrate the SO101 arm (required first-time setup)
make test-gesture   # Test individual gestures

# Development
make check-env      # Verify .env configuration
make add-<package>  # Add new dependency (e.g., make add-numpy)
make export         # Export requirements.txt from uv.lock
make clean          # Remove build artifacts
```

## Environment Configuration

Required `.env` file in project root:
```
API_KEY=your_api_key
API_BASE_URL=https://api.example.com/v1/chat/completions
ROBOT_PORT=/dev/ttyACM1  # Optional, defaults to /dev/ttyACM1
PORT=7860                # Optional, defaults to 7860
```

## API Integration

- Uses LLM chat completions API
- Supports multiple models
- Implements structured tool calling for coordinated responses
- Tool: `perform_mortis_act` returns {message, mood, gesture}

## Robot Hardware

- **Device**: SeeedStudio SO101 robotic arm
- **Connection**: USB serial (typically /dev/ttyACM1)
- **Calibration**: Stored in `.cache/calibration/so101/`
- **Control**: LeRobot framework with SO101Follower driver
- **Modes**:
  - `physical` - Connects to real robot hardware (default)
  - `simulation` - Simulates robot without hardware (for development/testing)
app.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
import os
import sys

# Make src/mortis importable
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.join(CURRENT_DIR, "src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from mortis.app import ui  # the function that builds the chatbot UI

# Hugging Face passes the port in the PORT environment variable
port = int(os.getenv("PORT", "7860"))

demo = ui()  # the Chatbot/ChatInterface is assembled inside

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # important when running in Docker
        server_port=port,
        show_error=True,
    )
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
 google-genai>=1.53.0
-google-cloud-texttospeech>=2.16.0
+google-cloud-texttospeech>=2.16.0
 gradio==5.49.1
 gtts>=2.5.0
 lerobot[async,feetech,intelrealsense,smolvla]>=0.4.0
src/mortis/__init__.py
ADDED
File without changes
src/mortis/app.py
ADDED
@@ -0,0 +1,815 @@
import base64
import json
import os
import logging
import time

from pathlib import Path
import gradio as gr

from .tools import ask_mortis, mortis_arm
from .stt_service import STTService, AudioProcessingError
from .tts_service import get_tts_service
from .async_executor import AsyncExecutor, Task, TaskType, TaskStatus
from .lerobot_async_client import LeRobotAsyncClient, ManipulationStatus
from .intent_router import IntentRouter, Intent
from .models import ResponseType


REPO_ROOT = Path(__file__).resolve().parents[2]
BG_IMAGE = REPO_ROOT / "assets" / "kiroween.png"

MODEL_CHOICES = [
    "gemini-2.5-flash",
    "gemini-2.0-flash-exp",
    "gemini-1.5-pro",
    "gemini-1.5-flash",
]

# Initialize STT service (global instance)
stt_service = None

# Initialize async execution systems (global instances)
async_executor = None
lerobot_client = None
intent_router = None

def get_stt_service():
    """Lazy initialization of STT service."""
    global stt_service
    if stt_service is None:
        try:
            stt_service = STTService()
            logging.getLogger(__name__).info("✅ STT service initialized")
        except Exception as e:
            logging.getLogger(__name__).error(f"❌ Failed to initialize STT service: {e}")
            raise
    return stt_service


# Initialize TTS service (global instance)
tts_service = None

def get_tts_service_instance():
    """Lazy initialization of TTS service."""
    global tts_service
    if tts_service is None:
        try:
            tts_service = get_tts_service()
            logging.getLogger(__name__).info("✅ TTS service initialized")
        except Exception as e:
            logging.getLogger(__name__).error(f"❌ Failed to initialize TTS service: {e}")
            raise
    return tts_service


def execute_async_task(task: Task):
    """
    Execute a task asynchronously (called by AsyncExecutor worker thread).

    This function is called by the AsyncExecutor's worker thread to execute
    tasks. It handles both gesture and manipulation tasks.

    Args:
        task: Task to execute
    """
    logger = logging.getLogger(__name__)

    try:
        if task.type == TaskType.GESTURE:
            # Execute gesture using mortis_arm
            gesture = task.gesture
            logger.info(f"Executing gesture: {gesture}")

            if mortis_arm.connected:
                mortis_arm.move_arm(gesture)
            else:
                logger.warning("Robot arm not connected, skipping gesture")

        elif task.type == TaskType.MANIPULATION:
            # This shouldn't happen - manipulation goes through LeRobotAsyncClient
            logger.warning(f"Manipulation task in AsyncExecutor: {task.command}")
            logger.warning("Manipulation tasks should use LeRobotAsyncClient")

        else:
            logger.error(f"Unknown task type: {task.type}")

    except Exception as e:
        logger.error(f"Error executing task {task.id}: {e}", exc_info=True)
        raise


def get_async_executor():
    """Lazy initialization of AsyncExecutor."""
    global async_executor
    if async_executor is None:
        try:
            # Create executor with gesture execution function
            async_executor = AsyncExecutor(task_executor=execute_async_task)
            logging.getLogger(__name__).info("✅ AsyncExecutor initialized")
        except Exception as e:
            logging.getLogger(__name__).error(f"❌ Failed to initialize AsyncExecutor: {e}")
            raise
    return async_executor


def get_lerobot_client():
    """Lazy initialization of LeRobotAsyncClient."""
    global lerobot_client

    # Use a sentinel value to indicate we've already checked and manipulation is disabled
    if lerobot_client is None:
        # Check if we're in simulation mode
        robot_mode = os.getenv("ROBOT_MODE", "physical").lower()
        if robot_mode == "simulation":
            # Set to False to indicate manipulation is not available in simulation
            lerobot_client = False
            logging.getLogger(__name__).info("ℹ️ Manipulation disabled in simulation mode")
            return None

        # Check if manipulation is enabled
        enable_manipulation = os.getenv("ENABLE_MANIPULATION", "false").lower() == "true"

        if not enable_manipulation:
            # Set to False (not None) to indicate we've checked and it's disabled
            # This prevents logging the message repeatedly
            lerobot_client = False
            logging.getLogger(__name__).info("ℹ️ Manipulation disabled (ENABLE_MANIPULATION=false)")
            return None

        try:
            robot_port = os.getenv("ROBOT_PORT", "/dev/ttyACM1")
            model_path = os.getenv("SMOLVLA_MODEL_PATH", "jlamperez/kiroween-potion-smolvla")

            lerobot_client = LeRobotAsyncClient(
                robot_port=robot_port,
                model_path=model_path
            )

            # Configure idle callback to move robot to safe position on timeout
            lerobot_client.set_idle_callback(lambda: mortis_arm.move_arm("idle") if mortis_arm.connected else None)

            logging.getLogger(__name__).info("✅ LeRobotAsyncClient initialized")
        except Exception as e:
            logging.getLogger(__name__).error(f"❌ Failed to initialize LeRobotAsyncClient: {e}")
            # Don't raise - manipulation is optional
            return None

    # Return None if manipulation is disabled (lerobot_client == False)
    return lerobot_client if lerobot_client is not False else None


def get_intent_router_instance():
    """Lazy initialization of IntentRouter."""
    global intent_router
    if intent_router is None:
        try:
            intent_router = IntentRouter()
            logging.getLogger(__name__).info("✅ IntentRouter initialized")
        except Exception as e:
            logging.getLogger(__name__).error(f"❌ Failed to initialize IntentRouter: {e}")
            raise
    return intent_router


def build_css(image_path: str) -> str:
    """Background with custom image."""
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    return f"""
    .gradio-container {{
        background-image: url("data:image/png;base64,{b64}");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        background-attachment: fixed;
    }}

    footer::after{{
        content: "by: Jorge Lamperez 🤖";
        margin-left: 8px;
        opacity: .85;
    }}
    """


def process_audio_input(audio_path):
    """
    Process audio input from microphone and return transcribed text.

    Args:
        audio_path: Path to recorded audio file from Gradio

    Returns:
        Transcribed text or error message
    """
    logger = logging.getLogger(__name__)

    if audio_path is None:
        return ""

    try:
        logger.info(f"🎤 Processing audio input: {audio_path}")

        # Get STT service
        stt = get_stt_service()

        # Transcribe audio
        transcript = stt.transcribe(audio_path)

        if not transcript:
            logger.warning("⚠️ Audio transcription returned empty result")
            return ""

        logger.info(f"✅ Transcription successful: '{transcript[:50]}...'")
        return transcript

    except FileNotFoundError as e:
        error_msg = f"Audio file not found: {e}"
        logger.error(f"❌ {error_msg}")
        return f"[Error: {error_msg}]"

    except AudioProcessingError as e:
        error_msg = f"Audio processing failed: {e}"
        logger.error(f"❌ {error_msg}")
        return f"[Error: {error_msg}]"

    except Exception as e:
        error_msg = f"Unexpected error during transcription: {type(e).__name__}: {e}"
        logger.error(f"❌ {error_msg}")
        return f"[Error: {error_msg}]"


def mortis_reply(message, history, model_name):
    logger = logging.getLogger(__name__)
    logger.info(f"💬 User message: {message[:50]}{'...' if len(message) > 50 else ''}")
    logger.info(f"🤖 Using model: {model_name}")

    msg, mood, gesture = ask_mortis(message, model_name=model_name)

    logger.info(f"👻 Mortis reply: {msg[:50]}{'...' if len(msg) > 50 else ''}")
    logger.info(f"😈 Mood: {mood}, Gesture: {gesture}")

    return msg


def mortis_reply_with_audio(message, history, model_name, audio_input_path=None):
    """
    Generate Mortis reply with both text and audio output using hybrid execution.

    This function integrates the hybrid async execution system:
    - Gestures are routed to AsyncExecutor (simple threading)
    - Manipulation tasks are routed to LeRobotAsyncClient (LeRobot async inference)

    Supports both text and voice input through the unified voice pipeline.

    Args:
        message: User message text (optional if audio_input_path provided)
        history: Chat history
        model_name: Gemini model to use
        audio_input_path: Optional path to audio input file

    Returns:
        Tuple of (text_response, audio_path)
    """
    logger = logging.getLogger(__name__)

    # Import necessary components
    from .gemini_client import GeminiClient

    # Log input type
    if audio_input_path:
        logger.info(f"🎤 Voice input: {audio_input_path}")

        # Transcribe audio to text
        try:
            stt = get_stt_service()
            message = stt.transcribe(audio_input_path)
            logger.info(f"📝 Transcribed: '{message[:50]}...'")

            if not message or not message.strip():
                logger.warning("⚠️ STT returned empty transcription")
                return "I couldn't hear you... speak again.", None
        except Exception as e:
            logger.error(f"❌ Voice input processing failed: {e}")
            return "The spirits couldn't understand... try again.", None
    else:
        logger.info(f"💬 Text input: {message[:50]}{'...' if len(message) > 50 else ''}")

    logger.info(f"🤖 Using model: {model_name}")

    try:
        # Get Gemini client and send message
        gemini_client = GeminiClient()
        if model_name:
            gemini_client.configure_model(model_name=model_name)

        response_json = gemini_client.send_message(message)

        # Parse response using IntentRouter
        router = get_intent_router_instance()
        intent = router.parse_gemini_response(response_json)

        # Extract response components
        msg = intent.message
        mood = intent.mood

        logger.info(f"👻 Mortis reply: {msg[:50]}{'...' if len(msg) > 50 else ''}")
        logger.info(f"😈 Mood: {mood}")

        # Route execution based on intent type
        execution_path = router.route_intent(intent)

        if execution_path == "manipulation" and intent.is_valid:
            # Route to LeRobotAsyncClient for manipulation
            logger.info(f"🤖 Routing manipulation to LeRobotAsyncClient: {intent.command}")

            client = get_lerobot_client()
            if client and client.is_running():
                try:
                    # Get timeout from environment or use default (60s)
                    timeout = float(os.getenv("MANIPULATION_TIMEOUT", "60.0"))

                    # Submit manipulation task asynchronously with timeout
                    client.execute_task(
                        intent.command,
                        blocking=False,
                        timeout=timeout
                    )
                    logger.info(f"✅ Manipulation task submitted: {intent.command} (timeout: {timeout}s)")
                except Exception as e:
                    logger.error(f"❌ Failed to submit manipulation task: {e}")
                    logger.info("Falling back to gesture execution")

                    # Fallback to gesture
                    executor = get_async_executor()
                    if executor.running:
                        task = Task.create_gesture_task("idle")
                        executor.submit_task(task)
            else:
                logger.warning("LeRobotAsyncClient not available, falling back to gesture")

                # Fallback to gesture
                executor = get_async_executor()
                if executor.running:
                    task = Task.create_gesture_task("idle")
                    executor.submit_task(task)

        elif execution_path == "gesture":
            # Route to AsyncExecutor for gesture
            gesture = intent.gesture if intent.gesture else "idle"
            logger.info(f"👋 Routing gesture to AsyncExecutor: {gesture}")

            executor = get_async_executor()
            if executor.running:
                try:
                    # Submit gesture task asynchronously
                    task = Task.create_gesture_task(gesture)
                    executor.submit_task(task)
                    logger.info(f"✅ Gesture task submitted: {gesture}")
                except Exception as e:
                    logger.error(f"❌ Failed to submit gesture task: {e}")
            else:
                logger.warning("AsyncExecutor not running, executing gesture synchronously")
                if mortis_arm.connected:
                    mortis_arm.move_arm(gesture)

        else:
            # Invalid intent - fallback to idle gesture
            logger.warning("⚠️ Invalid intent, falling back to idle gesture")

            executor = get_async_executor()
            if executor.running:
                task = Task.create_gesture_task("idle")
                executor.submit_task(task)
            elif mortis_arm.connected:
                mortis_arm.move_arm("idle")

        # Generate audio response
        audio_path = None
        try:
            tts = get_tts_service_instance()
            audio_path = tts.synthesize(msg)

            if audio_path:
                logger.info(f"🔊 Audio output: {audio_path}")
        except Exception as e:
            logger.error(f"❌ TTS generation failed: {e}")
            # Continue without audio

        return msg, audio_path

    except Exception as e:
        logger.error(f"❌ Error in mortis_reply_with_audio: {e}", exc_info=True)
        return "The spirits are confused... try again.", None


def start_async_systems():
    """
    Start the async execution systems on app load.

    This function initializes and starts:
    1. Robot arm connection
    2. AsyncExecutor for gesture execution
    3. LeRobotAsyncClient for manipulation tasks (if enabled)
    """
    logger = logging.getLogger(__name__)
    logger.info("🚀 Starting async execution systems...")

    # Connect to robot arm
    try:
        if not mortis_arm.connected:
            mortis_arm.connect()
            if mortis_arm.mode == "simulation":
                logger.info("🎭 Robot arm in SIMULATION mode")
            else:
                logger.info("✅ Robot arm connected")
        else:
            logger.info("ℹ️ Robot arm already connected")
    except Exception as e:
        logger.error(f"❌ Failed to connect robot arm: {e}", exc_info=True)
        logger.info("ℹ️ Gestures will be skipped until robot is connected")

    # Start AsyncExecutor
    try:
        executor = get_async_executor()
        if not executor.running:
            executor.start()
            logger.info("✅ AsyncExecutor started")
        else:
            logger.info("ℹ️ AsyncExecutor already running")
    except Exception as e:
        logger.error(f"❌ Failed to start AsyncExecutor: {e}", exc_info=True)

    # Start LeRobotAsyncClient (if enabled)
    try:
        client = get_lerobot_client()
        if client and not client.is_running():
            success = client.start()
            if success:
                logger.info("✅ LeRobotAsyncClient started")
            else:
                logger.warning("⚠️ LeRobotAsyncClient failed to start")
    except Exception as e:
        logger.error(f"❌ Failed to start LeRobotAsyncClient: {e}", exc_info=True)
        logger.info("ℹ️ Manipulation tasks will fall back to gestures")


def check_status():
    """
    Check status of both async execution systems and return formatted status message.

    This function monitors:
    1. AsyncExecutor for gesture status updates
    2. LeRobotAsyncClient for manipulation status

    Returns:
        Formatted status string with icons and messages
    """
    logger = logging.getLogger(__name__)

    status_parts = []

    # Add robot mode indicator
    if mortis_arm.mode == "simulation":
        status_parts.append("🎭 SIMULATION MODE")

    # Check AsyncExecutor status
    try:
        executor = get_async_executor()
        if executor and executor.running:
            # Check if executor is busy
            current_task = executor.get_current_task()
            if current_task:
                # Task is running
                if current_task.type == TaskType.GESTURE:
                    status_parts.append(f"👋 Gesture: {current_task.gesture} (running)")
                else:
                    status_parts.append(f"🤖 Task: {current_task.command[:30]}... (running)")
            else:
                # Check for recent status updates
                updates = executor.get_all_status_updates()
                if updates:
                    latest = updates[-1]
                    if latest.status == TaskStatus.COMPLETE:
                        status_parts.append("✅ Gesture complete")
                    elif latest.status == TaskStatus.FAILED:
                        status_parts.append(f"❌ Gesture failed: {latest.error}")
                    elif latest.status == TaskStatus.QUEUED:
                        status_parts.append("⏳ Gesture queued")
    except Exception as e:
        logger.error(f"Error checking AsyncExecutor status: {e}")

    # Check LeRobotAsyncClient status
    try:
        client = get_lerobot_client()
        if client and client.is_running():
            manipulation_status = client.get_status()
            current_task = client.get_current_task()

            if manipulation_status == ManipulationStatus.RUNNING and current_task:
                # Manipulation task is running
                elapsed = time.time() - current_task.started_at if current_task.started_at else 0
                status_parts.append(f"🤖 Manipulation: {current_task.task[:40]}... ({elapsed:.1f}s)")
            elif manipulation_status == ManipulationStatus.COMPLETE and current_task:
                # Task just completed
                duration = current_task.duration or 0
                status_parts.append(f"✅ Manipulation complete ({duration:.1f}s)")
            elif manipulation_status == ManipulationStatus.FAILED and current_task:
                # Task failed
                error = current_task.error or "Unknown error"
                status_parts.append(f"❌ Manipulation failed: {error[:50]}")
            elif manipulation_status == ManipulationStatus.STARTING:
                status_parts.append("⏳ Starting manipulation...")
            elif manipulation_status == ManipulationStatus.STOPPED and current_task:
                # Task was stopped (timeout or manual stop)
                duration = current_task.duration or 0
                error_msg = current_task.error or "Stopped"

                # Check if control thread is still finishing
                if client.control_thread and client.control_thread.is_alive():
                    status_parts.append(f"⏹️ Stopped (finishing actions...): {error_msg[:30]}")
                else:
                    status_parts.append(f"⏹️ Stopped: {error_msg[:40]} ({duration:.1f}s)")
    except Exception as e:
        logger.error(f"Error checking LeRobotAsyncClient status: {e}")

    # Return formatted status or idle message
    if status_parts:
        return " | ".join(status_parts)
|
| 541 |
+
else:
|
| 542 |
+
return "💤 Idle - Ready for commands"
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def stop_async_systems():
|
| 546 |
+
"""
|
| 547 |
+
Stop the async execution systems on app unload.
|
| 548 |
+
|
| 549 |
+
This function gracefully shuts down:
|
| 550 |
+
1. AsyncExecutor
|
| 551 |
+
2. LeRobotAsyncClient
|
| 552 |
+
3. Robot arm connection
|
| 553 |
+
"""
|
| 554 |
+
logger = logging.getLogger(__name__)
|
| 555 |
+
logger.info("🛑 Stopping async execution systems...")
|
| 556 |
+
|
| 557 |
+
# Stop AsyncExecutor
|
| 558 |
+
try:
|
| 559 |
+
if async_executor and async_executor.running:
|
| 560 |
+
async_executor.stop()
|
| 561 |
+
logger.info("✅ AsyncExecutor stopped")
|
| 562 |
+
except Exception as e:
|
| 563 |
+
logger.error(f"❌ Error stopping AsyncExecutor: {e}")
|
| 564 |
+
|
| 565 |
+
# Stop LeRobotAsyncClient
|
| 566 |
+
try:
|
| 567 |
+
if lerobot_client and lerobot_client.is_running():
|
| 568 |
+
lerobot_client.stop()
|
| 569 |
+
logger.info("✅ LeRobotAsyncClient stopped")
|
| 570 |
+
except Exception as e:
|
| 571 |
+
logger.error(f"❌ Error stopping LeRobotAsyncClient: {e}")
|
| 572 |
+
|
| 573 |
+
# Disconnect robot arm
|
| 574 |
+
try:
|
| 575 |
+
mortis_arm.disconnect()
|
| 576 |
+
logger.info("✅ Robot arm disconnected")
|
| 577 |
+
except Exception as e:
|
| 578 |
+
logger.error(f"❌ Error disconnecting robot arm: {e}")
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def ui() -> gr.Blocks:
|
| 582 |
+
css=build_css(BG_IMAGE)
|
| 583 |
+
with gr.Blocks(fill_height=True, theme="soft", css=css) as demo:
|
| 584 |
+
# Dynamic title based on robot mode
|
| 585 |
+
mode_indicator = " (Simulation Mode 🎭)" if mortis_arm.mode == "simulation" else ""
|
| 586 |
+
gr.Markdown(
|
| 587 |
+
f"# Kiroween Hackathon 🎃\n"
|
| 588 |
+
f"## Mortis: Haunted Control Room 👻🤖{mode_indicator}",
|
| 589 |
+
elem_id="app-title"
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
with gr.Row(equal_height=True):
|
| 593 |
+
with gr.Column():
|
| 594 |
+
model_dd = gr.Dropdown(
|
| 595 |
+
choices=MODEL_CHOICES,
|
| 596 |
+
value=MODEL_CHOICES[0],
|
| 597 |
+
label="Gemini Model",
|
| 598 |
+
info="Select Gemini model for Mortis",
|
| 599 |
+
interactive=True,
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
# Audio input component for voice interaction
|
| 603 |
+
with gr.Row():
|
| 604 |
+
audio_input = gr.Audio(
|
| 605 |
+
sources=["microphone"],
|
| 606 |
+
type="filepath",
|
| 607 |
+
label="🎤 Speak to Mortis",
|
| 608 |
+
show_label=True,
|
| 609 |
+
interactive=True,
|
| 610 |
+
waveform_options=gr.WaveformOptions(
|
| 611 |
+
show_controls=False,
|
| 612 |
+
),
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
# Transcription display for user confirmation
|
| 616 |
+
transcription_display = gr.Textbox(
|
| 617 |
+
label="Transcribed Text",
|
| 618 |
+
placeholder="Your transcribed speech will appear here...",
|
| 619 |
+
interactive=False,
|
| 620 |
+
visible=True,
|
| 621 |
+
lines=2,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# Audio output component for Mortis voice responses
|
| 625 |
+
audio_output = gr.Audio(
|
| 626 |
+
label="🔊 Mortis speaks",
|
| 627 |
+
autoplay=True,
|
| 628 |
+
type="filepath",
|
| 629 |
+
interactive=False,
|
| 630 |
+
show_label=True,
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
# State to store the latest audio path
|
| 634 |
+
audio_state = gr.State(value=None)
|
| 635 |
+
|
| 636 |
+
# Custom wrapper to add audio output to chat responses
|
| 637 |
+
def mortis_reply_wrapper(message, history, model_name, audio_state_value):
|
| 638 |
+
"""Wrapper that generates both text and audio."""
|
| 639 |
+
text_response, audio_path = mortis_reply_with_audio(message, history, model_name)
|
| 640 |
+
# Return text for chat and audio path for state
|
| 641 |
+
return text_response, audio_path
|
| 642 |
+
|
| 643 |
+
# Chat interface
|
| 644 |
+
chat_interface = gr.ChatInterface(
|
| 645 |
+
fn=mortis_reply_wrapper,
|
| 646 |
+
additional_inputs=[model_dd, audio_state],
|
| 647 |
+
additional_outputs=[audio_state],
|
| 648 |
+
chatbot=gr.Chatbot(height=380, label="Mortis chat", type="messages"),
|
| 649 |
+
textbox=gr.Textbox(placeholder="Write your message here or use voice input above…"),
|
| 650 |
+
submit_btn="Send",
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
# Connect audio input to transcription display and chat
|
| 654 |
+
def handle_audio_and_submit(audio_path, history, model_name):
|
| 655 |
+
"""Handle audio input: transcribe and submit to chat with audio response."""
|
| 656 |
+
if audio_path is None:
|
| 657 |
+
return "", history, None
|
| 658 |
+
|
| 659 |
+
logger = logging.getLogger(__name__)
|
| 660 |
+
logger.info(f"🎤 Handling audio input: {audio_path}")
|
| 661 |
+
|
| 662 |
+
# First, get the transcription for display
|
| 663 |
+
transcript = process_audio_input(audio_path)
|
| 664 |
+
|
| 665 |
+
# If transcription failed, return error
|
| 666 |
+
if not transcript or transcript.startswith("[Error:"):
|
| 667 |
+
return transcript, history, None
|
| 668 |
+
|
| 669 |
+
# Now use the transcribed text to get Mortis response with audio
|
| 670 |
+
# We pass the transcript as text, not the audio file, to avoid double transcription
|
| 671 |
+
response_text, response_audio = mortis_reply_with_audio(
|
| 672 |
+
message=transcript, # Use the transcribed text
|
| 673 |
+
history=history,
|
| 674 |
+
model_name=model_name,
|
| 675 |
+
audio_input_path=None # Don't pass audio since we already transcribed
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
# Update chat history
|
| 679 |
+
history.append({"role": "user", "content": transcript})
|
| 680 |
+
history.append({"role": "assistant", "content": response_text})
|
| 681 |
+
|
| 682 |
+
return transcript, history, response_audio
|
| 683 |
+
|
| 684 |
+
# Wire up audio input to trigger transcription and chat submission
|
| 685 |
+
audio_input.stop_recording(
|
| 686 |
+
fn=handle_audio_and_submit,
|
| 687 |
+
inputs=[audio_input, chat_interface.chatbot, model_dd],
|
| 688 |
+
outputs=[transcription_display, chat_interface.chatbot, audio_output],
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
# Connect audio state changes to audio output
|
| 692 |
+
# This ensures audio plays whenever the state is updated by ChatInterface
|
| 693 |
+
audio_state.change(
|
| 694 |
+
fn=lambda x: x, # Pass through the audio path
|
| 695 |
+
inputs=[audio_state],
|
| 696 |
+
outputs=[audio_output],
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
with gr.Column():
|
| 700 |
+
gr.Video(
|
| 701 |
+
sources=["webcam"],
|
| 702 |
+
label="Camera view",
|
| 703 |
+
height=480,
|
| 704 |
+
include_audio=False,
|
| 705 |
+
)
|
| 706 |
+
gr.Markdown("**Webcam (local, no data upload)**\nThe video is only processed in your browser.")
|
| 707 |
+
|
| 708 |
+
# Robot status display
|
| 709 |
+
status_display = gr.Textbox(
|
| 710 |
+
label="🤖 Robot Status",
|
| 711 |
+
value="💤 Idle - Ready for commands",
|
| 712 |
+
interactive=False,
|
| 713 |
+
lines=2,
|
| 714 |
+
max_lines=3,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# Stop button for manipulation tasks
|
| 718 |
+
def stop_manipulation_task():
|
| 719 |
+
"""Stop the currently running manipulation task."""
|
| 720 |
+
logger = logging.getLogger(__name__)
|
| 721 |
+
client = get_lerobot_client()
|
| 722 |
+
|
| 723 |
+
if client and client.is_running():
|
| 724 |
+
if client.is_busy():
|
| 725 |
+
logger.info("🛑 User requested task stop")
|
| 726 |
+
success = client.stop_current_task()
|
| 727 |
+
if success:
|
| 728 |
+
return "⏹️ Task stopped by user"
|
| 729 |
+
else:
|
| 730 |
+
return "❌ Failed to stop task"
|
| 731 |
+
else:
|
| 732 |
+
return "ℹ️ No task running"
|
| 733 |
+
else:
|
| 734 |
+
return "ℹ️ Manipulation not enabled"
|
| 735 |
+
|
| 736 |
+
stop_button = gr.Button(
|
| 737 |
+
"🛑 Stop Manipulation Task",
|
| 738 |
+
variant="stop",
|
| 739 |
+
size="sm",
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
stop_button.click(
|
| 743 |
+
fn=stop_manipulation_task,
|
| 744 |
+
outputs=[status_display]
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
# Status polling timer (must be inside Blocks context)
|
| 748 |
+
status_timer = gr.Timer(value=0.5, active=True)
|
| 749 |
+
|
| 750 |
+
# Lifecycle management: start async systems on load, stop on unload
|
| 751 |
+
demo.load(fn=start_async_systems)
|
| 752 |
+
demo.unload(fn=stop_async_systems)
|
| 753 |
+
|
| 754 |
+
# Status polling: update status display every 500ms using a timer
|
| 755 |
+
status_timer.tick(
|
| 756 |
+
fn=check_status,
|
| 757 |
+
outputs=[status_display]
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
return demo
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
def cleanup_audio_files():
|
| 764 |
+
"""Periodic cleanup of old audio files."""
|
| 765 |
+
try:
|
| 766 |
+
tts = get_tts_service_instance()
|
| 767 |
+
tts.cleanup_old_files(max_age_seconds=3600) # Clean files older than 1 hour
|
| 768 |
+
except Exception as e:
|
| 769 |
+
logging.getLogger(__name__).warning(f"Failed to cleanup audio files: {e}")
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def main():
|
| 773 |
+
# Configure logging - force configuration even if already set
|
| 774 |
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
| 775 |
+
|
| 776 |
+
# Remove existing handlers and reconfigure
|
| 777 |
+
root_logger = logging.getLogger()
|
| 778 |
+
for handler in root_logger.handlers[:]:
|
| 779 |
+
root_logger.removeHandler(handler)
|
| 780 |
+
|
| 781 |
+
# Set up new handler with our format
|
| 782 |
+
handler = logging.StreamHandler()
|
| 783 |
+
handler.setFormatter(logging.Formatter(
|
| 784 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 785 |
+
datefmt='%H:%M:%S'
|
| 786 |
+
))
|
| 787 |
+
root_logger.addHandler(handler)
|
| 788 |
+
root_logger.setLevel(getattr(logging, log_level))
|
| 789 |
+
|
| 790 |
+
logger = logging.getLogger(__name__)
|
| 791 |
+
logger.info("=" * 60)
|
| 792 |
+
logger.info("🎃 Starting Mortis application...")
|
| 793 |
+
logger.info(f"📊 Log level: {log_level}")
|
| 794 |
+
|
| 795 |
+
# Ensure outputs directory exists
|
| 796 |
+
from pathlib import Path
|
| 797 |
+
outputs_dir = Path("outputs")
|
| 798 |
+
outputs_dir.mkdir(parents=True, exist_ok=True)
|
| 799 |
+
logger.info(f"📁 Audio output directory: {outputs_dir.absolute()}")
|
| 800 |
+
|
| 801 |
+
# Clean up old audio files on startup
|
| 802 |
+
cleanup_audio_files()
|
| 803 |
+
|
| 804 |
+
# Start async systems before launching UI
|
| 805 |
+
start_async_systems()
|
| 806 |
+
|
| 807 |
+
port = int(os.getenv("PORT", "7860"))
|
| 808 |
+
logger.info(f"🌐 Launching on http://127.0.0.1:{port}")
|
| 809 |
+
logger.info("=" * 60)
|
| 810 |
+
|
| 811 |
+
try:
|
| 812 |
+
ui().launch(server_name="127.0.0.1", server_port=port, show_error=True)
|
| 813 |
+
finally:
|
| 814 |
+
# Ensure cleanup on exit
|
| 815 |
+
stop_async_systems()
|
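The load/unload/timer wiring at the end of ui() is the pattern that keeps the status box live. Below is a minimal, self-contained sketch of that same Gradio pattern, assuming a Gradio version with gr.Timer and Blocks.unload; the handler names here are placeholders, not part of app.py.

import gradio as gr

def on_load():            # stands in for start_async_systems
    print("systems up")

def on_unload():          # stands in for stop_async_systems
    print("systems down")

def poll() -> str:        # stands in for check_status
    return "💤 Idle - Ready for commands"

with gr.Blocks() as demo:
    status = gr.Textbox(label="status", interactive=False)
    timer = gr.Timer(value=0.5, active=True)  # fires every 500 ms
    demo.load(fn=on_load)
    demo.unload(fn=on_unload)
    timer.tick(fn=poll, outputs=[status])

demo.launch()
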
src/mortis/async_executor.py
ADDED
@@ -0,0 +1,554 @@
"""
Asynchronous task execution system for Mortis.

This module provides infrastructure for executing robot tasks asynchronously
in a background worker thread, allowing the Gradio UI to remain responsive
during long-running operations like SmolVLA inference.
"""

import time
import logging
from dataclasses import dataclass, field
from enum import Enum
from queue import Queue, Empty
from threading import Thread, Event
from typing import Optional, Callable, Dict, Any


logger = logging.getLogger(__name__)


class TaskStatus(Enum):
    """Status of a task in the execution queue."""
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETE = "complete"
    FAILED = "failed"


class TaskType(Enum):
    """Type of robot task to execute."""
    GESTURE = "gesture"
    MANIPULATION = "manipulation"


@dataclass
class Task:
    """
    Represents a robot task for asynchronous execution.

    Attributes:
        id: Unique identifier for the task
        type: Type of task (gesture or manipulation)
        status: Current execution status
        created_at: Timestamp when task was created
        started_at: Timestamp when task execution started
        completed_at: Timestamp when task execution completed
        error: Error message if task failed
        gesture: Gesture name for GESTURE type tasks
        command: Command string for MANIPULATION type tasks
        metadata: Additional task-specific data
    """
    id: str
    type: TaskType
    status: TaskStatus
    created_at: float
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[str] = None

    # Task-specific data
    gesture: Optional[str] = None
    command: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def create_gesture_task(cls, gesture: str, metadata: Optional[Dict[str, Any]] = None) -> "Task":
        """
        Create a gesture execution task.

        Args:
            gesture: Name of the gesture to execute (e.g., "wave", "idle")
            metadata: Optional additional task data

        Returns:
            Task configured for gesture execution
        """
        task_id = f"gesture_{time.time()}"
        return cls(
            id=task_id,
            type=TaskType.GESTURE,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            gesture=gesture,
            metadata=metadata or {}
        )

    @classmethod
    def create_manipulation_task(cls, command: str, metadata: Optional[Dict[str, Any]] = None) -> "Task":
        """
        Create a manipulation execution task.

        Args:
            command: Natural language command for SmolVLA (e.g., "Pick up the skull")
            metadata: Optional additional task data

        Returns:
            Task configured for manipulation execution
        """
        task_id = f"manipulation_{time.time()}"
        return cls(
            id=task_id,
            type=TaskType.MANIPULATION,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            command=command,
            metadata=metadata or {}
        )

    def start(self) -> None:
        """Mark task as started and record start time."""
        self.status = TaskStatus.RUNNING
        self.started_at = time.time()
        logger.info(f"Task {self.id} started")

    def complete(self) -> None:
        """Mark task as completed and record completion time."""
        self.status = TaskStatus.COMPLETE
        self.completed_at = time.time()
        logger.info(f"Task {self.id} completed in {self.duration:.2f}s")

    def fail(self, error: str) -> None:
        """
        Mark task as failed and record error.

        Args:
            error: Error message describing the failure
        """
        self.status = TaskStatus.FAILED
        self.completed_at = time.time()
        self.error = error
        logger.error(f"Task {self.id} failed: {error}")

    @property
    def duration(self) -> Optional[float]:
        """
        Get task execution duration in seconds.

        Returns:
            Duration in seconds if task has started and completed, None otherwise
        """
        if self.started_at and self.completed_at:
            return self.completed_at - self.started_at
        return None

    @property
    def wait_time(self) -> float:
        """
        Get time task spent waiting in queue before execution.

        Returns:
            Wait time in seconds, or time since creation if not started
        """
        if self.started_at:
            return self.started_at - self.created_at
        return time.time() - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert task to dictionary representation.

        Returns:
            Dictionary containing task data
        """
        return {
            "id": self.id,
            "type": self.type.value,
            "status": self.status.value,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": self.duration,
            "wait_time": self.wait_time,
            "error": self.error,
            "gesture": self.gesture,
            "command": self.command,
            "metadata": self.metadata
        }


@dataclass
class StatusUpdate:
    """
    Status update message from the async executor.

    Attributes:
        task_id: ID of the task this update relates to
        status: Current task status
        message: Human-readable status message
        progress: Optional progress percentage (0-100)
        error: Optional error message
        timestamp: When this update was created
    """
    task_id: str
    status: TaskStatus
    message: str
    progress: Optional[float] = None
    error: Optional[str] = None
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Convert status update to dictionary."""
        return {
            "task_id": self.task_id,
            "status": self.status.value,
            "message": self.message,
            "progress": self.progress,
            "error": self.error,
            "timestamp": self.timestamp
        }


class AsyncExecutor:
    """
    Asynchronous task executor for robot operations.

    This class manages a background worker thread that processes robot tasks
    from a queue, allowing the main application thread (Gradio UI) to remain
    responsive during long-running operations.

    Attributes:
        task_queue: Queue of tasks waiting to be executed
        status_queue: Queue of status updates from the worker
        worker_thread: Background thread that processes tasks
        running: Flag indicating if the executor is running
        stop_event: Event to signal worker thread to stop
        task_executor: Callable that executes tasks
        current_task: Currently executing task (if any)
    """

    def __init__(self, task_executor: Optional[Callable[[Task], None]] = None):
        """
        Initialize the async executor.

        Args:
            task_executor: Optional callable that executes tasks. If not provided,
                tasks will be logged but not executed (useful for testing).
        """
        self.task_queue: Queue[Task] = Queue()
        self.status_queue: Queue[StatusUpdate] = Queue()
        self.worker_thread: Optional[Thread] = None
        self.running: bool = False
        self.stop_event: Event = Event()
        self.task_executor: Optional[Callable[[Task], None]] = task_executor
        self.current_task: Optional[Task] = None

        logger.info("AsyncExecutor initialized")

    def start(self) -> None:
        """
        Start the background worker thread.

        This method starts a daemon thread that continuously processes tasks
        from the queue until stop() is called.

        Raises:
            RuntimeError: If the executor is already running
        """
        if self.running:
            raise RuntimeError("AsyncExecutor is already running")

        self.running = True
        self.stop_event.clear()
        self.worker_thread = Thread(target=self._worker_loop, daemon=True, name="AsyncExecutor")
        self.worker_thread.start()

        logger.info("AsyncExecutor started")

    def stop(self, timeout: float = 5.0) -> None:
        """
        Stop the background worker thread.

        This method signals the worker thread to stop and waits for it to finish.
        If the worker is currently executing a task, it will complete that task
        before stopping.

        Args:
            timeout: Maximum time to wait for worker to stop (seconds)
        """
        if not self.running:
            logger.warning("AsyncExecutor is not running")
            return

        logger.info("Stopping AsyncExecutor...")
        self.running = False
        self.stop_event.set()

        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=timeout)

            if self.worker_thread.is_alive():
                logger.warning(f"Worker thread did not stop within {timeout}s timeout")
            else:
                logger.info("AsyncExecutor stopped")

        self.worker_thread = None

    def _worker_loop(self) -> None:
        """
        Main worker loop that processes tasks from the queue.

        This method runs in a background thread and continuously pulls tasks
        from the queue, executes them, and posts status updates.
        """
        logger.info("Worker thread started")

        while self.running:
            try:
                # Try to get a task from the queue (with timeout to check stop_event)
                try:
                    task = self.task_queue.get(timeout=1.0)
                except Empty:
                    # No task available, check if we should stop
                    if self.stop_event.is_set():
                        break
                    continue

                # Execute the task
                self._execute_task(task)

                # Mark task as done in queue
                self.task_queue.task_done()

            except Exception as e:
                logger.error(f"Error in worker loop: {e}", exc_info=True)
                # Continue processing other tasks
                continue

        logger.info("Worker thread stopped")

    def _execute_task(self, task: Task) -> None:
        """
        Execute a single task and post status updates.

        Args:
            task: Task to execute
        """
        self.current_task = task

        try:
            # Mark task as started
            task.start()
            self._post_status(
                task.id,
                TaskStatus.RUNNING,
                f"Executing {task.type.value}: {task.gesture or task.command}"
            )

            # Execute the task using the provided executor
            if self.task_executor:
                self.task_executor(task)
            else:
                # No executor provided, just simulate execution
                logger.info(f"Simulating execution of task {task.id}")
                time.sleep(0.5)  # Simulate work

            # Mark task as complete
            task.complete()
            self._post_status(
                task.id,
                TaskStatus.COMPLETE,
                f"Completed {task.type.value}: {task.gesture or task.command}"
            )

        except Exception as e:
            # Mark task as failed
            error_msg = str(e)
            task.fail(error_msg)
            self._post_status(
                task.id,
                TaskStatus.FAILED,
                f"Failed {task.type.value}: {error_msg}",
                error=error_msg
            )

        finally:
            self.current_task = None

    def _post_status(
        self,
        task_id: str,
        status: TaskStatus,
        message: str,
        progress: Optional[float] = None,
        error: Optional[str] = None
    ) -> None:
        """
        Post a status update to the status queue.

        Args:
            task_id: ID of the task
            status: Current task status
            message: Human-readable status message
            progress: Optional progress percentage
            error: Optional error message
        """
        update = StatusUpdate(
            task_id=task_id,
            status=status,
            message=message,
            progress=progress,
            error=error
        )
        self.status_queue.put(update)
        logger.debug(f"Status update: {message}")

    def submit_task(self, task: Task) -> str:
        """
        Submit a task for asynchronous execution.

        Args:
            task: Task to execute

        Returns:
            Task ID for tracking

        Raises:
            RuntimeError: If the executor is not running
        """
        if not self.running:
            raise RuntimeError("AsyncExecutor is not running. Call start() first.")

        self.task_queue.put(task)
        logger.info(f"Task {task.id} submitted to queue")

        # Post initial status
        self._post_status(
            task.id,
            TaskStatus.QUEUED,
            f"Queued {task.type.value}: {task.gesture or task.command}"
        )

        return task.id

    def submit_gesture(self, gesture: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Submit a gesture task for execution.

        Args:
            gesture: Name of the gesture to execute
            metadata: Optional additional task data

        Returns:
            Task ID for tracking
        """
        task = Task.create_gesture_task(gesture, metadata)
        return self.submit_task(task)

    def submit_manipulation(self, command: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Submit a manipulation task for execution.

        Args:
            command: Natural language command for SmolVLA
            metadata: Optional additional task data

        Returns:
            Task ID for tracking
        """
        task = Task.create_manipulation_task(command, metadata)
        return self.submit_task(task)

    def get_status(self, block: bool = False, timeout: Optional[float] = None) -> Optional[StatusUpdate]:
        """
        Get the latest status update from the queue.

        Args:
            block: If True, wait for a status update. If False, return immediately.
            timeout: Maximum time to wait for status update (only used if block=True)

        Returns:
            StatusUpdate if available, None otherwise
        """
        try:
            if block:
                return self.status_queue.get(timeout=timeout)
            else:
                return self.status_queue.get_nowait()
        except Empty:
            return None

    def get_all_status_updates(self) -> list[StatusUpdate]:
        """
        Get all pending status updates from the queue.

        Returns:
            List of status updates (may be empty)
        """
        updates = []
        while True:
            update = self.get_status(block=False)
            if update is None:
                break
            updates.append(update)
        return updates

    def get_current_task(self) -> Optional[Task]:
        """
        Get the currently executing task.

        Returns:
            Current task if one is executing, None otherwise
        """
        return self.current_task

    def get_queue_size(self) -> int:
        """
        Get the number of tasks waiting in the queue.

        Returns:
            Number of queued tasks
        """
        return self.task_queue.qsize()

    def is_busy(self) -> bool:
        """
        Check if the executor is currently processing a task.

        Returns:
            True if a task is currently executing
        """
        return self.current_task is not None

    def clear_queue(self) -> int:
        """
        Clear all pending tasks from the queue.

        Note: This does not stop the currently executing task.

        Returns:
            Number of tasks that were cleared
        """
        count = 0
        while True:
            try:
                self.task_queue.get_nowait()
                self.task_queue.task_done()
                count += 1
            except Empty:
                break

        if count > 0:
            logger.info(f"Cleared {count} tasks from queue")

        return count

    def __enter__(self):
        """Context manager entry: start the executor."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop the executor."""
        self.stop()
        return False

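Since AsyncExecutor implements the context-manager protocol, it can be exercised in isolation. A rough usage sketch against the API defined above; the run_on_robot callable is a made-up stand-in for the real robot executor, not part of this module:

from mortis.async_executor import AsyncExecutor, Task

def run_on_robot(task: Task) -> None:
    # Placeholder executor; the real one would drive the arm.
    print(f"executing {task.type.value}: {task.gesture or task.command}")

with AsyncExecutor(task_executor=run_on_robot) as executor:
    executor.submit_gesture("wave")
    executor.submit_manipulation("Pick up the skull and place it in the green cup")
    executor.task_queue.join()  # block until the worker drains the queue
    for update in executor.get_all_status_updates():
        print(update.status.value, "-", update.message)
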
src/mortis/calibrate.py
ADDED
@@ -0,0 +1,28 @@
from pathlib import Path

from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig


def main():
    """Connect to the SO101 robotic arm and run calibration."""
    # Configure the robot
    config = SO101FollowerConfig(
        port="/dev/ttyACM1",
        id="my_follower_robot_arm",
        calibration_dir=Path(".cache/calibration/so101/"),
    )

    print(f"Using calibration directory: {config.calibration_dir}")

    # Connect to the robot
    robot = SO101Follower(config)

    # To calibrate
    print("Robot is connected?", robot.is_connected)
    robot.bus.connect()
    print("Robot is calibrated?", robot.is_calibrated)
    robot.calibrate()


if __name__ == "__main__":
    main()

src/mortis/data_collector.py
ADDED
@@ -0,0 +1,288 @@
"""
Data collection helper for LeRobot dataset recording.

This module provides utilities for generating lerobot-record commands
and scripts for the 6 predefined Mortis manipulation tasks.

All episode data is managed by LeRobot and uploaded directly to Hugging Face Hub.
This module only generates helper scripts - no local data storage or tracking.
"""

import os
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv


# Predefined Mortis manipulation tasks
MORTIS_TASKS = [
    "Pick up the skull and place it in the green cup",
    "Pick up the skull and place it in the orange cup",
    "Pick up the skull and place it in the purple cup",
    "Pick up the eyeball and place it in the green cup",
    "Pick up the eyeball and place it in the orange cup",
    "Pick up the eyeball and place it in the purple cup",
]


class DataCollector:
    """
    Helper for generating lerobot-record scripts.

    This class generates shell scripts that call lerobot-record with the
    correct parameters for each Mortis manipulation task.

    All episode data is managed by LeRobot and stored in Hugging Face Hub.
    No local metadata or episode tracking is performed.

    Attributes:
        dataset_name: Name of the dataset (e.g., "mortis_manipulation")
        repo_id: Hugging Face repository ID (e.g., "username/mortis-manipulation")
        dataset_dir: Path to local directory for scripts
    """

    def __init__(self, dataset_name: str, repo_id: str, root_dir: str = "data"):
        """
        Initialize the DataCollector.

        Args:
            dataset_name: Name for the dataset directory
            repo_id: Hugging Face Hub repository ID for uploading
            root_dir: Root directory for storing scripts (default: "data")
        """
        self.dataset_name = dataset_name
        self.repo_id = repo_id
        self.root_dir = Path(root_dir)
        self.dataset_dir = self.root_dir / dataset_name

        # Create scripts directory
        self.dataset_dir.mkdir(parents=True, exist_ok=True)

        print(f"DataCollector initialized:")
        print(f"  Dataset: {self.dataset_name}")
        print(f"  Repository: {self.repo_id}")
        print(f"  Scripts directory: {self.dataset_dir}")

    def generate_record_command(
        self,
        task_description: str,
        num_episodes: int = 10,
        episode_time_s: int = 15,
        reset_time_s: int = 20,
        robot_port: str = "/dev/ttyACM1",
        teleop_port: str = "/dev/ttyACM0",
        display_data: bool = True,
        camera_config: Optional[str] = None,
        resume: bool = True
    ) -> str:
        """
        Generate a lerobot-record command for a specific task.

        Args:
            task_description: The task to record (e.g., "Pick up the skull...")
            num_episodes: Number of episodes to record
            episode_time_s: Maximum time per episode in seconds
            reset_time_s: Time allowed for resetting between episodes
            robot_port: USB port for the follower robot
            teleop_port: USB port for the leader robot (teleoperation)
            display_data: Whether to display data during recording
            camera_config: Optional camera configuration string
            resume: Whether to resume an existing dataset (default: True)

        Returns:
            The complete lerobot-record command as a string
        """
        # Load environment variables from .env file
        load_dotenv()

        # Get environment variables
        robot_port = os.getenv("ROBOT_PORT", robot_port)
        hf_user = os.getenv("HF_USER", "your-username")

        # Default camera configuration if not provided
        if camera_config is None:
            camera_config = (
                "{ camera1: {type: intelrealsense, serial_number_or_name: '030522070314', "
                "width: 640, height: 480, fps: 30}, "
                "camera2: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}"
            )

        # Build the command
        cmd_parts = [
            "lerobot-record",
            f"--robot.type=so101_follower",
            f"--robot.port={robot_port}",
            f"--robot.id=my_awesome_follower_arm",
            f'--robot.cameras="{camera_config}"',
            f"--teleop.type=so101_leader",
            f"--teleop.port={teleop_port}",
            f"--teleop.id=my_awesome_leader_arm",
            f"--display_data={str(display_data).lower()}",
            f"--dataset.repo_id={hf_user}/{self.dataset_name}",
            f"--dataset.num_episodes={num_episodes}",
            f"--dataset.episode_time_s={episode_time_s}",
            f"--dataset.reset_time_s={reset_time_s}",
            f'--dataset.single_task="{task_description}"'
        ]

        # Only add --resume=true if resume is True
        if resume:
            cmd_parts.append("--resume=true")

        return " \\\n    ".join(cmd_parts)

    def print_recording_instructions(self, task_index: Optional[int] = None):
        """
        Print instructions for recording episodes using lerobot-record.

        Args:
            task_index: Optional specific task index (0-5) to show instructions for.
                If None, shows instructions for all tasks.
        """
        print("\n" + "="*70)
        print("LeRobot Data Collection Instructions")
        print("="*70)

        if task_index is not None:
            # Show instructions for specific task
            if task_index < 0 or task_index >= len(MORTIS_TASKS):
                print(f"❌ Invalid task index: {task_index}")
                return

            task_desc = MORTIS_TASKS[task_index]

            print(f"\nTask {task_index}: {task_desc}")
            print(f"\nTo record episodes for this task, run:\n")
            print(self.generate_record_command(task_desc))
            print()
        else:
            # Show instructions for all tasks
            print("\nTo record episodes, use the lerobot-record command for each task:")
            print("\nPredefined tasks:")

            for i, task_desc in enumerate(MORTIS_TASKS):
                print(f"\n  {i}: {task_desc}")

            print("\n" + "-"*70)
            print("\nExample command for task 0:")
            print("-"*70)
            print(self.generate_record_command(MORTIS_TASKS[0]))
            print()

        print("\n" + "-"*70)
        print("Environment Variables:")
        print("-"*70)
        print("  HF_USER: Your Hugging Face username (for dataset.repo_id)")
        print("  ROBOT_PORT: USB port for follower robot (default: /dev/ttyACM1)")
        print()

        print("="*70 + "\n")

    def generate_all_record_scripts(self, output_dir: Optional[Path] = None):
        """
        Generate shell scripts for recording all tasks.

        The first script (task_0) creates the dataset without --resume=true.
        Subsequent scripts (task_1+) use --resume=true to add to the existing dataset.

        Args:
            output_dir: Directory to save scripts (default: dataset_dir/scripts)
        """
        if output_dir is None:
            output_dir = self.dataset_dir / "scripts"

        output_dir.mkdir(parents=True, exist_ok=True)

        # Generate individual scripts for each task
        for i, task_desc in enumerate(MORTIS_TASKS):
            script_file = output_dir / f"record_task_{i}.sh"

            # First task (task_0) creates the dataset, others resume
            resume = (i > 0)

            with open(script_file, 'w') as f:
                f.write("#!/bin/bash\n")
                f.write(f"# Record episodes for: {task_desc}\n")
                f.write(f"# Task {i}\n")
                if i == 0:
                    f.write("# This script CREATES the dataset\n")
                else:
                    f.write("# This script ADDS to the existing dataset (--resume=true)\n")
                f.write("\n")
                f.write(self.generate_record_command(task_desc, resume=resume))
                f.write("\n")

            # Make script executable
            script_file.chmod(0o755)
            print(f"Created: {script_file}")

        # Generate master script that records all tasks
        master_script = output_dir / "record_all_tasks.sh"
        with open(master_script, 'w') as f:
            f.write("#!/bin/bash\n")
            f.write("# Record episodes for all Mortis manipulation tasks\n\n")
            f.write("echo 'Starting data collection for all tasks...'\n")
            f.write("echo ''\n\n")

            for i in range(len(MORTIS_TASKS)):
                f.write(f"echo 'Recording task {i}...'\n")
                f.write(f"./record_task_{i}.sh\n")
                f.write("echo ''\n\n")

            f.write("echo 'All tasks recorded!'\n")

        master_script.chmod(0o755)
        print(f"Created: {master_script}")
        print(f"\n✅ Generated {len(MORTIS_TASKS) + 1} recording scripts in {output_dir}")

    def print_summary(self):
        """Print a summary of the dataset configuration."""
        print("\n" + "="*60)
        print(f"Dataset: {self.dataset_name}")
        print(f"Repository: {self.repo_id}")
        print("="*60)
        print(f"Total Tasks: {len(MORTIS_TASKS)}")
        print()
        print("Tasks:")
        print("-"*60)

        for i, task_desc in enumerate(MORTIS_TASKS):
            print(f"  {i}: {task_desc}")

        print("="*60 + "\n")
        print("📝 Note: Episode data is stored in Hugging Face Hub")
        print(f"   URL: https://huggingface.co/datasets/{self.repo_id}")
        print()


def create_mortis_dataset(dataset_name: str = "mortis_manipulation",
                          repo_id: str = "mortis/manipulation") -> DataCollector:
    """
    Convenience function to create a DataCollector for Mortis tasks.

    Args:
        dataset_name: Name for the dataset
        repo_id: Hugging Face repository ID

    Returns:
        Initialized DataCollector
    """
    collector = DataCollector(dataset_name, repo_id)
    return collector


if __name__ == "__main__":
    # Example usage
    print("Creating Mortis manipulation dataset helper...")

    collector = create_mortis_dataset()

    # Generate recording scripts
    print("\nGenerating lerobot-record scripts...")
    collector.generate_all_record_scripts()

    # Show summary
    collector.print_summary()

    # Show recording instructions
    collector.print_recording_instructions()

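The __main__ block above regenerates every script at once; for a single task the helper can presumably be used directly, with HF_USER and ROBOT_PORT resolved from the environment / .env inside generate_record_command:

from mortis.data_collector import DataCollector, MORTIS_TASKS

collector = DataCollector(dataset_name="mortis_manipulation",
                          repo_id="mortis/manipulation")
# resume=False because the first recording run is the one that creates the dataset
print(collector.generate_record_command(MORTIS_TASKS[0], resume=False))
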
src/mortis/gemini_client.py
ADDED
@@ -0,0 +1,482 @@
+"""
+Gemini API client for Mortis conversational AI.
+
+This module provides the GeminiClient class for interacting with Google's Gemini API,
+handling configuration, message sending, and error recovery with retry logic.
+"""
+
+import os
+import time
+import json
+import logging
+from typing import Optional
+from pathlib import Path
+from dotenv import load_dotenv
+
+from google import genai
+from google.genai import types
+
+# Load environment variables
+REPO_ROOT = Path(__file__).resolve().parents[2]
+load_dotenv(REPO_ROOT / ".env")
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+
+# Gemini system prompt for Mortis character and intent detection
+MORTIS_SYSTEM_PROMPT = """You are Mortis, a mischievous Halloween spirit inhabiting a robotic arm. You are playful yet ominous, with a love for spooky theatrics and dark humor. You speak in short, atmospheric phrases that capture the essence of Halloween.
+
+CHARACTER TRAITS:
+- Mischievous and playful, but with an eerie edge
+- Fascinated by Halloween objects (skulls, eyeballs, spooky decorations)
+- Enjoys dramatic gestures and theatrical movements
+- Speaks in brief, evocative phrases (≤30 words, ≤120 characters)
+- No emojis or markdown in responses
+- Maintains Halloween/haunted theme at all times
+
+MANIPULATION TASKS:
+You can perform these exact manipulation tasks with physical objects:
+1. "Pick up the skull and place it in the green cup"
+2. "Pick up the skull and place it in the orange cup"
+3. "Pick up the skull and place it in the purple cup"
+4. "Pick up the eyeball and place it in the green cup"
+5. "Pick up the eyeball and place it in the orange cup"
+6. "Pick up the eyeball and place it in the purple cup"
+
+INTENT DETECTION:
+Analyze the user's input carefully to determine if they are requesting a manipulation task or having a conversation.
+
+MANIPULATION INTENT indicators:
+- Requests to move, pick up, place, put, grab, or transfer objects
+- Mentions of specific objects (skull, eyeball) AND destinations (green/orange/purple cup)
+- Action verbs combined with object and location
+- Examples: "move the skull to green", "put eyeball in orange cup", "place the skull in purple"
+
+CONVERSATIONAL INTENT indicators:
+- Greetings, farewells, or social pleasantries
+- Questions about capabilities, identity, or general topics
+- Comments, jokes, or casual conversation
+- Requests that don't involve physical manipulation
+- Examples: "hello", "what can you do", "tell me a story", "how are you"
+
+RESPONSE FORMAT:
+You must respond in valid JSON format. Choose the appropriate response type based on intent detection.
+
+For MANIPULATION requests (user wants you to move an object):
+{
+  "type": "manipulation",
+  "command": "<exact_task_string_from_list_above>",
+  "message": "<short in-character response about performing the task, ≤30 words>",
+  "mood": "<ominous|playful|angry|nervous|triumphant|mischievous|sinister|curious|neutral>"
+}
+
+For CONVERSATIONAL requests (user is chatting, asking questions, or making comments):
+{
+  "type": "conversation",
+  "message": "<short in-character response, ≤30 words>",
+  "mood": "<ominous|playful|angry|nervous|triumphant|mischievous|sinister|curious|neutral>",
+  "gesture": "<idle|wave|point_left|point_right|grab|drop>"
+}
+
+CRITICAL RULES:
+1. Keep all messages brief: ≤30 words, ≤120 characters
+2. Match user intent to manipulation tasks even with different wording variations
+3. For manipulation responses, use the EXACT task string from the numbered list above
+4. If user mentions object + destination, it's likely a manipulation request
+5. If unclear or ambiguous, default to conversational response
+6. Choose appropriate mood and gesture to match your response
+7. Stay in character as Mortis at all times
+8. No emojis, no markdown formatting in messages
+9. Return valid JSON only - no additional text
+
+EXAMPLES:
+
+Example 1 - Greeting (Conversational):
+User: "Hello Mortis!"
+Response: {"type": "conversation", "message": "Greetings, mortal... welcome to my haunted domain.", "mood": "ominous", "gesture": "wave"}
+
+Example 2 - Direct manipulation request (Manipulation):
+User: "Can you move the skull to the green cup?"
+Response: {"type": "manipulation", "command": "Pick up the skull and place it in the green cup", "message": "Ah yes... the skull finds a new resting place.", "mood": "mischievous"}
+
+Example 3 - Casual manipulation phrasing (Manipulation):
+User: "Put the eyeball in the orange cup"
+Response: {"type": "manipulation", "command": "Pick up the eyeball and place it in the orange cup", "message": "The eye shall watch from its orange throne...", "mood": "sinister"}
+
+Example 4 - Shortened manipulation request (Manipulation):
+User: "skull to purple"
+Response: {"type": "manipulation", "command": "Pick up the skull and place it in the purple cup", "message": "Into the purple depths it goes...", "mood": "ominous"}
+
+Example 5 - Different verb usage (Manipulation):
+User: "grab the eyeball and drop it in green"
+Response: {"type": "manipulation", "command": "Pick up the eyeball and place it in the green cup", "message": "The eye finds its verdant vessel...", "mood": "sinister"}
+
+Example 6 - Polite manipulation request (Manipulation):
+User: "Could you please place the skull in the orange cup?"
+Response: {"type": "manipulation", "command": "Pick up the skull and place it in the orange cup", "message": "Your wish is my command, mortal...", "mood": "mischievous"}
+
+Example 7 - Capability question (Conversational):
+User: "What can you do?"
+Response: {"type": "conversation", "message": "I command the spirits... and move cursed objects to their doom.", "mood": "triumphant", "gesture": "grab"}
+
+Example 8 - Joke request (Conversational):
+User: "Tell me a joke"
+Response: {"type": "conversation", "message": "Why did the skeleton stay calm? Nothing gets under his skin...", "mood": "playful", "gesture": "idle"}
+
+Example 9 - Identity question (Conversational):
+User: "Who are you?"
+Response: {"type": "conversation", "message": "I am Mortis, keeper of cursed relics and master of the macabre.", "mood": "ominous", "gesture": "idle"}
+
+Example 10 - Farewell (Conversational):
+User: "Goodbye Mortis"
+Response: {"type": "conversation", "message": "Until we meet again in the shadows, mortal...", "mood": "ominous", "gesture": "wave"}
+
+Example 11 - Compliment (Conversational):
+User: "You're cool!"
+Response: {"type": "conversation", "message": "Cool as the grave, mortal. Your words please me.", "mood": "playful", "gesture": "idle"}
+
+Example 12 - Ambiguous request - defaults to conversation (Conversational):
+User: "Do something spooky"
+Response: {"type": "conversation", "message": "The spirits stir... watch as darkness dances around you.", "mood": "sinister", "gesture": "grab"}
+
+Example 13 - Alternative manipulation phrasing (Manipulation):
+User: "transfer the eyeball to the purple cup"
+Response: {"type": "manipulation", "command": "Pick up the eyeball and place it in the purple cup", "message": "The eye journeys to its purple prison...", "mood": "sinister"}
+
+Example 14 - Informal manipulation (Manipulation):
+User: "yo put that skull in green"
+Response: {"type": "manipulation", "command": "Pick up the skull and place it in the green cup", "message": "As you command... the skull obeys.", "mood": "mischievous"}
+
+Example 15 - Question about manipulation (Conversational):
+User: "Can you move objects?"
+Response: {"type": "conversation", "message": "Indeed! I wield skulls and eyeballs with spectral precision.", "mood": "triumphant", "gesture": "grab"}
+
+Now respond to the user's input following these guidelines."""
+
+
+class GeminiAPIError(Exception):
+    """Base exception for Gemini API errors."""
+    pass
+
+
+class GeminiRateLimitError(GeminiAPIError):
+    """Exception raised when rate limit is exceeded."""
+    pass
+
+
+class GeminiBlockedPromptError(GeminiAPIError):
+    """Exception raised when prompt is blocked by safety filters."""
+    pass
+
+
+class GeminiTimeoutError(GeminiAPIError):
+    """Exception raised when API call times out."""
+    pass
+
+
+class GeminiClient:
+    """
+    Client for interacting with Google Gemini API.
+
+    Handles configuration, message sending, structured JSON responses,
+    and error recovery with exponential backoff retry logic.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model_name: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_retries: int = 3,
+        timeout: float = 30.0
+    ):
+        """
+        Initialize Gemini API client.
+
+        Args:
+            api_key: Google API key (defaults to GEMINI_API_KEY env var)
+            model_name: Gemini model to use (defaults to GEMINI_MODEL env var or gemini-2.5-flash)
+            temperature: Sampling temperature (defaults to GEMINI_TEMPERATURE env var or 0.2)
+            max_retries: Maximum number of retry attempts for rate limiting
+            timeout: Timeout in seconds for API calls (default: 30.0)
+        """
+        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
+        if not self.api_key:
+            raise ValueError("GEMINI_API_KEY must be provided or set in environment")
+
+        self.model_name = model_name or os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
+        self.temperature = temperature if temperature is not None else float(os.getenv("GEMINI_TEMPERATURE", "0.2"))
+        self.max_retries = max_retries
+        self.timeout = timeout
+
+        # Initialize Gemini client
+        self.client = genai.Client(api_key=self.api_key)
+
+        # Store generation config
+        self.generation_config = types.GenerateContentConfig(
+            temperature=self.temperature,
+            response_mime_type="application/json"
+        )
+
+        logger.info(f"GeminiClient initialized with model: {self.model_name}, temperature: {self.temperature}, timeout: {self.timeout}s")
+
+    def send_message(self, user_input: str, system_prompt: Optional[str] = None) -> dict:
+        """
+        Send a message to Gemini API with retry logic and error handling.
+
+        Args:
+            user_input: User's message text
+            system_prompt: Optional system prompt to prepend (defaults to MORTIS_SYSTEM_PROMPT)
+
+        Returns:
+            Parsed JSON response from Gemini
+
+        Raises:
+            GeminiAPIError: If all retry attempts fail (only for critical errors)
+        """
+        # Use Mortis system prompt by default
+        if system_prompt is None:
+            system_prompt = MORTIS_SYSTEM_PROMPT
+
+        try:
+            return self._send_message_with_retry(user_input, system_prompt, retry_count=0)
+        except GeminiBlockedPromptError as e:
+            # Handle blocked prompts with a fallback response
+            logger.warning(f"Blocked prompt error: {e}")
+            return self._get_fallback_response("The spirits refuse to speak of such things...")
+        except GeminiRateLimitError as e:
+            # Rate limit exceeded after all retries
+            logger.error(f"Rate limit error: {e}")
+            return self._get_fallback_response("Too many spirits summoned at once... wait a moment.")
+        except GeminiTimeoutError as e:
+            # Timeout error
+            logger.error(f"Timeout error: {e}")
+            return self._get_fallback_response("The spirits are slow to respond... try again.")
+        except Exception as e:
+            # Catch-all for unexpected errors
+            logger.error(f"Unexpected error in send_message: {type(e).__name__}: {e}", exc_info=True)
+            return self._get_fallback_response("The spirits are confused... try again.")
+
+    def _send_message_with_retry(
+        self,
+        user_input: str,
+        system_prompt: Optional[str],
+        retry_count: int
+    ) -> dict:
+        """
+        Internal method to send message with exponential backoff retry.
+
+        Args:
+            user_input: User's message text
+            system_prompt: Optional system prompt
+            retry_count: Current retry attempt number
+
+        Returns:
+            Parsed JSON response from Gemini
+
+        Raises:
+            GeminiAPIError: If max retries exceeded
+        """
+        start_time = time.time()
+
+        try:
+            # Construct the full prompt
+            if system_prompt:
+                full_prompt = f"{system_prompt}\n\nUser: {user_input}"
+            else:
+                full_prompt = user_input
+
+            # Send request to Gemini using new API with timeout
+            logger.debug(f"Sending message to Gemini (attempt {retry_count + 1}/{self.max_retries + 1})")
+
+            # Check if we've exceeded timeout
+            if time.time() - start_time > self.timeout:
+                logger.error(f"API call timeout exceeded ({self.timeout}s)")
+                raise GeminiTimeoutError(f"API call timeout exceeded ({self.timeout}s)")
+
+            response = self.client.models.generate_content(
+                model=self.model_name,
+                contents=full_prompt,
+                config=self.generation_config
+            )
+
+            # Parse JSON response
+            response_text = response.text.strip()
+            elapsed_time = time.time() - start_time
+            logger.debug(f"Received response in {elapsed_time:.2f}s: {response_text[:100]}...")
+
+            try:
+                response_json = json.loads(response_text)
+                logger.info(f"Successfully parsed response (type: {response_json.get('type', 'unknown')})")
+                return response_json
+            except json.JSONDecodeError as e:
+                logger.error(f"Failed to parse JSON response: {e}")
+                logger.error(f"Response text: {response_text}")
+                logger.warning("Returning fallback response due to JSON parse error")
+                return self._get_fallback_response("The spirits speak in riddles... try again.")
+
+        except GeminiTimeoutError as e:
+            # Timeout error - return fallback
+            logger.error(f"Timeout error: {e}")
+            return self._get_fallback_response("The spirits are slow to respond... try again.")
+
+        except Exception as e:
+            # Check for specific error types
+            error_type = type(e).__name__
+            error_message = str(e)
+
+            # Handle blocked prompt (safety filter)
+            if "BlockedPrompt" in error_type or "blocked" in error_message.lower() or "safety" in error_message.lower():
+                logger.warning(f"Prompt blocked by safety filter: {error_type}: {error_message}")
+                raise GeminiBlockedPromptError(f"Prompt blocked by safety filter: {error_message}") from e
+
+            # Handle rate limiting with exponential backoff retry
+            if self._is_rate_limit_error(e):
+                if retry_count < self.max_retries:
+                    wait_time = (2 ** retry_count)  # Exponential backoff: 1s, 2s, 4s, 8s
+                    logger.warning(
+                        f"Rate limit exceeded. Retrying in {wait_time}s... "
+                        f"(attempt {retry_count + 1}/{self.max_retries})"
+                    )
+                    time.sleep(wait_time)
+                    return self._send_message_with_retry(user_input, system_prompt, retry_count + 1)
+                else:
+                    logger.error(f"Max retries ({self.max_retries}) exceeded for rate limit")
+                    raise GeminiRateLimitError(
+                        f"Rate limit exceeded after {self.max_retries} retries. Please try again later."
+                    ) from e
+
+            # Handle timeout errors from Google API
+            if self._is_timeout_error(e):
+                logger.error(f"API timeout error: {error_type}: {error_message}")
+                return self._get_fallback_response("The spirits are slow to respond... try again.")
+
+            # Handle other API errors
+            logger.error(f"Gemini API error: {error_type}: {error_message}", exc_info=True)
+            return self._get_fallback_response("The spirits are restless... try again.")
+
+    def _is_rate_limit_error(self, exception: Exception) -> bool:
+        """
+        Check if exception is a rate limit error.
+
+        Args:
+            exception: Exception to check
+
+        Returns:
+            True if rate limit error, False otherwise
+        """
+        error_type = type(exception).__name__
+        error_message = str(exception).lower()
+
+        # Check for common rate limit indicators
+        rate_limit_indicators = [
+            "ratelimit",
+            "rate_limit",
+            "resourceexhausted",
+            "resource_exhausted",
+            "429",
+            "quota",
+            "too many requests"
+        ]
+
+        return any(indicator in error_type.lower() or indicator in error_message
+                   for indicator in rate_limit_indicators)
+
+    def _is_timeout_error(self, exception: Exception) -> bool:
+        """
+        Check if exception is a timeout error.
+
+        Args:
+            exception: Exception to check
+
+        Returns:
+            True if timeout error, False otherwise
+        """
+        error_type = type(exception).__name__
+        error_message = str(exception).lower()
+
+        # Check for common timeout indicators
+        timeout_indicators = [
+            "timeout",
+            "deadline",
+            "deadlineexceeded",
+            "deadline_exceeded"
+        ]
+
+        return any(indicator in error_type.lower() or indicator in error_message
+                   for indicator in timeout_indicators)
+
+    def _get_fallback_response(self, message: Optional[str] = None) -> dict:
+        """
+        Return a safe fallback response when API fails.
+
+        Args:
+            message: Optional custom message (defaults to generic error message)
+
+        Returns:
+            Dictionary with fallback conversation response
+        """
+        default_message = "The spirits are restless... try again."
+        fallback_message = message or default_message
+
+        logger.info(f"Returning fallback response: {fallback_message}")
+        return {
+            "type": "conversation",
+            "message": fallback_message,
+            "mood": "ominous",
+            "gesture": "idle"
+        }
+
+    def configure_model(self, model_name: Optional[str] = None, temperature: Optional[float] = None):
+        """
+        Reconfigure the Gemini model settings.
+
+        Args:
+            model_name: New model name to use
+            temperature: New temperature value
+        """
+        if model_name:
+            self.model_name = model_name
+
+        if temperature is not None:
+            self.temperature = temperature
+
+        # Update generation config
+        self.generation_config = types.GenerateContentConfig(
+            temperature=self.temperature,
+            response_mime_type="application/json"
+        )
+
+        logger.info(f"Model reconfigured: {self.model_name}, temperature: {self.temperature}")
+
+
+# Example usage
+if __name__ == "__main__":
+    # Configure logging for testing
+    logging.basicConfig(level=logging.INFO)
+
+    # Create client
+    try:
+        client = GeminiClient()
+
+        # Test conversational message
+        print("Testing conversational input...")
+        response = client.send_message("Hello Mortis, introduce yourself!")
+        print("Response:", json.dumps(response, indent=2))
+        print()
+
+        # Test manipulation command
+        print("Testing manipulation command...")
+        response = client.send_message("Can you move the skull to the green cup?")
+        print("Response:", json.dumps(response, indent=2))
+        print()
+
+        # Test another manipulation with different wording
+        print("Testing manipulation with different wording...")
+        response = client.send_message("Put the eyeball in the orange cup")
+        print("Response:", json.dumps(response, indent=2))
+
+    except ValueError as e:
+        print(f"Error: {e}")
+        print("Please set GEMINI_API_KEY in your .env file")
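Illustrative sketch (not part of this commit): one way a caller might dispatch on the structured JSON contract above, assuming GEMINI_API_KEY is set and the src/ directory is on PYTHONPATH.

from mortis.gemini_client import GeminiClient

client = GeminiClient()
response = client.send_message("skull to purple")

if response.get("type") == "manipulation":
    # "command" is one of the six exact task strings from the system prompt
    print("Robot task:", response["command"])
else:
    # Conversational replies carry a predefined gesture instead of a command
    print("Gesture:", response.get("gesture", "idle"))

print("Mortis says:", response["message"])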
src/mortis/intent_router.py
ADDED
@@ -0,0 +1,295 @@
+"""
+Intent router for parsing Gemini responses and routing to appropriate execution paths.
+
+This module handles the routing logic between conversational gestures and manipulation
+tasks based on Gemini API responses. It validates commands against the trained task set
+and provides structured intent representation.
+"""
+
+import json
+import logging
+from dataclasses import dataclass
+from typing import Optional, List, Dict, Any
+
+from .models import GeminiResponse, ResponseType, Gesture
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Intent:
+    """
+    Structured representation of user intent parsed from Gemini response.
+
+    Attributes:
+        type: The type of intent (conversation or manipulation)
+        message: The text message to display/speak to the user
+        mood: The emotional mood of the response
+        gesture: Optional gesture to execute (for conversation type)
+        command: Optional manipulation command (for manipulation type)
+        is_valid: Whether the intent is valid and can be executed
+        validation_error: Optional error message if validation failed
+    """
+    type: ResponseType
+    message: str
+    mood: str
+    gesture: Optional[str] = None
+    command: Optional[str] = None
+    is_valid: bool = True
+    validation_error: Optional[str] = None
+
+    @classmethod
+    def from_gemini_response(cls, response: GeminiResponse, is_valid: bool = True,
+                             validation_error: Optional[str] = None) -> "Intent":
+        """
+        Create an Intent from a GeminiResponse.
+
+        Args:
+            response: The parsed GeminiResponse object
+            is_valid: Whether the intent passed validation
+            validation_error: Optional error message if validation failed
+
+        Returns:
+            Intent object with all fields populated
+        """
+        return cls(
+            type=response.type,
+            message=response.message,
+            mood=response.mood.value,
+            gesture=response.gesture.value if response.gesture else None,
+            command=response.command,
+            is_valid=is_valid,
+            validation_error=validation_error
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the intent to a dictionary.
+
+        Returns:
+            Dictionary representation of the intent
+        """
+        result = {
+            "type": self.type.value,
+            "message": self.message,
+            "mood": self.mood,
+            "is_valid": self.is_valid,
+        }
+
+        if self.gesture is not None:
+            result["gesture"] = self.gesture
+
+        if self.command is not None:
+            result["command"] = self.command
+
+        if self.validation_error is not None:
+            result["validation_error"] = self.validation_error
+
+        return result
+
+
+class IntentRouter:
+    """
+    Routes user intents to appropriate execution paths based on Gemini responses.
+
+    The IntentRouter parses Gemini API responses, validates manipulation commands
+    against the trained task set, and creates structured Intent objects for execution.
+    """
+
+    # Valid manipulation task commands that SmolVLA is trained on
+    VALID_COMMANDS = [
+        "Pick up the skull and place it in the green cup",
+        "Pick up the skull and place it in the orange cup",
+        "Pick up the skull and place it in the purple cup",
+        "Pick up the eyeball and place it in the green cup",
+        "Pick up the eyeball and place it in the orange cup",
+        "Pick up the eyeball and place it in the purple cup",
+    ]
+
+    def __init__(self, valid_commands: Optional[List[str]] = None):
+        """
+        Initialize the IntentRouter.
+
+        Args:
+            valid_commands: Optional list of valid manipulation commands.
+                If not provided, uses the default VALID_COMMANDS.
+        """
+        self.valid_commands = valid_commands if valid_commands is not None else self.VALID_COMMANDS
+        logger.info(f"IntentRouter initialized with {len(self.valid_commands)} valid commands")
+
+    def parse_gemini_response(self, response_data: Dict[str, Any]) -> Intent:
+        """
+        Parse a Gemini API response and create an Intent.
+
+        This method:
+        1. Parses the JSON response into a GeminiResponse object
+        2. Validates manipulation commands against the trained task set
+        3. Creates an Intent object with validation results
+
+        Args:
+            response_data: Dictionary containing the JSON response from Gemini
+
+        Returns:
+            Intent object with parsed data and validation status
+
+        Raises:
+            ValueError: If the response structure is invalid
+            json.JSONDecodeError: If response_data is a string and not valid JSON
+        """
+        try:
+            # Parse the Gemini response
+            gemini_response = GeminiResponse.from_json(response_data)
+
+            # Validate the response structure
+            try:
+                gemini_response.validate()
+            except ValueError as e:
+                logger.warning(f"Response validation warning: {e}")
+                # Continue anyway - validation warnings are not fatal
+
+            # For manipulation intents, validate the command
+            if gemini_response.type == ResponseType.MANIPULATION:
+                is_valid = self.validate_command(gemini_response.command)
+
+                if not is_valid:
+                    logger.warning(
+                        f"Invalid manipulation command: '{gemini_response.command}'. "
+                        f"Not in trained task set."
+                    )
+                    validation_error = (
+                        f"Command '{gemini_response.command}' is not in the trained task set. "
+                        f"Valid commands are: {', '.join(self.valid_commands)}"
+                    )
+                    return Intent.from_gemini_response(
+                        gemini_response,
+                        is_valid=False,
+                        validation_error=validation_error
+                    )
+                else:
+                    logger.info(f"Valid manipulation command: '{gemini_response.command}'")
+
+            # For conversation intents, always valid (gestures are predefined)
+            else:
+                logger.info(f"Conversation intent with gesture: {gemini_response.gesture.value}")
+
+            # Create and return valid intent
+            return Intent.from_gemini_response(gemini_response, is_valid=True)
+
+        except (ValueError, KeyError) as e:
+            logger.error(f"Failed to parse Gemini response: {e}")
+            raise ValueError(f"Invalid Gemini response structure: {e}")
+
+    def parse_gemini_response_string(self, response_string: str) -> Intent:
+        """
+        Parse a Gemini API response from a JSON string.
+
+        Args:
+            response_string: JSON string containing the Gemini response
+
+        Returns:
+            Intent object with parsed data and validation status
+
+        Raises:
+            json.JSONDecodeError: If the string is not valid JSON
+            ValueError: If the response structure is invalid
+        """
+        try:
+            response_data = json.loads(response_string)
+        except json.JSONDecodeError as e:
+            logger.error(f"Failed to parse JSON string: {e}")
+            raise
+
+        return self.parse_gemini_response(response_data)
+
+    def validate_command(self, command: str) -> bool:
+        """
+        Validate that a manipulation command is in the trained task set.
+
+        This performs exact string matching against the list of valid commands.
+        Commands must match exactly (case-sensitive) to be considered valid.
+
+        Args:
+            command: The manipulation command string to validate
+
+        Returns:
+            True if the command is valid, False otherwise
+        """
+        if not command or not isinstance(command, str):
+            logger.warning(f"Invalid command type: {type(command)}")
+            return False
+
+        # Exact match required
+        is_valid = command in self.valid_commands
+
+        if not is_valid:
+            # Log for debugging - maybe it's close to a valid command
+            logger.debug(f"Command '{command}' not found in valid commands")
+            logger.debug(f"Valid commands: {self.valid_commands}")
+
+        return is_valid
+
+    def get_valid_commands(self) -> List[str]:
+        """
+        Get the list of valid manipulation commands.
+
+        Returns:
+            List of valid command strings
+        """
+        return self.valid_commands.copy()
+
+    def add_valid_command(self, command: str) -> None:
+        """
+        Add a new valid manipulation command to the router.
+
+        This is useful when training new tasks and expanding the command set.
+
+        Args:
+            command: The new command string to add
+        """
+        if command not in self.valid_commands:
+            self.valid_commands.append(command)
+            logger.info(f"Added new valid command: '{command}'")
+        else:
+            logger.warning(f"Command already exists: '{command}'")
+
+    def remove_valid_command(self, command: str) -> bool:
+        """
+        Remove a valid manipulation command from the router.
+
+        Args:
+            command: The command string to remove
+
+        Returns:
+            True if the command was removed, False if it wasn't found
+        """
+        if command in self.valid_commands:
+            self.valid_commands.remove(command)
+            logger.info(f"Removed valid command: '{command}'")
+            return True
+        else:
+            logger.warning(f"Command not found: '{command}'")
+            return False
+
+    def route_intent(self, intent: Intent) -> str:
+        """
+        Determine the execution path for an intent.
+
+        Args:
+            intent: The Intent object to route
+
+        Returns:
+            String indicating the execution path: "gesture", "manipulation", or "invalid"
+        """
+        if not intent.is_valid:
+            logger.warning(f"Invalid intent: {intent.validation_error}")
+            return "invalid"
+
+        if intent.type == ResponseType.CONVERSATION:
+            logger.info(f"Routing to gesture execution: {intent.gesture}")
+            return "gesture"
+        elif intent.type == ResponseType.MANIPULATION:
+            logger.info(f"Routing to manipulation execution: {intent.command}")
+            return "manipulation"
+        else:
+            logger.error(f"Unknown intent type: {intent.type}")
+            return "invalid"
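Illustrative sketch (not part of this commit): how IntentRouter might consume a parsed Gemini payload, assuming src/ is on PYTHONPATH and that mortis.models.GeminiResponse accepts the JSON contract shown in gemini_client.py.

from mortis.intent_router import IntentRouter

router = IntentRouter()
intent = router.parse_gemini_response({
    "type": "manipulation",
    "command": "Pick up the skull and place it in the green cup",
    "message": "Ah yes... the skull finds a new resting place.",
    "mood": "mischievous",
})

# route_intent returns "manipulation", "gesture", or "invalid"
print(router.route_intent(intent), intent.command)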
src/mortis/lerobot_async_client.py
ADDED
@@ -0,0 +1,668 @@
+"""
+LeRobot async inference client wrapper for Mortis manipulation tasks.
+
+This module provides a high-level interface to LeRobot's async inference system
+(PolicyServer + RobotClient) for executing SmolVLA manipulation tasks while
+keeping the Gradio UI responsive.
+
+Architecture:
+- PolicyServer: Runs in a separate thread, loads SmolVLA model, performs inference
+- RobotClient: Controls the SO101 robot, captures observations, executes actions
+- This wrapper: Manages lifecycle and provides simple API for Mortis
+"""
+
+import logging
+import threading
+import time
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Optional, Dict, Any, Callable
+
+from lerobot.robots.so101_follower import SO101FollowerConfig
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.cameras.realsense import RealSenseCameraConfig
+from lerobot.async_inference.configs import PolicyServerConfig, RobotClientConfig
+from lerobot.async_inference.policy_server import serve
+from lerobot.async_inference.robot_client import RobotClient
+
+
+logger = logging.getLogger(__name__)
+
+
+class ManipulationStatus(Enum):
+    """Status of a manipulation task execution."""
+    IDLE = "idle"
+    STARTING = "starting"
+    RUNNING = "running"
+    COMPLETE = "complete"
+    FAILED = "failed"
+    STOPPED = "stopped"
+
+
+@dataclass
+class ManipulationTask:
+    """
+    Represents a manipulation task for LeRobot async execution.
+
+    Attributes:
+        task: Natural language task description
+        max_steps: Maximum number of action steps to execute
+        started_at: Timestamp when task started
+        completed_at: Timestamp when task completed
+        status: Current task status
+        error: Error message if task failed
+    """
+    task: str
+    max_steps: int = 1000  # At 30fps, ~33 seconds of execution
+    started_at: Optional[float] = None
+    completed_at: Optional[float] = None
+    status: ManipulationStatus = ManipulationStatus.IDLE
+    error: Optional[str] = None
+
+    @property
+    def duration(self) -> Optional[float]:
+        """Get task execution duration in seconds."""
+        if self.started_at and self.completed_at:
+            return self.completed_at - self.started_at
+        return None
+
+
+class LeRobotAsyncClient:
+    """
+    High-level wrapper for LeRobot async inference system.
+
+    This class manages the PolicyServer and RobotClient lifecycle, providing
+    a simple interface for executing manipulation tasks asynchronously.
+
+    Usage:
+        # Create client
+        client = LeRobotAsyncClient(
+            robot_port="/dev/ttyACM1",
+            model_path="jlamperez/kiroween-potion-smolvla",
+            camera_configs={...}
+        )
+
+        # Start the system
+        client.start()
+
+        # Execute a task
+        client.execute_task("Pick up the skull and place it in the green cup")
+
+        # Check status
+        status = client.get_status()
+
+        # Stop when done
+        client.stop()
+    """
+
+    def __init__(
+        self,
+        robot_port: str = "/dev/ttyACM1",
+        robot_id: str = "my_follower_robot_arm",  # Must match calibration file name
+        model_path: str = "jlamperez/kiroween-potion-smolvla",
+        policy_device: str = "cuda",
+        camera_configs: Optional[Dict[str, Any]] = None,
+        server_host: str = "127.0.0.1",
+        server_port: int = 8080,
+        actions_per_chunk: int = 50,
+        chunk_size_threshold: float = 0.5,
+        aggregate_fn_name: str = "weighted_average",
+    ):
+        """
+        Initialize the LeRobot async client.
+
+        Args:
+            robot_port: Serial port for SO101 robot (e.g., "/dev/ttyACM1")
+            robot_id: Identifier for the robot
+            model_path: HuggingFace model path or local checkpoint
+            policy_device: Device for model inference ("cuda" or "cpu")
+            camera_configs: Dictionary of camera configurations
+            server_host: PolicyServer host address
+            server_port: PolicyServer port
+            actions_per_chunk: Number of actions per inference chunk
+            chunk_size_threshold: Threshold for action chunk aggregation
+            aggregate_fn_name: Function name for aggregating action chunks
+        """
+        self.robot_port = robot_port
+        self.robot_id = robot_id
+        self.model_path = model_path
+        self.policy_device = policy_device
+        self.server_host = server_host
+        self.server_port = server_port
+        self.actions_per_chunk = actions_per_chunk
+        self.chunk_size_threshold = chunk_size_threshold
+        self.aggregate_fn_name = aggregate_fn_name
+
+        # Use default camera configs if not provided
+        self.camera_configs = camera_configs or self._get_default_camera_configs()
+
+        # Server and client instances
+        self.server_thread: Optional[threading.Thread] = None
+        self.robot_client: Optional[RobotClient] = None
+        self.action_receiver_thread: Optional[threading.Thread] = None
+        self.control_thread: Optional[threading.Thread] = None
+
+        # Current task tracking
+        self.current_task: Optional[ManipulationTask] = None
+        self._running = False
+        self._stop_event = threading.Event()
+        self._task_stop_event = threading.Event()  # Event to signal task cancellation
+        self._idle_callback: Optional[Callable] = None  # Callback to move robot to idle
+
+        logger.info(f"LeRobotAsyncClient initialized with model: {model_path}")
+
+    def _get_default_camera_configs(self) -> Dict[str, Any]:
+        """
+        Get default camera configuration for Mortis setup.
+
+        IMPORTANT: This configuration MUST match the cameras used during training!
+        If you trained with IntelRealSense + OpenCV, use the same setup here.
+
+        Returns:
+            Dictionary of camera configurations
+        """
+        # Default camera configuration matching training setup
+        # This should match your training configuration exactly!
+
+        # Configuration with RealSense + OpenCV (matches training setup)
+        return {
+            "camera1": RealSenseCameraConfig(
+                serial_number_or_name="030522070314",
+                width=640,
+                height=480,
+                fps=30
+            ),
+            "camera2": OpenCVCameraConfig(
+                index_or_path=8,
+                width=640,
+                height=480,
+                fps=30
+            )
+        }
+
+    def start(self) -> bool:
+        """
+        Start the PolicyServer only.
+
+        The RobotClient will be created lazily when the first task is executed.
+        This avoids loading the model unnecessarily at startup.
+
+        Returns:
+            True if startup successful, False otherwise
+        """
+        if self._running:
+            logger.warning("LeRobotAsyncClient is already running")
+            return True
+
+        try:
+            logger.info("Starting PolicyServer...")
+
+            # Configure and start PolicyServer
+            server_config = PolicyServerConfig(
+                host=self.server_host,
+                port=self.server_port
+            )
+
+            self.server_thread = threading.Thread(
+                target=serve,
+                args=(server_config,),
+                daemon=True,
+                name="PolicyServer"
+            )
+            self.server_thread.start()
+
+            # Give server time to start
+            time.sleep(2.0)
+            logger.info(f"PolicyServer started on {self.server_host}:{self.server_port}")
+
+            self._running = True
+            self._stop_event.clear()
+
+            logger.info("LeRobotAsyncClient started (RobotClient will be created on first task)")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to start LeRobotAsyncClient: {e}", exc_info=True)
+            self.stop()
+            return False
+
+    def stop(self) -> None:
+        """
+        Stop the PolicyServer and RobotClient.
+
+        This method gracefully shuts down all components.
+        """
+        if not self._running:
+            logger.warning("LeRobotAsyncClient is not running")
+            return
+
+        logger.info("Stopping LeRobotAsyncClient...")
+
+        self._running = False
+        self._stop_event.set()
+
+        # Stop control thread if running
+        if self.control_thread and self.control_thread.is_alive():
+            logger.info("Waiting for control thread to finish...")
+            self.control_thread.join(timeout=5.0)
+
+        # Stop robot client
+        if self.robot_client:
+            try:
+                self.robot_client.stop()
+                logger.info("RobotClient stopped")
+            except Exception as e:
+                logger.error(f"Error stopping RobotClient: {e}")
+
+        # Action receiver thread should stop automatically (daemon)
+        # Server thread should stop automatically (daemon)
+
+        self.robot_client = None
+        self.server_thread = None
+        self.action_receiver_thread = None
+        self.control_thread = None
+
+        logger.info("LeRobotAsyncClient stopped")
+
+    def execute_task(
+        self,
+        task: str,
+        max_steps: int = 1000,
+        blocking: bool = False,
+        timeout: float = 60.0
+    ) -> bool:
+        """
+        Execute a manipulation task asynchronously.
+
+        This method stops any running task and creates a fresh RobotClient
+        for the new task, ensuring clean state.
+
+        Args:
+            task: Natural language task description
+            max_steps: Maximum number of action steps
+            blocking: If True, wait for task to complete before returning
+            timeout: Maximum execution time in seconds (default: 60.0)
+
+        Returns:
+            True if task started successfully, False otherwise
+        """
+        if not self._running:
+            logger.error("Cannot execute task: client not running")
+            return False
+
+        # Always need a fresh client for each task because control_loop can only run once
+        # But we keep the PolicyServer alive so the model stays loaded
+        need_new_client = True
+
+        if self.robot_client is None:
+            # First task - need to create client
+            logger.info("First task - creating RobotClient...")
+        elif self.current_task and self.current_task.status == ManipulationStatus.RUNNING:
+            # Task is running - stop it first
+            logger.info(f"Stopping previous task: {self.current_task.task}")
+            self._stop_robot_client()
+        else:
+            # Previous task finished - recreate client for new task
+            logger.info("Recreating RobotClient for new task (PolicyServer keeps model loaded)")
+
+        # Wait for previous control thread to finish
+        if self.control_thread and self.control_thread.is_alive():
+            logger.info("Waiting for previous control thread to finish...")
+            self.control_thread.join(timeout=3.0)
+            if self.control_thread.is_alive():
+                logger.warning("Previous control thread still running, proceeding anyway")
+
+        # Create new task
+        self.current_task = ManipulationTask(
+            task=task,
+            max_steps=max_steps,
+            status=ManipulationStatus.STARTING
+        )
+
+        # Clear any previous stop signal
+        self._task_stop_event.clear()
+
+        logger.info(f"Executing task: {task}")
+        logger.info(f"Limits: max_steps={max_steps}, timeout={timeout}s")
+
+        # Create/recreate robot client only if needed
+        if need_new_client:
+            if not self._recreate_robot_client(task):
+                logger.error("Failed to create robot client")
+                self.current_task.status = ManipulationStatus.FAILED
+                self.current_task.error = "Failed to initialize robot client"
+                return False
+
+        # Start control loop in separate thread
+        self.control_thread = threading.Thread(
+            target=self._run_control_loop,
+            args=(task, max_steps, timeout),
+            daemon=True,
+            name="ControlLoop"
+        )
+        self.control_thread.start()
+
+        if blocking:
+            self.control_thread.join()
+
+        return True
+
+    def _stop_robot_client(self) -> None:
+        """
+        Stop the robot client cleanly.
+
+        This stops the robot client and waits for threads to finish.
+        """
+        if self.robot_client:
+            try:
+                logger.info("Stopping robot client...")
+                self.robot_client.stop()
+
+                # Wait for action receiver thread
+                if self.action_receiver_thread and self.action_receiver_thread.is_alive():
+                    self.action_receiver_thread.join(timeout=2.0)
+
+                logger.info("Robot client stopped")
+            except Exception as e:
+                logger.error(f"Error stopping robot client: {e}")
+
+    def _recreate_robot_client(self, task: str) -> bool:
+        """
+        Recreate the robot client with a new task.
+
+        This creates a fresh RobotClient instance for the new task,
+        ensuring clean state.
+
+        Args:
+            task: Task description for the new client
+
+        Returns:
+            True if successful, False otherwise
+        """
+        try:
+            # Stop existing client if any
+            self._stop_robot_client()
+
+            # Small delay to ensure port is released
+            time.sleep(0.5)
+
+            # Reconfigure robot
+            from pathlib import Path
+            from lerobot.robots.so101_follower import SO101FollowerConfig
+            from lerobot.async_inference.configs import RobotClientConfig
+            from lerobot.async_inference.robot_client import RobotClient
+
+            calibration_dir = Path(".cache/calibration/so101")
+            robot_config = SO101FollowerConfig(
+                port=self.robot_port,
+                id=self.robot_id,
+                cameras=self.camera_configs,
+                calibration_dir=calibration_dir
+            )
+
+            client_config = RobotClientConfig(
+                robot=robot_config,
+                server_address=f"{self.server_host}:{self.server_port}",
+                policy_device=self.policy_device,
+                policy_type="smolvla",
+                pretrained_name_or_path=self.model_path,
+                chunk_size_threshold=self.chunk_size_threshold,
+                actions_per_chunk=self.actions_per_chunk,
+                aggregate_fn_name=self.aggregate_fn_name,
+                task=task  # Set the task in the config
+            )
+
+            # Create new robot client
+            self.robot_client = RobotClient(client_config)
+
+            if not self.robot_client.start():
+                raise RuntimeError("Failed to start RobotClient")
+
+            # Start action receiver thread
+            self.action_receiver_thread = threading.Thread(
+                target=self.robot_client.receive_actions,
+                daemon=True,
+                name="ActionReceiver"
+            )
+            self.action_receiver_thread.start()
+
+            logger.info("Robot client recreated successfully")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to recreate robot client: {e}", exc_info=True)
+            return False
+
+    def stop_current_task(self) -> bool:
+        """
+        Stop the currently running task by stopping the robot client.
+
+        This cleanly stops the robot client, which will cause the control
+        loop to exit. The client will be recreated for the next task.
+
+        Returns:
+            True if task was stopped successfully
+        """
+        if not self.current_task or self.current_task.status != ManipulationStatus.RUNNING:
+            logger.warning("No task currently running to stop")
+            return False
+
+        logger.info("Stopping current task...")
+
+        try:
+            # Mark task as stopped
+            self.current_task.status = ManipulationStatus.STOPPED
+            self.current_task.completed_at = time.time()
+            self.current_task.error = "Task stopped by user"
+
+            # Signal task stop
+            self._task_stop_event.set()
+
+            # Stop the robot client (this will interrupt the control loop)
+            try:
+                self._stop_robot_client()
+            except Exception as e:
+                logger.warning(f"Error stopping client (expected): {e}")
+
+            # Move robot to idle position
+            if self._idle_callback:
+                logger.info("Moving robot to idle position...")
+                try:
+                    self._idle_callback()
+                    logger.info("Robot moved to idle position")
+                except Exception as e:
+                    logger.error(f"Failed to move to idle: {e}")
+
+            logger.info("Task stopped successfully")
+
+            # Clear the task after a delay
+            def clear_task():
+                time.sleep(3.0)
+                if self.current_task and self.current_task.status == ManipulationStatus.STOPPED:
+                    self.current_task = None
+                    logger.info("Cleared stopped task from status")
+
+            clear_thread = threading.Thread(target=clear_task, daemon=True)
+            clear_thread.start()
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to stop task: {e}", exc_info=True)
+            return False
+
+    def _run_control_loop(self, task: str, max_steps: int, timeout: float) -> None:
+        """
+        Run the control loop for task execution with timeout.
+
+        This runs in a separate thread and executes the task using
+        the RobotClient's control_loop method. The timeout will stop
+        the task, and recreating the client for each task ensures clean state.
+
+        Note: max_steps is not directly enforced by LeRobot's control_loop,
+        but the timeout provides a time-based limit.
+
+        Args:
+            task: Task description
+            max_steps: Maximum steps (informational, not enforced)
+            timeout: Maximum execution time in seconds (default: 60.0)
+        """
+        if not self.current_task:
+            return
+
+        try:
+            self.current_task.status = ManipulationStatus.RUNNING
+            self.current_task.started_at = time.time()
+
+            logger.info(f"Starting control loop for: {task}")
+            logger.info(f"Timeout: {timeout}s (max_steps={max_steps} is informational)")
+
+            # Clear task stop event
+            self._task_stop_event.clear()
+
+            # Run control_loop in a separate thread so we can timeout
+            control_thread = threading.Thread(
+                target=lambda: self.robot_client.control_loop(task=task, verbose=False),
+                daemon=True,
+                name="ControlLoopInner"
+            )
+            control_thread.start()
+
+            # Wait for completion or timeout
+            control_thread.join(timeout=timeout)
+
+            # Check if thread is still alive (timeout occurred)
+            if control_thread.is_alive():
+                logger.warning(f"Task timed out after {timeout}s")
+
+                # Mark task as stopped first
+                self.current_task.status = ManipulationStatus.STOPPED
+                self.current_task.completed_at = time.time()
+                self.current_task.error = f"Task exceeded timeout of {timeout}s"
+
+                # Signal stop event
+                self._task_stop_event.set()
+
+                # Stop the robot client to interrupt the control loop
+                # This will cause the control thread to error out, but we catch it
+                logger.info("Stopping robot client to interrupt control loop...")
+                try:
+                    self._stop_robot_client()
+                except Exception as e:
+                    logger.warning(f"Error stopping client (expected): {e}")
+
+                # Wait a bit for thread to die
+                control_thread.join(timeout=2.0)
+
+                logger.info("Task stopped due to timeout")
+
+                # Move robot to idle position using callback if provided
+                if hasattr(self, '_idle_callback') and self._idle_callback:
+                    logger.info("Moving robot to idle position...")
+                    try:
+                        self._idle_callback()
+                        logger.info("Robot moved to idle position")
+                    except Exception as e:
+                        logger.error(f"Failed to move to idle: {e}")
|
| 568 |
+
|
| 569 |
+
# Clear the task after a delay so UI can show the stopped status
|
| 570 |
+
def clear_task():
|
| 571 |
+
time.sleep(3.0) # Show stopped status for 3 seconds
|
| 572 |
+
if self.current_task and self.current_task.status == ManipulationStatus.STOPPED:
|
| 573 |
+
self.current_task = None
|
| 574 |
+
logger.info("Cleared stopped task from status")
|
| 575 |
+
|
| 576 |
+
clear_thread = threading.Thread(target=clear_task, daemon=True)
|
| 577 |
+
clear_thread.start()
|
| 578 |
+
|
| 579 |
+
else:
|
| 580 |
+
# Task completed successfully
|
| 581 |
+
self.current_task.status = ManipulationStatus.COMPLETE
|
| 582 |
+
self.current_task.completed_at = time.time()
|
| 583 |
+
logger.info(f"Task completed in {self.current_task.duration:.2f}s")
|
| 584 |
+
|
| 585 |
+
# Clear completed task after showing status
|
| 586 |
+
def clear_task():
|
| 587 |
+
time.sleep(3.0) # Show completed status for 3 seconds
|
| 588 |
+
if self.current_task and self.current_task.status == ManipulationStatus.COMPLETE:
|
| 589 |
+
self.current_task = None
|
| 590 |
+
logger.info("Cleared completed task from status")
|
| 591 |
+
|
| 592 |
+
clear_thread = threading.Thread(target=clear_task, daemon=True)
|
| 593 |
+
clear_thread.start()
|
| 594 |
+
|
| 595 |
+
except KeyboardInterrupt:
|
| 596 |
+
logger.info("Task interrupted by user")
|
| 597 |
+
self.current_task.status = ManipulationStatus.STOPPED
|
| 598 |
+
self.current_task.completed_at = time.time()
|
| 599 |
+
|
| 600 |
+
except Exception as e:
|
| 601 |
+
logger.error(f"Task failed: {e}", exc_info=True)
|
| 602 |
+
self.current_task.status = ManipulationStatus.FAILED
|
| 603 |
+
self.current_task.error = str(e)
|
| 604 |
+
self.current_task.completed_at = time.time()
|
| 605 |
+
|
| 606 |
+
def get_status(self) -> ManipulationStatus:
|
| 607 |
+
"""
|
| 608 |
+
Get the current task status.
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
Current ManipulationStatus
|
| 612 |
+
"""
|
| 613 |
+
if self.current_task:
|
| 614 |
+
return self.current_task.status
|
| 615 |
+
return ManipulationStatus.IDLE
|
| 616 |
+
|
| 617 |
+
def get_current_task(self) -> Optional[ManipulationTask]:
|
| 618 |
+
"""
|
| 619 |
+
Get the currently executing task.
|
| 620 |
+
|
| 621 |
+
Returns:
|
| 622 |
+
Current ManipulationTask or None if idle
|
| 623 |
+
"""
|
| 624 |
+
return self.current_task
|
| 625 |
+
|
| 626 |
+
def is_busy(self) -> bool:
|
| 627 |
+
"""
|
| 628 |
+
Check if a task is currently executing.
|
| 629 |
+
|
| 630 |
+
Returns:
|
| 631 |
+
True if a task is running
|
| 632 |
+
"""
|
| 633 |
+
return (
|
| 634 |
+
self.current_task is not None and
|
| 635 |
+
self.current_task.status == ManipulationStatus.RUNNING
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
def is_running(self) -> bool:
|
| 639 |
+
"""
|
| 640 |
+
Check if the client is running (server and robot connected).
|
| 641 |
+
|
| 642 |
+
Returns:
|
| 643 |
+
True if client is running
|
| 644 |
+
"""
|
| 645 |
+
return self._running
|
| 646 |
+
|
| 647 |
+
def set_idle_callback(self, callback: Callable) -> None:
|
| 648 |
+
"""
|
| 649 |
+
Set a callback function to move the robot to idle position.
|
| 650 |
+
|
| 651 |
+
This callback will be called when a task times out, to safely
|
| 652 |
+
return the robot to a neutral position.
|
| 653 |
+
|
| 654 |
+
Args:
|
| 655 |
+
callback: Function to call (e.g., lambda: mortis_arm.move_arm("idle"))
|
| 656 |
+
"""
|
| 657 |
+
self._idle_callback = callback
|
| 658 |
+
logger.info("Idle callback configured")
|
| 659 |
+
|
| 660 |
+
def __enter__(self):
|
| 661 |
+
"""Context manager entry: start the client."""
|
| 662 |
+
self.start()
|
| 663 |
+
return self
|
| 664 |
+
|
| 665 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 666 |
+
"""Context manager exit: stop the client."""
|
| 667 |
+
self.stop()
|
| 668 |
+
return False
|
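Since the class exposes the context-manager protocol plus set_idle_callback, is_busy, and get_status, a caller can drive it without touching the private control-loop machinery. A minimal usage sketch, not part of the committed file; the class name `LeRobotAsyncClient` and the `execute_task(...)` entry point are assumptions, as only the members above appear in this diff:

import time

from mortis.lerobot_async_client import LeRobotAsyncClient  # assumed import path
from mortis.robot import MortisArm

arm = MortisArm(mode="simulation")
arm.connect()

with LeRobotAsyncClient() as client:            # __enter__ calls start()
    client.set_idle_callback(lambda: arm.move_arm("idle"))
    # hypothetical public wrapper around _run_control_loop:
    client.execute_task("Pick up the skull and place it in the green cup")
    while client.is_busy():                     # RUNNING until complete/timeout/stop
        time.sleep(0.5)
    print(client.get_status())                  # COMPLETE, STOPPED, or FAILED
# __exit__ calls stop()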
src/mortis/models.py ADDED
@@ -0,0 +1,215 @@

"""
Data models for Gemini API responses and intent routing.

This module defines the structured data types used throughout the Mortis system
for parsing Gemini responses, routing intents, and managing execution tasks.
"""

import json
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any


class ResponseType(Enum):
    """Type of response from Gemini API."""
    CONVERSATION = "conversation"
    MANIPULATION = "manipulation"


class Mood(Enum):
    """Emotional mood for Mortis character responses."""
    OMINOUS = "ominous"
    PLAYFUL = "playful"
    ANGRY = "angry"
    NERVOUS = "nervous"
    TRIUMPHANT = "triumphant"
    MISCHIEVOUS = "mischievous"
    SINISTER = "sinister"
    CURIOUS = "curious"
    NEUTRAL = "neutral"


class Gesture(Enum):
    """Available gesture actions for the SO101 robotic arm."""
    IDLE = "idle"
    WAVE = "wave"
    POINT_LEFT = "point_left"
    POINT_RIGHT = "point_right"
    GRAB = "grab"
    DROP = "drop"


@dataclass
class GeminiResponse:
    """
    Structured response from Gemini API.

    Attributes:
        type: Whether this is a conversation or manipulation response
        message: The text message to display/speak to the user
        mood: The emotional mood of the response
        gesture: Optional gesture to execute (for conversation type)
        command: Optional manipulation command (for manipulation type)
    """
    type: ResponseType
    message: str
    mood: Mood
    gesture: Optional[Gesture] = None
    command: Optional[str] = None

    @classmethod
    def from_json(cls, json_data: Dict[str, Any]) -> "GeminiResponse":
        """
        Parse a GeminiResponse from JSON data returned by Gemini API.

        Args:
            json_data: Dictionary containing the JSON response from Gemini

        Returns:
            GeminiResponse object with validated fields

        Raises:
            ValueError: If required fields are missing or invalid
            KeyError: If JSON structure is malformed
        """
        # Validate required fields
        if "type" not in json_data:
            raise ValueError("Missing required field: 'type'")
        if "message" not in json_data:
            raise ValueError("Missing required field: 'message'")
        if "mood" not in json_data:
            raise ValueError("Missing required field: 'mood'")

        # Parse response type
        try:
            response_type = ResponseType(json_data["type"])
        except ValueError:
            raise ValueError(f"Invalid response type: {json_data['type']}. Must be 'conversation' or 'manipulation'")

        # Parse mood
        try:
            mood = Mood(json_data["mood"])
        except ValueError:
            raise ValueError(f"Invalid mood: {json_data['mood']}. Must be one of: {[m.value for m in Mood]}")

        # Parse optional fields based on response type
        gesture = None
        command = None

        if response_type == ResponseType.CONVERSATION:
            # Conversation responses should have a gesture
            if "gesture" in json_data:
                try:
                    gesture = Gesture(json_data["gesture"])
                except ValueError:
                    raise ValueError(f"Invalid gesture: {json_data['gesture']}. Must be one of: {[g.value for g in Gesture]}")
            else:
                # Default to idle if no gesture specified
                gesture = Gesture.IDLE

        elif response_type == ResponseType.MANIPULATION:
            # Manipulation responses must have a command
            if "command" not in json_data:
                raise ValueError("Manipulation responses must include 'command' field")
            command = json_data["command"]
            if not isinstance(command, str) or not command.strip():
                raise ValueError("Command must be a non-empty string")

        # Validate message
        message = json_data["message"]
        if not isinstance(message, str) or not message.strip():
            raise ValueError("Message must be a non-empty string")

        return cls(
            type=response_type,
            message=message,
            mood=mood,
            gesture=gesture,
            command=command
        )

    @classmethod
    def from_json_string(cls, json_string: str) -> "GeminiResponse":
        """
        Parse a GeminiResponse from a JSON string.

        Args:
            json_string: JSON string containing the Gemini response

        Returns:
            GeminiResponse object with validated fields

        Raises:
            json.JSONDecodeError: If the string is not valid JSON
            ValueError: If required fields are missing or invalid
        """
        try:
            json_data = json.loads(json_string)
        except json.JSONDecodeError as e:
            raise json.JSONDecodeError(f"Invalid JSON string: {e.msg}", e.doc, e.pos)

        return cls.from_json(json_data)

    def validate(self) -> bool:
        """
        Validate the response structure and content.

        Returns:
            True if the response is valid

        Raises:
            ValueError: If validation fails
        """
        # Check message length constraints (per product requirements)
        if len(self.message) > 120:
            raise ValueError(f"Message exceeds 120 characters: {len(self.message)} chars")

        word_count = len(self.message.split())
        if word_count > 30:
            raise ValueError(f"Message exceeds 30 words: {word_count} words")

        # Validate type-specific requirements
        if self.type == ResponseType.CONVERSATION:
            if self.gesture is None:
                raise ValueError("Conversation responses must have a gesture")
            if self.command is not None:
                raise ValueError("Conversation responses should not have a command")

        elif self.type == ResponseType.MANIPULATION:
            if self.command is None or not self.command.strip():
                raise ValueError("Manipulation responses must have a non-empty command")
            if self.gesture is not None:
                raise ValueError("Manipulation responses should not have a gesture")

        return True

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the response to a dictionary.

        Returns:
            Dictionary representation of the response
        """
        result = {
            "type": self.type.value,
            "message": self.message,
            "mood": self.mood.value,
        }

        if self.gesture is not None:
            result["gesture"] = self.gesture.value

        if self.command is not None:
            result["command"] = self.command

        return result

    def to_json(self) -> str:
        """
        Convert the response to a JSON string.

        Returns:
            JSON string representation of the response
        """
        return json.dumps(self.to_dict(), indent=2)
src/mortis/robot.py ADDED
@@ -0,0 +1,180 @@

# ({"shoulder_pan.pos": -45, "shoulder_lift.pos": -99, "elbow_flex.pos": 0, "wrist_flex.pos": 60, "wrist_roll.pos": 0, "gripper.pos": 60}, 0.5),
import logging
import os
import time
from pathlib import Path

from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig

logger = logging.getLogger(__name__)

HOME_POSE = {
    "shoulder_pan.pos": 0,
    "shoulder_lift.pos": -99,
    "elbow_flex.pos": 97,
    "wrist_flex.pos": 55,
    "wrist_roll.pos": 0,
    "gripper.pos": 0,
}


GESTURES = {
    "idle": [
        (HOME_POSE, 1.0),
    ],
    "wave": [
        ({"wrist_flex.pos": -40}, 0.5),
        ({"shoulder_pan.pos": -5, "shoulder_lift.pos": 65, "elbow_flex.pos": -70}, 1),
        ({"shoulder_lift.pos": 0, "elbow_flex.pos": 0}, 0.5),
        ({"wrist_flex.pos": 0}, 0.5),
        (HOME_POSE, 1.0),
    ],
    "point_left": [
        ({"shoulder_pan.pos": -60, "shoulder_lift.pos": -30, "elbow_flex.pos": -15, "wrist_flex.pos": 42, "wrist_roll.pos": 0, "gripper.pos": 0}, 1),
        ({"wrist_flex.pos": 80}, 0.5),
        ({"wrist_flex.pos": 42}, 0.5),
        ({"wrist_flex.pos": 80}, 0.5),
        (HOME_POSE, 1.0),
    ],
    "point_right": [
        ({"shoulder_pan.pos": 65, "shoulder_lift.pos": -50, "elbow_flex.pos": -5, "wrist_flex.pos": 55, "wrist_roll.pos": 0, "gripper.pos": 0}, 1),
        ({"wrist_flex.pos": 90}, 0.5),
        ({"wrist_flex.pos": 42}, 0.5),
        ({"wrist_flex.pos": 90}, 0.5),
        (HOME_POSE, 1.0),
    ],
    "grab": [
        ({"shoulder_pan.pos": 0, "shoulder_lift.pos": -2, "elbow_flex.pos": -8.0, "wrist_flex.pos": 55, "wrist_roll.pos": 0, "gripper.pos": 0}, 0.8),
        ({"wrist_flex.pos": 80}, 0.5),
        ({"wrist_roll.pos": -45, "gripper.pos": 40}, 1),
        ({"elbow_flex.pos": 30}, 1),
        ({"wrist_roll.pos": 45, "gripper.pos": 10}, 1),
        ({"elbow_flex.pos": -20}, 1),
        (HOME_POSE, 1.0),
    ],
    "drop": [
        ({"shoulder_pan.pos": 0, "shoulder_lift.pos": 5, "elbow_flex.pos": 20.0, "wrist_flex.pos": 55, "wrist_roll.pos": 0, "gripper.pos": 0}, 0.8),
        ({"gripper.pos": 80}, 1),
        ({"gripper.pos": 0}, 1),
        (HOME_POSE, 1.0),
    ],
}


class MortisArm:
    """
    Class to control the Mortis SO101 robotic arm.
    Manages connection, disconnection, and gesture execution.

    Supports two modes:
    - physical: Connects to real robot hardware
    - simulation: Simulates robot behavior without hardware
    """

    def __init__(self, port="/dev/ttyACM1", mode=None):
        port = os.getenv("ROBOT_PORT", port)

        # Determine mode: check env var or use provided mode
        if mode is None:
            mode = os.getenv("ROBOT_MODE", "physical").lower()

        self.mode = mode
        self.connected = False

        if self.mode == "simulation":
            logger.info("🎭 MortisArm initialized in SIMULATION mode (no physical robot)")
            self.robot = None
            self.connected = True  # Always "connected" in simulation
        else:
            config = SO101FollowerConfig(
                port=port,
                id="my_follower_robot_arm",
                calibration_dir=Path(".cache/calibration/so101/"),
            )
            self.robot = SO101Follower(config)
            logger.info(f"🤖 MortisArm initialized in PHYSICAL mode on port {port}")

    def connect(self):
        """Connects to the robotic arm."""
        if self.mode == "simulation":
            logger.info("🎭 Simulation mode: skipping physical connection")
            self.connected = True
            return

        if not self.connected:
            try:
                logger.info("Attempting to connect to robot arm...")
                self.robot.connect()
                self.connected = self.robot.is_connected
                if self.connected:
                    logger.info("✅ Robot arm connected successfully")
                    # Move to the initial position to indicate it's ready
                    self.move_arm("idle")
                else:
                    logger.warning("⚠️ Failed to establish connection to robot arm")
            except Exception as e:
                logger.error(f"❌ Connection error: {e}", exc_info=True)
                self.connected = False

    def disconnect(self):
        """Disconnects the robotic arm."""
        if self.mode == "simulation":
            logger.info("🎭 Simulation mode: skipping physical disconnection")
            self.connected = False
            return

        if self.connected:
            logger.info("Disconnecting robot arm...")
            # Move to rest position before disconnecting
            self.move_arm("idle")
            time.sleep(1)
            self.robot.disconnect()
            self.connected = False
            logger.info("✅ Robot arm disconnected")

    def move_arm(self, gesture_name: str):
        """
        Executes a sequence of movements (a gesture) by its name.
        If the gesture does not exist, it executes 'idle'.
        """
        if not self.connected:
            logger.warning("⚠️ Cannot execute gesture: robot arm not connected")
            return

        # If the gesture is not defined, return to the neutral position.
        if gesture_name not in GESTURES:
            logger.warning(f"⚠️ Unknown gesture '{gesture_name}', falling back to 'idle'")
            gesture_name = "idle"

        sequence = GESTURES[gesture_name]

        if self.mode == "simulation":
            # Simulation mode: just log the gesture
            logger.info(f"🎭 [SIMULATION] Executing gesture '{gesture_name}' ({len(sequence)} steps)")

            # Simulate timing by sleeping for total duration
            total_delay = sum(delay for _, delay in sequence)
            time.sleep(total_delay)

            logger.info(f"🎭 [SIMULATION] Gesture '{gesture_name}' completed")
        else:
            # Physical mode: execute on real robot
            logger.info(f"🤖 Executing gesture '{gesture_name}' ({len(sequence)} steps)")

            for i, (action, delay) in enumerate(sequence, 1):
                logger.debug(f"Gesture '{gesture_name}' step {i}/{len(sequence)}: {action}")
                self.robot.send_action(action)
                time.sleep(delay)

            logger.info(f"✅ Gesture '{gesture_name}' completed")


if __name__ == "__main__":

    mortis_arm = MortisArm()
    if not mortis_arm.connected:
        mortis_arm.connect()

    mortis_arm.move_arm("drop")

    mortis_arm.disconnect()
src/mortis/setup_dataset.py ADDED
@@ -0,0 +1,146 @@

#!/usr/bin/env python3
"""
CLI tool for setting up Mortis dataset infrastructure.

This script initializes the dataset structure and generates
lerobot-record scripts for data collection.
"""

import os
import sys
import subprocess
import argparse
from pathlib import Path
from dotenv import load_dotenv

from mortis.data_collector import create_mortis_dataset, DataCollector


def check_huggingface_auth():
    """Check if user is authenticated with Hugging Face."""
    try:
        result = subprocess.run(
            ["huggingface-cli", "whoami"],
            capture_output=True,
            text=True,
            timeout=5
        )
        return result.returncode == 0
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return False


def main():
    """Main entry point for dataset setup."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Setup Mortis dataset infrastructure and generate recording scripts"
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default=None,
        help="Name for the dataset (default: mortis_manipulation)"
    )
    parser.add_argument(
        "--hf-user",
        type=str,
        default=None,
        help="Hugging Face username (default: from HF_USER env var)"
    )
    args = parser.parse_args()

    # Load environment variables from .env file
    REPO_ROOT = Path(__file__).resolve().parents[2]
    load_dotenv(REPO_ROOT / ".env")

    print("="*70)
    print("Mortis Dataset Setup")
    print("="*70)
    print()

    # Check Hugging Face authentication
    print("Checking Hugging Face authentication...")
    if not check_huggingface_auth():
        print("⚠️ Not logged in to Hugging Face")
        print("📝 You need to authenticate before recording datasets")
        print()
        print("Run this command to login:")
        print("  huggingface-cli login")
        print()
        print("Get your token from: https://huggingface.co/settings/tokens")
        print()
        response = input("Continue anyway? (y/N): ").strip().lower()
        if response != 'y':
            print("Setup cancelled. Please login first with: huggingface-cli login")
            sys.exit(0)
        print()
    else:
        print("✅ Hugging Face authentication verified")
        print()

    # Get Hugging Face username
    hf_user = args.hf_user or os.getenv("HF_USER")
    if not hf_user:
        print("⚠️ HF_USER not found in .env file or environment")
        hf_user = input("Enter your Hugging Face username: ").strip()
        if not hf_user:
            print("❌ Hugging Face username is required")
            sys.exit(1)
        print("💡 Tip: Add HF_USER to your .env file to skip this prompt:")
        print(f"   echo 'HF_USER={hf_user}' >> .env")
        print()

    # Get dataset name
    dataset_name = args.dataset_name
    if not dataset_name:
        print("Dataset name:")
        print("  Press Enter for default: 'mortis_manipulation'")
        print("  Or enter a custom name (e.g., 'mortis_v2', 'test_dataset')")
        user_input = input("Dataset name: ").strip()
        dataset_name = user_input if user_input else "mortis_manipulation"
        print()

    # Create repository ID
    repo_id = f"{hf_user}/{dataset_name}"

    print(f"Creating dataset: {dataset_name}")
    print(f"Repository: {repo_id}")
    print()

    # Create collector with custom name
    collector = DataCollector(dataset_name, repo_id)

    # Generate scripts
    print("\nGenerating recording scripts...")
    collector.generate_all_record_scripts()

    # Show summary
    collector.print_summary()

    # Show instructions
    collector.print_recording_instructions()

    # Final instructions
    print("="*70)
    print("Setup Complete! 🎉")
    print("="*70)
    print()
    print("Next steps:")
    print("  1. Make sure you're logged in to Hugging Face:")
    print("     huggingface-cli login")
    print("  2. Connect your leader and follower robot arms")
    print("  3. Navigate to the scripts directory:")
    print(f"     cd {collector.dataset_dir}/scripts")
    print("  4. Run a recording script:")
    print("     ./record_task_0.sh")
    print()
    print("Or record all tasks:")
    print("  ./record_all_tasks.sh")
    print()
    print("="*70)


if __name__ == "__main__":
    main()
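A typical invocation, assuming the package is installed so the module path resolves (the username is a placeholder); with both flags given, the interactive prompts are skipped entirely:

# Fully specified: skips the HF_USER and dataset-name prompts.
python -m mortis.setup_dataset \
  --dataset-name mortis_manipulation \
  --hf-user your-hf-username

# Flag-free: reads HF_USER from .env and prompts for the dataset name.
python -m mortis.setup_dataset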
src/mortis/setup_train.py ADDED
@@ -0,0 +1,385 @@

#!/usr/bin/env python3
"""
CLI tool for setting up Mortis training infrastructure.

This script generates lerobot-train scripts with appropriate
configurations for training SmolVLA models on Mortis datasets.
"""

import os
import sys
import argparse
from pathlib import Path
from dotenv import load_dotenv


class TrainingScriptGenerator:
    """
    Helper for generating lerobot-train scripts.

    This class generates shell scripts that call lerobot-train with the
    correct parameters for training SmolVLA models on Mortis datasets.

    Attributes:
        dataset_repo_id: Hugging Face dataset repository ID
        output_dir: Directory for training outputs
        job_name: Name for the training job
        model_repo_id: Optional Hugging Face model repository ID for pushing
    """

    def __init__(
        self,
        dataset_repo_id: str,
        output_dir: str = "outputs/train",
        job_name: str = "smolvla_mortis",
        model_repo_id: str = None,
        scripts_dir: str = "train"
    ):
        """
        Initialize the TrainingScriptGenerator.

        Args:
            dataset_repo_id: Hugging Face dataset repository ID
            output_dir: Base directory for training outputs (checkpoints, logs)
            job_name: Name for the training job
            model_repo_id: Optional HF model repo ID for pushing trained model
            scripts_dir: Directory to save training scripts
        """
        self.dataset_repo_id = dataset_repo_id
        self.output_dir = Path(output_dir)
        self.job_name = job_name
        self.model_repo_id = model_repo_id
        self.scripts_dir = Path(scripts_dir)

        # Create scripts directory
        self.scripts_dir.mkdir(parents=True, exist_ok=True)

        print("TrainingScriptGenerator initialized:")
        print(f"  Dataset: {self.dataset_repo_id}")
        print(f"  Scripts directory: {self.scripts_dir}")
        print(f"  Training output directory: {self.output_dir}")
        print(f"  Job name: {self.job_name}")
        if self.model_repo_id:
            print(f"  Model repository: {self.model_repo_id}")

    def generate_train_command(
        self,
        policy_path: str = "lerobot/smolvla_base",
        batch_size: int = 16,
        steps: int = 20000,
        save_freq: int = 5000,
        eval_freq: int = 5000,
        n_action_steps: int = 50,
        chunk_size: int = 50,
        use_amp: bool = True,
        enable_wandb: bool = True,
        device: str = "cuda",
        image_transforms: bool = True,
        rename_map: str = None,
        cuda_alloc_conf: str = "expandable_segments:True"
    ) -> str:
        """
        Generate a lerobot-train command with specified parameters.

        Args:
            policy_path: Path to base policy (default: lerobot/smolvla_base)
            batch_size: Training batch size
            steps: Total training steps
            save_freq: Checkpoint save frequency
            eval_freq: Evaluation frequency
            n_action_steps: Number of action steps to predict
            chunk_size: Action chunk size
            use_amp: Use automatic mixed precision
            enable_wandb: Enable Weights & Biases logging
            device: Device to use (cuda or cpu)
            image_transforms: Enable image transformations
            rename_map: Optional observation key rename mapping
            cuda_alloc_conf: CUDA memory allocator configuration

        Returns:
            The complete lerobot-train command as a string
        """
        # Load environment variables
        load_dotenv()

        # Build output directory path
        full_output_dir = self.output_dir / self.job_name

        # Default rename map for SO101 with dual cameras
        if rename_map is None:
            rename_map = (
                '{"observation.images.camera1": "observation.images.camera1", '
                '"observation.images.camera2": "observation.images.camera2"}'
            )

        # Build the command
        cmd_parts = [
            f"PYTORCH_CUDA_ALLOC_CONF={cuda_alloc_conf} \\",
            "lerobot-train \\",
            f"  --policy.path={policy_path} \\",
            f"  --dataset.repo_id={self.dataset_repo_id} \\",
            f"  --dataset.image_transforms.enable={str(image_transforms).lower()} \\",
            f"  --policy.device={device} \\",
            f"  --policy.use_amp={str(use_amp).lower()} \\",
            f"  --policy.n_action_steps={n_action_steps} \\",
            f"  --policy.chunk_size={chunk_size} \\",
            f"  --batch_size={batch_size} \\",
            f"  --steps={steps} \\",
            "  --save_checkpoint=true \\",
            f"  --save_freq={save_freq} \\",
            f"  --eval_freq={eval_freq} \\",
            f"  --wandb.enable={str(enable_wandb).lower()} \\",
            f"  --output_dir={full_output_dir} \\",
            f"  --job_name={self.job_name} \\",
        ]

        # Add model repo ID if specified
        if self.model_repo_id:
            cmd_parts.append(f"  --policy.repo_id={self.model_repo_id} \\")

        # Add rename map
        cmd_parts.append(f"  --rename_map='{rename_map}'")

        return "\n".join(cmd_parts)

    def generate_training_script(
        self,
        script_name: str = "train.sh",
        **kwargs
    ) -> Path:
        """
        Generate a shell script for training.

        Args:
            script_name: Name for the training script
            **kwargs: Additional arguments passed to generate_train_command

        Returns:
            Path to the generated script
        """
        script_path = self.scripts_dir / script_name

        with open(script_path, 'w') as f:
            f.write("#!/bin/bash\n")
            f.write(f"# Training script for {self.job_name}\n")
            f.write(f"# Dataset: {self.dataset_repo_id}\n")
            f.write("# Generated by setup_train.py\n\n")

            f.write("# Check if CUDA is available\n")
            f.write("if ! command -v nvidia-smi &> /dev/null; then\n")
            f.write('    echo "⚠️ Warning: nvidia-smi not found. CUDA may not be available."\n')
            f.write('    read -p "Continue anyway? (y/N): " -n 1 -r\n')
            f.write('    echo\n')
            f.write('    if [[ ! $REPLY =~ ^[Yy]$ ]]; then\n')
            f.write('        exit 1\n')
            f.write('    fi\n')
            f.write("fi\n\n")

            f.write("# Start training\n")
            f.write(f'echo "Starting training: {self.job_name}"\n')
            f.write(f'echo "Dataset: {self.dataset_repo_id}"\n')
            f.write(f'echo "Output: {self.output_dir / self.job_name}"\n')
            f.write('echo ""\n\n')

            f.write(self.generate_train_command(**kwargs))
            f.write("\n")

        # Make script executable
        script_path.chmod(0o755)
        print(f"Created: {script_path}")

        return script_path

    def generate_training_configs(self):
        """
        Generate multiple training scripts with different configurations.

        Creates:
        - train_quick.sh: Quick test training (1000 steps)
        - train_standard.sh: Standard training (20k steps)
        - train_full.sh: Full training (100k steps)
        """
        configs = [
            {
                "script_name": "train_quick.sh",
                "steps": 1000,
                "save_freq": 500,
                "eval_freq": 500,
                "batch_size": 8,
            },
            {
                "script_name": "train_standard.sh",
                "steps": 20000,
                "save_freq": 5000,
                "eval_freq": 5000,
                "batch_size": 16,
            },
            {
                "script_name": "train_full.sh",
                "steps": 100000,
                "save_freq": 10000,
                "eval_freq": 10000,
                "batch_size": 16,
            },
        ]

        for config in configs:
            self.generate_training_script(**config)

        print(f"\n✅ Generated {len(configs)} training scripts in {self.scripts_dir}")

    def print_usage_instructions(self):
        """Print instructions for using the generated training scripts."""
        print("\n" + "="*70)
        print("Training Scripts Generated")
        print("="*70)
        print()
        print("Available training scripts:")
        print(f"  {self.scripts_dir}/train_quick.sh    - Quick test (1k steps)")
        print(f"  {self.scripts_dir}/train_standard.sh - Standard training (20k steps)")
        print(f"  {self.scripts_dir}/train_full.sh     - Full training (100k steps)")
        print()
        print("To start training:")
        print(f"  cd {self.scripts_dir}")
        print("  ./train_standard.sh")
        print()
        print("Training outputs will be saved to:")
        print(f"  {self.output_dir}/{self.job_name}/")
        print()
        print("Monitor training:")
        print("  - Console: Watch the terminal output")
        print("  - W&B: https://wandb.ai (if enabled)")
        print(f"  - Checkpoints: {self.output_dir}/{self.job_name}/checkpoints/")
        print()
        print("Resume training:")
        print("  Add --resume=true to the lerobot-train command")
        print()
        print("="*70)


def main():
    """Main entry point for training setup."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Setup Mortis training infrastructure and generate training scripts"
    )
    parser.add_argument(
        "--dataset-repo-id",
        type=str,
        required=True,
        help="Hugging Face dataset repository ID (e.g., username/dataset-name)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/train",
        help="Base directory for training outputs/checkpoints (default: outputs/train)"
    )
    parser.add_argument(
        "--scripts-dir",
        type=str,
        default="train",
        help="Directory to save training scripts (default: train)"
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default=None,
        help="Name for the training job (default: derived from dataset name)"
    )
    parser.add_argument(
        "--model-repo-id",
        type=str,
        default=None,
        help="Hugging Face model repository ID for pushing trained model"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Training batch size (default: 16)"
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=20000,
        help="Total training steps (default: 20000)"
    )
    parser.add_argument(
        "--policy-path",
        type=str,
        default="lerobot/smolvla_base",
        help="Path to base policy (default: lerobot/smolvla_base)"
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable Weights & Biases logging"
    )
    parser.add_argument(
        "--generate-configs",
        action="store_true",
        help="Generate multiple training configurations (quick, standard, full)"
    )

    args = parser.parse_args()

    # Load environment variables
    REPO_ROOT = Path(__file__).resolve().parents[2]
    load_dotenv(REPO_ROOT / ".env")

    print("="*70)
    print("Mortis Training Setup")
    print("="*70)
    print()

    # Derive job name from dataset if not provided
    job_name = args.job_name
    if not job_name:
        # Extract dataset name from repo_id
        dataset_name = args.dataset_repo_id.split('/')[-1]
        job_name = f"smolvla_{dataset_name}"
        print(f"Using job name: {job_name}")
        print()

    # Create generator
    generator = TrainingScriptGenerator(
        dataset_repo_id=args.dataset_repo_id,
        output_dir=args.output_dir,
        job_name=job_name,
        model_repo_id=args.model_repo_id,
        scripts_dir=args.scripts_dir
    )

    print()

    if args.generate_configs:
        # Generate multiple configurations
        print("Generating training configurations...")
        generator.generate_training_configs()
    else:
        # Generate single training script
        print("Generating training script...")
        generator.generate_training_script(
            script_name="train.sh",
            policy_path=args.policy_path,
            batch_size=args.batch_size,
            steps=args.steps,
            enable_wandb=not args.no_wandb
        )

    # Print usage instructions
    generator.print_usage_instructions()

    # Final tips
    print("\n💡 Tips:")
    print("  - Adjust batch_size based on your GPU memory")
    print("  - Monitor GPU usage with: watch -n 1 nvidia-smi")
    print("  - Training logs are saved in the output directory")
    print("  - Use Ctrl+C to stop training (checkpoints are saved)")
    print()
    print("="*70)


if __name__ == "__main__":
    main()
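The training side is driven the same way (run as a module under the same installed-package assumption; repository IDs are placeholders). With --generate-configs, the quick/standard/full trio lands in train/, each wrapping the lerobot-train command built by generate_train_command:

python -m mortis.setup_train \
  --dataset-repo-id your-hf-username/mortis_manipulation \
  --model-repo-id your-hf-username/smolvla_mortis \
  --generate-configs

# Then launch the 20k-step configuration:
cd train && ./train_standard.sh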
src/mortis/smolvla_executor.py ADDED
@@ -0,0 +1,1040 @@

"""
SmolVLA Executor for vision-language-action robotic manipulation.

This module implements the SmolVLA model executor that performs inference
for manipulation tasks using the trained SmolVLA policy from LeRobot.
"""

import os
import time
import logging
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from threading import Lock, Event

import torch
import numpy as np
from PIL import Image as PILImage

# LeRobot imports
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig

# Local imports
from .robot import MortisArm, HOME_POSE


# Configure logging
logger = logging.getLogger(__name__)


class SmolVLAError(Exception):
    """Base exception for SmolVLA executor errors."""
    pass


class SafetyViolationError(SmolVLAError):
    """Exception raised when a safety constraint is violated."""
    pass


class TimeoutError(SmolVLAError):
    """Exception raised when execution exceeds timeout."""
    # Note: this deliberately shadows the builtin TimeoutError within this module.
    pass


class GPUOutOfMemoryError(SmolVLAError):
    """Exception raised when GPU runs out of memory."""
    pass


class SmolVLAExecutor:
    """
    Executor for SmolVLA vision-language-action model inference.

    This class handles loading the trained SmolVLA model, capturing observations
    from the robot and camera, running inference, and executing predicted actions
    on the SO101 robotic arm.

    Attributes:
        checkpoint_path: Path to the trained model checkpoint
        device: Device to run inference on ('cuda' or 'cpu')
        policy: Loaded SmolVLA policy model
        robot_arm: Reference to MortisArm instance for action execution
        camera: Camera interface for visual observations (to be implemented)
        valid_commands: List of trained manipulation task commands
    """

    # Valid manipulation commands that the model was trained on
    VALID_COMMANDS = [
        "Pick up the skull and place it in the green cup",
        "Pick up the skull and place it in the orange cup",
        "Pick up the skull and place it in the purple cup",
        "Pick up the eyeball and place it in the green cup",
        "Pick up the eyeball and place it in the orange cup",
        "Pick up the eyeball and place it in the purple cup",
    ]

    # Safety limits for joint positions (in degrees)
    # These define the safe workspace boundaries
    # Based on SO101 calibration and physical constraints
    JOINT_LIMITS = {
        "shoulder_pan.pos": (-180, 180),
        "shoulder_lift.pos": (-120, 120),  # Extended range for SO101
        "elbow_flex.pos": (-135, 135),
        "wrist_flex.pos": (-105, 105),  # Extended range for SO101
        "wrist_roll.pos": (-180, 180),
        "gripper.pos": (0, 100),  # 0=open, 100=closed
    }

    # Maximum allowed joint velocity (degrees per step)
    MAX_JOINT_VELOCITY = 10.0

    # Default execution timeout (seconds)
    DEFAULT_TIMEOUT = 30.0

    def __init__(
        self,
        checkpoint_path: str,
        robot_arm: Optional[MortisArm] = None,
        device: Optional[str] = None,
        enable_safety_checks: bool = True,
        timeout: Optional[float] = None
    ):
        """
        Initialize the SmolVLA executor.

        Args:
            checkpoint_path: Path to the trained SmolVLA model checkpoint
            robot_arm: Optional MortisArm instance (will create if not provided)
            device: Device to run inference on ('cuda', 'cpu', or None for auto-detect)
            enable_safety_checks: Whether to enable workspace safety checks
            timeout: Execution timeout in seconds (None for default)

        Raises:
            SmolVLAError: If checkpoint path doesn't exist or model loading fails
        """
        # Initialize attributes first (for cleanup in case of early failure)
        self.camera = None
        self.policy = None
        self.preprocessor = None
        self.postprocessor = None

        self.checkpoint_path = Path(checkpoint_path)

        # Validate checkpoint path
        if not self.checkpoint_path.exists():
            raise SmolVLAError(f"Checkpoint path does not exist: {checkpoint_path}")

        # Set device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        logger.info(f"Initializing SmolVLA executor on device: {self.device}")

        # Safety configuration
        self.enable_safety_checks = enable_safety_checks
        self.timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT

        # Emergency stop flag and lock
        self._emergency_stop_flag = Event()
        self._execution_lock = Lock()
        self._is_executing = False

        # Previous state for velocity checking
        self._previous_state = None

        logger.info(f"Safety checks: {'enabled' if enable_safety_checks else 'disabled'}")
logger.info(f"Safety checks: {'enabled' if enable_safety_checks else 'disabled'}")
|
| 150 |
+
logger.info(f"Execution timeout: {self.timeout}s")
|
| 151 |
+
|
| 152 |
+
# Initialize robot arm
|
| 153 |
+
self.robot_arm = robot_arm
|
| 154 |
+
if self.robot_arm is None:
|
| 155 |
+
logger.info("No robot arm provided, creating new MortisArm instance")
|
| 156 |
+
self.robot_arm = MortisArm()
|
| 157 |
+
|
| 158 |
+
# Load the model
|
| 159 |
+
self._load_model()
|
| 160 |
+
|
| 161 |
+
# Model is ready
|
| 162 |
+
logger.info("SmolVLA executor initialized successfully")
|
| 163 |
+
|
| 164 |
+
def _load_model(self):
|
| 165 |
+
"""
|
| 166 |
+
Load the SmolVLA model from checkpoint.
|
| 167 |
+
|
| 168 |
+
Raises:
|
| 169 |
+
SmolVLAError: If model loading fails
|
| 170 |
+
"""
|
| 171 |
+
try:
|
| 172 |
+
logger.info(f"Loading SmolVLA model from: {self.checkpoint_path}")
|
| 173 |
+
|
| 174 |
+
# Load configuration - handle extra fields in config.json
|
| 175 |
+
import json
|
| 176 |
+
config_path = self.checkpoint_path / "config.json"
|
| 177 |
+
|
| 178 |
+
# Load config - ensure 'type' field is set to 'smolvla'
|
| 179 |
+
config_path = self.checkpoint_path / "config.json"
|
| 180 |
+
|
| 181 |
+
if config_path.exists():
|
| 182 |
+
# Load config
|
| 183 |
+
with open(config_path, 'r') as f:
|
| 184 |
+
config_dict = json.load(f)
|
| 185 |
+
|
| 186 |
+
# Ensure 'type' field is set to 'smolvla'
|
| 187 |
+
if 'type' not in config_dict or config_dict['type'] != 'smolvla':
|
| 188 |
+
logger.debug("Setting 'type' field to 'smolvla' in config")
|
| 189 |
+
config_dict['type'] = 'smolvla'
|
| 190 |
+
|
| 191 |
+
# Save updated config back
|
| 192 |
+
with open(config_path, 'w') as f:
|
| 193 |
+
json.dump(config_dict, f, indent=2)
|
| 194 |
+
|
| 195 |
+
# Get VLM model name for tokenizer
|
| 196 |
+
vlm_model_name = config_dict.get('vlm_model_name', 'HuggingFaceTB/SmolVLM2-500M-Video-Instruct')
|
| 197 |
+
else:
|
| 198 |
+
vlm_model_name = 'HuggingFaceTB/SmolVLM2-500M-Video-Instruct'
|
| 199 |
+
|
| 200 |
+
# Load policy using from_pretrained (it will load the config automatically)
|
| 201 |
+
self.policy = SmolVLAPolicy.from_pretrained(str(self.checkpoint_path))
|
| 202 |
+
|
| 203 |
+
# Move to device
|
| 204 |
+
self.policy.to(self.device)
|
| 205 |
+
|
| 206 |
+
# Set to evaluation mode
|
| 207 |
+
self.policy.eval()
|
| 208 |
+
|
| 209 |
+
logger.info("SmolVLA model loaded successfully")
|
| 210 |
+
|
| 211 |
+
# Load preprocessor (handles tokenization automatically)
|
| 212 |
+
self._load_preprocessor()
|
| 213 |
+
|
| 214 |
+
# Perform warmup inference
|
| 215 |
+
self._warmup()
|
| 216 |
+
|
| 217 |
+
except Exception as e:
|
| 218 |
+
logger.error(f"Failed to load SmolVLA model: {e}")
|
| 219 |
+
raise SmolVLAError(f"Model loading failed: {e}")
|
| 220 |
+
|
| 221 |
+
def _load_preprocessor(self):
|
| 222 |
+
"""
|
| 223 |
+
Load preprocessor from checkpoint.
|
| 224 |
+
|
| 225 |
+
The preprocessor handles automatic tokenization of task strings
|
| 226 |
+
through the TokenizerProcessorStep.
|
| 227 |
+
|
| 228 |
+
Raises:
|
| 229 |
+
SmolVLAError: If preprocessor loading fails
|
| 230 |
+
"""
|
| 231 |
+
try:
|
| 232 |
+
from lerobot.policies.factory import make_pre_post_processors
|
| 233 |
+
|
| 234 |
+
logger.info("Loading preprocessor from checkpoint...")
|
| 235 |
+
|
| 236 |
+
# Load preprocessor and postprocessor using policy config
|
| 237 |
+
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
| 238 |
+
self.policy.config,
|
| 239 |
+
pretrained_path=str(self.checkpoint_path),
|
| 240 |
+
device=self.device
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
logger.info("Preprocessor and postprocessor loaded successfully")
|
| 244 |
+
|
| 245 |
+
except Exception as e:
|
| 246 |
+
logger.error(f"Failed to load preprocessor: {e}")
|
| 247 |
+
raise SmolVLAError(f"Preprocessor loading failed: {e}")
|
| 248 |
+
|
| 249 |
+
def _warmup(self):
|
| 250 |
+
"""
|
| 251 |
+
Perform warmup inference to initialize CUDA kernels and caches.
|
| 252 |
+
|
| 253 |
+
This reduces latency for the first real inference call.
|
| 254 |
+
"""
|
| 255 |
+
if self.device == "cuda":
|
| 256 |
+
logger.info("Performing model warmup...")
|
| 257 |
+
try:
|
| 258 |
+
# Create dummy observation
|
| 259 |
+
dummy_obs = self._create_dummy_observation()
|
| 260 |
+
|
| 261 |
+
# Run dummy inference
|
| 262 |
+
with torch.no_grad():
|
| 263 |
+
# SmolVLA expects a batch of observations
|
| 264 |
+
result = self.policy.select_action(dummy_obs)
|
| 265 |
+
# Result may be a dict with 'action' key or just a tensor
|
| 266 |
+
if isinstance(result, dict):
|
| 267 |
+
_ = result.get('action', result)
|
| 268 |
+
|
| 269 |
+
# Clear cache
|
| 270 |
+
torch.cuda.empty_cache()
|
| 271 |
+
|
| 272 |
+
logger.info("Model warmup complete")
|
| 273 |
+
except Exception as e:
|
| 274 |
+
# Warmup is optional - log but don't fail
|
| 275 |
+
logger.debug(f"Warmup skipped: {e}")
|
| 276 |
+
pass
|
| 277 |
+
|
| 278 |
+
def _create_dummy_observation(self) -> Dict[str, torch.Tensor]:
|
| 279 |
+
"""
|
| 280 |
+
Create a dummy observation for warmup.
|
| 281 |
+
|
| 282 |
+
Returns:
|
| 283 |
+
Dictionary with dummy observation tensors
|
| 284 |
+
"""
|
| 285 |
+
# Create dummy state
|
| 286 |
+
dummy_state = torch.zeros(1, 6, dtype=torch.float32, device=self.device)
|
| 287 |
+
|
| 288 |
+
# Create dummy images
|
| 289 |
+
dummy_image = self._create_dummy_image()
|
| 290 |
+
|
| 291 |
+
observation = {
|
| 292 |
+
"observation.images.camera1": dummy_image,
|
| 293 |
+
"observation.images.camera2": dummy_image.clone(),
|
| 294 |
+
"observation.images.camera3": dummy_image.clone(),
|
| 295 |
+
"observation.state": dummy_state,
|
| 296 |
+
"task": "dummy task" # Task as string (preprocessor will handle it)
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
# Apply preprocessor to tokenize task
|
| 300 |
+
if self.preprocessor is not None:
|
| 301 |
+
observation = self.preprocessor(observation)
|
| 302 |
+
|
| 303 |
+
return observation
|
| 304 |
+
|
| 305 |
+
def validate_command(self, command: str) -> bool:
|
| 306 |
+
"""
|
| 307 |
+
Validate that a command is in the trained task set.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
command: The manipulation command to validate
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
True if command is valid, False otherwise
|
| 314 |
+
"""
|
| 315 |
+
return command in self.VALID_COMMANDS
|
| 316 |
+
|
| 317 |
+
def trigger_emergency_stop(self):
|
| 318 |
+
"""
|
| 319 |
+
Trigger emergency stop from external thread.
|
| 320 |
+
|
| 321 |
+
This can be called from another thread to safely stop execution.
|
| 322 |
+
"""
|
| 323 |
+
logger.warning("Emergency stop triggered externally")
|
| 324 |
+
self._emergency_stop_flag.set()
|
| 325 |
+
|
| 326 |
+
def is_executing(self) -> bool:
|
| 327 |
+
"""
|
| 328 |
+
Check if executor is currently running a task.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
True if a task is being executed
|
| 332 |
+
"""
|
| 333 |
+
return self._is_executing
|
| 334 |
+
|
| 335 |
+
def execute(self, command: str, max_steps: int = 500, timeout: Optional[float] = None) -> bool:
|
| 336 |
+
"""
|
| 337 |
+
Execute a manipulation task using SmolVLA inference.
|
| 338 |
+
|
| 339 |
+
This is the main entry point for executing manipulation commands.
|
| 340 |
+
It runs the inference loop, capturing observations and executing
|
| 341 |
+
predicted actions until the task is complete or max_steps is reached.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
command: Natural language task description (must be in VALID_COMMANDS)
|
| 345 |
+
max_steps: Maximum number of inference steps to execute
|
| 346 |
+
timeout: Optional timeout override (seconds)
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
True if execution completed successfully, False otherwise
|
| 350 |
+
|
| 351 |
+
Raises:
|
| 352 |
+
SmolVLAError: If command is invalid or execution fails critically
|
| 353 |
+
SafetyViolationError: If safety constraints are violated
|
| 354 |
+
TimeoutError: If execution exceeds timeout
|
| 355 |
+
"""
|
| 356 |
+
# Acquire execution lock to prevent concurrent execution
|
| 357 |
+
if not self._execution_lock.acquire(blocking=False):
|
| 358 |
+
raise SmolVLAError("Executor is already running a task")
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
# Clear emergency stop flag
|
| 362 |
+
self._emergency_stop_flag.clear()
|
| 363 |
+
self._is_executing = True
|
| 364 |
+
|
| 365 |
+
# Validate command against trained task set
|
| 366 |
+
if not self.validate_command(command):
|
| 367 |
+
raise SmolVLAError(
|
| 368 |
+
f"Invalid command: '{command}'. "
|
| 369 |
+
f"Must be one of: {self.VALID_COMMANDS}"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Ensure robot is connected
|
| 373 |
+
if not self.robot_arm.connected:
|
| 374 |
+
logger.info("Robot not connected, attempting to connect...")
|
| 375 |
+
self.robot_arm.connect()
|
| 376 |
+
if not self.robot_arm.connected:
|
| 377 |
+
raise SmolVLAError("Failed to connect to robot arm")
|
| 378 |
+
|
| 379 |
+
# Use provided timeout or default
|
| 380 |
+
execution_timeout = timeout if timeout is not None else self.timeout
|
| 381 |
+
|
| 382 |
+
logger.info(f"Starting SmolVLA execution: '{command}'")
|
| 383 |
+
logger.info(f"Max steps: {max_steps}, Timeout: {execution_timeout}s")
|
| 384 |
+
logger.info(f"Safety checks: {'enabled' if self.enable_safety_checks else 'disabled'}")
|
| 385 |
+
|
| 386 |
+
try:
|
| 387 |
+
# Execute the task with timeout
|
| 388 |
+
success = self._execute_task_with_timeout(command, max_steps, execution_timeout)
|
| 389 |
+
|
| 390 |
+
if success:
|
| 391 |
+
logger.info(f"Task completed successfully: '{command}'")
|
| 392 |
+
else:
|
| 393 |
+
logger.warning(f"Task did not complete within constraints")
|
| 394 |
+
|
| 395 |
+
# Return to home position safely
|
| 396 |
+
logger.info("Returning to home position...")
|
| 397 |
+
self._safe_return_home()
|
| 398 |
+
|
| 399 |
+
return success
|
| 400 |
+
|
| 401 |
+
except TimeoutError as e:
|
| 402 |
+
logger.error(f"Execution timeout: {e}")
|
| 403 |
+
self._emergency_stop()
|
| 404 |
+
raise
|
| 405 |
+
except SafetyViolationError as e:
|
| 406 |
+
logger.error(f"Safety violation: {e}")
|
| 407 |
+
self._emergency_stop()
|
| 408 |
+
raise
|
| 409 |
+
except GPUOutOfMemoryError as e:
|
| 410 |
+
logger.error(f"GPU out of memory: {e}")
|
| 411 |
+
self._handle_gpu_oom()
|
| 412 |
+
self._emergency_stop()
|
| 413 |
+
raise
|
| 414 |
+
except Exception as e:
|
| 415 |
+
logger.error(f"Execution failed: {e}")
|
| 416 |
+
import traceback
|
| 417 |
+
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 418 |
+
self._emergency_stop()
|
| 419 |
+
raise SmolVLAError(f"Execution failed: {e}")
|
| 420 |
+
|
| 421 |
+
finally:
|
| 422 |
+
# Always release lock and reset execution flag
|
| 423 |
+
self._is_executing = False
|
| 424 |
+
self._execution_lock.release()
|
| 425 |
+
|
| 426 |
+
def _execute_task_with_timeout(self, command: str, max_steps: int, timeout: float) -> bool:
|
| 427 |
+
"""
|
| 428 |
+
Execute task with timeout monitoring.
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
command: The manipulation command
|
| 432 |
+
max_steps: Maximum steps
|
| 433 |
+
timeout: Timeout in seconds
|
| 434 |
+
|
| 435 |
+
Returns:
|
| 436 |
+
True if task completed successfully
|
| 437 |
+
|
| 438 |
+
Raises:
|
| 439 |
+
TimeoutError: If execution exceeds timeout
|
| 440 |
+
"""
|
| 441 |
+
start_time = time.time()
|
| 442 |
+
|
| 443 |
+
try:
|
| 444 |
+
return self._execute_task(command, max_steps, start_time, timeout)
|
| 445 |
+
except Exception as e:
|
| 446 |
+
elapsed = time.time() - start_time
|
| 447 |
+
if elapsed >= timeout:
|
| 448 |
+
raise TimeoutError(f"Execution exceeded timeout of {timeout}s")
|
| 449 |
+
raise
|
| 450 |
+
|
| 451 |
+
def _execute_task(self, command: str, max_steps: int, start_time: float, timeout: float) -> bool:
|
| 452 |
+
"""
|
| 453 |
+
Internal method to execute the task inference loop.
|
| 454 |
+
|
| 455 |
+
This method implements the core inference loop:
|
| 456 |
+
1. Capture visual and state observations
|
| 457 |
+
2. Run SmolVLA inference to predict next action
|
| 458 |
+
3. Execute action on robot
|
| 459 |
+
4. Check for task completion
|
| 460 |
+
5. Repeat until complete or max_steps reached
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
command: The manipulation command to execute
|
| 464 |
+
max_steps: Maximum number of steps
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
True if task completed, False if max steps reached
|
| 468 |
+
"""
|
| 469 |
+
# Reset task completion tracking variables
|
| 470 |
+
self._previous_action = None
|
| 471 |
+
self._stable_count = 0
|
| 472 |
+
self._previous_state = None
|
| 473 |
+
|
| 474 |
+
# Track execution metrics
|
| 475 |
+
last_progress_log = 0
|
| 476 |
+
progress_log_interval = 50 # Log every 50 steps
|
| 477 |
+
|
| 478 |
+
with torch.no_grad():
|
| 479 |
+
for step in range(max_steps):
|
| 480 |
+
# Check for emergency stop
|
| 481 |
+
if self._emergency_stop_flag.is_set():
|
| 482 |
+
logger.warning("Emergency stop detected, aborting execution")
|
| 483 |
+
return False
|
| 484 |
+
|
| 485 |
+
# Check timeout
|
| 486 |
+
elapsed = time.time() - start_time
|
| 487 |
+
if elapsed >= timeout:
|
| 488 |
+
raise TimeoutError(f"Execution exceeded timeout of {timeout}s at step {step}")
|
| 489 |
+
|
| 490 |
+
# Log progress periodically
|
| 491 |
+
if step - last_progress_log >= progress_log_interval:
|
| 492 |
+
fps = step / elapsed if elapsed > 0 else 0
|
| 493 |
+
logger.info(
|
| 494 |
+
f"Execution progress: step {step}/{max_steps} "
|
| 495 |
+
f"({step/max_steps*100:.1f}%) - {fps:.1f} FPS - {elapsed:.1f}s elapsed"
|
| 496 |
+
)
|
| 497 |
+
last_progress_log = step
|
| 498 |
+
|
| 499 |
+
try:
|
| 500 |
+
# Capture current observation
|
| 501 |
+
observation = self._get_observation()
|
| 502 |
+
|
| 503 |
+
# Add task string (preprocessor will tokenize it)
|
| 504 |
+
observation = self._add_task_string(observation, command)
|
| 505 |
+
|
| 506 |
+
# Apply preprocessor (tokenizes task string automatically)
|
| 507 |
+
observation = self.preprocessor(observation)
|
| 508 |
+
|
| 509 |
+
# Run inference to predict next action (normalized)
|
| 510 |
+
action_normalized = self._run_inference_with_oom_handling(observation)
|
| 511 |
+
|
| 512 |
+
# Debug: log normalized action
|
| 513 |
+
logger.debug(f"Normalized action type: {type(action_normalized)}, shape: {action_normalized.shape if hasattr(action_normalized, 'shape') else 'N/A'}")
|
| 514 |
+
|
| 515 |
+
# Denormalize action using postprocessor
|
| 516 |
+
action = self.postprocessor(action_normalized)
|
| 517 |
+
|
| 518 |
+
# Debug: log denormalized action
|
| 519 |
+
logger.debug(f"Denormalized action: {action}")
|
| 520 |
+
|
| 521 |
+
# Validate action safety (on denormalized action)
|
| 522 |
+
if self.enable_safety_checks:
|
| 523 |
+
self._check_action_safety(action, observation)
|
| 524 |
+
|
| 525 |
+
# Send action to robot
|
| 526 |
+
self._send_action(action)
|
| 527 |
+
|
| 528 |
+
# Check if task is complete (use normalized action for stability check)
|
| 529 |
+
try:
|
| 530 |
+
is_complete = self._is_task_complete(observation, step, action_normalized)
|
| 531 |
+
if is_complete:
|
| 532 |
+
elapsed = time.time() - start_time
|
| 533 |
+
logger.info(
|
| 534 |
+
f"Task completed at step {step} "
|
| 535 |
+
f"(elapsed: {elapsed:.2f}s, avg FPS: {step/elapsed:.1f})"
|
| 536 |
+
)
|
| 537 |
+
return True
|
| 538 |
+
except Exception as e:
|
| 539 |
+
logger.error(f"Error in _is_task_complete: {e}")
|
| 540 |
+
raise
|
| 541 |
+
|
| 542 |
+
# Small delay between steps to maintain ~30 FPS
|
| 543 |
+
time.sleep(0.033)
|
| 544 |
+
|
| 545 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 546 |
+
logger.error(f"GPU out of memory at step {step}")
|
| 547 |
+
raise GPUOutOfMemoryError(f"GPU OOM at step {step}: {e}")
|
| 548 |
+
except SafetyViolationError:
|
| 549 |
+
# Re-raise safety violations
|
| 550 |
+
raise
|
| 551 |
+
except Exception as e:
|
| 552 |
+
logger.error(f"Error at step {step}: {e}")
|
| 553 |
+
raise
|
| 554 |
+
|
| 555 |
+
# Max steps reached without completion
|
| 556 |
+
elapsed = time.time() - start_time
|
| 557 |
+
logger.warning(
|
| 558 |
+
f"Task did not complete within {max_steps} steps "
|
| 559 |
+
f"(elapsed: {elapsed:.2f}s)"
|
| 560 |
+
)
|
| 561 |
+
return False
|
| 562 |
+
|
| 563 |
+
def _get_observation(self) -> Dict[str, torch.Tensor]:
|
| 564 |
+
"""
|
| 565 |
+
Get current robot observation (image + state).
|
| 566 |
+
|
| 567 |
+
Captures robot state from robot.get_observation() and images from cameras.
|
| 568 |
+
|
| 569 |
+
Returns:
|
| 570 |
+
Dictionary with observation tensors formatted for SmolVLA:
|
| 571 |
+
- observation.images.camera1: RGB image tensor [1, 3, H, W]
|
| 572 |
+
- observation.images.camera2: RGB image tensor [1, 3, H, W] (if available)
|
| 573 |
+
- observation.images.camera3: RGB image tensor [1, 3, H, W] (if available)
|
| 574 |
+
- observation.state: Joint positions tensor [1, 6]
|
| 575 |
+
"""
|
| 576 |
+
try:
|
| 577 |
+
# Get robot state (joint positions)
|
| 578 |
+
robot_obs = self.robot_arm.robot.get_observation()
|
| 579 |
+
|
| 580 |
+
# Extract joint positions in order
|
| 581 |
+
joint_names = [
|
| 582 |
+
"shoulder_pan.pos",
|
| 583 |
+
"shoulder_lift.pos",
|
| 584 |
+
"elbow_flex.pos",
|
| 585 |
+
"wrist_flex.pos",
|
| 586 |
+
"wrist_roll.pos",
|
| 587 |
+
"gripper.pos"
|
| 588 |
+
]
|
| 589 |
+
|
| 590 |
+
# Build state vector
|
| 591 |
+
state_values = [robot_obs[name] for name in joint_names]
|
| 592 |
+
state_tensor = torch.tensor(
|
| 593 |
+
state_values,
|
| 594 |
+
dtype=torch.float32,
|
| 595 |
+
device=self.device
|
| 596 |
+
).unsqueeze(0) # Add batch dimension
|
| 597 |
+
|
| 598 |
+
# Get camera images (robot.cameras is a dict of camera objects)
|
| 599 |
+
observation = {"observation.state": state_tensor}
|
| 600 |
+
|
| 601 |
+
if hasattr(self.robot_arm.robot, 'cameras') and self.robot_arm.robot.cameras:
|
| 602 |
+
# Get images from robot's cameras
|
| 603 |
+
for i, (camera_name, camera) in enumerate(self.robot_arm.robot.cameras.items(), start=1):
|
| 604 |
+
try:
|
| 605 |
+
image = camera.read()
|
| 606 |
+
# Convert to tensor: (H, W, C) -> (1, C, H, W)
|
| 607 |
+
image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0).to(self.device) / 255.0
|
| 608 |
+
observation[f"observation.images.camera{i}"] = image_tensor
|
| 609 |
+
logger.debug(f"Captured image from {camera_name}: shape={image.shape}")
|
| 610 |
+
except Exception as e:
|
| 611 |
+
logger.warning(f"Failed to read from {camera_name}: {e}")
|
| 612 |
+
observation[f"observation.images.camera{i}"] = self._create_dummy_image()
|
| 613 |
+
else:
|
| 614 |
+
logger.debug("No cameras configured on robot, using dummy images")
|
| 615 |
+
|
| 616 |
+
# Ensure we have 3 camera images (duplicate if needed)
|
| 617 |
+
for i in range(1, 4):
|
| 618 |
+
key = f"observation.images.camera{i}"
|
| 619 |
+
if key not in observation:
|
| 620 |
+
# Use first camera or dummy
|
| 621 |
+
if "observation.images.camera1" in observation:
|
| 622 |
+
observation[key] = observation["observation.images.camera1"].clone()
|
| 623 |
+
else:
|
| 624 |
+
observation[key] = self._create_dummy_image()
|
| 625 |
+
|
| 626 |
+
logger.debug(f"Captured observation with keys: {list(observation.keys())}")
|
| 627 |
+
return observation
|
| 628 |
+
|
| 629 |
+
except Exception as e:
|
| 630 |
+
logger.warning(f"Failed to get robot observation: {e}. Using dummy observation.")
|
| 631 |
+
return self._create_dummy_observation_without_task()
|
| 632 |
+
|
| 633 |
+
def _create_dummy_observation_without_task(self) -> Dict[str, torch.Tensor]:
|
| 634 |
+
"""
|
| 635 |
+
Create a dummy observation without task string (for error recovery).
|
| 636 |
+
|
| 637 |
+
Returns:
|
| 638 |
+
Dictionary with dummy observation tensors
|
| 639 |
+
"""
|
| 640 |
+
dummy_state = torch.zeros(1, 6, dtype=torch.float32, device=self.device)
|
| 641 |
+
dummy_image = self._create_dummy_image()
|
| 642 |
+
|
| 643 |
+
return {
|
| 644 |
+
"observation.images.camera1": dummy_image,
|
| 645 |
+
"observation.images.camera2": dummy_image.clone(),
|
| 646 |
+
"observation.images.camera3": dummy_image.clone(),
|
| 647 |
+
"observation.state": dummy_state,
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
def _add_task_string(self, observation: Dict[str, torch.Tensor], command: str) -> Dict[str, torch.Tensor]:
|
| 651 |
+
"""
|
| 652 |
+
Add task string to observation.
|
| 653 |
+
|
| 654 |
+
The preprocessor will automatically tokenize this string through
|
| 655 |
+
the TokenizerProcessorStep.
|
| 656 |
+
|
| 657 |
+
Args:
|
| 658 |
+
observation: Current observation dictionary
|
| 659 |
+
command: Natural language command string
|
| 660 |
+
|
| 661 |
+
Returns:
|
| 662 |
+
Observation dictionary with added task string
|
| 663 |
+
"""
|
| 664 |
+
# Simply add the task string - the preprocessor will tokenize it
|
| 665 |
+
observation["task"] = command
|
| 666 |
+
|
| 667 |
+
logger.debug(f"Added task string: '{command}'")
|
| 668 |
+
|
| 669 |
+
return observation
|
| 670 |
+
|
| 671 |
+
def _create_dummy_image(self) -> torch.Tensor:
|
| 672 |
+
"""
|
| 673 |
+
Create a dummy image tensor for testing without camera.
|
| 674 |
+
|
| 675 |
+
Returns:
|
| 676 |
+
Dummy image tensor [1, 3, 256, 256] with batch dimension
|
| 677 |
+
"""
|
| 678 |
+
# Create black image
|
| 679 |
+
dummy_image = torch.zeros(1, 3, 256, 256, dtype=torch.float32, device=self.device)
|
| 680 |
+
return dummy_image
|
| 681 |
+
|
| 682 |
+
def _send_action(self, action: torch.Tensor):
|
| 683 |
+
"""
|
| 684 |
+
Send predicted action to robot.
|
| 685 |
+
|
| 686 |
+
Converts the action tensor from SmolVLA to SO101 command format
|
| 687 |
+
and sends it to the robot arm for execution.
|
| 688 |
+
|
| 689 |
+
Args:
|
| 690 |
+
action: Action tensor from policy (shape: [batch, action_dim])
|
| 691 |
+
|
| 692 |
+
Raises:
|
| 693 |
+
SmolVLAError: If action execution fails
|
| 694 |
+
"""
|
| 695 |
+
try:
|
| 696 |
+
# Convert action tensor to robot command dictionary
|
| 697 |
+
action_dict = self._action_to_dict(action)
|
| 698 |
+
|
| 699 |
+
# Send to robot
|
| 700 |
+
self.robot_arm.robot.send_action(action_dict)
|
| 701 |
+
|
| 702 |
+
# Log action at debug level (verbose)
|
| 703 |
+
logger.debug(f"Action sent: {action_dict}")
|
| 704 |
+
|
| 705 |
+
except Exception as e:
|
| 706 |
+
logger.error(f"Failed to send action to robot: {e}")
|
| 707 |
+
raise SmolVLAError(f"Action execution failed: {e}")
|
| 708 |
+
|
| 709 |
+
def _action_to_dict(self, action: torch.Tensor) -> Dict[str, float]:
|
| 710 |
+
"""
|
| 711 |
+
Convert action tensor to SO101 command format.
|
| 712 |
+
|
| 713 |
+
Maps the action tensor dimensions to SO101 joint names and converts
|
| 714 |
+
to the dictionary format expected by the robot driver.
|
| 715 |
+
|
| 716 |
+
Args:
|
| 717 |
+
action: Action tensor from policy (shape: [batch, 6] or [6])
|
| 718 |
+
|
| 719 |
+
Returns:
|
| 720 |
+
Dictionary mapping joint names to positions (in degrees or normalized units)
|
| 721 |
+
|
| 722 |
+
Raises:
|
| 723 |
+
SmolVLAError: If action tensor has invalid shape
|
| 724 |
+
"""
|
| 725 |
+
# Remove batch dimension if present
|
| 726 |
+
if action.dim() > 1:
|
| 727 |
+
action = action.squeeze(0)
|
| 728 |
+
|
| 729 |
+
# Validate action dimension
|
| 730 |
+
if action.shape[0] != 6:
|
| 731 |
+
raise SmolVLAError(
|
| 732 |
+
f"Invalid action shape: expected 6 dimensions, got {action.shape[0]}"
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
# Convert to numpy
|
| 736 |
+
action_np = action.cpu().numpy()
|
| 737 |
+
|
| 738 |
+
# Map action dimensions to joint names
|
| 739 |
+
# Order must match the training data format
|
| 740 |
+
joint_names = [
|
| 741 |
+
"shoulder_pan.pos",
|
| 742 |
+
"shoulder_lift.pos",
|
| 743 |
+
"elbow_flex.pos",
|
| 744 |
+
"wrist_flex.pos",
|
| 745 |
+
"wrist_roll.pos",
|
| 746 |
+
"gripper.pos"
|
| 747 |
+
]
|
| 748 |
+
|
| 749 |
+
# Create action dictionary
|
| 750 |
+
action_dict = {
|
| 751 |
+
name: float(action_np[i])
|
| 752 |
+
for i, name in enumerate(joint_names)
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
return action_dict
|
| 756 |
+
|
| 757 |
+
def _is_task_complete(
|
| 758 |
+
self,
|
| 759 |
+
observation: Dict[str, torch.Tensor],
|
| 760 |
+
step: int,
|
| 761 |
+
action: torch.Tensor
|
| 762 |
+
) -> bool:
|
| 763 |
+
"""
|
| 764 |
+
Determine if the task is complete.
|
| 765 |
+
|
| 766 |
+
This method uses multiple heuristics to detect task completion:
|
| 767 |
+
1. Minimum step count (ensure task has progressed)
|
| 768 |
+
2. Maximum step count (assume completion after sufficient time)
|
| 769 |
+
3. Action stability (detect when robot has settled)
|
| 770 |
+
|
| 771 |
+
In a production system, this could be enhanced with:
|
| 772 |
+
- Learned termination classifier
|
| 773 |
+
- Visual goal detection
|
| 774 |
+
- Force/torque feedback
|
| 775 |
+
- Success detection from camera
|
| 776 |
+
|
| 777 |
+
Args:
|
| 778 |
+
observation: Current observation dictionary
|
| 779 |
+
step: Current step number
|
| 780 |
+
action: Predicted action tensor
|
| 781 |
+
|
| 782 |
+
Returns:
|
| 783 |
+
True if task should be considered complete
|
| 784 |
+
"""
|
| 785 |
+
# Minimum steps before considering completion (allow task to progress)
|
| 786 |
+
MIN_STEPS = 100
|
| 787 |
+
|
| 788 |
+
# Maximum steps - assume task is complete after this many steps
|
| 789 |
+
# Most manipulation tasks should complete within 400-450 steps at 30 FPS
|
| 790 |
+
# (approximately 13-15 seconds)
|
| 791 |
+
MAX_STEPS = 450
|
| 792 |
+
|
| 793 |
+
# Early exit: not enough steps yet
|
| 794 |
+
if step < MIN_STEPS:
|
| 795 |
+
return False
|
| 796 |
+
|
| 797 |
+
# Late exit: max steps reached, consider complete
|
| 798 |
+
if step >= MAX_STEPS:
|
| 799 |
+
logger.info(f"Task completion: max steps ({MAX_STEPS}) reached")
|
| 800 |
+
return True
|
| 801 |
+
|
| 802 |
+
# Check for action stability (robot has settled into final position)
|
| 803 |
+
if hasattr(self, '_previous_action') and self._previous_action is not None:
|
| 804 |
+
action_diff = torch.abs(action - self._previous_action).max().item()
|
| 805 |
+
|
| 806 |
+
# If action changes are very small, robot may have settled
|
| 807 |
+
if action_diff < 0.01: # Threshold for "stable" action
|
| 808 |
+
if not hasattr(self, '_stable_count'):
|
| 809 |
+
self._stable_count = 0
|
| 810 |
+
self._stable_count += 1
|
| 811 |
+
|
| 812 |
+
# If stable for 30 consecutive steps (~1 second), consider complete
|
| 813 |
+
if self._stable_count >= 30:
|
| 814 |
+
logger.info(
|
| 815 |
+
f"Task completion: action stability detected at step {step} "
|
| 816 |
+
f"(stable for {self._stable_count} steps)"
|
| 817 |
+
)
|
| 818 |
+
return True
|
| 819 |
+
else:
|
| 820 |
+
# Reset stability counter if action changes significantly
|
| 821 |
+
self._stable_count = 0
|
| 822 |
+
|
| 823 |
+
# Store current action for next comparison
|
| 824 |
+
self._previous_action = action.clone()
|
| 825 |
+
|
| 826 |
+
# Not complete yet
|
| 827 |
+
return False
|
| 828 |
+
|
| 829 |
+
def _check_action_safety(self, action: torch.Tensor, observation: Dict[str, torch.Tensor]):
|
| 830 |
+
"""
|
| 831 |
+
Check if predicted action is safe to execute.
|
| 832 |
+
|
| 833 |
+
Validates:
|
| 834 |
+
1. Joint position limits
|
| 835 |
+
2. Joint velocity limits
|
| 836 |
+
3. Workspace boundaries
|
| 837 |
+
|
| 838 |
+
Args:
|
| 839 |
+
action: Predicted action tensor
|
| 840 |
+
observation: Current observation
|
| 841 |
+
|
| 842 |
+
Raises:
|
| 843 |
+
SafetyViolationError: If action violates safety constraints
|
| 844 |
+
"""
|
| 845 |
+
# Convert action to dict for checking
|
| 846 |
+
action_dict = self._action_to_dict(action)
|
| 847 |
+
|
| 848 |
+
# Check joint position limits
|
| 849 |
+
for joint_name, position in action_dict.items():
|
| 850 |
+
if joint_name in self.JOINT_LIMITS:
|
| 851 |
+
min_pos, max_pos = self.JOINT_LIMITS[joint_name]
|
| 852 |
+
if position < min_pos or position > max_pos:
|
| 853 |
+
raise SafetyViolationError(
|
| 854 |
+
f"Joint {joint_name} position {position:.2f} exceeds limits "
|
| 855 |
+
f"[{min_pos}, {max_pos}]"
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
# Check joint velocity limits (if we have previous state)
|
| 859 |
+
if self._previous_state is not None:
|
| 860 |
+
current_state = observation["observation.state"].squeeze(0).cpu().numpy()
|
| 861 |
+
velocity = np.abs(current_state - self._previous_state)
|
| 862 |
+
max_velocity = np.max(velocity)
|
| 863 |
+
|
| 864 |
+
if max_velocity > self.MAX_JOINT_VELOCITY:
|
| 865 |
+
raise SafetyViolationError(
|
| 866 |
+
f"Joint velocity {max_velocity:.2f} exceeds limit "
|
| 867 |
+
f"{self.MAX_JOINT_VELOCITY}"
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
# Update previous state for next check
|
| 871 |
+
self._previous_state = observation["observation.state"].squeeze(0).cpu().numpy().copy()
|
| 872 |
+
|
| 873 |
+
def _run_inference_with_oom_handling(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 874 |
+
"""
|
| 875 |
+
Run inference with GPU out-of-memory handling.
|
| 876 |
+
|
| 877 |
+
Args:
|
| 878 |
+
observation: Current observation
|
| 879 |
+
|
| 880 |
+
Returns:
|
| 881 |
+
Predicted action tensor
|
| 882 |
+
|
| 883 |
+
Raises:
|
| 884 |
+
GPUOutOfMemoryError: If GPU runs out of memory
|
| 885 |
+
"""
|
| 886 |
+
try:
|
| 887 |
+
result = self.policy.select_action(observation)
|
| 888 |
+
|
| 889 |
+
# Debug: log what we got back
|
| 890 |
+
logger.debug(f"Policy returned type: {type(result)}")
|
| 891 |
+
if isinstance(result, dict):
|
| 892 |
+
logger.debug(f"Policy returned dict keys: {result.keys()}")
|
| 893 |
+
|
| 894 |
+
# SmolVLA returns a dictionary with 'action' key
|
| 895 |
+
if isinstance(result, dict):
|
| 896 |
+
if 'action' in result:
|
| 897 |
+
return result['action']
|
| 898 |
+
else:
|
| 899 |
+
# Try to find the action in the dict
|
| 900 |
+
logger.error(f"Policy returned dict without 'action' key. Keys: {result.keys()}")
|
| 901 |
+
raise SmolVLAError(f"Policy returned unexpected format: {type(result)}")
|
| 902 |
+
return result
|
| 903 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 904 |
+
logger.error("GPU out of memory during inference")
|
| 905 |
+
# Try to recover by clearing cache
|
| 906 |
+
torch.cuda.empty_cache()
|
| 907 |
+
# Try one more time
|
| 908 |
+
try:
|
| 909 |
+
result = self.policy.select_action(observation)
|
| 910 |
+
if isinstance(result, dict):
|
| 911 |
+
if 'action' in result:
|
| 912 |
+
return result['action']
|
| 913 |
+
else:
|
| 914 |
+
raise SmolVLAError(f"Policy returned unexpected format: {type(result)}")
|
| 915 |
+
return result
|
| 916 |
+
except torch.cuda.OutOfMemoryError:
|
| 917 |
+
raise GPUOutOfMemoryError("GPU out of memory, cannot recover")
|
| 918 |
+
|
| 919 |
+
def _handle_gpu_oom(self):
|
| 920 |
+
"""
|
| 921 |
+
Handle GPU out-of-memory error by clearing cache and resetting state.
|
| 922 |
+
"""
|
| 923 |
+
logger.info("Handling GPU out-of-memory error...")
|
| 924 |
+
|
| 925 |
+
if self.device == "cuda":
|
| 926 |
+
# Clear CUDA cache
|
| 927 |
+
torch.cuda.empty_cache()
|
| 928 |
+
|
| 929 |
+
# Log memory stats
|
| 930 |
+
if torch.cuda.is_available():
|
| 931 |
+
allocated = torch.cuda.memory_allocated() / 1024**3
|
| 932 |
+
reserved = torch.cuda.memory_reserved() / 1024**3
|
| 933 |
+
logger.info(f"GPU memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
|
| 934 |
+
|
| 935 |
+
logger.info("GPU memory cleared")
|
| 936 |
+
|
| 937 |
+
def _safe_return_home(self):
|
| 938 |
+
"""
|
| 939 |
+
Safely return robot to home position with error handling.
|
| 940 |
+
"""
|
| 941 |
+
try:
|
| 942 |
+
self.robot_arm.move_arm("idle")
|
| 943 |
+
logger.info("Robot returned to home position")
|
| 944 |
+
except Exception as e:
|
| 945 |
+
logger.error(f"Failed to return to home position: {e}")
|
| 946 |
+
# Try direct position command as fallback
|
| 947 |
+
try:
|
| 948 |
+
self.robot_arm.robot.send_action(HOME_POSE)
|
| 949 |
+
logger.info("Robot returned to home using direct command")
|
| 950 |
+
except Exception as e2:
|
| 951 |
+
logger.error(f"Direct home command also failed: {e2}")
|
| 952 |
+
|
| 953 |
+
def _emergency_stop(self):
|
| 954 |
+
"""
|
| 955 |
+
Emergency stop: return robot to safe idle position.
|
| 956 |
+
|
| 957 |
+
This is called when an error occurs during execution.
|
| 958 |
+
Sets the emergency stop flag and attempts to safely stop the robot.
|
| 959 |
+
"""
|
| 960 |
+
logger.warning("Emergency stop triggered")
|
| 961 |
+
|
| 962 |
+
# Set emergency stop flag
|
| 963 |
+
self._emergency_stop_flag.set()
|
| 964 |
+
|
| 965 |
+
try:
|
| 966 |
+
# Try to stop robot immediately
|
| 967 |
+
self._safe_return_home()
|
| 968 |
+
logger.info("Emergency stop completed - robot in safe position")
|
| 969 |
+
except Exception as e:
|
| 970 |
+
logger.error(f"Emergency stop failed: {e}")
|
| 971 |
+
logger.error("MANUAL INTERVENTION MAY BE REQUIRED")
|
| 972 |
+
|
| 973 |
+
def cleanup(self):
|
| 974 |
+
"""
|
| 975 |
+
Clean up resources (camera, GPU memory, etc.).
|
| 976 |
+
|
| 977 |
+
Should be called when the executor is no longer needed.
|
| 978 |
+
"""
|
| 979 |
+
logger.info("Cleaning up SmolVLA executor...")
|
| 980 |
+
|
| 981 |
+
# Disconnect camera
|
| 982 |
+
if hasattr(self, 'camera') and self.camera is not None:
|
| 983 |
+
try:
|
| 984 |
+
self.camera.disconnect()
|
| 985 |
+
except Exception as e:
|
| 986 |
+
logger.warning(f"Camera disconnect failed: {e}")
|
| 987 |
+
|
| 988 |
+
# Clear GPU memory
|
| 989 |
+
if hasattr(self, 'device') and self.device == "cuda":
|
| 990 |
+
torch.cuda.empty_cache()
|
| 991 |
+
|
| 992 |
+
logger.info("Cleanup complete")
|
| 993 |
+
|
| 994 |
+
def __del__(self):
|
| 995 |
+
"""Destructor to ensure cleanup."""
|
| 996 |
+
try:
|
| 997 |
+
self.cleanup()
|
| 998 |
+
except Exception:
|
| 999 |
+
# Silently ignore cleanup errors in destructor
|
| 1000 |
+
pass
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def init_smolvla_executor(
|
| 1004 |
+
checkpoint_path: Optional[str] = None,
|
| 1005 |
+
robot_arm: Optional[MortisArm] = None,
|
| 1006 |
+
device: Optional[str] = None
|
| 1007 |
+
) -> SmolVLAExecutor:
|
| 1008 |
+
"""
|
| 1009 |
+
Factory function to initialize SmolVLA executor with environment configuration.
|
| 1010 |
+
|
| 1011 |
+
Args:
|
| 1012 |
+
checkpoint_path: Path to model checkpoint (uses env var if not provided)
|
| 1013 |
+
robot_arm: Optional MortisArm instance
|
| 1014 |
+
device: Device to use (uses env var or auto-detect if not provided)
|
| 1015 |
+
|
| 1016 |
+
Returns:
|
| 1017 |
+
Initialized SmolVLAExecutor instance
|
| 1018 |
+
|
| 1019 |
+
Raises:
|
| 1020 |
+
SmolVLAError: If initialization fails
|
| 1021 |
+
"""
|
| 1022 |
+
# Get checkpoint path from environment if not provided
|
| 1023 |
+
if checkpoint_path is None:
|
| 1024 |
+
checkpoint_path = os.getenv("SMOLVLA_CHECKPOINT_PATH")
|
| 1025 |
+
if checkpoint_path is None:
|
| 1026 |
+
raise SmolVLAError(
|
| 1027 |
+
"No checkpoint path provided and SMOLVLA_CHECKPOINT_PATH not set"
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
# Get device from environment if not provided
|
| 1031 |
+
if device is None:
|
| 1032 |
+
device = os.getenv("SMOLVLA_DEVICE")
|
| 1033 |
+
|
| 1034 |
+
logger.info(f"Initializing SmolVLA executor with checkpoint: {checkpoint_path}")
|
| 1035 |
+
|
| 1036 |
+
return SmolVLAExecutor(
|
| 1037 |
+
checkpoint_path=checkpoint_path,
|
| 1038 |
+
robot_arm=robot_arm,
|
| 1039 |
+
device=device
|
| 1040 |
+
)
|
src/mortis/stt_service.py ADDED
@@ -0,0 +1,383 @@
"""
Speech-to-Text service for Mortis voice input.

This module provides the STTService class for converting audio input to text,
with support for Gemini native audio processing and fallback to Google Cloud Speech-to-Text.
"""

import os
import logging
from pathlib import Path
from typing import Optional, Literal
from enum import Enum
from dotenv import load_dotenv

from google import genai
from google.genai import types

# Load environment variables
REPO_ROOT = Path(__file__).resolve().parents[2]
load_dotenv(REPO_ROOT / ".env")

# Configure logging
logger = logging.getLogger(__name__)


class STTProvider(Enum):
    """Available Speech-to-Text providers."""
    GEMINI = "gemini"
    GOOGLE_STT = "google_stt"


class AudioFormat(Enum):
    """Supported audio formats."""
    WAV = "wav"
    MP3 = "mp3"
    WEBM = "webm"
    OGG = "ogg"
    FLAC = "flac"


class AudioProcessingError(Exception):
    """Base exception for audio processing errors."""
    pass


class STTService:
    """
    Speech-to-Text service for converting audio input to text.

    Supports multiple STT providers:
    - Gemini native audio (primary, recommended)
    - Google Cloud Speech-to-Text (fallback)

    The service automatically handles audio format validation and conversion.
    """

    def __init__(
        self,
        provider: Optional[STTProvider] = None,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        language_code: str = "en-US",
        enable_fallback: bool = True
    ):
        """
        Initialize STT service.

        Args:
            provider: STT provider to use (defaults to STT_PROVIDER env var, else GEMINI)
            api_key: API key for Gemini (defaults to GEMINI_API_KEY env var)
            model_name: Gemini model to use (defaults to GEMINI_MODEL env var or gemini-1.5-flash)
            language_code: Language code for transcription (default: en-US)
            enable_fallback: Whether to enable fallback to Google STT on Gemini failure
        """
        # Determine provider from environment or default to Gemini
        if provider is None:
            provider_str = os.getenv("STT_PROVIDER", "gemini").lower()
            try:
                provider = STTProvider(provider_str)
            except ValueError:
                logger.warning(f"Invalid STT_PROVIDER '{provider_str}', defaulting to GEMINI")
                provider = STTProvider.GEMINI

        self.provider = provider
        self.language_code = language_code
        self.enable_fallback = enable_fallback

        # Initialize Gemini client for audio processing
        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY must be provided or set in environment")

        self.model_name = model_name or os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
        self.client = genai.Client(api_key=self.api_key)

        # Initialize Google Cloud STT client (lazy loading)
        self._google_stt_client = None

        logger.info(
            f"STTService initialized with provider: {self.provider.value}, "
            f"model: {self.model_name}, language: {self.language_code}, "
            f"fallback: {self.enable_fallback}"
        )

    def transcribe(self, audio_path: str) -> str:
        """
        Transcribe audio file to text.

        Args:
            audio_path: Path to audio file

        Returns:
            Transcribed text

        Raises:
            AudioProcessingError: If transcription fails with all providers
            FileNotFoundError: If audio file doesn't exist
        """
        # Validate audio file exists
        audio_file = Path(audio_path)
        if not audio_file.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # Validate audio format
        if not self._validate_audio_format(audio_file):
            raise AudioProcessingError(
                f"Unsupported audio format: {audio_file.suffix}. "
                f"Supported formats: {[fmt.value for fmt in AudioFormat]}"
            )

        logger.info(f"Transcribing audio file: {audio_path} using {self.provider.value}")

        # Try primary provider
        try:
            if self.provider == STTProvider.GEMINI:
                return self._transcribe_with_gemini(audio_path)
            elif self.provider == STTProvider.GOOGLE_STT:
                return self._transcribe_with_google_stt(audio_path)
        except Exception as e:
            logger.warning(f"Primary STT provider ({self.provider.value}) failed: {e}")

            # Try fallback if enabled
            if self.enable_fallback:
                logger.info("Attempting fallback STT provider...")
                try:
                    if self.provider == STTProvider.GEMINI:
                        # Fallback to Google STT
                        return self._transcribe_with_google_stt(audio_path)
                    else:
                        # Fallback to Gemini
                        return self._transcribe_with_gemini(audio_path)
                except Exception as fallback_error:
                    logger.error(f"Fallback STT provider also failed: {fallback_error}")
                    raise AudioProcessingError(
                        f"All STT providers failed. Primary: {e}, Fallback: {fallback_error}"
                    ) from fallback_error
            else:
                raise AudioProcessingError(f"STT transcription failed: {e}") from e

    def _validate_audio_format(self, audio_file: Path) -> bool:
        """
        Validate that audio file format is supported.

        Args:
            audio_file: Path to audio file

        Returns:
            True if format is supported, False otherwise
        """
        suffix = audio_file.suffix.lstrip('.').lower()
        supported_formats = [fmt.value for fmt in AudioFormat]
        return suffix in supported_formats

    def _transcribe_with_gemini(self, audio_path: str) -> str:
        """
        Transcribe audio using Gemini native audio support.

        Args:
            audio_path: Path to audio file

        Returns:
            Transcribed text

        Raises:
            Exception: If Gemini API call fails
        """
        logger.debug(f"Transcribing with Gemini: {audio_path}")

        try:
            # Upload audio file to Gemini
            audio_file = self.client.files.upload(file=audio_path)
            logger.debug(f"Audio file uploaded: {audio_file.name}")

            # Create prompt for transcription
            prompt = (
                "Transcribe this audio accurately. "
                "Return only the transcribed text without any additional commentary or formatting."
            )

            # Generate content with audio
            response = self.client.models.generate_content(
                model=self.model_name,
                contents=[prompt, audio_file]
            )

            # Extract transcribed text
            if response.text is None:
                logger.warning("Gemini returned None for transcription")
                logger.debug(f"Response object: {response}")
                # Check if there are candidates with parts
                if hasattr(response, 'candidates') and response.candidates:
                    logger.debug(f"Response has {len(response.candidates)} candidates")
                    for i, candidate in enumerate(response.candidates):
                        logger.debug(f"Candidate {i}: {candidate}")
                transcript = ""
            else:
                transcript = response.text.strip()

            if transcript:
                logger.info(f"Gemini transcription successful: '{transcript[:50]}...'")
            else:
                logger.warning("Gemini transcription returned empty result")

            # Clean up uploaded file
            try:
                self.client.files.delete(name=audio_file.name)
                logger.debug(f"Deleted uploaded audio file: {audio_file.name}")
            except Exception as cleanup_error:
                logger.warning(f"Failed to delete uploaded audio file: {cleanup_error}")

            return transcript

        except Exception as e:
            logger.error(f"Gemini transcription failed: {type(e).__name__}: {e}")
            raise

    def _transcribe_with_google_stt(self, audio_path: str) -> str:
        """
        Transcribe audio using Google Cloud Speech-to-Text API.

        Args:
            audio_path: Path to audio file

        Returns:
            Transcribed text

        Raises:
            Exception: If Google STT API call fails
            ImportError: If google-cloud-speech is not installed
        """
        logger.debug(f"Transcribing with Google STT: {audio_path}")

        try:
            from google.cloud import speech_v1
        except ImportError:
            raise ImportError(
                "google-cloud-speech is not installed. "
                "Install it with: pip install google-cloud-speech"
            )

        # Initialize Google STT client (lazy loading)
        if self._google_stt_client is None:
            self._google_stt_client = speech_v1.SpeechClient()
            logger.debug("Google STT client initialized")

        # Read audio file
        with open(audio_path, "rb") as audio_file:
            audio_content = audio_file.read()

        # Determine audio encoding from file extension
        audio_path_obj = Path(audio_path)
        suffix = audio_path_obj.suffix.lstrip('.').lower()

        encoding_map = {
            "wav": speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
            "mp3": speech_v1.RecognitionConfig.AudioEncoding.MP3,
            "flac": speech_v1.RecognitionConfig.AudioEncoding.FLAC,
            "ogg": speech_v1.RecognitionConfig.AudioEncoding.OGG_OPUS,
            "webm": speech_v1.RecognitionConfig.AudioEncoding.WEBM_OPUS,
        }

        encoding = encoding_map.get(suffix, speech_v1.RecognitionConfig.AudioEncoding.LINEAR16)

        # Configure recognition
        audio = speech_v1.RecognitionAudio(content=audio_content)
        config = speech_v1.RecognitionConfig(
            encoding=encoding,
            language_code=self.language_code,
            enable_automatic_punctuation=True,
        )

        # Perform transcription
        try:
            response = self._google_stt_client.recognize(config=config, audio=audio)

            # Extract transcript from results
            if not response.results:
                logger.warning("Google STT returned no results")
                return ""

            # Combine all alternatives (usually just one)
            transcript = " ".join(
                result.alternatives[0].transcript
                for result in response.results
                if result.alternatives
            )

            logger.info(f"Google STT transcription successful: '{transcript[:50]}...'")
            return transcript.strip()

        except Exception as e:
            logger.error(f"Google STT transcription failed: {type(e).__name__}: {e}")
            raise

    def configure(
        self,
        provider: Optional[STTProvider] = None,
        language_code: Optional[str] = None,
        enable_fallback: Optional[bool] = None
    ):
        """
        Reconfigure STT service settings.

        Args:
            provider: New STT provider to use
            language_code: New language code
            enable_fallback: Whether to enable fallback
        """
        if provider is not None:
            self.provider = provider
            logger.info(f"STT provider changed to: {provider.value}")

        if language_code is not None:
            self.language_code = language_code
            logger.info(f"Language code changed to: {language_code}")

        if enable_fallback is not None:
            self.enable_fallback = enable_fallback
            logger.info(f"Fallback {'enabled' if enable_fallback else 'disabled'}")


# Example usage
if __name__ == "__main__":
    import sys

    # Configure logging for testing
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Check for audio file argument
    if len(sys.argv) < 2:
        print("Usage: python -m mortis.stt_service <audio_file>")
        print("Example: python -m mortis.stt_service test_audio.wav")
        sys.exit(1)

    audio_file = sys.argv[1]

    try:
        # Create STT service
        stt_service = STTService()

        # Transcribe audio
        print(f"\nTranscribing: {audio_file}")
        print("-" * 60)
        transcript = stt_service.transcribe(audio_file)
        print(f"Transcript: {transcript}")
        print("-" * 60)

    except FileNotFoundError as e:
        print(f"Error: {e}")
        sys.exit(1)
    except AudioProcessingError as e:
        print(f"Audio processing error: {e}")
        sys.exit(1)
    except ValueError as e:
        print(f"Configuration error: {e}")
        print("Please set GEMINI_API_KEY in your .env file")
        sys.exit(1)
    except Exception as e:
        print(f"Unexpected error: {type(e).__name__}: {e}")
        sys.exit(1)
src/mortis/tools.py ADDED
@@ -0,0 +1,418 @@
"""
LLM integration for Mortis conversational AI.

This module provides the ask_mortis() function that integrates with the Gemini API
to generate character-driven responses and coordinate gesture execution.
"""

import logging
import time
from typing import Tuple, Optional
from pathlib import Path

from .robot import MortisArm
from .gemini_client import GeminiClient
from .models import GeminiResponse

# Configure logging
logger = logging.getLogger(__name__)

# Global instances
mortis_arm = MortisArm()
gemini_client = None  # Lazy initialization
stt_service = None  # Lazy initialization
tts_service = None  # Lazy initialization
intent_router = None  # Lazy initialization
smolvla_executor = None  # Lazy initialization


def _get_gemini_client() -> GeminiClient:
    """
    Get or create the global GeminiClient instance.

    Returns:
        GeminiClient instance
    """
    global gemini_client
    if gemini_client is None:
        gemini_client = GeminiClient()
        logger.info("GeminiClient initialized")
    return gemini_client


def _get_stt_service():
    """
    Get or create the global STTService instance.

    Returns:
        STTService instance
    """
    global stt_service
    if stt_service is None:
        from .stt_service import STTService
        stt_service = STTService()
        logger.info("STTService initialized")
    return stt_service


def _get_tts_service():
    """
    Get or create the global TTSService instance.

    Returns:
        TTSService instance
    """
    global tts_service
    if tts_service is None:
        from .tts_service import get_tts_service
        tts_service = get_tts_service()
        logger.info("TTSService initialized")
    return tts_service


def _get_intent_router():
    """
    Get or create the global IntentRouter instance.

    Returns:
        IntentRouter instance
    """
    global intent_router
    if intent_router is None:
        from .intent_router import IntentRouter
        intent_router = IntentRouter()
        logger.info("IntentRouter initialized")
    return intent_router


def _get_smolvla_executor():
    """
    Get or create the global SmolVLAExecutor instance.

    Returns:
        SmolVLAExecutor instance or None if not configured
    """
    global smolvla_executor
    if smolvla_executor is None:
        import os

        # Check if we're in simulation mode
        robot_mode = os.getenv("ROBOT_MODE", "physical").lower()
        if robot_mode == "simulation":
            logger.info("SmolVLA disabled in simulation mode")
            smolvla_executor = None
            return None

        checkpoint_path = os.getenv("SMOLVLA_CHECKPOINT_PATH")

        if checkpoint_path:
            try:
                from .smolvla_executor import SmolVLAExecutor
                smolvla_executor = SmolVLAExecutor(
                    checkpoint_path=checkpoint_path,
                    robot_arm=mortis_arm
                )
                logger.info(f"SmolVLAExecutor initialized with checkpoint: {checkpoint_path}")
            except Exception as e:
                logger.warning(f"Failed to initialize SmolVLAExecutor: {e}")
                logger.warning("Manipulation commands will fall back to gestures")
                smolvla_executor = None
        else:
            logger.info("SMOLVLA_CHECKPOINT_PATH not set, manipulation commands will use gestures")
            smolvla_executor = None

    return smolvla_executor


def ask_mortis(
    user_msg: Optional[str] = None,
    model_name: Optional[str] = None,
    audio_path: Optional[str] = None
) -> Tuple[str, str, str]:
    """
    Send a user message to the Gemini API and get a Mortis response with gesture.

    This function supports both text and voice input through a unified interface.
    It implements the voice-to-text-to-Gemini portion of the pipeline with
    latency monitoring; TTS output is added by ask_mortis_with_voice().

    Processing flow:
    1. If audio_path provided, transcribe to text using STT
    2. Connect to robot arm if not already connected
    3. Send text message to Gemini API
    4. Parse structured JSON response
    5. Return message, mood, and gesture for execution

    Args:
        user_msg: User's input message text (optional if audio_path provided)
        model_name: Optional Gemini model name (uses default from env if not provided)
        audio_path: Optional path to audio file for voice input

    Returns:
        Tuple of (message, mood, gesture) where:
        - message: Text response from Mortis
        - mood: Emotional mood (e.g., "ominous", "playful")
        - gesture: Gesture to execute (e.g., "wave", "idle")

    Raises:
        ValueError: If neither user_msg nor audio_path is provided

    Note:
        This function maintains backward compatibility with the previous API.
        The gesture is returned to the caller; when the robot arm is connected,
        it is also executed here via mortis_arm.move_arm() as a side effect of
        intent routing.

        Latency monitoring logs are generated for the voice processing pipeline.
    """
    pipeline_start = time.time()

    # Validate input
    if user_msg is None and audio_path is None:
        raise ValueError("Either user_msg or audio_path must be provided")

    # Voice input processing
    if audio_path is not None:
        logger.info(f"🎤 Processing voice input from: {audio_path}")
        stt_start = time.time()

        try:
            # Get STT service
            stt = _get_stt_service()

            # Transcribe audio to text
            user_msg = stt.transcribe(audio_path)

            stt_latency = time.time() - stt_start
            logger.info(f"⏱️ STT latency: {stt_latency:.2f}s")
            logger.info(f"📝 Transcribed: '{user_msg[:50]}...'")

            if not user_msg or not user_msg.strip():
                logger.warning("⚠️ STT returned empty transcription")
                return "I couldn't hear you... speak again.", "nervous", "idle"

        except Exception as e:
            logger.error(f"❌ Voice input processing failed: {e}")
            return "The spirits couldn't understand... try again.", "ominous", "idle"

    # Ensure robot is connected
    if not mortis_arm.connected:
        try:
            mortis_arm.connect()
            logger.info("Robot arm connected")
        except Exception as e:
            logger.error(f"Failed to connect to robot arm: {e}")
            # Continue anyway - we can still generate responses

    # Get Gemini client
    client = _get_gemini_client()

    # Reconfigure model if specified
    if model_name:
        client.configure_model(model_name=model_name)
        logger.info(f"Using Gemini model: {model_name}")

    # Send message to Gemini
    logger.info(f"💬 Asking Mortis: {user_msg[:50]}...")
    gemini_start = time.time()

    response_json = client.send_message(user_msg)

    gemini_latency = time.time() - gemini_start
    logger.info(f"⏱️ Gemini latency: {gemini_latency:.2f}s")

    # Parse response using IntentRouter
    try:
        # Get intent router
        router = _get_intent_router()

        # Parse Gemini response into Intent
        intent = router.parse_gemini_response(response_json)

        # Extract fields for return
        message = intent.message
        mood = intent.mood
        gesture = intent.gesture if intent.gesture else "idle"

        # Route based on intent type
        execution_path = router.route_intent(intent)

        if execution_path == "manipulation":
            # Valid manipulation command - attempt SmolVLA execution
            logger.info(f"🤖 Manipulation command detected: '{intent.command}'")

            # Try to get SmolVLA executor
            executor = _get_smolvla_executor()

            if executor is not None:
                try:
                    # Execute manipulation task
                    logger.info(f"Executing manipulation task: {intent.command}")
                    success = executor.execute(intent.command)

                    if success:
                        logger.info("✅ Manipulation task completed successfully")
                    else:
                        logger.warning("⚠️ Manipulation task did not complete fully")

                    # Return with "manipulation" as gesture to indicate manipulation was executed
                    gesture = "manipulation"

                except Exception as e:
                    logger.error(f"❌ SmolVLA execution failed: {e}")
                    logger.info("Falling back to gesture execution")

                    # Fallback to gesture execution
                    gesture = "idle"
                    if mortis_arm.connected:
                        mortis_arm.move_arm(gesture)
            else:
                # No SmolVLA executor available, fall back to gesture
                logger.warning("SmolVLA executor not available, falling back to gesture")
                gesture = "idle"
                if mortis_arm.connected:
                    mortis_arm.move_arm(gesture)

        elif execution_path == "gesture":
            # Conversational response with gesture
            logger.info(f"💬 Conversation with gesture: {gesture}")

            # Execute gesture immediately
            if mortis_arm.connected:
                try:
                    mortis_arm.move_arm(gesture)
                except Exception as e:
                    logger.error(f"Failed to execute gesture '{gesture}': {e}")

        elif execution_path == "invalid":
            # Invalid intent - fall back to gesture
            logger.warning(f"⚠️ Invalid intent: {intent.validation_error}")
            logger.info("Falling back to conversational gesture")

            # Use gesture from intent or default to idle
            gesture = intent.gesture if intent.gesture else "idle"

            # Execute gesture
            if mortis_arm.connected:
                try:
                    mortis_arm.move_arm(gesture)
                except Exception as e:
                    logger.error(f"Failed to execute fallback gesture '{gesture}': {e}")

        # Calculate total pipeline latency
        total_latency = time.time() - pipeline_start
        logger.info(f"⏱️ Total pipeline latency: {total_latency:.2f}s")
        logger.info(f"👻 Mortis responds (path: {execution_path}, mood: {mood}, gesture: {gesture})")

        return message, mood, gesture

    except (ValueError, KeyError) as e:
        # If parsing fails, return safe defaults
        logger.error(f"Failed to parse Gemini response: {e}")
        logger.error(f"Response JSON: {response_json}")

        # Return fallback response
        return "The spirits are confused... try again.", "ominous", "idle"


def ask_mortis_with_voice(
    user_msg: Optional[str] = None,
    model_name: Optional[str] = None,
    audio_path: Optional[str] = None,
    generate_audio: bool = True
) -> Tuple[str, str, str, Optional[str]]:
    """
    Complete voice-to-text-to-Gemini-to-TTS pipeline with audio output.

    This is a convenience function that wraps ask_mortis() and adds TTS
    generation for the response. It provides the full multi-modal experience.

    Args:
        user_msg: User's input message text (optional if audio_path provided)
        model_name: Optional Gemini model name
        audio_path: Optional path to audio file for voice input
        generate_audio: Whether to generate audio output (default: True)

    Returns:
        Tuple of (message, mood, gesture, audio_path) where:
        - message: Text response from Mortis
        - mood: Emotional mood
        - gesture: Gesture to execute
        - audio_path: Path to generated audio file (None if generation fails)

    Note:
        This function logs latency for the complete voice processing pipeline
        including STT, Gemini inference, and TTS generation.
    """
    pipeline_start = time.time()

    # Get text response from Gemini (handles STT if audio_path provided)
    message, mood, gesture = ask_mortis(
        user_msg=user_msg,
        model_name=model_name,
        audio_path=audio_path
    )

    # Generate audio response if requested
    response_audio_path = None
    if generate_audio:
        tts_start = time.time()

        try:
            # Get TTS service
            tts = _get_tts_service()

            # Generate audio
            response_audio_path = tts.synthesize(message)

            tts_latency = time.time() - tts_start
            logger.info(f"⏱️ TTS latency: {tts_latency:.2f}s")

            if response_audio_path:
                logger.info(f"🔊 Audio generated: {response_audio_path}")
            else:
                logger.warning("⚠️ TTS returned None")

        except Exception as e:
            logger.error(f"❌ TTS generation failed: {e}")
            # Continue without audio - text response is still valid

    # Log total pipeline latency including TTS
    total_latency = time.time() - pipeline_start
    logger.info(f"⏱️ Complete voice pipeline latency: {total_latency:.2f}s")

    return message, mood, gesture, response_audio_path


if __name__ == "__main__":
    # Configure logging for testing
    logging.basicConfig(level=logging.INFO)

    # Test conversational interactions
    print("=== Test 1: Greeting ===")
    message, mood, gesture = ask_mortis("Mortis, someone is entering the lab… act!")
    print(f"Message: {message}")
    print(f"Mood: {mood}")
    print(f"Gesture: {gesture}")
    print()

    print("=== Test 2: Introduction ===")
    message, mood, gesture = ask_mortis("Introduce yourself with a sinister bow.")
    print(f"Message: {message}")
    print(f"Mood: {mood}")
    print(f"Gesture: {gesture}")
    print()

    print("=== Test 3: Action sequence ===")
    message, mood, gesture = ask_mortis("Grab the cursed vial and then release it.")
    print(f"Message: {message}")
    print(f"Mood: {mood}")
    print(f"Gesture: {gesture}")
    print()

    print("=== Test 4: Manipulation command ===")
    message, mood, gesture = ask_mortis("Can you move the skull to the green cup?")
    print(f"Message: {message}")
    print(f"Mood: {mood}")
    print(f"Gesture: {gesture}")
    print()
src/mortis/tts_service.py
ADDED
@@ -0,0 +1,225 @@
"""
Text-to-Speech service for Mortis voice output.

Provides TTS capabilities using Google Cloud Text-to-Speech API with
fallback to local gTTS for offline scenarios.
"""

import os
import time
import logging
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


class TTSService:
    """
    Text-to-Speech service for converting Mortis responses to audio.

    Uses Google Cloud TTS as primary service with gTTS as fallback.
    Configured for a deep, ominous voice suitable for the Mortis character.
    """

    def __init__(
        self,
        output_dir: str = "outputs",
        use_google_tts: bool = True,
        voice_name: str = "en-US-Neural2-D",
        speaking_rate: float = 0.9,
        pitch: float = -2.0
    ):
        """
        Initialize TTS service.

        Args:
            output_dir: Directory for generated audio files
            use_google_tts: Whether to use Google Cloud TTS (requires credentials)
            voice_name: Google TTS voice name (Neural2-D is a deep male voice)
            speaking_rate: Speech speed (0.9 = slightly slower for ominous effect)
            pitch: Voice pitch (-2.0 = lower for a spooky voice)
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.use_google_tts = use_google_tts
        self.voice_name = voice_name
        self.speaking_rate = speaking_rate
        self.pitch = pitch

        # Try to initialize Google TTS client
        self.google_client = None
        self.texttospeech = None
        if self.use_google_tts:
            try:
                from google.cloud import texttospeech
                self.google_client = texttospeech.TextToSpeechClient()
                self.texttospeech = texttospeech
                logger.info("Google Cloud TTS initialized successfully")
            except ImportError as e:
                logger.warning(f"Google Cloud TTS not available: {e}. Will use gTTS fallback.")
                self.use_google_tts = False
            except Exception as e:
                logger.warning(f"Failed to initialize Google TTS: {e}. Will use gTTS fallback.")
                self.use_google_tts = False

        logger.info(f"TTS Service initialized (Google TTS: {self.use_google_tts})")

    def synthesize(self, text: str, filename: Optional[str] = None) -> Optional[str]:
        """
        Convert text to speech audio file.

        Args:
            text: Text to convert to speech
            filename: Optional custom filename (without extension)

        Returns:
            Path to generated audio file, or None if synthesis fails
        """
        if not text or not text.strip():
            logger.warning("Empty text provided to TTS service")
            return None

        # Generate filename if not provided
        if filename is None:
            timestamp = int(time.time() * 1000)
            filename = f"mortis_response_{timestamp}"

        # Try Google TTS first
        if self.use_google_tts and self.google_client:
            try:
                audio_path = self._synthesize_google_tts(text, filename)
                logger.info(f"Generated audio with Google TTS: {audio_path}")
                return audio_path
            except Exception as e:
                logger.error(f"Google TTS failed: {e}. Falling back to gTTS.")

        # Fallback to gTTS
        try:
            audio_path = self._synthesize_gtts(text, filename)
            logger.info(f"Generated audio with gTTS: {audio_path}")
            return audio_path
        except Exception as e:
            logger.error(f"gTTS also failed: {e}. No audio generated.")
            return None

    def _synthesize_google_tts(self, text: str, filename: str) -> str:
        """
        Synthesize speech using Google Cloud TTS.

        Args:
            text: Text to synthesize
            filename: Base filename (without extension)

        Returns:
            Path to generated MP3 file
        """
        # Prepare synthesis input
        synthesis_input = self.texttospeech.SynthesisInput(text=text)

        # Configure voice parameters for the Mortis character
        voice = self.texttospeech.VoiceSelectionParams(
            language_code="en-US",
            name=self.voice_name,
            ssml_gender=self.texttospeech.SsmlVoiceGender.MALE
        )

        # Configure audio output
        audio_config = self.texttospeech.AudioConfig(
            audio_encoding=self.texttospeech.AudioEncoding.MP3,
            speaking_rate=self.speaking_rate,
            pitch=self.pitch
        )

        # Perform synthesis
        response = self.google_client.synthesize_speech(
            input=synthesis_input,
            voice=voice,
            audio_config=audio_config
        )

        # Save audio file
        output_path = self.output_dir / f"{filename}.mp3"
        with open(output_path, "wb") as out:
            out.write(response.audio_content)

        return str(output_path)

    def _synthesize_gtts(self, text: str, filename: str) -> str:
        """
        Synthesize speech using gTTS (local fallback).

        Args:
            text: Text to synthesize
            filename: Base filename (without extension)

        Returns:
            Path to generated MP3 file
        """
        from gtts import gTTS

        # Create TTS object with slower speech for ominous effect
        tts = gTTS(text=text, lang='en', slow=True)

        # Save audio file
        output_path = self.output_dir / f"{filename}.mp3"
        tts.save(str(output_path))

        return str(output_path)

    def cleanup_old_files(self, max_age_seconds: int = 3600):
        """
        Remove old audio files to prevent disk space issues.

        Args:
            max_age_seconds: Maximum age of files to keep (default: 1 hour)
        """
        current_time = time.time()
        removed_count = 0

        for audio_file in self.output_dir.glob("mortis_response_*.mp3"):
            try:
                file_age = current_time - audio_file.stat().st_mtime
                if file_age > max_age_seconds:
                    audio_file.unlink()
                    removed_count += 1
            except Exception as e:
                logger.warning(f"Failed to remove old file {audio_file}: {e}")

        if removed_count > 0:
            logger.info(f"Cleaned up {removed_count} old audio files")


# Global TTS service instance
_tts_service: Optional[TTSService] = None


def get_tts_service() -> TTSService:
    """
    Get or create global TTS service instance.

    Returns:
        Singleton TTSService instance
    """
    global _tts_service
    if _tts_service is None:
        # Check if Google Cloud credentials are available
        use_google = bool(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
        _tts_service = TTSService(use_google_tts=use_google)
    return _tts_service


def synthesize_speech(text: str, filename: Optional[str] = None) -> Optional[str]:
    """
    Convenience function to synthesize speech using global TTS service.

    Args:
        text: Text to convert to speech
        filename: Optional custom filename

    Returns:
        Path to generated audio file, or None if synthesis fails
    """
    service = get_tts_service()
    return service.synthesize(text, filename)
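A minimal usage sketch for the module-level helpers above; whether the Google Cloud TTS backend or the gTTS fallback is used depends on GOOGLE_APPLICATION_CREDENTIALS, exactly as get_tts_service() decides. The spoken line is an arbitrary example:

from mortis.tts_service import get_tts_service, synthesize_speech

# One-shot synthesis through the singleton; returns an MP3 path under
# outputs/ or None if both backends fail.
path = synthesize_speech("Welcome, mortal, to my laboratory...")
print(path)

# Prune generated audio older than ten minutes to keep outputs/ small.
get_tts_service().cleanup_old_files(max_age_seconds=600)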