题目大意
给你一个长度为N的整数序列(数字可以为负), 定义一个连续子串的价值为子序列中出现过的数字之和(同一个数字只算一次), 问价值第K大子串的价值。
[latex] N≤100000, K≤200000 [/latex], 保证第K个子串存在。
解题报告
首先PST可以做到O(N log N)求区间权值。
然后第一个想法就是把区间[1, N]塞入优先队列, 每次取出堆头[l, r], 把[l-1, r]和[l, r-1]去重后塞回去, 反复K次。
但是由于数字可以是负数就没有单调性了…
换思路, 考虑固定右端点, 用PST做区间最大值可以快速求出固定右端点后权值最大的区间。那么可以先对每个右端点都query一下让区间价值最大的左端点, 揉到priority_queue里, 每次弹出堆头, 把堆头对应的区间从世界上抹去(PST上新开一层, 单点修改权值为-INF), 再query一下让区间价值最大的左端点, 塞优先队列里。
重复K次就可以拿到第K大值了。
代码
各种傻瓜错..断断续续写了3d…
#include <cstdio>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <set>
#include <vector>
#include <queue>
#include <map>
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const LL LLINF = 0x3f3f3f3f3f3f3f3fLL;
const int MAXN = 100010;
inline int getInt () {
int ret; char ch; bool flg = false;
while((ch = getchar()) < '0' || ch > '9') flg |= (ch == '-');
ret = ch-'0';
while((ch = getchar()) >= '0' && ch <= '9') ret = ret*10+(ch-'0');
return flg ? -ret : ret;
}
int A[MAXN];
int pre[MAXN], nxt[MAXN];
map<int, int> mp;
struct Data {
int y1, y2; LL val;
Data () {}
Data (int a, int b, LL c):y1(a), y2(b), val(c) {}
};
vector<Data> lines[MAXN];
int N, K;
namespace PST {
const int MAXD = MAXN*100;
LL mark[MAXD]; int dtot;
int lc[MAXD], rc[MAXD];
LL maxv[MAXD], maxp[MAXD];
int roots[MAXN], tot;
int agl, agr; LL agv;
inline void updata (int u) {
if((maxv[lc[u]] >= maxv[rc[u]] && lc[u]) || !rc[u]) {
maxp[u] = maxp[lc[u]], maxv[u] = maxv[lc[u]];
} else maxp[u] = maxp[rc[u]], maxv[u] = maxv[rc[u]];
maxv[u] += mark[u];
//maxv[u] = max(maxv[lc[u]], maxv[rc[u]])+mark[u];
}
void build (int&u, int l, int r) {
u = ++dtot;
if(l == r) {maxp[u] = l; return ;}
int mid = (l+r)>>1;
build(lc[u], l, mid);
build(rc[u], mid+1, r);
updata(u);
}
void init () {
build(roots[0], 1, N);
}
void insert (int&u, int lb, int rb, int r) {
if(rb < agl || agr < lb) return ;
if(!u) {
mark[u = ++dtot] = mark[r];
maxv[u] = maxv[r];
maxp[u] = maxp[r];
}
//if(!maxp[u]) maxp[u] = lb;
if(agl <= lb && rb <= agr) {
mark[u] += agv;
maxv[u] += agv;
return ;
}
int mid = (lb+rb)>>1;
insert(lc[u], lb, mid, lc[r]);
insert(rc[u], mid+1, rb, rc[r]);
}
int agp;
void erase (int&u, int lb, int rb, int r) {
if(agp < lb || rb < agp) return ;
u = ++dtot;
mark[u] = mark[r], lc[u] = lc[r], rc[u] = rc[r];
maxv[u] = maxv[r], maxp[u] = maxp[r];
if(lb == rb) {mark[u] = -LLINF; maxv[u] = -LLINF; return ;}
int mid = (lb+rb)>>1;
erase(lc[u], lb, mid, lc[r]);
erase(rc[u], mid+1, rb, rc[r]);
updata(u);
}
void handleErase (int x, int y) {
agp = y;
erase(roots[x], 1, N, roots[x]);
}
inline void newFloor () {++tot;}
void solidfy (int&u, int r) {
if(!u) {u = r; return ;}
solidfy(lc[u], lc[r]);
solidfy(rc[u], rc[r]);
if(lc[u] || rc[u]) updata(u);
}
inline void closeFloor () {
solidfy(roots[tot], roots[tot-1]);
}
inline void handleInsert (int l, int r, LL val) {
agv = val;
agl = l, agr = r;
//printf("pos:%d %d %d %d\n", tot, l, r, val);
insert(roots[tot], 1, N, roots[tot-1]);
}
int qy;
LL query (int u, int lb, int rb) {
if(!u || qy < lb || rb < qy) return 0;
if(lb == rb) return mark[u];
int mid = (lb+rb)>>1;
return mark[u]+query(lc[u], lb, mid)+query(rc[u], mid+1, rb);
}
inline LL handleQuery (int a, int b) {
qy = b;
return query(roots[a], 1, N);
}
inline int getPos (int p) {
return maxp[roots[p]];
}
inline LL getValue (int p) {
//printf("val:%d\n", maxv[roots[p]]);
return maxv[roots[p]];
}
}
struct Interval {
int l, r; LL val;
Interval () {}
Interval (int a, int b, LL v):l(a), r(b), val(v) {}
inline bool operator < (const Interval & b) const {
return val < b.val;
}
};
priority_queue<Interval> que;
set<pair<int, int> > st;
int main () {
scanf("%d%d", &N, &K);
for(int i = 1; i<=N; i++) A[i] = getInt();
for(int i = 1; i<=N; i++) {
pre[i] = mp[A[i]];
mp[A[i]] = i;
}
for(int i = 1; i<=N; i++) {
lines[pre[i]+1].push_back(Data(i, N, A[i]));
lines[i+1].push_back(Data(i, N, -A[i]));
if(i != 1) lines[i].push_back(Data(1, i-1, -INF));
if(i != 1) lines[i+1].push_back(Data(1, i-1, INF));
}
PST::init();
//PST::maxv[0] = -LLINF;
for(int i = 1; i<=N; i++) {
PST::newFloor();
for(int j = 0; j<(int)lines[i].size(); j++)
PST::handleInsert(lines[i][j].y1, lines[i][j].y2, lines[i][j].val);
PST::closeFloor();
}
PST::handleQuery(1, 1);
for(int i = 1; i<=N; i++)
que.push(Interval(i, PST::getPos(i), PST::getValue(i)));
int cur = 0;
while((++cur) < K) {
Interval inv = que.top();
//printf("at:[%d, %d]\n", inv.l, inv.r);
que.pop();
PST::handleErase(inv.l, inv.r);
que.push(Interval(inv.l, PST::getPos(inv.l), PST::getValue(inv.l)));
}
cout << que.top().val << endl;
}
/*
7 5
3 -2 1 2 2 1 3 -2
3 5
-31 -11 46
*/
Join the discussion