题目大意
给你一个长度为N的整数序列(数字可以为负), 定义一个连续子串的价值为子序列中出现过的数字之和(同一个数字只算一次), 问价值第K大子串的价值。
[latex] N≤100000, K≤200000 [/latex], 保证第K个子串存在。
解题报告
首先PST可以做到O(N log N)求区间权值。
然后第一个想法就是把区间[1, N]塞入优先队列, 每次取出堆头[l, r], 把[l-1, r]和[l, r-1]去重后塞回去, 反复K次。
但是由于数字可以是负数就没有单调性了…
换思路, 考虑固定右端点, 用PST做区间最大值可以快速求出固定右端点后权值最大的区间。那么可以先对每个右端点都query一下让区间价值最大的左端点, 揉到priority_queue里, 每次弹出堆头, 把堆头对应的区间从世界上抹去(PST上新开一层, 单点修改权值为-INF), 再query一下让区间价值最大的左端点, 塞优先队列里。
重复K次就可以拿到第K大值了。
代码
各种傻瓜错..断断续续写了3d…
#include <cstdio> #include <cstring> #include <climits> #include <cstdlib> #include <iostream> #include <set> #include <vector> #include <queue> #include <map> using namespace std; typedef long long LL; const int INF = 0x3f3f3f3f; const LL LLINF = 0x3f3f3f3f3f3f3f3fLL; const int MAXN = 100010; inline int getInt () { int ret; char ch; bool flg = false; while((ch = getchar()) < '0' || ch > '9') flg |= (ch == '-'); ret = ch-'0'; while((ch = getchar()) >= '0' && ch <= '9') ret = ret*10+(ch-'0'); return flg ? -ret : ret; } int A[MAXN]; int pre[MAXN], nxt[MAXN]; map<int, int> mp; struct Data { int y1, y2; LL val; Data () {} Data (int a, int b, LL c):y1(a), y2(b), val(c) {} }; vector<Data> lines[MAXN]; int N, K; namespace PST { const int MAXD = MAXN*100; LL mark[MAXD]; int dtot; int lc[MAXD], rc[MAXD]; LL maxv[MAXD], maxp[MAXD]; int roots[MAXN], tot; int agl, agr; LL agv; inline void updata (int u) { if((maxv[lc[u]] >= maxv[rc[u]] && lc[u]) || !rc[u]) { maxp[u] = maxp[lc[u]], maxv[u] = maxv[lc[u]]; } else maxp[u] = maxp[rc[u]], maxv[u] = maxv[rc[u]]; maxv[u] += mark[u]; //maxv[u] = max(maxv[lc[u]], maxv[rc[u]])+mark[u]; } void build (int&u, int l, int r) { u = ++dtot; if(l == r) {maxp[u] = l; return ;} int mid = (l+r)>>1; build(lc[u], l, mid); build(rc[u], mid+1, r); updata(u); } void init () { build(roots[0], 1, N); } void insert (int&u, int lb, int rb, int r) { if(rb < agl || agr < lb) return ; if(!u) { mark[u = ++dtot] = mark[r]; maxv[u] = maxv[r]; maxp[u] = maxp[r]; } //if(!maxp[u]) maxp[u] = lb; if(agl <= lb && rb <= agr) { mark[u] += agv; maxv[u] += agv; return ; } int mid = (lb+rb)>>1; insert(lc[u], lb, mid, lc[r]); insert(rc[u], mid+1, rb, rc[r]); } int agp; void erase (int&u, int lb, int rb, int r) { if(agp < lb || rb < agp) return ; u = ++dtot; mark[u] = mark[r], lc[u] = lc[r], rc[u] = rc[r]; maxv[u] = maxv[r], maxp[u] = maxp[r]; if(lb == rb) {mark[u] = -LLINF; maxv[u] = -LLINF; return ;} int mid = (lb+rb)>>1; erase(lc[u], lb, mid, lc[r]); erase(rc[u], mid+1, rb, rc[r]); updata(u); } void handleErase (int x, int y) { agp = y; erase(roots[x], 1, N, roots[x]); } inline void newFloor () {++tot;} void solidfy (int&u, int r) { if(!u) {u = r; return ;} solidfy(lc[u], lc[r]); solidfy(rc[u], rc[r]); if(lc[u] || rc[u]) updata(u); } inline void closeFloor () { solidfy(roots[tot], roots[tot-1]); } inline void handleInsert (int l, int r, LL val) { agv = val; agl = l, agr = r; //printf("pos:%d %d %d %d\n", tot, l, r, val); insert(roots[tot], 1, N, roots[tot-1]); } int qy; LL query (int u, int lb, int rb) { if(!u || qy < lb || rb < qy) return 0; if(lb == rb) return mark[u]; int mid = (lb+rb)>>1; return mark[u]+query(lc[u], lb, mid)+query(rc[u], mid+1, rb); } inline LL handleQuery (int a, int b) { qy = b; return query(roots[a], 1, N); } inline int getPos (int p) { return maxp[roots[p]]; } inline LL getValue (int p) { //printf("val:%d\n", maxv[roots[p]]); return maxv[roots[p]]; } } struct Interval { int l, r; LL val; Interval () {} Interval (int a, int b, LL v):l(a), r(b), val(v) {} inline bool operator < (const Interval & b) const { return val < b.val; } }; priority_queue<Interval> que; set<pair<int, int> > st; int main () { scanf("%d%d", &N, &K); for(int i = 1; i<=N; i++) A[i] = getInt(); for(int i = 1; i<=N; i++) { pre[i] = mp[A[i]]; mp[A[i]] = i; } for(int i = 1; i<=N; i++) { lines[pre[i]+1].push_back(Data(i, N, A[i])); lines[i+1].push_back(Data(i, N, -A[i])); if(i != 1) lines[i].push_back(Data(1, i-1, -INF)); if(i != 1) lines[i+1].push_back(Data(1, i-1, INF)); } PST::init(); //PST::maxv[0] = -LLINF; for(int i = 1; i<=N; i++) { PST::newFloor(); for(int j = 0; j<(int)lines[i].size(); j++) PST::handleInsert(lines[i][j].y1, lines[i][j].y2, lines[i][j].val); PST::closeFloor(); } PST::handleQuery(1, 1); for(int i = 1; i<=N; i++) que.push(Interval(i, PST::getPos(i), PST::getValue(i))); int cur = 0; while((++cur) < K) { Interval inv = que.top(); //printf("at:[%d, %d]\n", inv.l, inv.r); que.pop(); PST::handleErase(inv.l, inv.r); que.push(Interval(inv.l, PST::getPos(inv.l), PST::getValue(inv.l))); } cout << que.top().val << endl; } /* 7 5 3 -2 1 2 2 1 3 -2 3 5 -31 -11 46 */
Join the discussion