关注小程序 找一找教程网-随时随地学编程

C/C++教程

P4062 [Code+#1] Yazid 的新生舞会 - 线段树

题解

为啥我写个线段树还得调 1h 啊?

考虑枚举每一种颜色 \(c\)。设 \(S_i\) 为 \(a_{1\dots i}\) 中 \(c\) 的出现次数,那么一个区间 \((l,r]\) 是合法的当且仅当 \(2S_r-r>2S_l-l\)。设 \(f(x)=2S_x-x\)。按顺序枚举 \(c\) 的每一个出现位置,设这个位置为 \(p\),下一个出现位置为 \(q\)。那么,对于 \([p,q)\) 中的每个位置 \(i\),\(f(i)=f(p)+p-i\)。因此,\([p,q)\) 之间的位置不会互相产生贡献。并且,\([0,p)\) 之间的位置对 \([p,q)\) 的贡献系数是一段斜率为 \(0\) 的线段加上一段斜率为 \(-1\) 的线段。由此我们便可以快速计算以 \([p,q)\) 中的某个位置为右端点的合法线段个数了。我们需要支持的操作是:

  • 区间 \(+1\);
  • 询问区间的和以及二阶和。

线段树或树状数组都可以胜任。

代码

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
#define For(Ti,Ta,Tb) for(int Ti=(Ta);Ti<=(Tb);++Ti)
#define Dec(Ti,Ta,Tb) for(int Ti=(Ta);Ti>=(Tb);--Ti)
typedef long long ll;
const int N=5e5+5;
int n,tp;
struct SegmentTree{
	struct Node{
		int l,r,clean;ll s[2],Add;
	}t[N<<3];
	void Pushup(int p){
		t[p].s[0]=t[p*2].s[0]+t[p*2+1].s[0];
		t[p].s[1]=t[p*2].s[1]+t[p*2+1].s[1]+(t[p*2+1].l-t[p].l)*t[p*2+1].s[0];
	}
	void PushAdd(int p,ll k){
		t[p].s[0]+=k*(t[p].r-t[p].l+1);
		t[p].s[1]+=1LL*(t[p].r-t[p].l+2)*(t[p].r-t[p].l+1)/2*k;
		t[p].Add+=k;
	}
	void PushClean(int p){
		t[p].clean=1,t[p].Add=0;
		t[p].s[0]=t[p].s[1]=t[p].Add=0;
	}
	void Pushdown(int p){
		if(t[p].clean){
			PushClean(p*2),PushClean(p*2+1);
			t[p].clean=0;
		}
		PushAdd(p*2,t[p].Add),PushAdd(p*2+1,t[p].Add);
		t[p].Add=0;
	}
	void Build(int p,int l,int r){
		t[p].l=l,t[p].r=r;
		if(l==r) return;
		Build(p*2,l,(l+r)/2),Build(p*2+1,(l+r)/2+1,r);
	}
	void Add(int p,int l,int r,ll k){
		if(l<=t[p].l&&t[p].r<=r) return PushAdd(p,k);
		Pushdown(p);
		int mid=(t[p].l+t[p].r)>>1;
		if(l<=mid) Add(p*2,l,r,k);
		if(r>mid) Add(p*2+1,l,r,k);
		Pushup(p);
	}
	pair<ll,ll> Query(int p,int l,int r){
		if(l>t[p].r||r<t[p].l) return {0,0};
		if(l<=t[p].l&&t[p].r<=r) return {t[p].s[0],t[p].s[0]*(t[p].l-l)+t[p].s[1]};
		Pushdown(p);
		auto resl=Query(p*2,l,r),resr=Query(p*2+1,l,r);
		return {resl.first+resr.first,resl.second+resr.second};
	}
}seg;
int a[N];
vector<int> occ[N];
int main(){
	ios::sync_with_stdio(false),cin.tie(nullptr);
	cin>>n>>tp;
	int mx=0;
	For(i,1,n) cin>>a[i],++a[i],mx=max(mx,a[i]),occ[a[i]].push_back(i);
	seg.Build(1,1,(n+3)*2);
	ll ans=0;
	const int delt=n+3;
	For(i,1,n){
		if(!occ[i].size()) continue;
		seg.PushClean(1);
		seg.Add(1,delt-occ[i].front()+1,delt,1);
		occ[i].push_back(n+1);
		for(auto it=occ[i].begin();it!=prev(occ[i].end());++it){
			int cnt=it-occ[i].begin()+1,l=2*cnt-*next(it)+1,r=2*cnt-*it;
			ans+=seg.Query(1,1,delt+r).first*(r-l+1)-seg.Query(1,l+delt,r+delt).second;
			seg.Add(1,l+delt,r+delt,1);
		}
	}
	cout<<ans;
	return 0;
}