我们使用一道例题来帮助我们理解莫队算法

洛谷P1972

例题

题目描述:

HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,
他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,
他的项链变得越来越长。有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?
这个问题很难回答……因为项链实在是太长了。于是,他只好求助睿智的你来解决这个问题。

输入输出格式:

输入格式:

第一行:一个整数N,表示项链的长度。

第二行:N 个整数,表示依次表示项链中贝壳的编号(编号为0 到1000000 之间的整数)。

第三行:一个整数M,表示HH 询问的个数。

接下来M 行:每行两个整数,L 和R(1 ≤ L ≤ R ≤ N),表示询问的区间。

输出格式:

M 行,每行一个整数,依次表示询问对应的答案。

输入输出样例

输入样例#1:

6
1 2 3 4 3 5
3
1 2
3 5
2 6

输出样例#1:

2
2
4

引入

想象一下,如果用暴力算法进行求解的话,我们会选择开一个$cnt$数组,遍历区间,累加求解,

显然,这样的时间复杂度太高,肯定会爆掉,所以我们需要莫队算法


首先,我们定义两个指针$curl$,$curr$,每次询问我们通过移动这两个指针来框定区间,

假设一开始 $curl$ 指向 $4$ ,$curr$ 指向 $6$,

下一个询问要求区间 $3$ ~ $5$,那么我们 $curl–$ ,顺带插入 $3$ ,$curr–$,顺带删去 $6$,

注意当前的先后顺序
$curl–$ 要求先减后加入,而 $curr–$ 要求先删去再减

同理,$curl++$ 要求先删去再加,而 $curr++$ 要求先加再加入

我们可以写出这一部分

int lsans;//表示该区间的答案
inline void add(int pos){lsans+=(++cnt[a[pos]]==1);}//加入
inline void del(int pos){lsans-=(--cnt[a[pos]]==0);}//删去


while(curr<rr) add(++curr);//变化范围
while(curr>rr) del(curr--);
while(curl>ll) add(--curl);
while(curl<ll) del(curl++);

莫队的优化

我们可以很容易的发现,如果面对特别设计的数据,上面的时间复杂度仍然很高

举个栗子,有6个询问如下:

(1, 100)  (2, 2)  (3, 99)  (4, 4)  (5, 102)  (6, 7)

我们如果直接按左端点上升排序,

用上述方法处理时,左端点会移动$6$次,右端点会移动移动$98+97+95+98+95=483$次。

我们可以先按左端点上升排序,如果左端点所在的块相同,再在块内按右端点上升排序,得到结果就像这样

(2, 2)  (4, 4)  (6, 7)  (5, 102)  (3, 99)  (1, 100)

左端点移动次数为$2+2+1+2+2=9$次,比原来稍多。右端点移动次数为$2+3+95+3+1=104$,右端点的移动次数大大降低了。

$Code$:

struct ques
{
    int l , r , id ;
}que[100005];

bool cmp(const ques &a,const ques &b) 
{
    return (a.l/blo==b.l/blo)?a.r<b.r:a.l<b.l;
}

最终的代码

#pragma GCC optimize(3)
#include <bits/stdc++.h>
#define sync_with_stdio(false)
using namespace std;
inline int read(){
   int s=0,w=1;
   char ch=getchar();
   while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
   while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
   return s*w;
}
struct ques
{
    int l , r , id ;
}que[100005];
int n,m, blo,lsans;
bool cmp(const ques &a,const ques &b) {return (a.l/blo==b.l/blo)?a.r<b.r:a.l<b.l;}
int a[100005],cnt[1000005],ans[100005];
inline void add(int pos){lsans+=(++cnt[a[pos]]==1);}
inline void del(int pos){lsans-=(--cnt[a[pos]]==0);}
int main()
{
    n=read();
    for(int i=1;i<=n;i++) a[i]=read();
    m=read();blo=sqrt(m);
    for(int i=1;i<=m;i++)
    {
        que[i].l=read(),que[i].r=read();
        que[i].id=i;
    }
    sort(que+1,que+m+1,cmp);
    int curl=0,curr=0;
    for(int i=1;i<=m;i++)
    {
        int ll=que[i].l,rr=que[i].r,idd=que[i].id;
        while(curr<rr) add(++curr);
        while(curr>rr) del(curr--);
        while(curl>ll) add(--curl);
        while(curl<ll) del(curl++);
        ans[idd]=lsans;
    }
    for(int i=1;i<=m;i++) printf("%d\n", ans[i]);
}

CODE_LIFE