作者:XiaoQuQu,发表于 Tue Feb 27 2024。
写这题的时候才发现之前写过的“SA”是使用 std::sort()
版的 $n\log ^2 n$ 的后缀数组,然后爆补 $n \log n$ 的计数排序。发现对于计数排序还是有很多不理解的地方,写一篇题解,怕自己啥时候又忘了。
阅读这篇题解可能需要你对后缀数组有一定的基础认识,可以前往 OI-Wiki 等查看“后缀数组”的介绍,本篇题解只是作为对于其他博客等的补充。
考虑排序的本质,其实是求出每个数的排名。
计数排序先用桶储存了所有数的出现次数,如序列 $a=[2,2,1,2,3,4]$,桶数组 $b=[1,3,1,1]$,对 $b$ 数组做前缀和,得到前缀和数组 $c=[1,4,5,6]$。
接下来我们从最大数到最小数枚举 $i$,则我们发现,所有数 $i$ 的排名都应该为 $c_{i-1}+1$。
这一部分在阅读其他博文时可能有“看起来很简单,写起来很复杂”的感觉,在这里我尝试尽量用贴近代码语言的方式表述。
考虑倍增地求后缀数组,枚举 $k=2^w$,保证 $k\le n$,假设我们已知所有长度为 $2^{w-1}$ 的字符串的排名,从 $i$ 开始的字符串排名记为 $r_i$,我们希望求长度为 $2^w$ 次方的字符串的排名。
考虑如何比较两个从 $i,j$ 开始的字符串的大小,其实相当于比较 $r_i,r_j$ 的大小,若相等比较 $r_{i+2^{w-1}},r_{j+2^{w-1}}$ 的大小。
所以我们有一个很朴素的思想,也就是按照 $r_{i+2^{w-1}}$ 的大小作为第二关键字,$r_i$ 的大小作为第一关键字,然后进行排序,注意表述顺序,因为我们这里使用 LSD 进行基数排序,会先对优先级低的关键字比较。
但是这样的常数会很大,考虑我们怎么样省略掉比较第二关键字的步骤。
我们发现,可以枚举 $i$,对于 $i+k>n$ 的 $i$,他按照第二关键字排名后肯定是在最前面的,因为其第二关键字为 $0$。
对于 $i+k<n$ 的怎么办?我们可以枚举 $i$。若在这一轮倍增之前第 $i$ 小的字符串起始于 $sa_i$ 且 $sa_i-k>0$,那么我们就可以肯定,对于 $sa_i-k$ 这一项,排完序后在 $i+k<n$ 的部分的排名是 $i$。
这样我们就省略了对于第二关键字的排序,直接对于第一关键字排序即可,具体写起来是这样的。
int p = 0;
for (int i = n; i + (1 << w) > n; --i) id[++p] = i; // for i + k > n
for (int i = 1; i <= n; ++i)
if (sa[i] > (1 << w)) id[++p] = sa[i] - (1 << w); // 第 i 个位置的数字应该为 sa[i] - k
设对于第二关键字排好序时,第 $i$ 小的字符串起始于 $id_i$。考虑计数排序,可以直接按照 $r_{id_i}$ 进行计数排序,得到的结果直接存在 $sa_i$ 中。
接下来考虑更新 $r$ 数组,考虑从小到大枚举 $i$,然后直接将 $r_{sa_i}$ 更新为 $i$。但是这样会有问题。即有些字符串是相同的,我们需要对这些字符串进行去重。
考虑我们是如何比较两个字符串的大小,发现判断两个以 $i,j$ 开头的字符串是否相同,可以直接判断旧的 $r_i,r_j$ 与 $r_{i+k},r_{j+k}$ 是否相同即可。
完整代码如下。
const int MAXN = 2e6 + 5, MAXD = 256;
int n, sa[MAXN], rk[MAXN], cnt[MAXN], oldrk[MAXN], id[MAXN];
char s[MAXN];
void work() {
cin >> (s + 1); n = strlen(s + 1);
for (int i = 1; i <= n; ++i) cnt[rk[i] = s[i]]++;
for (int i = 1; i <= MAXD; ++i) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
for (int w = 0; (1 << w) <= n; ++w) {
int p = 0;
for (int i = n; i + (1 << w) > n; --i) id[++p] = i;
for (int i = 1; i <= n; ++i)
if (sa[i] > (1 << w)) id[++p] = sa[i] - (1 << w);
for (int i = 0; i <= p; ++i) cnt[i] = 0;
for (int i = 1; i <= n; ++i) ++cnt[rk[id[i]]];
for (int i = 1; i <= p; ++i) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i) sa[cnt[rk[id[i]]]--] = id[i];
p = 0;
for (int i = 1; i <= n; ++i) oldrk[i] = rk[i];
for (int i = 1; i <= n; ++i) {
if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + (1 << w)] == oldrk[sa[i - 1] + (1 << w)])
rk[sa[i]] = p;
else rk[sa[i]] = ++p;
}
if (p == n) break;
}
for (int i = 1; i <= n; ++i) cout << sa[i] << ' ';
}
Copyright © 2024 LVJ, Open-Source Project. 本站内容在无特殊说明情况下均遵循 CC-BY-SA 4.0 协议,内容版权归属原作者。