Spaces:
Sleeping
Sleeping
File size: 2,117 Bytes
868b252 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, validator
from reworkd_platform.web.api.agent.analysis import Analysis
# Closed set of model identifiers accepted by ModelSettings.model; each one
# also appears as a key of LLM_MODEL_MAX_TOKENS.
LLM_Model = Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"]
# Closed set of loop-step names. Nothing in the visible part of this module
# references the alias, so its consumers presumably live elsewhere.
Loop_Step = Literal["start", "analyze", "execute", "create", "summarize", "chat"]
# Upper bound on ModelSettings.max_tokens for each supported model; consulted
# by ModelSettings.validate_max_tokens.
LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = {
    "gpt-3.5-turbo": 4_000,
    "gpt-3.5-turbo-16k": 16_000,
    "gpt-4": 8_000,
}
class ModelSettings(BaseModel):
    """LLM configuration attached to an agent run.

    Every field has a default, so ``ModelSettings()`` is a valid instance.
    ``max_tokens`` is additionally capped per model through
    ``LLM_MODEL_MAX_TOKENS``.
    """

    # Must be one of the LLM_Model literals.
    model: LLM_Model = Field(default="gpt-3.5-turbo")
    custom_api_key: Optional[str] = Field(default=None)
    # Constrained to the closed interval [0.0, 1.0].
    temperature: float = Field(default=0.9, ge=0.0, le=1.0)
    # Non-negative; the per-model ceiling is enforced by the validator below.
    max_tokens: int = Field(default=500, ge=0)
    language: str = Field(default="English")

    @validator("max_tokens")
    def validate_max_tokens(cls, v: int, values: Dict[str, Any]) -> int:
        """Reject a ``max_tokens`` value above the selected model's limit.

        Args:
            v: The ``max_tokens`` value that already passed the ``ge=0`` check.
            values: Previously validated fields of this model.

        Returns:
            ``v`` unchanged when it is within the model's limit.

        Raises:
            ValueError: If ``v`` exceeds ``LLM_MODEL_MAX_TOKENS[model]``.
        """
        # pydantic leaves fields that failed their own validation out of
        # `values`. Indexing with values["model"] would then raise KeyError,
        # which pydantic does not convert into a validation error. The invalid
        # model is already reported, so there is no limit to compare against.
        model = values.get("model")
        if model is None:
            return v

        max_tokens = LLM_MODEL_MAX_TOKENS[model]
        if v > max_tokens:
            raise ValueError(f"Model {model} only supports {max_tokens} tokens")
        return v
class AgentRunCreate(BaseModel):
    """Base schema holding a goal string and the model settings to use."""

    # Required: no default is declared.
    goal: str
    # Falls back to a ModelSettings built from its own field defaults.
    model_settings: ModelSettings = Field(ModelSettings())
class AgentRun(AgentRunCreate):
    """AgentRunCreate extended with a required ``run_id`` string."""

    run_id: str
class AgentTaskAnalyze(AgentRun):
    """AgentRun extended with a task and a list of tool names."""

    task: str
    # Defaults to an empty list when omitted.
    tool_names: List[str] = Field([])
    # Redeclared with the same type and default already inherited from
    # AgentRunCreate.
    model_settings: ModelSettings = Field(ModelSettings())
class AgentTaskExecute(AgentRun):
    """AgentRun extended with a task and its ``Analysis``; both are required."""

    task: str
    # Analysis is imported from reworkd_platform.web.api.agent.analysis.
    analysis: Analysis
class AgentTaskCreate(AgentRun):
    """AgentRun extended with task lists and optional last-task details.

    Every field added here has a default, so only the inherited required
    fields must be supplied.
    """

    tasks: List[str] = Field([])
    last_task: Optional[str] = Field(None)
    result: Optional[str] = Field(None)
    completed_tasks: List[str] = Field([])
class AgentSummarize(AgentRun):
    """AgentRun extended with a list of result strings (empty by default)."""

    results: List[str] = Field([])
class AgentChat(AgentRun):
    """AgentRun extended with a required message and optional results."""

    message: str
    # Defaults to an empty list when omitted.
    results: List[str] = Field([])
class NewTasksResponse(BaseModel):
    """Schema pairing a run id with a list of new task strings."""

    run_id: str
    # Required (the Ellipsis marks "no default"); the field is populated from
    # the camelCase alias "newTasks".
    new_tasks: List[str] = Field(..., alias="newTasks")
class RunCount(BaseModel):
    """Schema with an integer count and two optional timestamps."""

    count: int
    # NOTE(review): these Optional fields declare no explicit default. Whether
    # that makes them required or implicitly None differs between pydantic
    # major versions, so confirm against the version this project pins.
    first_run: Optional[datetime]
    last_run: Optional[datetime]
|