CF123D_SAM_算法日常[27/100]

题目链接

VJ
CF

题意

  • 如果字符串y在字符串x中出现n次,那么F(x,y)=n*(n+1)/2 (可以看做是一个长为n的区间,求滑动区块的总个数)
  • 现在给一个字符串,求所有的F(s,x)的和,x为字符串的所有不相同的子串.

思路

  • 直接SAM
  • right[v]就是SAM上状态表示的所有字符串出现的次数
  • 那么每个状态的答案就是right[v](right[v]+1)/2*(st[v].len-st[st[v].link].len)
  • 前面right[v](right[v]+1)/2是串的组合
  • 后面是 st[v].len - st[st[v].link].len是后缀的前缀长度,是本质不同的串的贡献
  • 也即后缀的前缀每个字母的贡献—>就是 每个后缀节点t跳父亲节点fa跳掉的那部分t的前缀 中的 以每一个字母开头的串t的后缀 都是和串t所在状态节点出现次数(前面的串的组合数)相同的!
  • 累加答案完成计算

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include <bits/stdc++.h>
#define ms(a, x) memset(a, x, sizeof(a))
typedef long long LL;

using namespace std;
const int N = 5e5 + 7;
struct SAM {
#define MAXALP 30
struct state {
int len, link, nxt[MAXALP];
int leftmost; //某个状态的right集合中r值最小的
int rightmost; //某个状态的right集合的r的最大值
int Right; //right集合大小
};
state st[N * 2];
char S[N];
int sz, last, rt;
char s[N];
int cnt[2 * N], rk[2 * N]; //for radix sort
int idx(char c){
if (c >= 'a' && c <= 'z')
return c - 'a';
return c - 'A' + 26;
}
void init(){
sz = 0;
ms(st, 0);
last = rt = ++sz;
st[1].len = 0;
st[1].link = -1;
st[1].rightmost = 0;
ms(st[1].nxt, -1);
}
void extend(int c, int head){
int cur = ++sz;
st[cur].len = st[last].len + 1;
st[cur].leftmost = st[cur].rightmost = head;
memset(st[cur].nxt, -1, sizeof(st[cur].nxt));
int p;
for (p = last; p != -1 && st[p].nxt[c] == -1; p = st[p].link)
st[p].nxt[c] = cur;
if (p == -1) {
st[cur].link = rt;
} else {
int q = st[p].nxt[c];
if (st[p].len + 1 == st[q].len) {

st[cur].link = q;
} else {
int clone = ++sz;
st[clone].len = st[p].len + 1;
st[clone].link = st[q].link;
memcpy(st[clone].nxt, st[q].nxt, sizeof(st[q].nxt));
st[clone].leftmost = st[q].leftmost;
st[clone].rightmost = st[q].rightmost;
for (; p != -1 && st[p].nxt[c] == q; p = st[p].link)
st[p].nxt[c] = clone;
st[q].link = st[cur].link = clone;
}
}
last = cur;
}
void build(){
init();
for (int i = 0, _len = strlen(S); i < _len; i++) {
st[sz + 1].Right = 1;
extend(idx(S[i]), i);
}
}
void topo(){
ms(cnt, 0);
for (int i = 1; i <= sz; i++) cnt[st[i].len]++;
for (int i = 1; i <= sz; i++) cnt[i] += cnt[i - 1];
//rk[1]是len最小的状态的标号
for (int i = 1; i <= sz; i++) rk[cnt[st[i].len]--] = i;
}
//跑拓扑序,预处理一些东西
void pre(){
for (int i = sz; i >= 2; i--) {
int v = rk[i];
int fa = st[v].link;
if (fa == -1) continue;
st[fa].rightmost = max(st[fa].rightmost, st[v].rightmost);
st[fa].Right += st[v].Right;
}
}
void solve(){
LL ans = 0;
for (int i = sz; i >= 2; i--) {
int v = rk[i];
if (st[v].link == -1) continue;
// 前面是串的组合
// 后面是 st[v].len - st[st[v].link].len是后缀的前缀,是本质不同的串的贡献
// 每个字母的贡献--->就是每个后缀节点t跳父亲节点fa跳掉的那部分t的前缀中的每一个字母开头的后缀都是和串t出现次数相同的!
ans = ans + 1LL * st[v].Right * (st[v].Right + 1) / 2 * (st[v].len - st[st[v].link].len);
// cout<<"TEST: "<<st[v].len - st[st[v].link].len<<endl;
}
printf("%lld\n", ans);
}
} A;
char B[N];

int main(){
scanf("%s", A.S);
A.build();
A.topo();
A.pre();
A.solve();

return 0;
}

参考

111qqz

每天一句叨叨

路还长,别太狂,以后指不定谁辉煌