|
import datetime
|
|
import requests
|
|
import matplotlib.pyplot as plt
|
|
from mplfinance.original_flavor import candlestick_ohlc
|
|
import numpy as np
|
|
from sklearn.linear_model import LinearRegression
|
|
import os
|
|
from pathlib import Path
|
|
import streamlit as st
|
|
|
|
# Directory (relative to the working directory) where generated chart images are saved.
PLOT_DIR = Path("./Plots")

# Create it at import time. exist_ok=True replaces the former exists()-then-mkdir()
# sequence, which could raise FileExistsError if two runs raced to create it.
PLOT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Gate.io REST v4 spot candlestick endpoint, assembled from its parts.
host = "https://api.gateio.ws"
prefix = "/api/v4"
endpoint = '/spot/candlesticks'
url = f"{host}{prefix}{endpoint}"

# Headers sent with every candlestick request.
headers = {
    'Accept': 'application/json',
    'Content-Type': 'application/json',
}

# Number of candles requested per API call; get_API_data pages through the
# requested range in chunks of this size.
max_API_request_allowed = 900
|
|
|
|
def lin_reg(data, threshold_channel_len):
    """Greedily split ``data`` into consecutive segments fitted by linear regression.

    The regression input is ``row[0]`` and the target is
    ``(row[2] + row[3]) / 2``. Starting at index ``start``, every window
    ``data[start:stop]`` with ``stop >= start + threshold_channel_len`` is fitted
    with ``LinearRegression``. The window with the highest R^2 (the earliest one
    on ties) becomes a segment, and the next segment starts right after it.

    Args:
        data: sequence of rows indexable at positions 0, 2 and 3.
        threshold_channel_len: minimum number of rows per segment. Values below
            1 yield no segments.

    Returns:
        A list of ``[slope, intercept, start, end]`` entries with ``end``
        inclusive. Trailing rows that cannot fill a minimum-length window are
        left unfitted.
    """
    X = np.array([row[0] for row in data]).reshape(-1, 1)
    y = np.array([(row[2] + row[3]) / 2 for row in data]).reshape(-1, 1)

    segments = []
    start = 0
    stop_min = threshold_channel_len
    while start < stop_min <= len(data):
        best_score = float("-inf")
        best_segment = None
        for stop in range(stop_min, len(data) + 1):
            reg = LinearRegression().fit(X[start:stop], y[start:stop])
            score = reg.score(X[start:stop], y[start:stop])
            # sklearn reports R^2 as NaN for a one-sample window. A comparison
            # with NaN is always False, so such a window is never selected.
            if score > best_score:
                best_score = score
                best_segment = [reg.coef_[0][0], reg.intercept_[0], start, stop - 1]
        if best_segment is None:
            # Only undefined (NaN) scores were seen, i.e. a single row remains.
            # Without this guard the loop would restart at index -1 and try to
            # fit an empty slice, which raises ValueError.
            break
        segments.append(best_segment)
        start = best_segment[3] + 1
        stop_min = start + threshold_channel_len
    return segments
|
|
|
|
def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept that places a line of slope ``m`` on the candle envelope.

    For ``line_type == "top"``, the result is the lowest intercept such that no
    ``row[2]`` value lies more than ``epsilon`` above ``m * row[0] + intercept``.
    For ``"bottom"``, it is the highest intercept such that no ``row[3]`` value
    lies more than ``epsilon`` below the line. Both bounds have a closed form,
    so no search is needed. The function name is kept for existing callers.

    Args:
        data: rows indexed as ``row[0]`` (x), ``row[2]`` (upper value) and
            ``row[3]`` (lower value); presumably candle high/low — confirm.
        line_type: ``"top"`` or ``"bottom"``.
        m: slope of the line.
        b: unused; accepted for backward compatibility.
        epsilon: allowed overshoot; negative values behave like 0.

    Returns:
        The intercept. For empty ``data`` this is ``-inf`` ("top") or ``inf``
        ("bottom").

    Raises:
        ValueError: if ``line_type`` is neither ``"top"`` nor ``"bottom"``.
    """
    tolerance = max(epsilon, 0)
    if line_type == "top":
        # Each row requires intercept >= row[2] - m * row[0] - tolerance.
        return max((row[2] - m * row[0] for row in data), default=float("-inf")) - tolerance
    if line_type == "bottom":
        # Each row requires intercept <= row[3] - m * row[0] + tolerance.
        return min((row[3] - m * row[0] for row in data), default=float("inf")) + tolerance
    raise ValueError("line_type must be 'top' or 'bottom', got {!r}".format(line_type))
|
|
|
|
def plot_lines(lines, plt, converted_data):
    """Draw every ``[slope, intercept, start, end]`` line onto ``plt``.

    Each line runs from ``converted_data[start][0]`` to ``converted_data[end][0]``
    along x and is sampled at 10 evenly spaced points. ``plt`` can be any object
    with a ``plot(x, y)`` method, such as ``matplotlib.pyplot`` or an Axes.
    """
    for slope, intercept, first, last in lines:
        x_start = converted_data[first][0]
        x_stop = converted_data[last][0]
        xs = list(np.linspace(x_start, x_stop, 10))
        ys = list(map(lambda x: slope * x + intercept, xs))
        plt.plot(xs, ys)
|
|
|
|
def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime, timeout=30):
    """Download ``<currency>_USDT`` candlesticks from the Gate.io endpoint in ``url``.

    One candle is expected per ``interval_timedelta`` step from ``start_datetime``
    up to and including ``end_datetime``. The range is requested in chunks of
    ``max_API_request_allowed`` candles, and the final chunk is clamped so that
    nothing after ``end_datetime`` is requested.

    Args:
        currency: base currency symbol; the pair sent is ``<currency>_USDT``.
        interval_timedelta: positive step between consecutive candles.
        interval: interval string sent to the API unchanged.
        start_datetime, end_datetime: range bounds. They are converted with
            ``datetime.timestamp()``, so naive values are read as local time.
        timeout: seconds to wait for each HTTP request.

    Returns:
        The concatenated JSON rows of all chunks. Returns ``[]`` when the range
        is empty, or after reporting the problem via ``st.error`` when a
        request fails.
    """
    if end_datetime < start_datetime:
        return []
    # Number of k >= 0 with start + k * step <= end (exact timedelta arithmetic).
    total_dates = (end_datetime - start_datetime) // interval_timedelta + 1

    data = []
    for first in range(0, total_dates, max_API_request_allowed):
        # Clamp the last chunk so candles after end_datetime are not requested.
        last = min(first + max_API_request_allowed, total_dates) - 1
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": int((start_datetime + first * interval_timedelta).timestamp()),
            "to": int((start_datetime + last * interval_timedelta).timestamp()),
            "interval": interval,
        }
        try:
            # A timeout keeps the app from hanging forever on an unresponsive server.
            r = requests.get(url=url, headers=headers, params=query_param, timeout=timeout)
        except requests.RequestException as exc:
            st.error("API request failed: {}".format(exc))
            return []
        if r.status_code != 200:
            st.error("Invalid API Request")
            return []
        data += r.json()
    return data
|
|
|
|
def _parse_mmddyyyy(text):
    """Parse ``MM/DD/YYYY`` into a UTC-midnight ``datetime``.

    Raises:
        ValueError: if ``text`` is not three ``/``-separated integers forming a
            valid calendar date.
    """
    month, day, year = (int(part) for part in text.strip().split("/"))
    return datetime.datetime(year=year, month=month, day=day, tzinfo=datetime.timezone.utc)


def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch ``<currency>_USDT`` candles, fit channel lines, then save and show the chart.

    Args:
        currency: base currency symbol from user input. It is stripped and
            upper-cased, and must be ASCII alphanumeric because it is embedded
            in the output file name.
        interval: one of ``"1h"``, ``"4h"``, ``"1d"``; anything else is treated
            as weekly.
        startdate, enddate: ``MM/DD/YYYY`` strings, interpreted as UTC midnight.
        threshold_channel_len: minimum number of candles per fitted segment.
        testcasecase_id: identifier embedded in the saved file name.

    Invalid input is reported with ``st.error`` and the function returns early.
    The chart is written to ``PLOT_DIR`` and rendered with ``st.pyplot``.
    """
    # `import matplotlib.pyplot as plt` binds only `plt`, not `matplotlib`, so the
    # dates module must be imported explicitly.
    import matplotlib.dates as mdates

    currency = currency.strip().upper()
    # The symbol becomes part of a file path below; reject anything like "../x".
    if not (currency.isascii() and currency.isalnum()):
        st.error("Currency must be an alphanumeric symbol such as BTC.")
        return

    interval_timedeltas = {
        "1h": datetime.timedelta(hours=1),
        "4h": datetime.timedelta(hours=4),
        "1d": datetime.timedelta(days=1),
    }
    interval_timedelta = interval_timedeltas.get(interval, datetime.timedelta(weeks=1))
    # NOTE(review): Gate.io's v4 interval list names the weekly interval "7d" and
    # has no "1w" — verify against the current API docs.
    api_interval = "7d" if interval == "1w" else interval

    try:
        # UTC-aware bounds keep query timestamps consistent with the UTC candle
        # timestamps converted below (naive values would be read as local time).
        start_datetime = _parse_mmddyyyy(startdate)
        end_datetime = _parse_mmddyyyy(enddate)
    except ValueError:
        st.error("Dates must be valid and in MM/DD/YYYY format.")
        return
    if end_datetime < start_datetime:
        st.error("End date must not be before start date.")
        return

    data = get_API_data(currency, interval_timedelta, api_interval, start_datetime, end_datetime)
    if not data:
        return

    # candlestick_ohlc expects rows of (time, open, high, low, close); the
    # response columns are reordered accordingly.
    converted_data = [
        [
            mdates.date2num(datetime.datetime.fromtimestamp(float(row[0]), tz=datetime.timezone.utc)),
            float(row[5]),
            float(row[3]),
            float(row[4]),
            float(row[2]),
        ]
        for row in data
    ]

    fig, ax = plt.subplots()
    try:
        # Width is in days: keep 0.4 for daily candles and scale it for other
        # intervals so hourly candles do not overlap.
        width = 0.4 * (interval_timedelta / datetime.timedelta(days=1))
        candlestick_ohlc(ax, converted_data, width=width, colorup='#77d879', colordown='#db3f3f')
        ax.xaxis_date()

        top_fitting_lines_data = []
        bottom_fitting_lines_data = []
        epsilon = 0
        for m, b, start, end in lin_reg(converted_data, threshold_channel_len):
            window = converted_data[start:end + 1]
            top_b = binary_search(window, "top", m, b, epsilon)
            bottom_b = binary_search(window, "bottom", m, b, epsilon)
            top_fitting_lines_data.append([m, top_b, start, end])
            bottom_fitting_lines_data.append([m, bottom_b, start, end])

        # Draw on this figure's Axes rather than pyplot's global "current" figure.
        plot_lines(top_fitting_lines_data, ax, converted_data)
        plot_lines(bottom_fitting_lines_data, ax, converted_data)
        ax.set_title("{}_USDT".format(currency))

        file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
        fig.savefig(PLOT_DIR / file_name)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until closed; close it to avoid leaking
        # memory across reruns.
        plt.close(fig)
|
|
|
|
def main():
    """Build the Streamlit input form and render the chart on demand.

    Widgets are created in display order. Pressing "Generate Plot" calls
    ``testcasecase`` with the current inputs and a fixed id of 1.
    """
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")

    currency = st.text_input("Currency", "BTC")
    interval = st.selectbox("Interval", ["1h", "4h", "1d", "1w"])
    start_text = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2022")
    end_text = st.text_input("End Date (MM/DD/YYYY)", "12/31/2022")
    min_segment_len = st.number_input(
        "Threshold Channel Length",
        min_value=1,
        max_value=1000,
        value=10,
    )

    clicked = st.button("Generate Plot")
    if not clicked:
        return
    testcasecase(currency, interval, start_text, end_text, min_segment_len, 1)


# Entry point when the module is executed as a script.
if __name__ == "__main__":
    main()
|
|
|