带更新操作的区间逆序数统计问题求解问询
带更新的区间逆序数查询高效解决方案
针对你提出的带更新操作的区间逆序数查询问题,这里分享一个利用元素值范围小(≤40)的特性设计的线段树方案,能达到**O(q log n)**的时间复杂度,完美适配题目中的1e5规模约束。
问题回顾
给定大小为n的整数数组arr,处理q次两种查询:
- 类型1:输入l、r,输出区间[l, r]内的逆序数(即满足j < i且arr[j] > arr[i]的数对数量)
- 类型2:输入x、y,将arr[x]的值更新为y
输入示例:
n=5,q=3,arr={1,4,3,5,2}
- 类型1,l=1,r=5 → 输出4
- 类型2,x=1,y=4
- 类型1,l=1,r=5 → 输出6
约束:1 ≤ n, q ≤ 1e5,arr[i], y ≤40,时间限制4秒。
核心思路
普通的线段树或树状数组在处理带更新的区间逆序数时,时间复杂度难以达标,但这里元素值范围极小(仅40种可能),我们可以给线段树的每个节点存储额外信息来优化:
每个线段树节点包含两个字段:
inversion_count:当前区间内的逆序数frequency[40]:当前区间内每个数值的出现次数
线段树构建与更新
- 叶子节点:对应数组中的单个元素,
inversion_count为0,frequency数组中对应元素值的位置设为1,其余为0。 - 父节点合并:
frequency数组直接是左右子节点frequency的逐位相加inversion_count等于左右子节点各自的逆序数之和,再加上左右子段合并时新增的逆序数——遍历左子段的所有较大值和右子段的所有较小值,计算它们的出现次数乘积之和。
- 更新操作:找到目标位置的叶子节点,更新其
frequency,然后自底向上重新合并所有父节点的信息即可。
区间查询
查询区间[l, r]时,将区间拆分为线段树中的若干子区间,合并这些子区间的节点信息,最终得到的节点的inversion_count就是目标区间的逆序数。
因为每次合并节点时处理frequency的时间是O(40²)(常数时间),所以每个查询和更新操作的时间复杂度都是O(log n),整体复杂度为O(q log n),完全满足时间要求。
实现代码
/** Lost Arrow (Aryan V S) Saturday 2020-10-10 **/ #include "bits/stdc++.h" using namespace std; struct node { int64_t inv = 0; vector <int> freq = vector <int> (40, 0); void combine (const node& l, const node& r) { inv = l.inv + r.inv; for (int i = 39; i >= 0; --i) { for (int j = 0; j < i; ++j) { // frequency of bigger numbers in the left * frequency of smaller numbers on the right inv += 1LL * l.freq [i] * r.freq [j]; } freq [i] = l.freq [i] + r.freq [i]; } } }; void build (vector <node>& tree, vector <int>& a, int v, int tl, int tr) { if (tl == tr) { tree [v].inv = 0; tree [v].freq [a [tl]] = 1; } else { int tm = (tl + tr) / 2; build(tree, a, 2 * v + 1, tl, tm); build(tree, a, 2 * v + 2, tm + 1, tr); tree [v].combine(tree [2 * v + 1], tree [2 * v + 2]); } } void update (vector <node>& tree, int v, int tl, int tr, int pos, int val) { if (tl == tr) { tree [v].inv = 0; tree [v].freq = vector <int> (40, 0); tree [v].freq [val] = 1; } else { int tm = (tl + tr) / 2; if (pos <= tm) update(tree, 2 * v + 1, tl, tm, pos, val); else update(tree, 2 * v + 2, tm + 1, tr, pos, val); tree [v].combine(tree [2 * v + 1], tree [2 * v + 2]); } } node inv_cnt (vector <node>& tree, int v, int tl, int tr, int l, int r) { if (l > r) return node(); if (tl == l && tr == r) return tree [v]; int tm = (tl + tr) / 2; node result; result.combine(inv_cnt(tree, 2 * v + 1, tl, tm, l, min(r, tm)), inv_cnt(tree, 2 * v + 2, tm + 1, tr, max(l, tm + 1), r)); return result; } void solve () { int n, q; cin >> n >> q; vector <int> a (n); for (int i = 0; i < n; ++i) { cin >> a [i]; --a [i]; } vector <node> tree (4 * n); build(tree, a, 0, 0, n - 1); while (q--) { int type, x, y; cin >> type >> x >> y; --x; --y; if (type == 1) { node result = inv_cnt(tree, 0, 0, n - 1, x, y); cout << result.inv << '\n'; } else if (type == 2) { update(tree, 0, 0, n - 1, x, y); } else assert(false); } } int main () { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.precision(10); std::cout << std::fixed << std::boolalpha; int t = 1; // std::cin >> t; while (t--) solve(); return 0; }
说明:代码未完全遵循严格编程规范,仅针对该问题实现核心功能。
内容的提问来源于Stack Exchange,提问作者Arrow




