socialmodel / app.py
developersajidbashir's picture
Create app.py
0aad598 verified
raw
history blame
No virus
5.84 kB
import gym
import numpy as np
import matplotlib.pyplot as plt
import requests
import pandas as pd
from datetime import datetime, timedelta
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from gym import spaces
import time
import firebase_admin
from firebase_admin import credentials, db
import os
# Firebase Realtime Database connection. `ref` is the database root reference;
# run_model() writes the accumulated signals under it.
cred = credentials.Certificate("credentials.json")
firebase_admin.initialize_app(
    credential=cred,
    options={
        "databaseURL": "https://socail-swap-default-rtdb.asia-southeast1.firebasedatabase.app/",
    },
)
ref = db.reference()

# Flag defined here; nothing in this file reads it.
stopmodel = False

# Signals accumulated across run_model() iterations, stored as
# (timestamp, price, ema) tuples.
buy_signals = []
sell_signals = []
class TradingEnv(gym.Env):
    """Single-asset trading environment over a price DataFrame.

    *data* must provide ``Close`` and ``EMA`` columns. Each observation is the
    last ``window_size`` rows of those two columns, min-max scaled together
    into [0, 1]. Actions are:

    * 0: earns no reward (hold).
    * 1: rewarded by the price rise over the next bar (buy).
    * 2: rewarded by the price fall over the next bar (sell).

    Uses the classic gym API: ``reset()`` returns the observation and
    ``step()`` returns ``(obs, reward, done, info)``.
    """

    def __init__(self, data, window_size=50):
        super().__init__()
        self.data = data
        self.window_size = window_size
        self.current_step = window_size
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(window_size, 2), dtype=np.float32)

    def reset(self):
        """Rewind to the first full window and return its observation."""
        self.current_step = self.window_size
        return self._get_observation()

    def _get_observation(self):
        """Return the scaled (window_size, 2) Close/EMA window ending before current_step."""
        # Explicit positional slicing; the index is expected to be timestamps.
        window = self.data.iloc[self.current_step - self.window_size:self.current_step]
        values = window[['Close', 'EMA']].to_numpy(dtype=np.float64)
        lo = values.min()
        span = values.max() - lo
        if span > 0:
            scaled = (values - lo) / span
        else:
            # A perfectly flat window would otherwise be 0/0 -> NaN, which
            # falls outside the [0, 1] observation Box.
            scaled = np.zeros_like(values)
        # Match the float32 dtype declared by observation_space.
        return scaled.astype(np.float32)

    def step(self, action):
        """Advance one bar and reward *action* by the corresponding price move."""
        reward = 0
        done = False
        self.current_step += 1
        if self.current_step >= len(self.data):
            done = True
        else:
            # NOTE(review): the counter is advanced before the reward is read.
            # The action chosen on a window ending at bar k-1 is therefore
            # rewarded by the move from bar k to bar k+1. Confirm this
            # one-bar lag is intended.
            close = self.data['Close']
            move = close.iloc[self.current_step] - close.iloc[self.current_step - 1]
            if action == 1:
                reward = move
            elif action == 2:
                reward = -move
        return self._get_observation(), reward, done, {}
def fetch_data(symbol='ETHUSDT', interval='1h', start_date='2021-01-01'):
    """Download Binance spot klines for *symbol* from *start_date* until now.

    Pages through ``GET /api/v3/klines``. Each request starts at the close time
    of the last candle received. Paging stops when a page comes back empty or
    the cursor passes the current time.

    Args:
        symbol: Binance trading pair, e.g. ``'ETHUSDT'``.
        interval: Kline interval string, e.g. ``'1h'``.
        start_date: First day to fetch, as ``'YYYY-MM-DD'``, interpreted as
            midnight UTC.

    Returns:
        A DataFrame indexed by candle open time (``timestamp``, naive UTC)
        with a single float column ``Close``. It is empty if no candles came
        back.

    Raises:
        ValueError: If *start_date* is not in ``YYYY-MM-DD`` form.
        requests.HTTPError: If Binance answers with a non-2xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    from datetime import timezone

    url = 'https://api.binance.com/api/v3/klines'
    # Work in integer epoch milliseconds with an explicit UTC zone. Calling
    # .timestamp() on a naive datetime interprets it as *local* time, which
    # shifted every page boundary (duplicates or gaps) on non-UTC hosts.
    first_day = datetime.strptime(start_date, '%Y-%m-%d').replace(tzinfo=timezone.utc)
    cursor_ms = int(first_day.timestamp() * 1000)
    end_ms = int(datetime.now(timezone.utc).timestamp() * 1000)

    klines = []
    while cursor_ms < end_ms:
        params = {
            'symbol': symbol,
            'interval': interval,
            'startTime': cursor_ms,
            'limit': 1000,  # documented maximum page size (default is 500)
        }
        # Without a timeout, a stalled connection blocks forever.
        response = requests.get(url, params=params, timeout=30)
        # Error payloads are JSON objects, not lists. Without this check they
        # would be appended to `klines` key by key.
        response.raise_for_status()
        page = response.json()
        if not page:
            break
        klines.extend(page)
        # Field 6 is the candle close time (ms). The next page starts there,
        # which returns candles opening after it.
        cursor_ms = page[-1][6]

    df = pd.DataFrame(klines, columns=['timestamp', 'Open', 'High', 'Low', 'Close', 'Volume', 'Close_time', 'Quote_asset_volume', 'Number_of_trades', 'Taker_buy_base_asset_volume', 'Taker_buy_quote_asset_volume', 'Ignore'])
    df['Close'] = df['Close'].astype(float)
    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('timestamp', inplace=True)
    # Return an independent frame: callers (calculate_ema) add columns to it.
    return df[['Close']].copy()
def calculate_ema(data, span=20):
    """Add an ``EMA`` column to *data* in place and return the same frame.

    The column is the recursive (``adjust=False``) exponential moving average
    of ``Close`` with the given *span*.
    """
    smoother = data['Close'].ewm(adjust=False, span=span)
    data['EMA'] = smoother.mean()
    return data
def _predict_actions(agent, env, n_steps):
    """Roll *agent* through the vectorised *env* for up to *n_steps* and return its actions."""
    obs = env.reset()
    actions = []
    for _ in range(n_steps):
        # deterministic=True takes the policy's most likely action. The SB3
        # default samples from the action distribution, which made the
        # published signals random and non-reproducible between runs.
        action, _state = agent.predict(obs, deterministic=True)
        actions.append(action[0])
        obs, _reward, done, _info = env.step(action)
        if done[0]:
            break
    return actions


def _record_new_signals(dates, prices, emas, actions):
    """Append signals for bars not yet recorded in either module-level list.

    Action 1 becomes a buy signal and action 2 a sell signal. A bar timestamp
    is recorded at most once across buy_signals and sell_signals combined.
    """
    # A set of already-recorded timestamps, built once. The previous version
    # rebuilt these lists per element, which is O(n^2) in the history size.
    seen = {signal[0] for signal in buy_signals}
    seen.update(signal[0] for signal in sell_signals)
    for date, price, ema, action in zip(dates, prices, emas, actions):
        if date in seen:
            continue
        if action == 1:
            buy_signals.append((date, price, ema))
        elif action == 2:
            sell_signals.append((date, price, ema))
        else:
            continue
        seen.add(date)


def _signals_to_records(signals, signal_type):
    """Serialise (timestamp, price, ema) tuples into dicts for Firebase."""
    return [
        {
            'timestamp': date.strftime('%Y-%m-%d %H:%M:%S'),
            'type': signal_type,
            'price': round(price, 2),
            'ema': round(ema, 2),
        }
        for date, price, ema in signals
    ]


def run_model():
    """Refresh trading signals from recent Binance candles once per hour.

    Each iteration does the following:

    1. Fetches candles starting at midnight (UTC) three days ago and adds the
       EMA column.
    2. Runs the module-level ``model`` over them.
    3. Adds previously unseen buy/sell bars to the global signal lists.
    4. Overwrites ``signals/data`` in Firebase with every accumulated signal
       (buys first, then sells).

    Network errors are treated as transient: they are printed and the
    iteration is retried after a short pause. Any other exception is printed
    and ends the loop.
    """
    window = 50  # must match TradingEnv's default window_size
    retry_delay = 60  # seconds to wait after a network failure
    while True:
        try:
            start = (datetime.utcnow() - timedelta(days=3)).strftime('%Y-%m-%d')
            new_data = calculate_ema(fetch_data(start_date=start))
            if len(new_data) < window:
                print("Not enough data to update the environment.")
                time.sleep(3600)
                continue
            env = DummyVecEnv([lambda: TradingEnv(new_data)])
            model.set_env(env)
            # Action i is taken on the window ending at bar window-1+i and is
            # attributed to bar window+i.
            dates = new_data.index[window:]
            prices = new_data['Close'].iloc[window:]
            emas = new_data['EMA'].iloc[window:]
            actions = _predict_actions(model, env, len(dates))
            _record_new_signals(dates, prices, emas, actions)
            all_signals_data = (_signals_to_records(buy_signals, 'b')
                                + _signals_to_records(sell_signals, 's'))
            ref.child('signals').child('data').set(all_signals_data)
            time.sleep(3600)
        except OSError as e:
            # requests' exceptions (connection errors, timeouts, HTTP errors)
            # derive from OSError. A transient outage should not stop the loop.
            print(f"Network error, retrying in {retry_delay}s: {e}")
            time.sleep(retry_delay)
        except Exception as e:
            print(f"An error occurred: {e}")
            break
if __name__ == "__main__":
data = fetch_data()
data = calculate_ema(data)
if len(data) < 50:
raise ValueError("Not enough data to fill the window size.")
env = DummyVecEnv([lambda: TradingEnv(data)])
if os.path.exists("./ppo_trading_agent.zip"):
model = PPO.load("ppo_trading_agent", env=env)
else:
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)
print("Ender function")
run_model()