You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

Streamlit应用中实现侧边栏仅在指定Tab显示、切换Tab时自动隐藏并返回时重新显示的方案

Streamlit应用中实现侧边栏仅在指定Tab显示、切换Tab时自动隐藏并返回时重新显示的方案

嘿,我来帮你搞定这个侧边栏的显示逻辑问题!你的需求核心是让侧边栏只在「Modeling and analysis」这个Tab里出现,切走就自动隐藏,切回来又重新显示对吧?我们可以通过Streamlit的会话状态(session_state)来跟踪当前活跃Tab,再配合条件渲染来实现这个效果,下面是具体的修改方案:

关键思路拆解

  1. 跟踪当前选中的Tab:在会话状态里新增一个变量记录用户当前打开的Tab,这样我们就能精准判断是否要显示侧边栏。
  2. 条件渲染侧边栏:只有当当前Tab是「Modeling and analysis」,且满足之前的sidebar_score条件时,才渲染侧边栏的权重/分数调整内容;切换到其他Tab时,直接清空侧边栏。
  3. 修复变量作用域问题:把filtered_dfselected_features这些关键数据存到会话状态里,避免切换Tab时数据丢失导致逻辑失效。

修改后的完整代码

import streamlit as st
import pandas as pd
from io import BytesIO

# 假设这些是你已有的自定义函数
def read_file(uploaded_file):
    # 替换为你实际的文件读取逻辑
    if uploaded_file.name.endswith('.csv'):
        return pd.read_csv(uploaded_file)
    elif uploaded_file.name.endswith('.xlsx'):
        return pd.read_excel(uploaded_file)
    else:
        return pd.DataFrame()

def archetype_selection_and_filtration(df):
    # 替换为你实际的Archetype筛选逻辑
    st.selectbox("选择Archetype", df.columns[:3])
    return df

def feature_selection(df, selected_features, col1, col2, col3):
    # 替换为你实际的特征选择逻辑
    selected = st.multiselect("选择特征", df.columns[3:10])
    return selected

def get_inflator_deflator_dict(df):
    # 替换为你实际的Inflator/Deflator逻辑
    return {}

def process_and_score_data(df, weights_dict, scores_dict, feature_list, inflator_deflator_dict):
    # 替换为你实际的打分逻辑
    df['score'] = 50
    return df

def get_selected_columns():
    # 替换为你实际的列选择逻辑
    return {"column3": "status2"}

def create_propensity_tiers(scored_df):
    # 替换为你实际的分层逻辑
    scored_df['tier'] = 'Tier 2'
    return scored_df

def display_feature_distribution_charts(scored_df, selected_features, column3):
    # 替换为你实际的图表展示逻辑
    st.bar_chart(scored_df[column3].value_counts())

def display_additional_charts(scored_df, selected_features, column3):
    # 替换为你实际的其他图表逻辑
    st.line_chart(scored_df['score'].head(20))

def main():
    col1, col2 = st.columns([1, 12])  # 调整比例控制列宽

    with col2:
        st.title("Business Heuristic Model")

    with col1:
        st.image("download.jpg", width=100)

    # 初始化所有会话状态变量
    if "current_tab" not in st.session_state:
        st.session_state.current_tab = "Home"
    if "archetype_selected" not in st.session_state:
        st.session_state.archetype_selected = False
    if "features_selected" not in st.session_state:
        st.session_state.features_selected = False
    if "sidebar_score" not in st.session_state:
        st.session_state.sidebar_score = False
    if "filtered_df" not in st.session_state:
        st.session_state.filtered_df = None
    if "selected_features" not in st.session_state:
        st.session_state.selected_features = []
    if "scored_df" not in st.session_state:
        st.session_state.scored_df = None
    if "weights_dict" not in st.session_state:
        st.session_state.weights_dict = {}
    if "scores_dict" not in st.session_state:
        st.session_state.scores_dict = {}
    if "uploaded_file" not in st.session_state:
        st.session_state.uploaded_file = None

    # 创建Tabs
    tab_names = ["Home", "Archetype filtration", "Modeling and analysis", "Download results"]
    tabs = st.tabs(tab_names)
    
    # Home Tab逻辑
    with tabs[0]:
        st.session_state.current_tab = "Home"
        st.header("Upload data file")
        uploaded_file = st.file_uploader(
            "", type=["csv", "xlsx", "json", "pkl", "txt"]
        )
        if uploaded_file is not None:
            st.success("File uploaded successfully!")
            st.session_state.uploaded_file = uploaded_file
            # 上传新文件后重置相关状态
            st.session_state.filtered_df = None
            st.session_state.archetype_selected = False
        else:
            st.warning("**Disclaimer:** Please ensure that the uploaded file contains only the required columns for analysis.")

    # Archetype filtration Tab逻辑
    with tabs[1]:
        st.session_state.current_tab = "Archetype filtration"
        uploaded_file = st.session_state.get('uploaded_file')
        if uploaded_file is not None:
            df = read_file(uploaded_file)
            st.session_state.filtered_df = archetype_selection_and_filtration(df)

            if st.session_state.filtered_df is not None and not st.session_state.filtered_df.empty:
                if st.button("Submit Archetype Selection"):
                    st.success("Archetype selected successfully!")
                    st.session_state.archetype_selected = True
                    st.session_state.features_selected = False
        else:
            st.info("请先在Home Tab上传数据文件")

    # Modeling and analysis Tab逻辑(核心侧边栏控制)
    with tabs[2]:
        st.session_state.current_tab = "Modeling and analysis"
        if st.session_state.filtered_df is not None and not st.session_state.filtered_df.empty:
            if st.session_state.archetype_selected and not st.session_state.features_selected:
                st.session_state.selected_features = feature_selection(
                    st.session_state.filtered_df, st.session_state.selected_features, "id2", "name2", "status2"
                )
                if st.session_state.selected_features:
                    st.session_state.sidebar_score = True
                    st.session_state.features_selected = True  # 标记特征已选择
                else:
                    st.warning("请至少选择一个特征")
            elif not st.session_state.archetype_selected:
                st.warning("请先在Archetype filtration Tab选择并提交Archetype")

            # 仅当前Tab是Modeling and analysis且满足条件时,渲染侧边栏
            if st.session_state.current_tab == "Modeling and analysis" and st.session_state.sidebar_score:
                weights_dict = {}
                scores_dict = {}
                with st.sidebar.header("Adjust Weights"):
                    for idx, feature in enumerate(st.session_state.selected_features):
                        unique_key = f"weight_{feature}_{idx}"
                        weights_dict[feature] = st.sidebar.slider(
                            f"{feature}", min_value=0, max_value=10, value=1, key=unique_key
                        )

                with st.sidebar.header("Adjust Scores"):
                    for feature in st.session_state.selected_features:
                        st.sidebar.subheader(feature)
                        unique_values = sorted(st.session_state.filtered_df[feature].unique())
                        scores_dict[feature] = {}
                        for idx, value in enumerate(unique_values):
                            unique_key = f"{feature}_{value}_{idx}"
                            scores_dict[feature][value] = st.sidebar.slider(
                                f"{value}",
                                min_value=0,
                                max_value=100,
                                value=50,
                                step=5,
                                key=unique_key
                            )

                # 保存权重和分数到会话状态,供下载Tab使用
                st.session_state.weights_dict = weights_dict
                st.session_state.scores_dict = scores_dict

                # 调用打分函数
                inflator_deflator_dict = get_inflator_deflator_dict(st.session_state.filtered_df)
                st.session_state.scored_df = process_and_score_data(
                    df=st.session_state.filtered_df,
                    weights_dict=weights_dict,
                    scores_dict=scores_dict,
                    feature_list=list(weights_dict.keys()),
                    inflator_deflator_dict=inflator_deflator_dict,
                )

                # 展示结果
                st.header("Scored Data")
                st.write(st.session_state.scored_df.head())

                st.header("Charts")
                selected_columns = get_selected_columns()
                column3 = selected_columns["column3"]
                st.write(f"Customer Status Column: {column3}")

                st.session_state.scored_df = create_propensity_tiers(st.session_state.scored_df)
                display_feature_distribution_charts(st.session_state.scored_df, st.session_state.selected_features, column3)
                display_additional_charts(st.session_state.scored_df, st.session_state.selected_features, column3)
        else:
            st.info("请先完成Archetype筛选")

    # Download results Tab逻辑
    with tabs[3]:
        st.session_state.current_tab = "Download results"
        if st.session_state.scored_df is not None and not st.session_state.scored_df.empty:
            st.write("Scored Data:")
            st.dataframe(st.session_state.scored_df)

            sorted_features = sorted(st.session_state.selected_features, key=lambda x: st.session_state.weights_dict.get(x, 0), reverse=True)
            for feature in sorted_features:
                st.header(f"{feature}")
                st.subheader(f"Assigned weight: {st.session_state.weights_dict.get(feature, 'N/A')}")
                st.subheader("Assigned scores:")
                scores_df = pd.DataFrame(list(st.session_state.scores_dict.get(feature, {}).items()), columns=["Category", "Score"]).sort_values(by="Score", ascending=False)
                st.table(scores_df)
        
            download_options = st.multiselect(
                "Select items to download:",
                options=["Scored data", "Weights, scores and inflator/deflators", "HTML of charts"],
                default=[]
            )

            if download_options:
                if "Scored data" in download_options:
                    scored_data_excel = BytesIO()
                    with pd.ExcelWriter(scored_data_excel, engine='xlsxwriter') as writer:
                        st.session_state.scored_df.to_excel(writer, index=False, sheet_name='Scored Data')
                    st.download_button(
                        label="Download scored data",
                        data=scored_data_excel.getvalue(),
                        file_name="scored_data.xlsx",
                        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                    )
                if "Weights, scores and inflator/deflators" in download_options:
                    weights_df = pd.DataFrame(list(st.session_state.weights_dict.items()), columns=["Feature", "Weight"])
                    scores_combined_df = pd.concat(
                        [pd.DataFrame({"Feature": feature, "Category": list(st.session_state.scores_dict[feature].keys()), "Score": list(st.session_state.scores_dict[feature].values())}) for feature in st.session_state.scores_dict],
                        ignore_index=True
                    )
                    combined_df = pd.merge(scores_combined_df, weights_df, on="Feature", how="left")
                    
                    # 处理Inflator/Deflator数据
                    inflator_deflator_dict = get_inflator_deflator_dict(st.session_state.filtered_df)
                    inflator_deflator_list = []
                    for feature, values in inflator_deflator_dict.items():
                        for category, value in values.items():
                            inflator_deflator_list.append({"Feature": feature, "Category": category, "Inflator/Deflator": value})
                    inflator_deflator_df = pd.DataFrame(inflator_deflator_list)
                    combined_final = pd.concat([combined_df, inflator_deflator_df], ignore_index=True)

                    # 生成下载文件
                    combined_excel = BytesIO()
                    with pd.ExcelWriter(combined_excel, engine='xlsxwriter') as writer:
                        combined_final.to_excel(writer, index=False, sheet_name='Weights & Scores')
                    st.download_button(
                        label="Download weights, scores and inflator/deflators",
                        data=combined_excel.getvalue(),
                        file_name="weights_scores_inflators.xlsx",
                        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                    )
        else:
            st.info("请先完成Modeling and analysis Tab的打分流程")

    # 切换到非目标Tab时,自动清空侧边栏
    if st.session_state.current_tab != "Modeling and analysis":
        st.sidebar.empty()

if __name__ == "__main__":
    main()

核心修改点说明

  • 会话状态跟踪:新增current_tab变量记录当前选中的Tab,每个Tab内部都会更新这个值,确保我们能准确判断用户所在位置。
  • 侧边栏条件渲染:只有在current_tab是「Modeling and analysis」且sidebar_score为True时,才渲染侧边栏内容;其他Tab下调用st.sidebar.empty()清空侧边栏,实现自动隐藏效果。
  • 数据持久化:把filtered_dfselected_featuresscored_df等关键数据都存在st.session_state里,彻底解决切换Tab时数据丢失的问题,保证整个流程的连贯性。

备注:内容来源于stack exchange,提问作者Varish Tripathi

火山引擎 最新活动