From 49e0c4c9bd42755bb3dbaebed848a06c33499003 Mon Sep 17 00:00:00 2001 From: moshferatu Date: Mon, 1 Jan 2024 06:00:09 -0800 Subject: [PATCH] Add support for multiple traces (e.g., charts where multiple lines are plotted) --- candlestick_chart.py | 7 ++++--- chart.py | 3 ++- line_chart.py | 24 +++++++++++++----------- line_chart_example.py | 3 +-- plot.py | 3 ++- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/candlestick_chart.py b/candlestick_chart.py index b239fc3..e921aa0 100644 --- a/candlestick_chart.py +++ b/candlestick_chart.py @@ -1,6 +1,7 @@ from chart import Chart from pandas import Series from plotly.graph_objects import Candlestick +from typing import List class CandlestickChart(Chart): @@ -11,8 +12,8 @@ class CandlestickChart(Chart): self.lows = lows self.closes = closes - def trace(self) -> Candlestick: - return Candlestick( + def traces(self) -> List[Candlestick]: + return [Candlestick( x = self.x, open = self.opens, high = self.highs, @@ -20,4 +21,4 @@ class CandlestickChart(Chart): close = self.closes, # TODO: Make colors configurable. increasing = dict(line = dict(color = 'limegreen', width = 1), fillcolor = 'limegreen'), decreasing = dict(line = dict(color = 'red', width = 1), fillcolor = 'red') - ) \ No newline at end of file + )] \ No newline at end of file diff --git a/chart.py b/chart.py index e1cb015..aebd954 100644 --- a/chart.py +++ b/chart.py @@ -1,6 +1,7 @@ from plotly.basedatatypes import BaseTraceType +from typing import List class Chart: - def trace(self) -> BaseTraceType: + def traces(self) -> List[BaseTraceType]: pass \ No newline at end of file diff --git a/line_chart.py b/line_chart.py index 6907023..2652e63 100644 --- a/line_chart.py +++ b/line_chart.py @@ -1,19 +1,21 @@ from chart import Chart from pandas import Series from plotly.graph_objects import Scatter +from typing import List class LineChart(Chart): - def __init__(self, x: Series, y: Series, name: str): + def __init__(self, x: Series, values: List[Series]): self.x = x - self.y = y - self.name = name + self.values = values - def trace(self) -> Scatter: - return Scatter( - x = self.x, - y = self.y, - name = self.name, - mode = 'lines', - line=dict(color = 'yellow') # TODO: Make this configurable. - ) \ No newline at end of file + def traces(self) -> List[Scatter]: + traces = [] + for value in self.values: + traces.append(Scatter( + x = self.x, + y = value, + mode = 'lines', + line=dict(color = 'yellow') # TODO: Make this configurable. + )) + return traces \ No newline at end of file diff --git a/line_chart_example.py b/line_chart_example.py index 34affa2..bb6d4ab 100644 --- a/line_chart_example.py +++ b/line_chart_example.py @@ -10,7 +10,6 @@ data = ohlc('SPX.XO', '1d', start_date = start_date, end_date = end_date) line_chart = LineChart( x = data['Timestamp'], - y = data['Close'], - name = 'SPX' + values = [data['Close']] ) plot(line_chart) \ No newline at end of file diff --git a/plot.py b/plot.py index 772b399..02ea3c5 100644 --- a/plot.py +++ b/plot.py @@ -11,7 +11,8 @@ def figure_with_subplots(subplots: List[List[Chart]]) -> Figure: for i, row in enumerate(subplots, start = 1): for j, chart in enumerate(row, start = 1): - figure.add_trace(chart.trace(), row = i, col = j) + for trace in chart.traces(): + figure.add_trace(trace, row = i, col = j) figure.update_xaxes(showgrid = False, showticklabels = True, rangeslider = dict(visible = False), row = i, col = j) figure.update_yaxes(showgrid = False, side = 'right', row = i, col = j)