Add support for multiple traces (e.g., charts where multiple lines are plotted)

This commit is contained in:
moshferatu 2024-01-01 06:00:09 -08:00
parent 4a60be48ee
commit 49e0c4c9bd
5 changed files with 22 additions and 18 deletions

View File

@ -1,6 +1,7 @@
from chart import Chart from chart import Chart
from pandas import Series from pandas import Series
from plotly.graph_objects import Candlestick from plotly.graph_objects import Candlestick
from typing import List
class CandlestickChart(Chart): class CandlestickChart(Chart):
@ -11,8 +12,8 @@ class CandlestickChart(Chart):
self.lows = lows self.lows = lows
self.closes = closes self.closes = closes
def trace(self) -> Candlestick: def traces(self) -> List[Candlestick]:
return Candlestick( return [Candlestick(
x = self.x, x = self.x,
open = self.opens, open = self.opens,
high = self.highs, high = self.highs,
@ -20,4 +21,4 @@ class CandlestickChart(Chart):
close = self.closes, # TODO: Make colors configurable. close = self.closes, # TODO: Make colors configurable.
increasing = dict(line = dict(color = 'limegreen', width = 1), fillcolor = 'limegreen'), increasing = dict(line = dict(color = 'limegreen', width = 1), fillcolor = 'limegreen'),
decreasing = dict(line = dict(color = 'red', width = 1), fillcolor = 'red') decreasing = dict(line = dict(color = 'red', width = 1), fillcolor = 'red')
) )]

View File

@ -1,6 +1,7 @@
from plotly.basedatatypes import BaseTraceType from plotly.basedatatypes import BaseTraceType
from typing import List
class Chart: class Chart:
def trace(self) -> BaseTraceType: def traces(self) -> List[BaseTraceType]:
pass pass

View File

@ -1,19 +1,21 @@
from chart import Chart from chart import Chart
from pandas import Series from pandas import Series
from plotly.graph_objects import Scatter from plotly.graph_objects import Scatter
from typing import List
class LineChart(Chart): class LineChart(Chart):
def __init__(self, x: Series, y: Series, name: str): def __init__(self, x: Series, values: List[Series]):
self.x = x self.x = x
self.y = y self.values = values
self.name = name
def trace(self) -> Scatter: def traces(self) -> List[Scatter]:
return Scatter( traces = []
x = self.x, for value in self.values:
y = self.y, traces.append(Scatter(
name = self.name, x = self.x,
mode = 'lines', y = value,
line=dict(color = 'yellow') # TODO: Make this configurable. mode = 'lines',
) line=dict(color = 'yellow') # TODO: Make this configurable.
))
return traces

View File

@ -10,7 +10,6 @@ data = ohlc('SPX.XO', '1d', start_date = start_date, end_date = end_date)
line_chart = LineChart( line_chart = LineChart(
x = data['Timestamp'], x = data['Timestamp'],
y = data['Close'], values = [data['Close']]
name = 'SPX'
) )
plot(line_chart) plot(line_chart)

View File

@ -11,7 +11,8 @@ def figure_with_subplots(subplots: List[List[Chart]]) -> Figure:
for i, row in enumerate(subplots, start = 1): for i, row in enumerate(subplots, start = 1):
for j, chart in enumerate(row, start = 1): for j, chart in enumerate(row, start = 1):
figure.add_trace(chart.trace(), row = i, col = j) for trace in chart.traces():
figure.add_trace(trace, row = i, col = j)
figure.update_xaxes(showgrid = False, showticklabels = True, rangeslider = dict(visible = False), row = i, col = j) figure.update_xaxes(showgrid = False, showticklabels = True, rangeslider = dict(visible = False), row = i, col = j)
figure.update_yaxes(showgrid = False, side = 'right', row = i, col = j) figure.update_yaxes(showgrid = False, side = 'right', row = i, col = j)