Jarvis-K commited on
Commit
d731338
1 Parent(s): acf4310

simplified parser

Browse files
Files changed (1) hide show
  1. deciders/parser.py +26 -67
deciders/parser.py CHANGED
@@ -1,72 +1,31 @@
1
  from pydantic import BaseModel, Field, validator
2
 
3
- # Define your desired data structure.
4
- class TwoAction(BaseModel):
5
- action: int = Field(description="the choosed action to perform")
6
-
7
- # You can add custom validation logic easily with Pydantic.
8
- @validator('action')
9
- def action_is_valid(cls, field):
10
- if field not in [1, 2]:
11
- raise ValueError("Action is not valid ([1, 2])!")
12
- return field
13
-
14
- class ThreeAction(BaseModel):
15
- action: int = Field(description="the choosed action to perform")
16
-
17
- # You can add custom validation logic easily with Pydantic.
18
- @validator('action')
19
- def action_is_valid(cls, field):
20
- if field not in [1, 2, 3]:
21
- raise ValueError("Action is not valid ([1, 2, 3])!")
22
- return field
23
-
24
- class FourAction(BaseModel):
25
- action: int = Field(description="the choosed action to perform")
26
-
27
- # You can add custom validation logic easily with Pydantic.
28
- @validator('action')
29
- def action_is_valid(cls, field):
30
- if field not in [1, 2, 3, 4]:
31
- raise ValueError("Action is not valid ([1, 2, 3, 4])!")
32
- return field
33
-
34
- class SixAction(BaseModel):
35
- action: int = Field(description="the choosed action to perform")
36
-
37
- # You can add custom validation logic easily with Pydantic.
38
- @validator('action')
39
- def action_is_valid(cls, field):
40
- if field not in [1, 2, 3, 4, 5, 6]:
41
- raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
42
- return field
43
-
44
-
45
- class NineAction(BaseModel):
46
- action: int = Field(description="the choosed action to perform")
47
 
48
- # You can add custom validation logic easily with Pydantic.
49
- @validator('action')
50
- def action_is_valid(cls, field):
51
- if field not in [1, 2, 3, 4, 5, 6, 7, 8, 9]:
52
- raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6, 7, 8, 9])!")
53
- return field
 
 
54
 
55
- class FullAtariAction(BaseModel):
56
- action: int = Field(description="the choosed action to perform")
57
- @validator('action')
58
- def action_is_valid(cls, info):
59
- if info not in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]:
60
- raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])!")
61
- return info
62
 
63
- class ContinuousAction(BaseModel):
64
- action: float = Field(description="the choosed action to perform")
65
- # You can add custom validation logic easily with Pydantic.
66
- @validator('action')
67
- def action_is_valid(cls, field):
68
- if not (field >= -1 and field <= 1):
69
- raise ValueError("Action is not valid ([-1,1])!")
70
- return field
71
-
72
- PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction, 9:NineAction, 18: FullAtariAction}
 
 
 
 
1
  from pydantic import BaseModel, Field, validator
2
 
3
+ class DisActionModel(BaseModel):
4
+ action: int = Field(description="the chosen action to perform")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ @classmethod
7
+ def create_validator(cls, max_action):
8
+ @validator('action', allow_reuse=True)
9
+ def action_is_valid(cls, field):
10
+ if field not in range(1, max_action + 1):
11
+ raise ValueError(f"Action is not valid ([1, {max_action}])!")
12
+ return field
13
+ return action_is_valid
14
 
15
+ # Generate classes dynamically
16
+ def generate_action_class(max_action):
17
+ return type(f"{max_action}Action", (DisActionModel,), {'action_is_valid': DisActionModel.create_validator(max_action)})
 
 
 
 
18
 
19
+ # Dictionary of parsers with dynamic class generation
20
+ PARSERS = {num: generate_action_class(num) for num in [2, 3, 4, 6, 9, 18]}
21
+
22
+ # class ContinuousAction(BaseModel):
23
+ # action: float = Field(description="the choosed action to perform")
24
+ # # You can add custom validation logic easily with Pydantic.
25
+ # @validator('action')
26
+ # def action_is_valid(cls, field):
27
+ # if not (field >= -1 and field <= 1):
28
+ # raise ValueError("Action is not valid ([-1,1])!")
29
+ # return field
30
+
31
+ # PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction, 9:NineAction, 18: FullAtariAction}