You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在Python Dash应用中对齐Plotly子图网格的行和列标签

如何在Python Dash应用中对齐Plotly子图网格的行和列标签

我太懂你这种烦恼了——当变量名特别长的时候,手动调坐标很容易让顶部和侧边的标签和子图对不上,看起来乱糟糟的。咱们可以通过绑定标签到对应子图的轴坐标来解决这个问题,不用再算全局的paper坐标,让Plotly自己帮咱们精准对齐。

问题根源

你原来的代码用了全局的paper坐标来定位标签,但换行后的标签本身有高度/宽度,而且子图的分布是按比例的,固定的paper坐标很容易出现偏移。换成子图自身的轴引用,就能让标签和对应的子图牢牢绑定。

解决方案:修改注释定位逻辑

我给你调整了标签的xref/yref参数,还有坐标值,同时优化了边距避免标签被截断。下面是修改后的完整代码,关键修改点我都标了注释:

import dash
from dash import dcc, html
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import numpy as np
import plotly.graph_objects as go


# Sample DataFrame
df = pd.DataFrame({
    "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA": ["1", "2", "3", "4"],
    "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"],
    "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC": ["cat", "dog", "cat", "mouse"],
    "DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD": ["10.5", "20.3", "30.1", "40.2"],
    'EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE': ['apple', 'apple', 'apple', 'banana']
})

# Convert data types
def convert_dtypes(df):
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col])  # Convert to int/float
        except ValueError:
            try:
                df[col] = pd.to_datetime(df[col])  # Convert to datetime
            except ValueError:
                df[col] = df[col].astype("string")  # Keep as string
    return df

df = convert_dtypes(df)
columns = df.columns
num_cols = len(columns)

# Dash App
app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1("Pairwise Column Plots"),
    dcc.Graph(id='grid-plots')
])

@app.callback(
    dash.Output('grid-plots', 'figure'),
    dash.Input('grid-plots', 'id')  # Dummy input to trigger callback
)
def create_plot_grid(_):
    fig = sp.make_subplots(rows = num_cols, cols = num_cols, 
                           shared_xaxes = False, shared_yaxes = False)

    annotations = []  # Store subplot titles dynamically
    # --------------------------
    # 关键修改1:顶部列标签对齐
    # --------------------------
    for j, col_label in enumerate(columns):
        wrapped_label = '<br>'.join(col_label[x:x+10] for x in range(0, len(col_label), 10))
        annotations.append(
            dict(
                text=f"<b>{wrapped_label}</b>",
                xref=f"x{j+1}",  # 绑定到第j+1列的x轴,而非全局paper
                yref=f"y{1}",     # 绑定到第一行的y轴
                x=0.5,            # 该列的中心位置(归一化坐标,0-1)
                y=1.1,            # 子图顶部上方的位置,避免重叠
                showarrow=False,
                font=dict(size=14, color="black"),
                align="center"    # 换行文本居中显示
            )
        )
    # --------------------------
    # 关键修改2:侧边行标签对齐
    # --------------------------
    for i, row_label in enumerate(columns):
        wrapped_label = '<br>'.join(row_label[x:x+10] for x in range(0, len(row_label), 10))
        annotations.append(
            dict(
                text=f"<b>{wrapped_label}</b>",
                xref=f"x{1}",     # 绑定到第一列的x轴
                yref=f"y{i+1}",   # 绑定到第i+1行的y轴
                x=-0.1,           # 子图左侧的位置,留出空间放标签
                y=0.5,            # 该行的中心位置
                showarrow=False,
                font=dict(size=14, color="black"),
                textangle=-90,
                align="center"    # 换行文本居中显示
            )
        )

    # 子图内容部分保持不变
    for i, x_col in enumerate(columns):
        for j, y_col in enumerate(columns):
            dtype_x, dtype_y = df[x_col].dtype, df[y_col].dtype
            row, col = i + 1, j + 1  # Adjust for 1-based indexing

            # 只显示上三角网格
            if j <= i:
                trace = None

            # 数值vs数值:散点图
            elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
                trace = px.scatter(df, x = x_col, y = y_col).data[0]

            # 数值vs分类:箱线图
            elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y):
                trace = px.box(df, x = y_col, y = x_col).data[0]
            elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
                trace = px.box(df, x = x_col, y = y_col).data[0]

            # 分类vs分类:计数热力图
            elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y):
                counts_df = (
                    df
                    .groupby([x_col, y_col])
                    .size()
                    .reset_index(name = 'count')
                    .pivot_table(index = x_col, columns = y_col, values = "count", aggfunc="sum")  
                ) 
                trace = go.Heatmap(z = counts_df.values, x = counts_df.columns, y = counts_df.index, showscale = False)

            # 时间vs数值:折线图
            elif pd.api.types.is_datetime64_any_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
                trace = px.line(df, x = x_col, y = y_col).data[0]
            elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_datetime64_any_dtype(dtype_y):
                trace = px.line(df, x = y_col, y = x_col).data[0]

            else:
                trace = None  # 暂不支持的类型组合

            if trace:
                fig.add_trace(trace, row = row, col = col)

    # --------------------------
    # 关键修改3:调整边距,避免长标签被截断
    # --------------------------
    fig.update_layout(
        height = 300 * num_cols, 
        width = 300 * num_cols, 
        showlegend = False,
        annotations = annotations,
        margin=dict(t=120, l=120)  # 顶部和左侧留足够空间放换行后的标签
    )
    return fig

if __name__ == '__main__':
    app.run_server(debug = True)

为什么这样修改有效?

  1. 轴绑定定位:用xref=f"x{j+1}"yref=f"y{1}"代替全局paper坐标,标签会直接跟着对应列/行的子图走,不管子图网格怎么缩放,都能保持中心对齐。
  2. 文本居中:设置align="center"让换行后的长标签自己居中,不会出现左偏或右偏的情况。
  3. 边距调整:增加顶部和左侧的边距,确保换行后的标签不会被图表容器截断。

现在你运行修改后的代码,就能看到顶部和侧边的长标签精准对齐对应的子图,而且换行后的文本也整整齐齐的啦~

备注:内容来源于stack exchange,提问作者The_Questioner

火山引擎 最新活动