"""Helpers for rendering ``Chart`` objects as Plotly figures."""
from typing import List

from plotly.graph_objects import Figure
from plotly.subplots import make_subplots

from chart import Chart

# TODO: Make this configurable.
_BACKGROUND_COLOR = '#0f0f0f'
_FONT_COLOR = '#7a7c7d'


def subplot_titles(subplots: List[List[Chart]]) -> List[str]:
    """Return the title of every chart in *subplots*, in row-major order.

    Only charts that are present are included; cells missing from short
    rows contribute no entry.
    """
    return [chart.title for row in subplots for chart in row]


def figure_with_subplots(subplots: List[List[Chart]]) -> Figure:
    """Build a figure that lays *subplots* out as a grid of charts.

    ``subplots[i][j]`` is drawn in grid row ``i + 1``, column ``j + 1``.
    The grid is as wide as the longest row; shorter rows leave their
    trailing cells empty (with no title).

    Args:
        subplots: Rows of charts to place in the grid.

    Returns:
        A Plotly figure with every chart's traces added and the shared
        dark theme applied.

    Raises:
        ValueError: If *subplots* contains no charts at all.
    """
    num_rows = len(subplots)
    # Size the grid by the widest row so every chart gets a real cell.
    num_columns = max((len(row) for row in subplots), default = 0)
    if num_columns == 0:
        raise ValueError('subplots must contain at least one chart')

    # make_subplots assigns titles to grid cells in row-major order, so pad
    # short rows with '' to keep each title above its own chart.
    titles: List[str] = []
    for row in subplots:
        titles.extend(chart.title for chart in row)
        titles.extend([''] * (num_columns - len(row)))

    figure = make_subplots(rows = num_rows, cols = num_columns, subplot_titles = titles)

    for row_number, row in enumerate(subplots, start = 1):
        for col_number, chart in enumerate(row, start = 1):
            for trace in chart.traces():
                figure.add_trace(trace, row = row_number, col = col_number)
            figure.update_xaxes(
                showgrid = False,
                showticklabels = True,
                rangebreaks = chart.range_breaks(),
                rangeslider = dict(visible = False),
                row = row_number,
                col = col_number
            )
            figure.update_yaxes(
                showgrid = False,
                side = 'right',
                row = row_number,
                col = col_number
            )

    # Called without row/col so it applies to the axes of every subplot.
    # The zero line is given the background colour, which makes it blend in.
    figure.update_xaxes(zerolinecolor = _BACKGROUND_COLOR)
    figure.update_yaxes(zerolinecolor = _BACKGROUND_COLOR)

    figure.update_layout(
        bargap = 0,
        bargroupgap = 0,
        paper_bgcolor = _BACKGROUND_COLOR,
        plot_bgcolor = _BACKGROUND_COLOR,
        font_color = _FONT_COLOR,
        showlegend = False,
        margin = dict(pad = 15)
    )
    return figure


def plot(chart: Chart) -> None:
    """Show *chart* alone in a one-cell figure, with Plotly's mode bar disabled."""
    figure_with_subplots([[chart]]).show(config = {'displayModeBar': False})