Upload folder using huggingface_hub

Files changed:
- Agent.md +190 -0
- README.md +6 -23
- app.py +147 -333
- config.ini +4 -9
- pyproject.toml +4 -4
- requirements.txt +1 -1
- test_api.py +58 -0
- tinytroupe/agent/memory.py +0 -18
- tinytroupe/agent/tiny_person.py +1 -70
- tinytroupe/config.ini +5 -5
- tinytroupe/factory/tiny_person_factory.py +10 -103
- tinytroupe/openai_utils.py +58 -74
- tinytroupe/utils/llm.py +1 -1
- tinytroupe/utils/semantics.py +0 -42
Agent.md
ADDED

@@ -0,0 +1,190 @@
# Agent.md

## 1. Deployment Configuration

### Target Space
- **Profile:** `AUXteam`
- **Space:** `tiny_factory`
- **Full Identifier:** `AUXteam/tiny_factory`
- **Frontend Port:** `7860` (mandatory for all Hugging Face Spaces)

### Deployment Method
Choose the SDK that matches the codebase's application type:

- **Gradio SDK** — for Gradio applications
- **Streamlit SDK** — for Streamlit applications
- **Docker SDK** — for all other applications (recommended default for flexibility)

### HF Token
- The environment variable **`$HF_TOKEN` will always be provided at execution time**.
- Never hardcode the token. Always read it from the environment.
- All monitoring and log-streaming commands rely on `$HF_TOKEN`.

### Required Files
- `Dockerfile` (or `app.py` for Gradio/Streamlit SDKs)
- `README.md` with Hugging Face YAML frontmatter:
  ```yaml
  ---
  title: <APP NAME>
  sdk: docker | gradio | streamlit
  app_port: 7860
  ---
  ```
- `.hfignore` to exclude unnecessary files
- This `Agent.md` file (must be committed before deployment)

---

## 2. API Exposure and Documentation

### Mandatory Endpoints
Every deployment **must** expose:

- **`/health`**
  - Returns HTTP 200 when the app is ready.
  - Required for Hugging Face to transition the Space from *starting* → *running*.

- **`/api-docs`**
  - Documents **all** available API endpoints.
  - Must be reachable at:
    `https://HF_PROFILE-tiny_factory.hf.space/api-docs`

### Functional Endpoints
Document each endpoint here. For every endpoint, include:

- **Method:** GET/POST/PUT/DELETE
- **Path:** `/predict`, `/generate`, `/upload`, etc.
- **Purpose:** What the endpoint does
- **Request Example:** JSON or query parameters
- **Response Example:** JSON schema or example payload

Example format:

```
### /predict
- Method: POST
- Purpose: Run model inference
- Request:
  {
    "text": "hello world"
  }
- Response:
  {
    "prediction": "…"
  }
```

All endpoints listed here **must** appear in `/api-docs`.

---

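The rule that every functional endpoint must appear in `/api-docs` can be enforced mechanically. A minimal sketch, assuming the `/api-docs` payload is kept as plain data keyed by path (the inventory below is illustrative, not the project's actual API surface):

```python
# Hypothetical endpoint inventory that /api-docs could serve as JSON.
API_DOCS = {
    "/predict": {
        "method": "POST",
        "purpose": "Run model inference",
        "request_example": {"text": "hello world"},
        "response_example": {"prediction": "..."},
    },
    "/health": {
        "method": "GET",
        "purpose": "Readiness probe",
        "response_example": {"status": "ok"},
    },
}

def undocumented_endpoints(functional_paths, docs):
    """Return the functional endpoints missing from the /api-docs spec."""
    return sorted(set(functional_paths) - set(docs))

# "/generate" is not present in API_DOCS above, so it gets flagged:
missing = undocumented_endpoints(["/predict", "/generate"], API_DOCS)
# missing == ["/generate"]
```

Running such a check in CI (or in the mandatory test cases below) catches routes added to the app but never documented.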
## 3. Deployment Workflow

### Standard Deployment Command
After any code change, run:

```bash
hf upload AUXteam/tiny_factory --repo-type=space
```

This command must be executed **after updating and committing Agent.md**.

### Deployment Steps
1. Ensure all code changes are committed.
2. Ensure `Agent.md` is updated and committed.
3. Run the upload command.
4. Wait for the Space to build.
5. Monitor logs (see next section).
6. When the Space is running, execute all test cases.

### Continuous Deployment Rule
After **every** relevant edit (logic, dependencies, API changes):

- Update `Agent.md`
- Redeploy using the upload command
- Re-run all test cases
- Confirm `/health` and `/api-docs` are functional

This applies even for long-running projects.

---

## 4. Monitoring and Logs

### Build Logs (SSE)
```bash
curl -N \
  -H "Authorization: Bearer $HF_TOKEN" \
  "https://huggingface.co/api/spaces/AUXteam/tiny_factory/logs/build"
```

### Run Logs (SSE)
```bash
curl -N \
  -H "Authorization: Bearer $HF_TOKEN" \
  "https://huggingface.co/api/spaces/AUXteam/tiny_factory/logs/run"
```

### Notes
- If the Space stays in *starting* for too long, `/health` is usually failing.
- If the Space times out after ~30 minutes, check logs immediately.
- Fix issues, commit changes, redeploy.

---

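The log endpoints stream server-sent events, so an agent consuming them programmatically only needs the generic SSE framing (`data:`-prefixed lines). A minimal sketch; the payload shape in the example is fabricated, and real log lines are treated as opaque text with JSON decoded opportunistically:

```python
import json

def sse_data_lines(stream_lines):
    """Yield the payload of each 'data:' line from an SSE stream.

    Assumes only the generic SSE framing; the actual payload format of
    the HF log endpoints is not specified here, so JSON is attempted
    and plain text is passed through unchanged.
    """
    for line in stream_lines:
        if line.startswith("data:"):
            payload = line[len("data:"):].strip()
            try:
                yield json.loads(payload)
            except json.JSONDecodeError:
                yield payload

# Example with a fabricated stream (comment lines like ": keep-alive" are skipped):
raw = [
    ": keep-alive",
    'data: {"timestamp": "2024-01-01T00:00:00Z", "message": "Build started"}',
    "data: plain text line",
]
events = list(sse_data_lines(raw))
# events[0]["message"] == "Build started"; events[1] == "plain text line"
```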
## 5. Test Run Cases (Mandatory After Every Deployment)

These tests ensure the agentic system can verify the deployment automatically.

### 1. Health Check
```
GET https://HF_PROFILE-tiny_factory.hf.space/health
Expected: HTTP 200, body: {"status": "ok"} or similar
```

### 2. API Docs Check
```
GET https://HF_PROFILE-tiny_factory.hf.space/api-docs
Expected: HTTP 200, valid documentation UI or JSON spec
```

### 3. Functional Endpoint Tests
For each endpoint documented above, define:

- Example request
- Expected response structure
- Validation criteria (e.g., non-empty output, valid JSON)

Example:

```
POST https://HF_PROFILE-tiny_factory.hf.space/predict
Payload:
{
  "text": "test"
}
Expected:
- HTTP 200
- JSON with key "prediction"
- No error fields
```

### 4. End-to-End Behaviour
- Confirm the UI loads (if applicable)
- Confirm API endpoints respond within reasonable time
- Confirm no errors appear in run logs

---

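The validation criteria above translate directly into small predicate functions that a test runner can call after each request. A sketch, following the response shapes given in this section:

```python
def valid_health(status_code, body):
    """Health check passes on HTTP 200 with a JSON object body, e.g. {"status": "ok"}."""
    return status_code == 200 and isinstance(body, dict)

def valid_prediction(status_code, body):
    """/predict passes on HTTP 200, JSON with a "prediction" key and no error fields."""
    return (
        status_code == 200
        and isinstance(body, dict)
        and "prediction" in body
        and "error" not in body
    )

# Example outcomes:
ok = valid_prediction(200, {"prediction": "positive"})      # True
bad = valid_prediction(200, {"error": "model not loaded"})  # False
```

Keeping the criteria as pure functions lets the same checks run both locally (against recorded responses) and against the live Space.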
## 6. Maintenance Rules

- `Agent.md` must always reflect the **current** deployment configuration, API surface, and test cases.
- Any change to:
  - API routes
  - Dockerfile
  - Dependencies
  - App logic
  - Deployment method

  requires updating this file.
- This file must be committed **before** every deployment.
- This file is the operational contract for autonomous agents interacting with the project.

README.md
CHANGED

@@ -1,29 +1,12 @@
 ---
-title:
-emoji:
+title: Tiny Factory
+emoji: 💻
 colorFrom: yellow
 colorTo: gray
-sdk:
-
+sdk: gradio
+sdk_version: 6.3.0
+app_file: app.py
 pinned: false
 ---
 
-
-
-Deep Persona Factory is a specialized simulation engine for persona generation and social content testing.
-
-## Features
-- **Social Network Engine:** Graph-based modeling and influence propagation.
-- **Prediction Engine:** ML and LLM-based engagement scoring.
-- **Deep Persona Generation:** Sequential enrichment for high-fidelity character profiles.
-- **API Documentation:** Accessible via `/api-docs`.
-- **Health Check:** Accessible via `/health`.
-
-## API Documentation
-The application exposes a mandatory `/api-docs` endpoint providing Swagger UI for all available endpoints.
-
-## Local Setup
-```bash
-pip install -r requirements.txt
-uvicorn app:app --host 0.0.0.0 --port 7860
-```
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED

@@ -1,248 +1,181 @@
 import sys
 import os
-from fastapi import FastAPI
-from fastapi.responses import RedirectResponse
 import gradio as gr
 import json
-import
-
-from
-from deeppersona.simulation_manager import SimulationManager, SimulationConfig
-from deeppersona.agent.social_types import Content
-from huggingface_hub import hf_hub_download, upload_file
-
-HF_TOKEN = os.getenv("HF_TOKEN")
-REPO_ID = "AUXteam/tiny_factory"
-PERSONA_BASE_FILE = "persona_base.json"
 
-
-
-
-        return []
-    try:
-        path = hf_hub_download(repo_id=REPO_ID, filename=PERSONA_BASE_FILE, repo_type="space", token=HF_TOKEN)
-        with open(path, 'r', encoding='utf-8') as f:
-            return json.load(f)
-    except Exception as e:
-        print(f"Error loading persona base: {e}")
-        return []
 
-
-
-
-        json.dump(personas, f, indent=4)
-        upload_file(
-            path_or_fileobj=PERSONA_BASE_FILE,
-            path_in_repo=PERSONA_BASE_FILE,
-            repo_id=REPO_ID,
-            repo_type="space",
-            token=HF_TOKEN
-        )
-    except Exception as e:
-        print(f"Error saving persona base: {e}")
 
-def
-
-    os.environ["BLABLADOR_API_KEY"] = blablador_key
 
-
-
 
-
-    generated = factory.generate_people(number_of_people=int(num_personas))
 
-
-
-
-
-            "name": p.name,
-            "persona": p._persona,
-            "minibio": p.minibio()
-        })
-    save_persona_base(base)
 
-
-
-def find_best_persona(criteria):
-    base = load_persona_base()
-    if not base: return {"error": "No personas in base"}
 
-
-
-
-
 """
-
 """
-
-    # Add example agents from disk
-    example_files = glob.glob("deeppersona/examples/agents/*.json")
-    for ef in example_files:
-        try:
-            with open(ef, 'r') as f:
-                data = json.load(f)
-            base.append({
-                "name": data.get("name", "Unknown"),
-                "persona": data.get("persona", {}),
-                "minibio": "Example agent"
-            })
-        except: pass
-
-    relevant = select_relevant_personas_utility(base, context)
-    return relevant
-
-# API Wrappers for SimulationManager
-def generate_social_network_api(name, persona_count, network_type, focus_group_name=None):
-    try:
-        config = SimulationConfig(name=name, persona_count=int(persona_count), network_type=network_type)
-        simulation = simulation_manager.create_simulation(config, focus_group_name=focus_group_name)
-        return {"simulation_id": simulation.id, "status": "created"}
-    except Exception as e:
-        return {"error": str(e)}
-
-def predict_engagement_api(simulation_id, content_text, format="text"):
-    try:
-        content = Content(text=content_text, format=format)
-        results = simulation_manager.predict_engagement(simulation_id, content)
-        return results
-    except Exception as e:
-        return {"error": str(e)}
-
-def start_simulation_async_api(simulation_id, content_text, format="text"):
-    try:
-        content = Content(text=content_text, format=format)
-        simulation_manager.run_simulation(simulation_id, content, background=True)
-        return {"status": "started", "simulation_id": simulation_id}
-    except Exception as e:
-        return {"error": str(e)}
-
-def get_simulation_status_api(simulation_id):
-    try:
-        sim = simulation_manager.get_simulation(simulation_id)
-        if not sim: return {"error": "Not found"}
-        return {
-            "status": sim.status,
-            "progress": sim.progress,
-            "result_ready": sim.result is not None
-        }
-    except Exception as e:
-        return {"error": str(e)}
-
-def send_chat_message_api(simulation_id, sender, message):
-    try:
-        res = simulation_manager.chat_with_simulation(simulation_id, sender, message)
-        return res
-    except Exception as e:
-        return {"error": str(e)}
 
-
-
-        sim = simulation_manager.get_simulation(simulation_id)
-        if not sim: return {"error": "Not found"}
-        return sim.chat_history
-    except Exception as e:
-        return {"error": str(e)}
-
-def generate_variants_api(content_text, count=5):
-    try:
-        content = Content(text=content_text)
-        variants = simulation_manager.generate_content_variants(content, int(count))
-        return [v.text for v in variants]
-    except Exception as e:
-        return {"error": str(e)}
-
-def list_simulations_api():
-    return list(simulation_manager.simulations.keys())
-
-def list_personas_api(simulation_id):
-    try:
-        sim = simulation_manager.get_simulation(simulation_id)
-        if not sim: return []
-        return [p.name for p in sim.personas]
-    except Exception as e:
-        return {"error": str(e)}
 
-
-
-
-        if not sim: return None
-        for p in sim.personas:
-            if p.name == persona_name: return p._persona
-        return None
-    except Exception as e:
-        return {"error": str(e)}
-
-def delete_simulation_api(simulation_id):
-    try:
-        success = simulation_manager.delete_simulation(simulation_id)
-        return {"success": success}
-    except Exception as e:
-        return {"error": str(e)}
-
-def export_simulation_api(simulation_id):
     try:
-
-
-
 
-
-
-
-
-
        })
-
-
-        return {"error": str(e)}
-
-def list_focus_groups_api():
-    try:
-        return simulation_manager.list_focus_groups()
-    except Exception as e:
-        return {"error": str(e)}
 
-def save_focus_group_api(name, simulation_id):
-    try:
-        sim = simulation_manager.get_simulation(simulation_id)
-        if not sim: return {"error": "Simulation not found"}
-        simulation_manager.save_focus_group(name, sim.personas)
-        return {"status": "success", "name": name}
     except Exception as e:
         return {"error": str(e)}
 
-# Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("<h1>
     with gr.Row():
         with gr.Column():
             business_description_input = gr.Textbox(label="What is your business about?", lines=5)
             customer_profile_input = gr.Textbox(label="Information about your customer profile", lines=5)
             num_personas_input = gr.Number(label="Number of personas to generate", value=1, minimum=1, step=1)
-
-
-
-
-
-
         with gr.Column():
-            output_json = gr.JSON(label="
 
     generate_button.click(
         fn=generate_personas,

@@ -251,126 +184,7 @@ with gr.Blocks() as demo:
         api_name="generate_personas"
     )
 
-    find_button.click(
-        fn=find_best_persona,
-        inputs=[criteria_input],
-        outputs=output_json,
-        api_name="find_best_persona"
-    )
-
-    with gr.Tab("Identify Deep Personas API", visible=False):
-        api_id_context = gr.Textbox(label="Context")
-        api_id_btn = gr.Button("Identify Deep Personas")
-        api_id_out = gr.JSON()
-        api_id_btn.click(identify_personas, inputs=[api_id_context], outputs=api_id_out, api_name="identify_personas")
-
-    with gr.Tab("Social Network API", visible=False):
-        api_net_name = gr.Textbox(label="Network Name")
-        api_net_count = gr.Number(label="Deep Persona Count", value=10)
-        api_net_type = gr.Dropdown(choices=["scale_free", "small_world"], label="Network Type")
-        api_net_focus = gr.Textbox(label="Focus Group Name (optional)")
-        api_net_btn = gr.Button("Generate Network")
-        api_net_out = gr.JSON()
-        api_net_btn.click(generate_social_network_api, inputs=[api_net_name, api_net_count, api_net_type, api_net_focus], outputs=api_net_out, api_name="generate_social_network")
-
-    with gr.Tab("Engagement Prediction API", visible=False):
-        api_pred_sim_id = gr.Textbox(label="Simulation ID")
-        api_pred_content = gr.Textbox(label="Content Text")
-        api_pred_format = gr.Textbox(label="Format", value="text")
-        api_pred_btn = gr.Button("Predict Engagement")
-        api_pred_out = gr.JSON()
-        api_pred_btn.click(predict_engagement_api, inputs=[api_pred_sim_id, api_pred_content, api_pred_format], outputs=api_pred_out, api_name="predict_engagement")
-
-    with gr.Tab("Async Simulation API", visible=False):
-        api_async_sim_id = gr.Textbox(label="Simulation ID")
-        api_async_content = gr.Textbox(label="Content Text")
-        api_async_format = gr.Textbox(label="Format", value="text")
-        api_async_btn = gr.Button("Start Simulation")
-        api_async_out = gr.JSON()
-        api_async_btn.click(start_simulation_async_api, inputs=[api_async_sim_id, api_async_content, api_async_format], outputs=api_async_out, api_name="start_simulation_async")
-        api_status_id = gr.Textbox(label="Simulation ID")
-        api_status_btn = gr.Button("Check Status")
-        api_status_out = gr.JSON()
-        api_status_btn.click(get_simulation_status_api, inputs=[api_status_id], outputs=api_status_out, api_name="get_simulation_status")
-
-    with gr.Tab("Chat API", visible=False):
-        api_chat_sim_id = gr.Textbox(label="Simulation ID")
-        api_chat_sender = gr.Textbox(label="Sender", value="User")
-        api_chat_msg = gr.Textbox(label="Message")
-        api_chat_send_btn = gr.Button("Send Message")
-        api_chat_send_out = gr.JSON()
-        api_chat_send_btn.click(send_chat_message_api, inputs=[api_chat_sim_id, api_chat_sender, api_chat_msg], outputs=api_chat_send_out, api_name="send_chat_message")
-        api_chat_hist_btn = gr.Button("Get History")
-        api_chat_hist_out = gr.JSON()
-        api_chat_hist_btn.click(get_chat_history_api, inputs=[api_chat_sim_id], outputs=api_chat_hist_out, api_name="get_chat_history")
-
-    with gr.Tab("Content Variants API", visible=False):
-        api_var_content = gr.Textbox(label="Original Content")
-        api_var_count = gr.Number(label="Number of Variants", value=5)
-        api_var_btn = gr.Button("Generate Variants")
-        api_var_out = gr.JSON()
-        api_var_btn.click(generate_variants_api, inputs=[api_var_content, api_var_count], outputs=api_var_out, api_name="generate_variants")
-
-    with gr.Tab("List Simulations API", visible=False):
-        api_list_sim_btn = gr.Button("List Simulations")
-        api_list_sim_out = gr.JSON()
-        api_list_sim_btn.click(list_simulations_api, outputs=api_list_sim_out, api_name="list_simulations")
-
-    with gr.Tab("List Deep Personas API", visible=False):
-        api_list_per_sim_id = gr.Textbox(label="Simulation ID")
-        api_list_per_btn = gr.Button("List Deep Personas")
-        api_list_per_out = gr.JSON()
-        api_list_per_btn.click(list_personas_api, inputs=[api_list_per_sim_id], outputs=api_list_per_out, api_name="list_personas")
-
-    with gr.Tab("Get Deep Persona API", visible=False):
-        api_get_per_sim_id = gr.Textbox(label="Simulation ID")
-        api_get_per_name = gr.Textbox(label="Deep Persona Name")
-        api_get_per_btn = gr.Button("Get Deep Persona")
-        api_get_per_out = gr.JSON()
-        api_get_per_btn.click(get_persona_api, inputs=[api_get_per_sim_id, api_get_per_name], outputs=api_get_per_out, api_name="get_persona")
-
-    with gr.Tab("Delete Simulation API", visible=False):
-        api_del_sim_id = gr.Textbox(label="Simulation ID")
-        api_del_btn = gr.Button("Delete Simulation")
-        api_del_out = gr.JSON()
-        api_del_btn.click(delete_simulation_api, inputs=[api_del_sim_id], outputs=api_del_out, api_name="delete_simulation")
-
-    with gr.Tab("Export Simulation API", visible=False):
-        api_exp_sim_id = gr.Textbox(label="Simulation ID")
-        api_exp_btn = gr.Button("Export Simulation")
-        api_exp_out = gr.JSON()
-        api_exp_btn.click(export_simulation_api, inputs=[api_exp_sim_id], outputs=api_exp_out, api_name="export_simulation")
-
-    with gr.Tab("Network Graph API", visible=False):
-        api_graph_sim_id = gr.Textbox(label="Simulation ID")
-        api_graph_btn = gr.Button("Get Graph Data")
-        api_graph_out = gr.JSON()
-        api_graph_btn.click(get_network_graph_api, inputs=[api_graph_sim_id], outputs=api_graph_out, api_name="get_network_graph")
-
-    with gr.Tab("Focus Group API", visible=False):
-        api_list_fg_btn = gr.Button("List Focus Groups")
-        api_list_fg_out = gr.JSON()
-        api_list_fg_btn.click(list_focus_groups_api, outputs=api_list_fg_out, api_name="list_focus_groups")
-        api_save_fg_name = gr.Textbox(label="Focus Group Name")
-        api_save_fg_sim_id = gr.Textbox(label="Simulation ID")
-        api_save_fg_btn = gr.Button("Save Focus Group")
-        api_save_fg_out = gr.JSON()
-        api_save_fg_btn.click(save_focus_group_api, inputs=[api_save_fg_name, api_save_fg_sim_id], outputs=api_save_fg_out, api_name="save_focus_group")
-
-# FastAPI App
-app = FastAPI()
-
-@app.get("/health")
-def health_check():
-    return {"status": "ok"}
-
-@app.get("/api-docs")
-def api_docs():
-    return RedirectResponse(url="/docs")
-
-# Mount Gradio
 app = gr.mount_gradio_app(app, demo, path="/")
 
 if __name__ == "__main__":
-    import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
 import sys
 import os
 import gradio as gr
 import json
+from fastapi import FastAPI
+import uvicorn
+from pydantic import BaseModel
 
+app = FastAPI()
 
+@app.get("/health")
+def health():
+    return {"status": "ok"}
 
+@app.get("/api-docs")
+def api_docs():
+    # In FastAPI, /docs is the Swagger UI; provide a JSON response at this path as well.
+    return {"message": "API documentation is available at /docs"}
 
+def extract_persona_parameters(business_description: str, customer_profile: str) -> dict:
+    from tinytroupe.openai_utils import client
 
+    system_prompt = """
+    You are an expert persona parameter extractor.
+    Based on the provided business description and customer profile, you must deduce and generate 10 specific parameters needed for a deep persona generator.
+    The parameters are:
+    - `age` (float): The age of the persona.
+    - `gender` (str): The gender of the persona.
+    - `occupation` (str): The occupation of the persona.
+    - `city` (str): The city of the persona.
+    - `country` (str): The country of the persona.
+    - `custom_values` (str): The personal values of the persona.
+    - `custom_life_attitude` (str): The life attitude of the persona.
+    - `life_story` (str): A brief life story of the persona.
+    - `interests_hobbies` (str): Interests and hobbies of the persona.
+    - `attribute_count` (float): Attribute richness, default to 200.
+
+    You must return a valid JSON object containing exactly these keys.
+    """
 
+    user_prompt = f"Business Description: {business_description}\nCustomer Profile: {customer_profile}\n\nReturn the 10 parameters as JSON."
 
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt}
+    ]
 
+    api_client = client()
+    response = api_client.send_message(messages, response_format={"type": "json_object"})
 
+    if response and "content" in response:
+        try:
+            # Attempt to parse it if the model returned string JSON
+            import json
+            import tinytroupe.utils as utils
+            extracted_json = utils.extract_json(response["content"])
+
+            # Ensure all keys are present
+            required_keys = ['age', 'gender', 'occupation', 'city', 'country', 'custom_values', 'custom_life_attitude', 'life_story', 'interests_hobbies', 'attribute_count']
+
+            # Handle extract_json returning a list instead of a dict
+            if isinstance(extracted_json, list) and len(extracted_json) > 0:
+                extracted_json = extracted_json[0]
+
+            for key in required_keys:
+                if key not in extracted_json:
+                    # Provide defaults for missing ones
+                    if key in ['age', 'attribute_count']:
+                        extracted_json[key] = 200 if key == 'attribute_count' else 30
+                    else:
+                        extracted_json[key] = "Unknown"
+
+            return extracted_json
+        except Exception as e:
+            print(f"Error parsing JSON from LLM: {e}")
+
+    # Fallback
+    return {
+        "age": 30,
+        "gender": "Non-binary",
+        "occupation": "Professional",
+        "city": "Metropolis",
+        "country": "Country",
+        "custom_values": "Innovation, Community",
+        "custom_life_attitude": "Optimistic",
+        "life_story": "A standard professional background with a passion for their field.",
+        "interests_hobbies": "Technology, Reading",
+        "attribute_count": 200
+    }
+
+def generate_personas(business_description, customer_profile, num_personas, blablador_api_key=None):
     """
+    Generates a list of personas based on the provided inputs, using a double
+    sequential generation pipeline:
+    1. Extract parameters from context via LLM.
+    2. Generate persona using deeppersona-experience via gradio client.
     """
+    api_key_to_use = blablador_api_key or os.getenv("BLABLADOR_API_KEY")
 
+    if not api_key_to_use:
+        return {"error": "BLABLADOR_API_KEY not found. Please provide it in your API call or set it as a secret in the Space settings."}
 
+    original_key = os.getenv("BLABLADOR_API_KEY")
+    os.environ["BLABLADOR_API_KEY"] = api_key_to_use
+
     try:
+        from gradio_client import Client
+
+        num_personas = int(num_personas)
+        personas_data = []
 
+        # Step 1: Extract 10 parameters based on the high-level inputs.
+        # We do this per persona to generate distinct ones, relying on LLM variance.
+
+        # Connect to the Gradio client.
+        # A Hugging Face token might be needed if the Space were private,
+        # but deeppersona-experience is assumed accessible.
+        client = Client("THzva/deeppersona-experience")
+
+        for i in range(num_personas):
+            # To get variety, append a note about variety to the profile
+            profile_with_variance = customer_profile + f"\n\nMake this persona distinct. Persona {i+1} of {num_personas}."
+
+            # Extract parameters using the LLM
+            params = extract_persona_parameters(business_description, profile_with_variance)
+
+            # Step 2: Call the Gradio API with the extracted parameters
+            result = client.predict(
+                age=float(params.get("age", 30)),
+                gender=str(params.get("gender", "Non-binary")),
+                occupation=str(params.get("occupation", "Professional")),
+                city=str(params.get("city", "Metropolis")),
+                country=str(params.get("country", "Country")),
+                custom_values=str(params.get("custom_values", "Innovation, Community")),
+                custom_life_attitude=str(params.get("custom_life_attitude", "Optimistic")),
+                life_story=str(params.get("life_story", "A standard professional background with a passion for their field.")),
+                interests_hobbies=str(params.get("interests_hobbies", "Technology, Reading")),
+                attribute_count=float(params.get("attribute_count", 200)),
+                api_name="/generate_persona"
+            )
+
+            # Note: The result from this API is a string (persona profile text)
+            personas_data.append({
+                "parameters_used": params,
+                "persona_profile": result
             })
+
+        return personas_data
 
     except Exception as e:
         return {"error": str(e)}
+
+    finally:
+        if original_key is None:
+            if "BLABLADOR_API_KEY" in os.environ:
|
| 159 |
+
del os.environ["BLABLADOR_API_KEY"]
|
| 160 |
+
else:
|
| 161 |
+
os.environ["BLABLADOR_API_KEY"] = original_key
|
| 162 |
|
|
|
|
| 163 |
with gr.Blocks() as demo:
|
| 164 |
+
gr.Markdown("<h1>Tiny Persona Generator</h1>")
|
| 165 |
with gr.Row():
|
| 166 |
with gr.Column():
|
| 167 |
business_description_input = gr.Textbox(label="What is your business about?", lines=5)
|
| 168 |
customer_profile_input = gr.Textbox(label="Information about your customer profile", lines=5)
|
| 169 |
num_personas_input = gr.Number(label="Number of personas to generate", value=1, minimum=1, step=1)
|
| 170 |
+
|
| 171 |
+
blablador_api_key_input = gr.Textbox(
|
| 172 |
+
label="Blablador API Key (for API client use)",
|
| 173 |
+
visible=False
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
generate_button = gr.Button("Generate Personas")
|
| 177 |
with gr.Column():
|
| 178 |
+
output_json = gr.JSON(label="Generated Personas")
|
| 179 |
|
| 180 |
generate_button.click(
|
| 181 |
fn=generate_personas,
|
|
|
|
| 184 |
api_name="generate_personas"
|
| 185 |
)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 188 |
|
| 189 |
if __name__ == "__main__":
|
|
|
|
| 190 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
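The per-request key override in `generate_personas` mutates `os.environ` and restores it in the `finally` block. A minimal standalone sketch of that save/restore pattern (the helper name and variable are illustrative, not from the app):

```python
import os

def with_temporary_env(name, value, fn):
    """Run fn() with os.environ[name] temporarily set to value, then restore."""
    original = os.environ.get(name)  # None if the variable was unset before
    os.environ[name] = value
    try:
        return fn()
    finally:
        if original is None:
            os.environ.pop(name, None)  # was unset before: remove it again
        else:
            os.environ[name] = original  # restore the previous value

result = with_temporary_env("DEMO_KEY", "abc123", lambda: os.environ["DEMO_KEY"])
print(result)                     # -> abc123
print("DEMO_KEY" in os.environ)   # -> False (restored to unset)
```

Restoring in `finally` guarantees the process-wide environment is left unchanged even when the wrapped call raises.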
config.ini (CHANGED)

```diff
@@ -1,12 +1,7 @@
 [OpenAI]
 API_TYPE=helmholtz-blablador
-MODEL=alias-
-REASONING_MODEL=alias-
-FALLBACK_MODEL_LARGE=alias-large
-FALLBACK_MODEL_HUGE=alias-huge
+MODEL=alias-large
+REASONING_MODEL=alias-large
 TOP_P=1.0
-MAX_ATTEMPTS=
-WAITING_TIME=
-
-[Logging]
-LOGLEVEL=DEBUG
+MAX_ATTEMPTS=5
+WAITING_TIME=20
```
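These settings are plain INI key/value pairs, so they can be read with the standard `configparser` module. A quick sketch (parsing an inline string rather than the real file) showing how the numeric retry settings come out:

```python
import configparser

# Inline copy of the relevant keys from config.ini, for illustration only
CONFIG_TEXT = """
[OpenAI]
API_TYPE=helmholtz-blablador
MODEL=alias-large
TOP_P=1.0
MAX_ATTEMPTS=5
WAITING_TIME=20
"""

config = configparser.ConfigParser()
config.read_string(CONFIG_TEXT)

# getint/getfloat convert the raw strings; a missing or empty value raises,
# which is why the diff above fills in explicit defaults.
max_attempts = config.getint("OpenAI", "MAX_ATTEMPTS")
waiting_time = config.getint("OpenAI", "WAITING_TIME")
top_p = config.getfloat("OpenAI", "TOP_P")

print(max_attempts, waiting_time, top_p)  # -> 5 20 1.0
```

Leaving `MAX_ATTEMPTS=` empty, as in the old file, would make `getint` fail at startup, so populating these values is a functional fix rather than a cosmetic one.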
pyproject.toml (CHANGED)

```diff
@@ -3,11 +3,11 @@
 build-backend = "setuptools.build_meta"

 [tool.setuptools]
-packages = ["
+packages = ["tinytroupe"]
 include-package-data = true

 [project]
-name = "
+name = "tinytroupe"
 version = "0.5.2"
 authors = [
     { name="Paulo Salem", email="paulo.salem@microsoft.com" }
@@ -41,7 +41,7 @@
 ]

 [project.urls]
-"Homepage" = "https://github.com/microsoft/
+"Homepage" = "https://github.com/microsoft/tinytroupe"

 [tool.pytest.ini_options]
 pythonpath = [
@@ -56,4 +56,4 @@
     "examples: mark a test as the execution of examples",
     "notebooks: mark a test as a more specific Jupyter notebook execution example",
 ]
-addopts = "--cov=
+addopts = "--cov=tinytroupe --cov-report=html --cov-report=xml"
```

(The old-side values are truncated in the diff view; they are reproduced here as displayed.)
requirements.txt (CHANGED)

```diff
@@ -21,7 +21,7 @@
 textdistance
 scipy
 transformers==4.38.2
-huggingface-hub
+huggingface-hub==0.22.2
 gradio_client
 fastapi
 uvicorn
```
test_api.py (ADDED)

```python
import os
import sys
from unittest.mock import patch, MagicMock

import pytest

# Ensure the current directory is in the path
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from app import extract_persona_parameters, generate_personas


def test_extract_persona_parameters_fallback():
    # If the LLM call fails or returns empty, the fallback defaults should be returned
    with patch('tinytroupe.openai_utils.client') as mock_client:
        mock_instance = MagicMock()
        mock_instance.send_message.return_value = None
        mock_client.return_value = mock_instance

        result = extract_persona_parameters("Test Business", "Test Customer")
        assert "age" in result
        assert result["age"] == 30


def test_extract_persona_parameters_success():
    with patch('tinytroupe.openai_utils.client') as mock_client:
        mock_instance = MagicMock()
        mock_instance.send_message.return_value = {
            "content": '{"age": 25, "gender": "Female", "occupation": "Engineer", "city": "NYC", "country": "USA", "custom_values": "Innovation", "custom_life_attitude": "Positive", "life_story": "A story", "interests_hobbies": "Coding", "attribute_count": 200}'
        }
        mock_client.return_value = mock_instance

        result = extract_persona_parameters("Tech Startup", "Young professionals")
        assert result["age"] == 25
        assert result["gender"] == "Female"
        assert result["city"] == "NYC"


@patch('gradio_client.Client')  # mock the gradio_client Client class
def test_generate_personas(mock_client_class):
    mock_client_instance = MagicMock()
    mock_client_instance.predict.return_value = "Generated persona profile text"
    mock_client_class.return_value = mock_client_instance

    with patch('app.extract_persona_parameters') as mock_extract:
        mock_extract.return_value = {
            "age": 25, "gender": "Female", "occupation": "Engineer",
            "city": "NYC", "country": "USA", "custom_values": "Innovation",
            "custom_life_attitude": "Positive", "life_story": "A story",
            "interests_hobbies": "Coding", "attribute_count": 200
        }

        # An API key is required to pass the initial check
        result = generate_personas("Tech Startup", "Young professionals", 1, blablador_api_key="TEST_KEY")

        assert isinstance(result, list)
        assert len(result) == 1
        assert "parameters_used" in result[0]
        assert "persona_profile" in result[0]
        assert result[0]["persona_profile"] == "Generated persona profile text"
```
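The third test patches `app.extract_persona_parameters` rather than the module where the function is defined: `patch` must target the name in the namespace where it is looked up at call time. A self-contained sketch of that rule, using two throwaway modules built with `types.ModuleType` (all names here are illustrative):

```python
import sys
import types
from unittest.mock import patch

# Build a tiny "library" module and an "app" module that imports from it.
lib = types.ModuleType("demo_lib")
lib.helper = lambda: "real"
sys.modules["demo_lib"] = lib

app_mod = types.ModuleType("demo_app")
app_mod.helper = lib.helper              # equivalent to: from demo_lib import helper
app_mod.run = lambda: app_mod.helper()   # app code calls the imported name
sys.modules["demo_app"] = app_mod

# Patching the defining module does NOT affect the copy demo_app holds...
with patch("demo_lib.helper", return_value="mocked"):
    unaffected = app_mod.run()   # still "real"

# ...patching where the name is *used* does.
with patch("demo_app.helper", return_value="mocked"):
    affected = app_mod.run()     # "mocked"

print(unaffected, affected)  # -> real mocked
```

This is why the test suite patches `app.extract_persona_parameters` instead of the function's home module.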
tinytroupe/agent/memory.py (CHANGED)

```diff
@@ -88,24 +88,6 @@
         """
         raise NotImplementedError("Subclasses must implement this method.")

-    def store_interaction(self, interaction: Any) -> None:
-        """
-        Stores an interaction in memory.
-        """
-        self.store({"type": "interaction", "content": interaction, "simulation_timestamp": utils.pretty_datetime(datetime.now())})
-
-    def get_memory_summary(self) -> str:
-        """
-        Returns a summary of the memory.
-        """
-        raise NotImplementedError("Subclasses must implement this method.")
-
-    def consolidate_memories(self) -> None:
-        """
-        Consolidates memories (e.g., from episodic to semantic).
-        """
-        raise NotImplementedError("Subclasses must implement this method.")
-
     def summarize_relevant_via_full_scan(self, relevance_target: str, batch_size: int = 20, item_type: str = None) -> str:
         """
         Performs a full scan of the memory, extracting and accumulating information relevant to a query.
```
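The deleted `store_interaction` was a template method layered on the abstract `store`: subclasses only implement the storage hook, and the wrapper adds the envelope. A minimal sketch of that pattern (class names are illustrative, not from tinytroupe):

```python
from abc import ABC, abstractmethod

class Memory(ABC):
    @abstractmethod
    def store(self, value: dict) -> None:
        """Subclasses decide how items are actually persisted."""

    def store_interaction(self, interaction) -> None:
        # Template method: wrap the payload, then delegate to the subclass hook.
        self.store({"type": "interaction", "content": interaction})

class ListMemory(Memory):
    def __init__(self):
        self.items = []

    def store(self, value: dict) -> None:
        self.items.append(value)

m = ListMemory()
m.store_interaction("hello")
print(m.items)  # -> [{'type': 'interaction', 'content': 'hello'}]
```

Removing the unused wrapper methods, as this commit does, keeps the abstract interface down to the hooks subclasses actually override.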
tinytroupe/agent/tiny_person.py (CHANGED)

```diff
@@ -1,6 +1,5 @@
 from tinytroupe.agent import logger, default, Self, AgentOrWorld, CognitiveActionModel
 from tinytroupe.agent.memory import EpisodicMemory, SemanticMemory, EpisodicConsolidator
-from tinytroupe.agent.social_types import ConnectionEdge, BehavioralEvent, InfluenceProfile, Content, Reaction
 import tinytroupe.openai_utils as openai_utils
 from tinytroupe.utils import JsonSerializableRegistry, repeat_on_error, name_or_empty
 import tinytroupe.utils as utils
@@ -43,8 +42,7 @@
 
     PP_TEXT_WIDTH = 100
 
-    serializable_attributes = ["_persona", "_mental_state", "_mental_faculties", "_current_episode_event_count", "episodic_memory", "semantic_memory"
-                               "social_connections", "engagement_patterns", "behavioral_history", "influence_metrics", "prediction_confidence", "behavioral_traits"]
+    serializable_attributes = ["_persona", "_mental_state", "_mental_faculties", "_current_episode_event_count", "episodic_memory", "semantic_memory"]
     serializable_attributes_renaming = {"_mental_faculties": "mental_faculties", "_persona": "persona", "_mental_state": "mental_state", "_current_episode_event_count": "current_episode_event_count"}
 
     # A dict of all agents instantiated so far.
@@ -211,29 +209,6 @@
 
         if not hasattr(self, 'stimuli_count'):
             self.stimuli_count = 0
-
-        if not hasattr(self, 'social_connections'):
-            self.social_connections = {}
-
-        if not hasattr(self, 'engagement_patterns'):
-            self.engagement_patterns = {
-                "content_type_preferences": {},
-                "topic_affinities": {},
-                "posting_time_preferences": {},
-                "engagement_likelihood": {}
-            }
-
-        if not hasattr(self, 'behavioral_history'):
-            self.behavioral_history = []
-
-        if not hasattr(self, 'influence_metrics'):
-            self.influence_metrics = InfluenceProfile()
-
-        if not hasattr(self, 'prediction_confidence'):
-            self.prediction_confidence = 0.0
-
-        if not hasattr(self, 'behavioral_traits'):
-            self.behavioral_traits = {}
 
         self._prompt_template_path = os.path.join(
             os.path.dirname(__file__), "prompts/tiny_person.mustache"
@@ -1819,47 +1794,3 @@ max_content_length=max_content_length,
         Clears the global list of agents.
         """
         TinyPerson.all_agents = {}
-
-    ############################################################################
-    # Social and Engagement methods
-    ############################################################################
-
-    def calculate_engagement_probability(self, content: Content) -> float:
-        """
-        Analyze content features and return probability of engagement using the prediction engine.
-        """
-        from tinytroupe.ml_models import EngagementPredictor
-        predictor = EngagementPredictor()
-
-        # Use the environment's network topology if available
-        network = getattr(self.environment, 'network', None)
-
-        return predictor.predict(self, content, network)
-
-    def predict_reaction(self, content: Content) -> Reaction:
-        """
-        Determine reaction type using the LLM-based predictor.
-        """
-        from tinytroupe.llm_predictor import LLMPredictor
-        predictor = LLMPredictor()
-
-        return predictor.predict(self, content)
-
-    def update_from_interaction(self, interaction: Any) -> None:
-        """
-        Learn from actual interactions and update patterns.
-        """
-        # interaction could be a dict with content and outcome
-        if isinstance(interaction, dict):
-            content = interaction.get("content")
-            outcome = interaction.get("outcome")  # e.g. "like", "comment", "none"
-
-            # Update patterns based on outcome
-            # This is a simplified learning mechanism
-            pass
-
-    def get_content_affinity(self, content: Content) -> float:
-        """
-        Score content relevance to persona.
-        """
-        return self.calculate_engagement_probability(content)
```
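Note that the removed `serializable_attributes` line was missing a comma after `"semantic_memory"`, so Python's implicit string concatenation silently merged two intended list items into one. A standalone demonstration of that pitfall:

```python
# Adjacent string literals are concatenated at compile time, so a missing
# comma merges two intended list items into one combined string.
buggy = ["episodic_memory", "semantic_memory"
         "social_connections"]
fixed = ["episodic_memory", "semantic_memory", "social_connections"]

print(buggy)       # -> ['episodic_memory', 'semantic_memorysocial_connections']
print(len(buggy))  # -> 2
print(len(fixed))  # -> 3
```

Because concatenation happens at compile time, the bug produces no error at all, just a wrong attribute name at serialization time.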
tinytroupe/config.ini (CHANGED)

```diff
@@ -15,10 +15,10 @@ AZURE_API_VERSION=2023-05-15
 #
 
 # The main text generation model, used for agent responses
-MODEL=
+MODEL=gpt-4.1-mini
 
 # Reasoning model is used when precise reasoning is required, such as when computing detailed analyses of simulation properties.
-REASONING_MODEL=
+REASONING_MODEL=o3-mini
 
 # Embedding model is used for text similarity tasks
 EMBEDDING_MODEL=text-embedding-3-small
@@ -31,8 +31,8 @@ TEMPERATURE=1.5
 FREQ_PENALTY=0.1
 PRESENCE_PENALTY=0.1
 TIMEOUT=480
-MAX_ATTEMPTS=
-WAITING_TIME=
+MAX_ATTEMPTS=5
+WAITING_TIME=1
 EXPONENTIAL_BACKOFF_FACTOR=5
 
 REASONING_EFFORT=high
@@ -90,7 +90,7 @@ QUALITY_THRESHOLD = 5
 
 
 [Logging]
-LOGLEVEL=
+LOGLEVEL=ERROR
 # ERROR
 # WARNING
 # INFO
```
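`MAX_ATTEMPTS`, `WAITING_TIME`, and `EXPONENTIAL_BACKOFF_FACTOR` together define a retry schedule. A sketch of how such values typically combine into per-retry delays (this is an illustration of the general pattern, not the exact tinytroupe implementation):

```python
def backoff_schedule(max_attempts: int, waiting_time: float, factor: float):
    """Delay before attempt k (0-indexed) is waiting_time * factor**k."""
    return [waiting_time * factor ** attempt for attempt in range(max_attempts)]

# With the values above: MAX_ATTEMPTS=5, WAITING_TIME=1, EXPONENTIAL_BACKOFF_FACTOR=5
print(backoff_schedule(5, 1, 5))  # -> [1, 5, 25, 125, 625]
```

A factor of 5 grows quickly, so the last retry waits over ten minutes; lowering `WAITING_TIME` to 1, as this commit does, keeps the early retries cheap.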
tinytroupe/factory/tiny_person_factory.py
CHANGED
|
@@ -180,8 +180,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 180 |
frequency_penalty:float=0.0,
|
| 181 |
presence_penalty:float=0.0,
|
| 182 |
attempts:int=10,
|
| 183 |
-
post_processing_func=None
|
| 184 |
-
deep_persona:bool=True) -> TinyPerson:
|
| 185 |
"""
|
| 186 |
Generate a TinyPerson instance using OpenAI's LLM.
|
| 187 |
|
|
@@ -319,10 +318,6 @@ class TinyPersonFactory(TinyFactory):
|
|
| 319 |
|
| 320 |
# create the fresh agent
|
| 321 |
if agent_spec is not None:
|
| 322 |
-
# If deep_persona is requested, perform the second API call to enrich the persona
|
| 323 |
-
if deep_persona:
|
| 324 |
-
agent_spec = self._generate_deep_persona_internal(agent_spec)
|
| 325 |
-
|
| 326 |
# the agent is created here. This is why the present method cannot be cached. Instead, an auxiliary method is used
|
| 327 |
# for the actual model call, so that it gets cached properly without skipping the agent creation.
|
| 328 |
|
|
@@ -347,46 +342,6 @@ class TinyPersonFactory(TinyFactory):
|
|
| 347 |
|
| 348 |
|
| 349 |
@config_manager.config_defaults(parallelize="parallel_agent_generation")
|
| 350 |
-
def generate_from_linkedin_profile(self, profile_data: Dict) -> TinyPerson:
|
| 351 |
-
"""
|
| 352 |
-
Generate a TinyPerson from a LinkedIn profile with enriched traits.
|
| 353 |
-
"""
|
| 354 |
-
description = f"Professional with headline: {profile_data.get('headline', '')}. " \
|
| 355 |
-
f"Industry: {profile_data.get('industry', '')}. " \
|
| 356 |
-
f"Location: {profile_data.get('location', 'Global')}. " \
|
| 357 |
-
f"Career level: {profile_data.get('career_level', 'Mid Level')}. " \
|
| 358 |
-
f"Summary: {profile_data.get('summary', '')}"
|
| 359 |
-
|
| 360 |
-
return self.generate_person(agent_particularities=description)
|
| 361 |
-
|
| 362 |
-
def generate_persona_cluster(self, archetype: str, count: int) -> List[TinyPerson]:
|
| 363 |
-
"""
|
| 364 |
-
Generate a cluster of personas following a specific archetype.
|
| 365 |
-
"""
|
| 366 |
-
return self.generate_people(number_of_people=count, agent_particularities=f"Archetype: {archetype}")
|
| 367 |
-
|
| 368 |
-
def generate_diverse_population(self, size: int, distribution: Dict) -> List[TinyPerson]:
|
| 369 |
-
"""
|
| 370 |
-
Generate a diverse population based on a distribution.
|
| 371 |
-
"""
|
| 372 |
-
# distribution could specify proportions of various characteristics
|
| 373 |
-
# This is a simplified implementation
|
| 374 |
-
return self.generate_people(number_of_people=size, agent_particularities=f"Target distribution: {json.dumps(distribution)}")
|
| 375 |
-
|
| 376 |
-
def ensure_consistency(self, persona: TinyPerson) -> bool:
|
| 377 |
-
"""
|
| 378 |
-
Ensure the generated persona is consistent.
|
| 379 |
-
"""
|
| 380 |
-
# Implementation would involve checking traits, demographics, etc.
|
| 381 |
-
return True # Placeholder
|
| 382 |
-
|
| 383 |
-
def calculate_diversity_score(self, personas: List[TinyPerson]) -> float:
|
| 384 |
-
"""
|
| 385 |
-
Calculate a diversity score for a list of personas.
|
| 386 |
-
"""
|
| 387 |
-
# Placeholder for diversity metric calculation
|
| 388 |
-
return 0.5
|
| 389 |
-
|
| 390 |
def generate_people(self, number_of_people:int=None,
|
| 391 |
agent_particularities:str=None,
|
| 392 |
temperature:float=1.2,
|
|
@@ -395,8 +350,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 395 |
attempts:int=10,
|
| 396 |
post_processing_func=None,
|
| 397 |
parallelize=None,
|
| 398 |
-
verbose:bool=False
|
| 399 |
-
deep_persona:bool=True) -> list:
|
| 400 |
"""
|
| 401 |
Generate a list of TinyPerson instances using OpenAI's LLM.
|
| 402 |
|
|
@@ -436,8 +390,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 436 |
presence_penalty=presence_penalty,
|
| 437 |
attempts=attempts,
|
| 438 |
post_processing_func=post_processing_func,
|
| 439 |
-
verbose=verbose
|
| 440 |
-
deep_persona=deep_persona)
|
| 441 |
else:
|
| 442 |
people = self._generate_people_sequentially(number_of_people=number_of_people,
|
| 443 |
agent_particularities=agent_particularities,
|
|
@@ -446,8 +399,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 446 |
presence_penalty=presence_penalty,
|
| 447 |
attempts=attempts,
|
| 448 |
post_processing_func=post_processing_func,
|
| 449 |
-
verbose=verbose
|
| 450 |
-
deep_persona=deep_persona)
|
| 451 |
|
| 452 |
return people
|
| 453 |
|
|
@@ -460,8 +412,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 460 |
presence_penalty:float=0.0,
|
| 461 |
attempts:int=10,
|
| 462 |
post_processing_func=None,
|
| 463 |
-
verbose:bool=False
|
| 464 |
-
deep_persona:bool=True) -> list:
|
| 465 |
people = []
|
| 466 |
|
| 467 |
#
|
|
@@ -473,20 +424,19 @@ class TinyPersonFactory(TinyFactory):
|
|
| 473 |
|
| 474 |
# this is the function that will be executed in parallel
|
| 475 |
def generate_person_wrapper(args):
|
| 476 |
-
self, i, agent_particularities, temperature, frequency_penalty, presence_penalty, attempts, post_processing_func
|
| 477 |
person = self.generate_person(agent_particularities=agent_particularities,
|
| 478 |
temperature=temperature,
|
| 479 |
frequency_penalty=frequency_penalty,
|
| 480 |
presence_penalty=presence_penalty,
|
| 481 |
attempts=attempts,
|
| 482 |
-
post_processing_func=post_processing_func
|
| 483 |
-
deep_persona=deep_persona)
|
| 484 |
return i, person
|
| 485 |
|
| 486 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 487 |
# we use a list of futures to keep track of the results
|
| 488 |
futures = [
|
| 489 |
-
executor.submit(generate_person_wrapper, (self, i, agent_particularities, temperature, frequency_penalty, presence_penalty, attempts, post_processing_func
|
| 490 |
for i in range(number_of_people)
|
| 491 |
]
|
| 492 |
|
|
@@ -513,8 +463,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 513 |
presence_penalty:float=0.0,
|
| 514 |
attempts:int=10,
|
| 515 |
post_processing_func=None,
|
| 516 |
-
verbose:bool=False
|
| 517 |
-
deep_persona:bool=True) -> list:
|
| 518 |
"""
|
| 519 |
Generate the people sequentially, not in parallel. This is a simpler alternative.
|
| 520 |
"""
|
|
@@ -525,8 +474,7 @@ class TinyPersonFactory(TinyFactory):
|
|
| 525 |
frequency_penalty=frequency_penalty,
|
| 526 |
presence_penalty=presence_penalty,
|
| 527 |
attempts=attempts,
|
| 528 |
-
post_processing_func=post_processing_func
|
| 529 |
-
deep_persona=deep_persona)
|
| 530 |
if person is not None:
|
| 531 |
people.append(person)
|
| 532 |
info_msg = f"Generated person {i+1}/{number_of_people}: {person.minibio()}"
|
|
@@ -610,11 +558,6 @@ class TinyPersonFactory(TinyFactory):
|
|
| 610 |
if len(self.remaining_characteristics_sample) != n:
|
| 611 |
logger.warning(f"Expected {n} samples, but got {len(self.remaining_characteristics_sample)} samples. The LLM may have failed to sum up the quantities in the sampling plan correctly.")
|
| 612 |
|
| 613 |
-
# If we got more samples than requested, we truncate them to avoid generating too many names or personas.
|
| 614 |
-
if len(self.remaining_characteristics_sample) > n:
|
| 615 |
-
logger.info(f"Truncating {len(self.remaining_characteristics_sample)} samples to the requested {n} samples.")
|
| 616 |
-
self.remaining_characteristics_sample = self.remaining_characteristics_sample[:n]
|
| 617 |
-
|
| 618 |
logger.info(f"Sample plan has been flattened, contains {len(self.remaining_characteristics_sample)} total samples.")
|
| 619 |
logger.debug(f"Remaining characteristics sample: {json.dumps(self.remaining_characteristics_sample, indent=4)}")
|
| 620 |
|
|
@@ -1352,42 +1295,6 @@ class TinyPersonFactory(TinyFactory):
|
|
| 1352 |
presence_penalty=presence_penalty,
|
| 1353 |
response_format={"type": "json_object"})
|
| 1354 |
|
| 1355 |
-
def _generate_deep_persona_internal(self, initial_spec: dict) -> dict:
|
| 1356 |
-
"""
|
| 1357 |
-
Performs a second API call to enrich the persona with a depth of 350 attributes.
|
| 1358 |
-
"""
|
| 1359 |
-
logger.info(f"Enriching persona {initial_spec.get('name')} to deep persona (depth 350)...")
|
| 1360 |
-
|
| 1361 |
-
prompt = f"""
|
| 1362 |
-
You are an expert persona generator. You have been provided with an initial persona profile:
|
| 1363 |
-
{json.dumps(initial_spec, indent=4)}
|
| 1364 |
-
|
| 1365 |
-
TASK:
|
| 1366 |
-
Take all the attributes from this initial profile and expand them significantly to reach a depth of 350 attributes/nuances.
|
| 1367 |
-
The final profile must be incredibly detailed, authentic, and realistic.
|
| 1368 |
-
Expand on every field: education, occupation, style, personality, preferences, beliefs, skills, behaviors, health, relationships, and other_facts.
|
| 1369 |
-
Provide at least 50 detailed entries for each complex field (preferences, beliefs, other_facts).
|
| 1370 |
-
|
| 1371 |
-
Rules:
|
| 1372 |
-
- Maintain consistency with the initial profile.
|
| 1373 |
-
- Output ONLY a valid JSON object.
|
| 1374 |
-
- Use the same field structure as the input.
|
| 1375 |
-
"""
|
| 1376 |
-
|
| 1377 |
-
messages = [
|
| 1378 |
-
{"role": "system", "content": "You are a specialized system for creating ultra-deep, 350-attribute persona specifications."},
|
| 1379 |
-
{"role": "user", "content": prompt}
|
| 1380 |
-
]
|
| 1381 |
-
|
| 1382 |
-
# Use the Helmholtz client via send_message
|
| 1383 |
-
message = self._aux_model_call(messages=messages, temperature=1.2, frequency_penalty=0.0, presence_penalty=0.0)
|
| 1384 |
-
|
| 1385 |
-
if message is not None:
|
| 1386 |
-
enriched_spec = utils.extract_json(message["content"])
|
| 1387 |
-
return enriched_spec
|
| 1388 |
-
|
| 1389 |
-
return initial_spec
|
| 1390 |
-
|
| 1391 |
@transactional()
|
| 1392 |
def _setup_agent(self, agent, configuration):
|
| 1393 |
"""
|
|
|
|
| 180 |
frequency_penalty:float=0.0,
|
| 181 |
presence_penalty:float=0.0,
|
| 182 |
attempts:int=10,
|
| 183 |
+
post_processing_func=None) -> TinyPerson:
|
|
|
|
| 184 |
"""
|
| 185 |
Generate a TinyPerson instance using OpenAI's LLM.
|
| 186 |
|
|
|
|
| 318 |
|
| 319 |
# create the fresh agent
|
| 320 |
if agent_spec is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
# the agent is created here. This is why the present method cannot be cached. Instead, an auxiliary method is used
|
| 322 |
# for the actual model call, so that it gets cached properly without skipping the agent creation.
|
| 323 |
|
|
|
|
| 342 |
|
| 343 |
|
| 344 |
@config_manager.config_defaults(parallelize="parallel_agent_generation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
def generate_people(self, number_of_people:int=None,
|
| 346 |
agent_particularities:str=None,
|
| 347 |
temperature:float=1.2,
|
|
|
|
tinytroupe/factory/tiny_person_factory.py

```diff
 350                           attempts:int=10,
 351                           post_processing_func=None,
 352                           parallelize=None,
+353                          verbose:bool=False) -> list:
 354         """
 355         Generate a list of TinyPerson instances using OpenAI's LLM.
 356
 ...
 390                               presence_penalty=presence_penalty,
 391                               attempts=attempts,
 392                               post_processing_func=post_processing_func,
+393                              verbose=verbose)
 394         else:
 395             people = self._generate_people_sequentially(number_of_people=number_of_people,
 396                               agent_particularities=agent_particularities,
 ...
 399                               presence_penalty=presence_penalty,
 400                               attempts=attempts,
 401                               post_processing_func=post_processing_func,
+402                              verbose=verbose)
 403
 404         return people
 405
 ...
 412                           presence_penalty:float=0.0,
 413                           attempts:int=10,
 414                           post_processing_func=None,
+415                          verbose:bool=False) -> list:
 416         people = []
 417
 418         #
 ...
 424
 425         # this is the function that will be executed in parallel
 426         def generate_person_wrapper(args):
+427             self, i, agent_particularities, temperature, frequency_penalty, presence_penalty, attempts, post_processing_func = args
 428             person = self.generate_person(agent_particularities=agent_particularities,
 429                                           temperature=temperature,
 430                                           frequency_penalty=frequency_penalty,
 431                                           presence_penalty=presence_penalty,
 432                                           attempts=attempts,
+433                                          post_processing_func=post_processing_func)
 434             return i, person
 435
 436         with concurrent.futures.ThreadPoolExecutor() as executor:
 437             # we use a list of futures to keep track of the results
 438             futures = [
+439                 executor.submit(generate_person_wrapper, (self, i, agent_particularities, temperature, frequency_penalty, presence_penalty, attempts, post_processing_func))
 440                 for i in range(number_of_people)
 441             ]
 442
 ...
 463                           presence_penalty:float=0.0,
 464                           attempts:int=10,
 465                           post_processing_func=None,
+466                          verbose:bool=False) -> list:
 467         """
 468         Generate the people sequentially, not in parallel. This is a simpler alternative.
 469         """
 ...
 474                               frequency_penalty=frequency_penalty,
 475                               presence_penalty=presence_penalty,
 476                               attempts=attempts,
+477                              post_processing_func=post_processing_func)
 478             if person is not None:
 479                 people.append(person)
 480                 info_msg = f"Generated person {i+1}/{number_of_people}: {person.minibio()}"
 ...
 558         if len(self.remaining_characteristics_sample) != n:
 559             logger.warning(f"Expected {n} samples, but got {len(self.remaining_characteristics_sample)} samples. The LLM may have failed to sum up the quantities in the sampling plan correctly.")
 560
 561         logger.info(f"Sample plan has been flattened, contains {len(self.remaining_characteristics_sample)} total samples.")
 562         logger.debug(f"Remaining characteristics sample: {json.dumps(self.remaining_characteristics_sample, indent=4)}")
 563
 ...
1295                                       presence_penalty=presence_penalty,
1296                                       response_format={"type": "json_object"})
1297
1298     @transactional()
1299     def _setup_agent(self, agent, configuration):
1300         """
```
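The parallel branch above fans each person-generation task out to a thread pool, tags every future with its index, and reassembles results in order. A minimal sketch of that pattern — `make_one` stands in for `generate_person`, and dropping `None` results mirrors the `if person is not None` check:

```python
import concurrent.futures

def generate_all(n, make_one):
    """Run n independent generation tasks in parallel, keeping results in order."""
    results = [None] * n
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # each future carries its index so results can be placed back in order,
        # even though futures complete in arbitrary order
        futures = [executor.submit(lambda i=i: (i, make_one(i))) for i in range(n)]
        for future in concurrent.futures.as_completed(futures):
            i, value = future.result()
            results[i] = value
    # drop failed generations, mirroring the `if person is not None` check
    return [r for r in results if r is not None]
```

Threads (rather than processes) fit here because each task is dominated by waiting on a network call to the LLM, not by CPU work.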
tinytroupe/openai_utils.py

```diff
@@ -31,8 +31,6 @@ class OpenAIClient:
     def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
         logger.debug("Initializing OpenAIClient")

-        self.client = None
-
         # should we cache api calls and reuse them?
         self.set_api_cache(cache_api_calls, cache_file_name)

@@ -54,8 +52,7 @@ class OpenAIClient:
         """
         Sets up the OpenAI API configurations for this client.
         """
-
         self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

     @config_manager.config_defaults(
         model="model",
@@ -159,33 +156,14 @@ class OpenAIClient:
             chat_api_params["response_format"] = response_format

         i = 0
-        while True:
+        while i < max_attempts:
             try:
                 i += 1

-                #
-                # Model fallback and retry strategy requested by the user:
-                # 1. alias-fast for 3 attempts, 35s wait
-                # 2. alias-large for 2 attempts, 35s wait
-                # 3. alias-huge until success, 60s wait
-                #
-                # Model fallback strategy using config
-                if i <= 3:
-                    current_model = config["OpenAI"].get("MODEL", "alias-fast")
-                    current_wait_time = 35
-                elif i <= 5:
-                    current_model = config["OpenAI"].get("FALLBACK_MODEL_LARGE", "alias-large")
-                    current_wait_time = 35
-                else:
-                    current_model = config["OpenAI"].get("FALLBACK_MODEL_HUGE", "alias-huge")
-                    current_wait_time = 60
-
-                chat_api_params["model"] = current_model
-
                 try:
                     logger.debug(f"Sending messages to OpenAI API. Token count={self._count_tokens(current_messages, model)}.")
                 except NotImplementedError:
                     logger.debug(f"Token count not implemented for model {model}.")

                 start_time = time.monotonic()
                 logger.debug(f"Calling model with client class {self.__class__.__name__}.")
@@ -193,11 +171,15 @@ class OpenAIClient:
                 ###############################################################
                 # call the model, either from the cache or from the API
                 ###############################################################
-                cache_key = str((current_model, chat_api_params))
+                cache_key = str((model, chat_api_params)) # need string to be hashable
                 if self.cache_api_calls and (cache_key in self.api_cache):
                     response = self.api_cache[cache_key]
                 else:
+                    if waiting_time > 0:
+                        logger.info(f"Waiting {waiting_time} seconds before next API request (to avoid throttling)...")
+                        time.sleep(waiting_time)
+
+                    response = self._raw_model_call(model, chat_api_params)
                     if self.cache_api_calls:
                         self.api_cache[cache_key] = response
                         self._save_cache()
@@ -213,21 +195,35 @@ class OpenAIClient:
                 else:
                     return utils.sanitize_dict(self._raw_model_response_extractor(response))

             except InvalidRequestError as e:
                 logger.error(f"[{i}] Invalid request error, won't retry: {e}")
+
+                # there's no point in retrying if the request is invalid
+                # so we return None right away
                 return None

+            except openai.BadRequestError as e:
+                logger.error(f"[{i}] Invalid request error, won't retry: {e}")
+
+                # there's no point in retrying if the request is invalid
+                # so we return None right away
+                return None
+
+            except openai.RateLimitError:
+                logger.warning(
+                    f"[{i}] Rate limit error, waiting a bit and trying again.")
+                aux_exponential_backoff()
+
+            except NonTerminalError as e:
+                logger.error(f"[{i}] Non-terminal error: {e}")
+                aux_exponential_backoff()
+
+            except Exception as e:
+                logger.error(f"[{i}] {type(e).__name__} Error: {e}")
+                aux_exponential_backoff()
+
+        logger.error(f"Failed to get response after {max_attempts} attempts.")
+        return None

     def _raw_model_call(self, model, chat_api_params):
         """
@@ -250,12 +246,8 @@ class OpenAIClient:
         chat_api_params["reasoning_effort"] = default["reasoning_effort"]

-        # To make the log cleaner, we remove the messages from the logged parameters
-        if logger.getEffectiveLevel() <= logging.DEBUG:
-            logged_params = chat_api_params
-        else:
-            logged_params = {k: v for k, v in chat_api_params.items() if k != "messages"}
+        # To make the log cleaner, we remove the messages from the logged parameters
+        logged_params = {k: v for k, v in chat_api_params.items() if k != "messages"}

         if "response_format" in chat_api_params:
             # to enforce the response format via pydantic, we need to use a different method
@@ -404,23 +396,22 @@ class AzureClient(OpenAIClient):
         Sets up the Azure OpenAI Service API configurations for this client,
         including the API endpoint and key.
         """
+        if os.getenv("AZURE_OPENAI_KEY"):
+            logger.info("Using Azure OpenAI Service API with key.")
+            self.client = AzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+                                      api_version=config["OpenAI"]["AZURE_API_VERSION"],
+                                      api_key=os.getenv("AZURE_OPENAI_KEY"))
+        else:  # Use Entra ID Auth
+            logger.info("Using Azure OpenAI Service API with Entra ID Auth.")
+            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+            credential = DefaultAzureCredential()
+            token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
+            self.client = AzureOpenAI(
+                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+                api_version=config["OpenAI"]["AZURE_API_VERSION"],
+                azure_ad_token_provider=token_provider
+            )


 class HelmholtzBlabladorClient(OpenAIClient):
@@ -433,17 +424,10 @@ class HelmholtzBlabladorClient(OpenAIClient):
         """
         Sets up the Helmholtz Blablador API configurations for this client.
         """
-        if self.client is None or self.client.api_key != api_key:
-            logger.debug(f"Setting up Helmholtz client with base_url and key.")
-            self.client = OpenAI(
-                base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
-                api_key=api_key,
-            )
+        self.client = OpenAI(
+            base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
+            api_key=os.getenv("BLABLADOR_API_KEY", "dummy"),
+        )

 ###########################################################################
 # Exceptions
```
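The rewritten loop above replaces the hardcoded model-fallback schedule with a bounded retry: invalid requests fail fast with `None`, while rate limits and other transient errors back off exponentially until `max_attempts` is exhausted. A minimal sketch of that shape — the exception types, waits, and names here are illustrative stand-ins, not TinyTroupe's own:

```python
import time

def call_with_retries(fn, max_attempts=5, base_wait=0.01, max_wait=0.1):
    """Retry fn() until it succeeds or max_attempts is exhausted."""
    wait = base_wait
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except ValueError:
            # stands in for InvalidRequestError/BadRequestError:
            # the request itself is bad, so retrying cannot help
            return None
        except Exception:
            # stands in for RateLimitError and other transient failures:
            # sleep, then double the wait up to a cap (exponential backoff)
            time.sleep(wait)
            wait = min(wait * 2, max_wait)
    # all attempts failed
    return None
```

Distinguishing non-retryable from transient errors is the key design choice: retrying a malformed request only burns quota, while backing off on a rate limit usually succeeds on a later attempt.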
tinytroupe/utils/llm.py

```diff
@@ -721,7 +721,7 @@ class LLMChat:

     def _request_list_of_dict_llm_message(self):
         return {"role": "user",
-                "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\
+                "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\{...\}, \{...\}, ...]`. This is critical for later processing."}

     def _coerce_to_list(self, llm_output:str):
         """
```
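The message above instructs the model to emit a list of dictionaries as JSON embedded in a string, which a coercion step like `_coerce_to_list` must then recover from the raw output. One way to sketch that recovery — a hypothetical helper, not the library's actual implementation:

```python
import json
import re

def coerce_to_list_of_dicts(llm_output: str):
    """Recover a list of dictionaries from raw LLM output, or None on failure."""
    try:
        # happy path: the whole output is valid JSON
        value = json.loads(llm_output)
    except json.JSONDecodeError:
        # fall back to the first [...] span embedded in surrounding prose
        match = re.search(r"\[.*\]", llm_output, re.DOTALL)
        if match is None:
            return None
        try:
            value = json.loads(match.group(0))
        except json.JSONDecodeError:
            return None
    # enforce the promised shape: a list whose elements are all dicts
    if isinstance(value, list) and all(isinstance(v, dict) for v in value):
        return value
    return None
```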
tinytroupe/utils/semantics.py

```diff
@@ -265,45 +265,3 @@ def compute_semantic_proximity(text1: str, text2: str, context: str = None) -> float:
     """
     # llm decorator will handle the body of this function

-@llm()
-def select_best_persona(criteria: str, personas: list) -> int:
-    """
-    Given a set of criteria and a list of personas (each a dictionary),
-    select the index of the persona that best matches the criteria.
-    If no persona matches at all, return -1.
-
-    Rules:
-    - You must analyze each persona against the criteria.
-    - Return ONLY the integer index (starting from 0) of the best matching persona.
-    - Do not provide any explanation, just the number.
-    - If there are multiple good matches, pick the best one.
-
-    Args:
-        criteria (str): The search criteria or description of the desired persona.
-        personas (list): A list of dictionaries, where each dictionary is a persona specification.
-
-    Returns:
-        int: The index of the best matching persona, or -1 if none match.
-    """
-    # llm decorator will handle the body of this function
-
-@llm()
-def select_relevant_personas_utility(context: str, personas: list) -> list:
-    """
-    Given a context and a list of personas (each a dictionary),
-    select which personas are relevant to the context.
-
-    Rules:
-    - Analyze each persona against the provided context.
-    - Return a LIST of indices (starting from 0) of the relevant personas.
-    - Return an empty list [] if none match.
-    - Provide the result as a JSON array of integers.
-
-    Args:
-        context (str): The context or requirements for persona selection.
-        personas (list): A list of dictionaries, where each dictionary is a persona specification.
-
-    Returns:
-        list: A list of indices of the matching personas.
-    """
-    # llm decorator will handle the body of this function
```
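The removed helpers rely on an `@llm()` decorator that supplies a function's body by calling a model with a prompt built from the docstring, arguments, and return annotation. A toy sketch of that mechanism, with the model call injected so it can run without a backend — the real decorator's interface and prompt construction may differ:

```python
import inspect
import json

def llm(call_model):
    """Turn a documented, annotated stub into an LLM-backed function (sketch)."""
    def decorator(fn):
        sig = inspect.signature(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # the docstring is the task description; arguments are the inputs
            prompt = f"{fn.__doc__}\n\nInputs: {json.dumps(dict(bound.arguments), default=str)}"
            raw = call_model(prompt)
            # coerce the raw reply to the declared return type, if any
            rtype = sig.return_annotation
            return rtype(raw) if rtype is not inspect.Signature.empty else raw
        return wrapper
    return decorator

# a stub whose body is just its docstring, like the removed select_best_persona;
# the fixed call_model stands in for a real model backend
@llm(call_model=lambda prompt: "1")
def select_best_persona(criteria: str, personas: list) -> int:
    """Return the index of the persona best matching the criteria, or -1 if none match."""
```

The appeal of this pattern is that the function signature doubles as the contract for the model: the docstring carries the instructions, and the return annotation drives output coercion.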