树中节点被路径遍历的最大次数统计问题
解决树路径覆盖最大次数问题
这个问题我之前处理大规模树数据时碰到过,暴力枚举每条路径的所有节点肯定超时——毕竟N和k都能到1e5,O(k×路径长度)的复杂度完全扛不住。这里给你推荐一个高效解法:树上差分+LCA(最近公共祖先),时间复杂度能做到O(N + k logN),完美适配大数据规模。
核心思路
树上差分的本质是利用「区间修改,单点查询」的思想,把路径覆盖的计数转化为几个点的差分操作,最后通过一次遍历完成计数累加。对于任意路径(u, v),我们可以把它拆成两段:u到LCA(u, v),以及v到LCA(u, v)的父节点。通过对这几个关键节点的差分操作,就能间接统计所有节点的覆盖次数。
具体步骤
1. 预处理LCA
要快速找到任意两个节点的最近公共祖先,推荐用倍增法预处理:
- 先通过一次DFS记录每个节点的深度和直接父节点(2^0级祖先)
- 再通过动态规划填充倍增数组,
up[k][u]表示节点u的2^k级祖先,这样查询LCA的时间复杂度是O(logN)
2. 初始化差分数组
创建一个长度为N+1的数组cnt(假设节点编号从1开始),初始值全为0。
3. 处理每条路径
对每条路径(u, v)执行以下操作:
- 计算
l = LCA(u, v) cnt[u] += 1:标记路径起点cnt[v] += 1:标记路径终点cnt[l] -= 1:抵消LCA被重复计数的部分- 如果l不是根节点,
cnt[parent[l]] -= 1:抵消LCA父节点被错误计数的部分
4. 计算最终覆盖次数
通过后序遍历(先处理所有子节点,再处理父节点),把当前节点的cnt值加上所有子节点的cnt值,得到的就是该节点被路径覆盖的总次数。
5. 找到最大值
遍历所有节点的最终cnt值,取最大的那个就是答案。
示例验证
用题目中的例子来验证:
- 树结构:1是根,子节点2、3;2的子节点4、5
- 路径:(1,5)、(2,3)
处理路径(1,5):
- LCA是1,所以
cnt[1] +=1、cnt[5] +=1、cnt[1] -=1,根节点无父节点,无需额外操作。此时cnt数组:[0,0,0,0,0,1]
处理路径(2,3):
- LCA是1,所以
cnt[2] +=1、cnt[3] +=1、cnt[1] -=1,根节点无父节点。此时cnt数组:[0,-1,1,1,0,1]
后序遍历累加:
- 节点4:
cnt=0(无子女) - 节点5:
cnt=1(无子女) - 节点2:
cnt=1 + 0 +1 =2 - 节点3:
cnt=1(无子女) - 节点1:
cnt=-1 +2 +1=2
最终最大值是2,和题目示例一致。
代码示例(Python风格)
import math def main(): import sys input = sys.stdin.read data = input().split() idx = 0 n = int(data[idx]) idx +=1 k = int(data[idx]) idx +=1 # 构建邻接表 adj = [[] for _ in range(n+1)] for _ in range(n-1): u = int(data[idx]) v = int(data[idx+1]) adj[u].append(v) adj[v].append(u) idx +=2 # 预处理LCA(倍增法) log_max = math.floor(math.log2(n)) +1 depth = [0]*(n+1) up = [[-1]*(n+1) for _ in range(log_max)] # 第一次DFS填充depth和up[0] stack = [(1, -1)] while stack: u, p = stack.pop() up[0][u] = p for v in adj[u]: if v != p: depth[v] = depth[u]+1 stack.append((v, u)) # 填充倍增数组 for k_level in range(1, log_max): for u in range(1, n+1): if up[k_level-1][u] != -1: up[k_level][u] = up[k_level-1][up[k_level-1][u]] # LCA查询函数 def get_lca(u, v): if depth[u] < depth[v]: u, v = v, u # 把u升到和v同深度 for k_level in range(log_max-1, -1, -1): if depth[u] - (1 << k_level) >= depth[v]: u = up[k_level][u] if u == v: return u # 一起往上跳 for k_level in range(log_max-1, -1, -1): if up[k_level][u] != -1 and up[k_level][u] != up[k_level][v]: u = up[k_level][u] v = up[k_level][v] return up[0][u] # 处理路径,更新差分数组 cnt = [0]*(n+1) for _ in range(k): u = int(data[idx]) v = int(data[idx+1]) idx +=2 l = get_lca(u, v) cnt[u] +=1 cnt[v] +=1 cnt[l] -=1 if up[0][l] != -1: cnt[up[0][l]] -=1 # 后序遍历计算最终计数,同时找最大值 max_count = 0 stack = [(1, -1, False)] while stack: u, p, visited = stack.pop() if visited: # 累加子节点的cnt for v in adj[u]: if v != p: cnt[u] += cnt[v] if cnt[u] > max_count: max_count = cnt[u] else: stack.append((u, p, True)) # 倒序压入子节点,保证处理顺序正确 for v in reversed(adj[u]): if v != p: stack.append((v, u, False)) print(max_count) if __name__ == "__main__": main()
内容的提问来源于stack exchange,提问作者DebD




