如何在Python Dash应用中对齐Plotly子图网格的行和列标签
如何在Python Dash应用中对齐Plotly子图网格的行和列标签
我太懂你这种烦恼了——当变量名特别长的时候,手动调坐标很容易让顶部和侧边的标签和子图对不上,看起来乱糟糟的。咱们可以通过绑定标签到对应子图的轴坐标来解决这个问题,不用再算全局的paper坐标,让Plotly自己帮咱们精准对齐。
问题根源
你原来的代码用了全局的paper坐标来定位标签,但换行后的标签本身有高度/宽度,而且子图的分布是按比例的,固定的paper坐标很容易出现偏移。换成子图自身的轴引用,就能让标签和对应的子图牢牢绑定。
解决方案:修改注释定位逻辑
我给你调整了标签的xref/yref参数,还有坐标值,同时优化了边距避免标签被截断。下面是修改后的完整代码,关键修改点我都标了注释:
import dash from dash import dcc, html import pandas as pd import plotly.express as px import plotly.subplots as sp import numpy as np import plotly.graph_objects as go # Sample DataFrame df = pd.DataFrame({ "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA": ["1", "2", "3", "4"], "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"], "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC": ["cat", "dog", "cat", "mouse"], "DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD": ["10.5", "20.3", "30.1", "40.2"], 'EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE': ['apple', 'apple', 'apple', 'banana'] }) # Convert data types def convert_dtypes(df): for col in df.columns: try: df[col] = pd.to_numeric(df[col]) # Convert to int/float except ValueError: try: df[col] = pd.to_datetime(df[col]) # Convert to datetime except ValueError: df[col] = df[col].astype("string") # Keep as string return df df = convert_dtypes(df) columns = df.columns num_cols = len(columns) # Dash App app = dash.Dash(__name__) app.layout = html.Div([ html.H1("Pairwise Column Plots"), dcc.Graph(id='grid-plots') ]) @app.callback( dash.Output('grid-plots', 'figure'), dash.Input('grid-plots', 'id') # Dummy input to trigger callback ) def create_plot_grid(_): fig = sp.make_subplots(rows = num_cols, cols = num_cols, shared_xaxes = False, shared_yaxes = False) annotations = [] # Store subplot titles dynamically # -------------------------- # 关键修改1:顶部列标签对齐 # -------------------------- for j, col_label in enumerate(columns): wrapped_label = '<br>'.join(col_label[x:x+10] for x in range(0, len(col_label), 10)) annotations.append( dict( text=f"<b>{wrapped_label}</b>", xref=f"x{j+1}", # 绑定到第j+1列的x轴,而非全局paper yref=f"y{1}", # 绑定到第一行的y轴 x=0.5, # 该列的中心位置(归一化坐标,0-1) y=1.1, # 子图顶部上方的位置,避免重叠 showarrow=False, font=dict(size=14, color="black"), align="center" # 换行文本居中显示 ) ) # -------------------------- # 关键修改2:侧边行标签对齐 # -------------------------- for i, row_label in enumerate(columns): wrapped_label = '<br>'.join(row_label[x:x+10] for x in range(0, len(row_label), 10)) annotations.append( dict( text=f"<b>{wrapped_label}</b>", xref=f"x{1}", # 绑定到第一列的x轴 yref=f"y{i+1}", # 绑定到第i+1行的y轴 x=-0.1, # 子图左侧的位置,留出空间放标签 y=0.5, # 该行的中心位置 showarrow=False, font=dict(size=14, color="black"), textangle=-90, align="center" # 换行文本居中显示 ) ) # 子图内容部分保持不变 for i, x_col in enumerate(columns): for j, y_col in enumerate(columns): dtype_x, dtype_y = df[x_col].dtype, df[y_col].dtype row, col = i + 1, j + 1 # Adjust for 1-based indexing # 只显示上三角网格 if j <= i: trace = None # 数值vs数值:散点图 elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y): trace = px.scatter(df, x = x_col, y = y_col).data[0] # 数值vs分类:箱线图 elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y): trace = px.box(df, x = y_col, y = x_col).data[0] elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y): trace = px.box(df, x = x_col, y = y_col).data[0] # 分类vs分类:计数热力图 elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y): counts_df = ( df .groupby([x_col, y_col]) .size() .reset_index(name = 'count') .pivot_table(index = x_col, columns = y_col, values = "count", aggfunc="sum") ) trace = go.Heatmap(z = counts_df.values, x = counts_df.columns, y = counts_df.index, showscale = False) # 时间vs数值:折线图 elif pd.api.types.is_datetime64_any_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y): trace = px.line(df, x = x_col, y = y_col).data[0] elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_datetime64_any_dtype(dtype_y): trace = px.line(df, x = y_col, y = x_col).data[0] else: trace = None # 暂不支持的类型组合 if trace: fig.add_trace(trace, row = row, col = col) # -------------------------- # 关键修改3:调整边距,避免长标签被截断 # -------------------------- fig.update_layout( height = 300 * num_cols, width = 300 * num_cols, showlegend = False, annotations = annotations, margin=dict(t=120, l=120) # 顶部和左侧留足够空间放换行后的标签 ) return fig if __name__ == '__main__': app.run_server(debug = True)
为什么这样修改有效?
- 轴绑定定位:用
xref=f"x{j+1}"和yref=f"y{1}"代替全局paper坐标,标签会直接跟着对应列/行的子图走,不管子图网格怎么缩放,都能保持中心对齐。 - 文本居中:设置
align="center"让换行后的长标签自己居中,不会出现左偏或右偏的情况。 - 边距调整:增加顶部和左侧的边距,确保换行后的标签不会被图表容器截断。
现在你运行修改后的代码,就能看到顶部和侧边的长标签精准对齐对应的子图,而且换行后的文本也整整齐齐的啦~
备注:内容来源于stack exchange,提问作者The_Questioner




