1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
#include <bits/stdc++.h>
#include <algorithm>
#include <cstdio>
#include <cstring>
#define ms(a, x) memset(a, x, sizeof(a))
typedef long long LL;
using namespace std;
const int N = 5e5 + 7;

// Suffix automaton over the string S.
// Usage: fill S, then build(); topo(); pre(); solve();
// solve() prints the sum, over every distinct substring t of S, of
// occ(t) * (occ(t) + 1) / 2, where occ(t) is t's number of occurrences.
struct SAM {
    // Transition slots per state. idx() maps 'a'..'z' -> 0..25 and
    // 'A'..'Z' -> 26..51, so 52 slots are required (a smaller value lets
    // uppercase input write past nxt[] into neighbouring states).
    static const int MAXALP = 52;

    struct state {
        int len;           // length of the longest string in this state
        int link;          // suffix link; -1 only for the root
        int nxt[MAXALP];   // transitions, -1 = no transition
        int leftmost;      // end position recorded when the state was created
        int rightmost;     // max end position, propagated up in pre()
        int Right;         // occurrence count, final after pre()
    };

    state st[N * 2];
    char S[N];             // input string, NUL-terminated
    int sz, last, rt;      // states in use (1-based), last prefix state, root
    char s[N];             // not used by the visible code
    int cnt[2 * N], rk[2 * N];  // counting-sort buckets / states sorted by len

    // Transition slot for character c: lowercase first, then uppercase.
    // NOTE(review): any non-letter maps outside [0, MAXALP); input is
    // assumed to consist of ASCII letters only.
    int idx(char c){
        if (c >= 'a' && c <= 'z') return c - 'a';
        return c - 'A' + 26;
    }

    // Allocate the next state id and reset it (no transitions, counters 0).
    // Resetting per state instead of memset-ing all of st keeps the memory
    // touched proportional to the automaton actually built.
    int newState(int len, int link){
        int id = ++sz;
        state &t = st[id];
        t.len = len;
        t.link = link;
        memset(t.nxt, -1, sizeof(t.nxt));
        t.leftmost = t.rightmost = 0;
        t.Right = 0;
        return id;
    }

    // Reset the automaton to a single root state (id 1).
    void init(){
        sz = 0;
        last = rt = newState(0, -1);
    }

    // Append character slot c (the character at index head of S).
    void extend(int c, int head){
        int cur = newState(st[last].len + 1, 0);
        st[cur].leftmost = st[cur].rightmost = head;
        int p = last;
        while (p != -1 && st[p].nxt[c] == -1) {
            st[p].nxt[c] = cur;
            p = st[p].link;
        }
        if (p == -1) {
            st[cur].link = rt;
        } else {
            int q = st[p].nxt[c];
            if (st[p].len + 1 == st[q].len) {
                st[cur].link = q;
            } else {
                // Split q: the clone takes q's transitions and link, but is
                // not a prefix state, so its Right stays 0.
                int clone = newState(st[p].len + 1, st[q].link);
                memcpy(st[clone].nxt, st[q].nxt, sizeof(st[q].nxt));
                st[clone].leftmost = st[q].leftmost;
                st[clone].rightmost = st[q].rightmost;
                for (; p != -1 && st[p].nxt[c] == q; p = st[p].link)
                    st[p].nxt[c] = clone;
                st[q].link = st[cur].link = clone;
            }
        }
        last = cur;
    }

    // Build the automaton for S. Each new prefix state gets Right = 1 (it
    // owns one end position); clones keep 0.
    void build(){
        init();
        for (int i = 0, _len = strlen(S); i < _len; i++) {
            extend(idx(S[i]), i);
            st[last].Right = 1;
        }
    }

    // Counting-sort states 1..sz by len into rk[1..sz] (ascending), so
    // rk[1] is the root, the only state with len 0.
    void topo(){
        fill(cnt, cnt + sz + 1, 0);   // only cnt[0..sz] is read below
        for (int i = 1; i <= sz; i++) cnt[st[i].len]++;
        for (int i = 1; i <= sz; i++) cnt[i] += cnt[i - 1];
        for (int i = 1; i <= sz; i++) rk[cnt[st[i].len]--] = i;
    }

    // Walk states from longest to shortest and push Right (occurrence
    // counts) and rightmost into each suffix-link parent.
    void pre(){
        for (int i = sz; i >= 2; i--) {
            int v = rk[i];
            int fa = st[v].link;
            if (fa == -1) continue;
            st[fa].rightmost = max(st[fa].rightmost, st[v].rightmost);
            st[fa].Right += st[v].Right;
        }
    }

    // Each non-root state v stands for len(v) - len(link(v)) distinct
    // substrings, all occurring Right(v) times; sum Right*(Right+1)/2 over
    // them and print the total. Requires build(), topo(), pre() first.
    void solve(){
        long long ans = 0;
        for (int i = sz; i >= 2; i--) {
            int v = rk[i];
            if (st[v].link == -1) continue;
            ans += 1LL * st[v].Right * (st[v].Right + 1) / 2
                   * (st[v].len - st[st[v].link].len);
        }
        printf("%lld\n", ans);
    }
} A;
char B[N];
int main(){ scanf("%s", A.S); A.build(); A.topo(); A.pre(); A.solve();
return 0; }
|