From aa06e17dcfc0db19a7fb73199cdf42d2536d911f Mon Sep 17 00:00:00 2001 From: moshferatu Date: Wed, 3 Jan 2024 06:32:29 -0800 Subject: [PATCH] Add titles to charts --- candlestick_chart.py | 11 ++++++++++- candlestick_chart_example.py | 3 ++- chart.py | 1 + line_chart.py | 3 ++- line_chart_example.py | 3 ++- multi_line_chart_example.py | 3 ++- plot.py | 11 +++++++++-- 7 files changed, 28 insertions(+), 7 deletions(-) diff --git a/candlestick_chart.py b/candlestick_chart.py index e921aa0..d5fe773 100644 --- a/candlestick_chart.py +++ b/candlestick_chart.py @@ -5,12 +5,21 @@ from typing import List class CandlestickChart(Chart): - def __init__(self, x: Series, opens: Series, highs: Series, lows: Series, closes: Series): + def __init__( + self, + x: Series, + opens: Series, + highs: Series, + lows: Series, + closes: Series, + title: str = '' + ): self.x = x self.opens = opens self.highs = highs self.lows = lows self.closes = closes + self.title = title def traces(self) -> List[Candlestick]: return [Candlestick( diff --git a/candlestick_chart_example.py b/candlestick_chart_example.py index 565d4cf..792928d 100644 --- a/candlestick_chart_example.py +++ b/candlestick_chart_example.py @@ -13,6 +13,7 @@ candlestick_chart = CandlestickChart( opens = data['Open'], highs = data['High'], lows = data['Low'], - closes = data['Close'] + closes = data['Close'], + title = 'SPX' ) plot(candlestick_chart) \ No newline at end of file diff --git a/chart.py b/chart.py index 9684342..5cb72d8 100644 --- a/chart.py +++ b/chart.py @@ -5,6 +5,7 @@ from typing import List class Chart: x: Series = None + title: str = None def traces(self) -> List[BaseTraceType]: pass diff --git a/line_chart.py b/line_chart.py index 2652e63..944171e 100644 --- a/line_chart.py +++ b/line_chart.py @@ -5,9 +5,10 @@ from typing import List class LineChart(Chart): - def __init__(self, x: Series, values: List[Series]): + def __init__(self, x: Series, values: List[Series], title: str = ''): self.x = x self.values = values + self.title = title def traces(self) -> List[Scatter]: traces = [] diff --git a/line_chart_example.py b/line_chart_example.py index bb6d4ab..009b2cb 100644 --- a/line_chart_example.py +++ b/line_chart_example.py @@ -10,6 +10,7 @@ data = ohlc('SPX.XO', '1d', start_date = start_date, end_date = end_date) line_chart = LineChart( x = data['Timestamp'], - values = [data['Close']] + values = [data['Close']], + title = 'SPX (Close)' ) plot(line_chart) \ No newline at end of file diff --git a/multi_line_chart_example.py b/multi_line_chart_example.py index 67ceae6..a1f2e41 100644 --- a/multi_line_chart_example.py +++ b/multi_line_chart_example.py @@ -10,6 +10,7 @@ data = ohlc('SPX.XO', '1d', start_date = start_date, end_date = end_date) line_chart = LineChart( x = data['Timestamp'], - values = [data['High'], data['Low']] + values = [data['High'], data['Low']], + title = 'SPX (High, Low)' ) plot(line_chart) \ No newline at end of file diff --git a/plot.py b/plot.py index 5604933..715de16 100644 --- a/plot.py +++ b/plot.py @@ -3,11 +3,18 @@ from plotly.graph_objects import Figure from plotly.subplots import make_subplots from typing import List +def subplot_titles(subplots: List[List[Chart]]) -> List[str]: + subplot_titles = [] + for row in subplots: + for chart in row: + subplot_titles.append(chart.title) + return subplot_titles + def figure_with_subplots(subplots: List[List[Chart]]) -> Figure: num_rows = len(subplots) num_columns = len(subplots[0]) - - figure = make_subplots(rows = num_rows, cols = num_columns) + + figure = make_subplots(rows = num_rows, cols = num_columns, subplot_titles = subplot_titles(subplots)) for i, row in enumerate(subplots, start = 1): for j, chart in enumerate(row, start = 1):