diff --git a/strategies/cumulative_rsi.py b/strategies/cumulative_rsi.py index f6d9f33..b613a8d 100644 --- a/strategies/cumulative_rsi.py +++ b/strategies/cumulative_rsi.py @@ -1,45 +1,17 @@ -import pandas as pd -import numpy as np - +from numpy import where from pandas import DataFrame, Series -def calculate_moving_average(data: DataFrame, window: int = 200) -> Series: - """ - Calculate the 200-day moving average and return it as a Series without modifying the original DataFrame. - """ - return data['Close'].rolling(window = window).mean() - -def calculate_rsi(data: DataFrame, period: int = 2) -> Series: - """ - Calculate the 2-period RSI and return it as a Series without modifying the original DataFrame. - """ - delta = data['Close'].diff() - gain = np.where(delta > 0, delta, 0) - loss = np.where(delta < 0, -delta, 0) - - alpha = 1 / period - avg_gain = pd.Series(gain).ewm(alpha = alpha, adjust = False).mean() - avg_loss = pd.Series(loss).ewm(alpha = alpha, adjust = False).mean() - - rs = avg_gain / avg_loss - return 100 - (100 / (1 + rs)) - -def calculate_cumulative_rsi(rsi: Series, window: int = 2) -> Series: - """ - Calculate the cumulative RSI over a specified window period and return it as a Series. - """ - return rsi.rolling(window = window).sum() +from indicators import rsi, sma def cumulative_rsi(data: DataFrame, rsi_period: int = 2, cumulative_period: int = 2) -> Series: """ - Generate 'L'ong entry signals based on the Cumulative RSI strategy. - Returns a Series with 'L' for entry signals and 'N' otherwise without modifying the original DataFrame. - - Entry Condition: 2-period cumulative RSI below 35 and above the 200-day moving average. + Calculate signals for the Cumulative RSI strategy. + + Returns a Series with 'L' for long signals and 'N' otherwise. """ - ma_200 = calculate_moving_average(data) - rsi_2 = calculate_rsi(data, period = rsi_period) - cumulative_rsi_2 = calculate_cumulative_rsi(rsi_2, window = cumulative_period) + ma_200 = sma(data, period = 200) + rsi_2 = rsi(data, period = rsi_period) + cumulative_rsi_2 = rsi_2.rolling(window = cumulative_period).sum() long_condition = (data['Close'] > ma_200) & (cumulative_rsi_2 < 35) - return Series(np.where(long_condition, 'L', 'N'), index = data.index) \ No newline at end of file + return Series(where(long_condition, 'L', 'N'), index = data.index) \ No newline at end of file