Astocoder committed on
Commit
07baa2d
·
1 Parent(s): b4726be

Working Quant-Gym with fixed TradingEnvironment

Browse files
Files changed (1) hide show
  1. server/environment.py +74 -62
server/environment.py CHANGED
@@ -1,66 +1,78 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from typing import Optional
4
- import random
5
 
6
- app = FastAPI()
 
 
 
 
 
7
 
8
- # Simple data
9
- prices = [150, 152, 151, 153, 155, 154, 156, 158, 157, 159]
10
- cash = 10000
11
- shares = 0
12
- step_num = 0
13
-
14
- class Action(BaseModel):
15
- action: str # BUY, SELL, or GET_PRICE
16
- amount: Optional[int] = 0
17
-
18
- @app.get("/health")
19
- def health():
20
- return {"status": "healthy"}
21
-
22
- @app.post("/reset")
23
- def reset():
24
- global cash, shares, step_num
25
- cash = 10000
26
- shares = 0
27
- step_num = 0
28
- return {"cash": cash, "shares": shares, "price": prices[0]}
29
-
30
- @app.post("/step")
31
- def step(action: Action):
32
- global cash, shares, step_num
33
- step_num = min(step_num + 1, len(prices) - 1)
34
- price = prices[step_num]
35
 
36
- if action.action == "BUY" and action.amount:
37
- cost = price * action.amount
38
- if cost <= cash:
39
- cash -= cost
40
- shares += action.amount
41
- elif action.action == "SELL" and action.amount:
42
- if action.amount <= shares:
43
- cash += price * action.amount
44
- shares -= action.amount
45
 
46
- return {
47
- "price": price,
48
- "cash": cash,
49
- "shares": shares,
50
- "portfolio_value": cash + (shares * price),
51
- "step": step_num
52
- }
53
-
54
- @app.get("/tasks")
55
- def tasks():
56
- return {
57
- "tasks": [
58
- {"id": 1, "name": "Get Price", "description": "Get current stock price"},
59
- {"id": 2, "name": "Buy Stock", "description": "Buy shares of stock"},
60
- {"id": 3, "name": "Sell Stock", "description": "Sell shares of stock"}
61
- ]
62
- }
63
-
64
- @app.get("/")
65
- def root():
66
- return {"message": "Trading Environment API", "status": "running"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
4
 
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ import json
8
+ import numpy as np
9
+ from typing import Optional, Dict, Any, List
10
+ from models import MarketObservation, AgentAction
11
 
12
class TradingEnvironment:
    """Single-asset trading simulator driven by a fixed, in-memory price series.

    The environment keeps a cursor (``idx``) into ``prices``, a cash balance
    and a share count. Each call to :meth:`step` advances the cursor by one
    (clamped at the last price) and then applies the requested action at the
    new price.
    """

    def __init__(self) -> None:
        # Hard-coded sample data; this class does not read any CSV file.
        self.prices: List[float] = [150, 152, 151, 153, 155, 154, 156, 158, 157, 159]
        self.news: List[Dict[str, str]] = [
            {"headline": "Apple announces new AI chip", "sentiment": "positive"},
            {"headline": "Supply chain delays expected", "sentiment": "negative"},
            {"headline": "Analysts raise price target", "sentiment": "positive"},
            {"headline": "Market shows strong growth", "sentiment": "positive"},
        ]
        # reset() creates the mutable episode state (idx, cash, shares, ...).
        self.reset()

    def reset(self) -> "MarketObservation":
        """Restore the starting state and return the first observation.

        Returns:
            The observation at step 0 with 10000.0 cash and no holdings.
        """
        self.idx = 0
        self.cash = 10000.0
        self.shares = 0
        self.total_steps = len(self.prices)
        self.tasks_completed: List[Any] = []
        return self._get_observation()

    def step(self, action: "AgentAction") -> "MarketObservation":
        """Advance one time step and apply ``action`` at the new price.

        Args:
            action: Object exposing ``type`` and, depending on the type,
                ``amount`` (for ``"BUY"``/``"SELL"``) or ``strategy``
                (for ``"BACKTEST"``). Any other type only advances time.

        Returns:
            The observation after the action; for ``"BACKTEST"`` it also
            carries ``backtest_results``.
        """
        # Time moves forward on every call; the cursor sticks at the last price.
        self.idx = min(self.idx + 1, self.total_steps - 1)
        price = self.prices[self.idx]

        if action.type == "BACKTEST":
            return self._get_observation_with_backtest(action.strategy)
        if action.type in ("BUY", "SELL"):
            # Rejected trades are ignored silently; the observation simply
            # shows an unchanged balance and holdings.
            self._execute_trade(action.type, action.amount, price)
        return self._get_observation()

    def _execute_trade(self, side: str, amount: Optional[float], price: float) -> bool:
        """Apply a BUY or SELL to cash/shares if it is valid; return whether it ran.

        A trade is rejected (state untouched) when ``amount`` is missing, zero
        or negative, when a BUY costs more than the available cash, or when a
        SELL exceeds the shares held.
        """
        # A negative amount would otherwise pass the affordability/holdings
        # checks below and act as the opposite, unchecked trade.
        if not amount or amount <= 0:
            return False

        if side == "BUY":
            cost = price * amount
            if cost > self.cash:
                return False
            self.cash -= cost
            self.shares += amount
            return True

        if side == "SELL":
            if amount > self.shares:
                return False
            self.cash += price * amount
            self.shares -= amount
            return True

        return False

    def _get_observation(self) -> "MarketObservation":
        """Build the observation for the current cursor position."""
        price = self.prices[self.idx]
        # The news list is shorter than the price series, so it is cycled.
        news_idx = self.idx % len(self.news)

        return MarketObservation(
            timestamp=f"step_{self.idx}",
            price=float(price),
            balance=round(self.cash, 2),
            holdings=self.shares,
            portfolio_value=round(self.cash + self.shares * price, 2),
            last_news=self.news[news_idx],
        )

    def _get_observation_with_backtest(self, strategy: Optional[str]) -> "MarketObservation":
        """Return the current observation with canned backtest metrics attached.

        The metrics are fixed constants, not computed from ``prices``: one set
        when ``strategy`` contains "momentum" (case-insensitive), another set
        for anything else, including a missing strategy.
        """
        obs = self._get_observation()
        if strategy and "momentum" in strategy.lower():
            obs.backtest_results = {"sharpe_ratio": 1.35, "max_drawdown": 0.12, "total_return": 0.18}
        else:
            obs.backtest_results = {"sharpe_ratio": 0.85, "max_drawdown": 0.18, "total_return": 0.09}
        return obs

    def state(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the episode progress."""
        return {
            "current_step": self.idx,
            "total_steps": self.total_steps,
            # NOTE(review): relies on the observation model exposing .dict();
            # confirm against the pydantic version used by `models`.
            "observation": self._get_observation().dict(),
            "tasks_completed": self.tasks_completed,
        }