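AC 自动机有两种常见写法:trie 树版和 trie 图版。trie 树版只保留 trie 上真实存在的边,匹配失配时沿 fail 指针回跳;trie 图版在求 fail 的同时把不存在的儿子补成“沿 fail 回跳后的去向”,匹配时每个字符只需走一步。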

注意事项:

  1. trie 树版可以直接将根(空串对应的节点)的下标设为 1;而 trie 图版若将根的下标设为 1,还需将 son 数组的初值全部设为 1(即让所有尚不存在的转移都指向根)。下面的代码均以 0 作为根,全局数组的默认值 0 恰好就是根,因此无需额外初始化。

  2. 在函数 evafail 中,首先加入队列的不应该是根节点,而应该是根节点的各个儿子节点(它们的 fail 直接指向根)。若从根节点开始 BFS,根的儿子的 fail 会被算成它自己。这与 KMP 的函数 evafail 中变量 i 从 2 开始的原因一致。

trie 树版

代码:

// Trie storage. Node 0 is the root (the empty string); a 0 entry in `son`
// means "no such child".
int cntnd;                         // id of the most recently created node (0 = only the root exists)
int son[MAXsumlenb + 10][26 + 2];  // son[u][c]: child of node u via letter index c, or 0
int fail[MAXsumlenb + 10];         // fail[u]: failure link of node u (filled in by evafail)
int cntep[MAXsumlenb + 10];        // cntep[u]: number of inserted patterns that end at node u

// Inserts the pattern b[1..lenb] (1-indexed) into the trie.
// Each b[i] is used directly as a column index of `son`, so the caller is
// expected to have mapped letters to small non-negative indices beforehand
// (evafail scans columns 0..25).
void insert(int lenb, char *b) {
    int node = 0;
    for (int pos = 1; pos <= lenb; ++pos) {
        int &child = son[node][b[pos]];
        if (!child) child = ++cntnd;  // create the missing child on demand
        node = child;
    }
    ++cntep[node];
}
int l, r;                  // head and tail positions of the BFS queue (que[l..r] is pending)
int que[MAXsumlenb + 10];  // BFS queue of trie node ids, filled from index 1

// Computes fail[] for every trie node (trie-tree version: `son` is left untouched).
// Nodes are processed in BFS order, so fail[u] is final before u's children are handled.
void evafail() {
    l = 1;
    r = 0;
    // Depth-1 nodes fail to the root and seed the queue; the root itself is never enqueued.
    for (int c = 0; c < 26; ++c) {
        const int child = son[0][c];
        if (child == 0) continue;
        fail[child] = 0;
        que[++r] = child;
    }
    while (l <= r) {
        const int u = que[l++];
        for (int c = 0; c < 26; ++c) {
            const int v = son[u][c];
            if (v == 0) continue;
            // Follow u's failure chain until some node has a c-child or the root is reached.
            int f = fail[u];
            while (f != 0 && son[f][c] == 0) f = fail[f];
            if (son[f][c] != 0) f = son[f][c];
            fail[v] = f;
            que[++r] = v;
        }
    }
}
// Not referenced by the visible part of match; presumably used by the elided
// per-position step below -- confirm against the concrete problem.
bool vis[MAXsumlenb + 10];

// Template for running the text a[1..lena] (1-indexed, letters already used
// as column indices of `son`) through the trie-tree automaton.
// The per-position work is elided in the original (`// ...`), so as written
// the function always returns 0.
ll match(int lena, char *a) {
    ll ans = 0;
    int state = 0;
    for (int pos = 1; pos <= lena; ++pos) {
        const int c = a[pos];
        // On a mismatch, fall back along failure links until a c-edge exists or we hit the root.
        while (state != 0 && son[state][c] == 0) state = fail[state];
        if (son[state][c] != 0) state = son[state][c];
        // ...
    }
    return ans;
}

trie 图版

代码:

// Trie storage. Node 0 is the root (the empty string); before evafail runs,
// a 0 entry in `son` means "no such child".
int cntnd;                         // id of the most recently created node (0 = only the root exists)
int son[MAXsumlenb + 10][26 + 2];  // son[u][c]: child of node u via letter index c, or 0
int fail[MAXsumlenb + 10];         // fail[u]: failure link of node u (filled in by evafail)
int cntep[MAXsumlenb + 10];        // cntep[u]: number of inserted patterns that end at node u

// Inserts the pattern b[1..lenb] (1-indexed) into the trie.
// Each b[i] is used directly as a column index of `son`, so the caller is
// expected to have mapped letters to small non-negative indices beforehand
// (evafail scans columns 0..25).
void insert(int lenb, char *b) {
    int node = 0;
    for (int pos = 1; pos <= lenb; ++pos) {
        int &child = son[node][b[pos]];
        if (!child) child = ++cntnd;  // create the missing child on demand
        node = child;
    }
    ++cntep[node];
}
int l, r;                  // head and tail positions of the BFS queue (que[l..r] is pending)
int que[MAXsumlenb + 10];  // BFS queue of trie node ids, filled from index 1

// Computes fail[] and completes `son` into a full transition table (trie-graph version).
// After this runs, son[u][c] is defined for every dequeued node u and every c in 0..25:
// either the real child, or the transition copied from fail[u].
void evafail() {
    l = 1;
    r = 0;
    // Depth-1 nodes fail to the root and seed the queue; the root itself is never
    // enqueued, so its missing children stay 0, i.e. they lead back to the root.
    for (int c = 0; c < 26; ++c) {
        const int child = son[0][c];
        if (child == 0) continue;
        fail[child] = 0;
        que[++r] = child;
    }
    while (l <= r) {
        const int u = que[l++];
        const int fu = fail[u];
        for (int c = 0; c < 26; ++c) {
            int &v = son[u][c];
            const int viaFail = son[fu][c];  // where fail[u] goes on letter c
            if (v == 0) {
                v = viaFail;  // missing edge: borrow the transition of fail[u]
                continue;
            }
            fail[v] = viaFail;
            que[++r] = v;
        }
    }
}
// Not referenced by the visible part of match; presumably used by the elided
// per-position step below -- confirm against the concrete problem.
bool vis[MAXsumlenb + 10];

// Template for running the text a[1..lena] (1-indexed, letters already used
// as column indices of `son`) through the trie-graph automaton.
// Every character is a single table lookup; no failure-link walk is performed here.
// The per-position work is elided in the original (`// ...`), so as written
// the function always returns 0.
ll match(int lena, char *a) {
    ll ans = 0;
    int state = 0;
    int pos = 0;
    while (pos < lena) {
        ++pos;
        state = son[state][a[pos]];
        // ...
    }
    return ans;
}

题目

Luogu P3808 【模板】AC 自动机(简单版)(注:该题数据较弱,错误的代码也可能通过,不建议以是否通过该题作为判断代码正确与否的标准)

以 trie 树版为例,代码:

#include<bits/stdc++.h>
using namespace std;
// Wide integer type; used as the return type of match().
typedef long long ll;
// Array-size bounds. Every array below is declared with one of these plus 10 slots of slack.
// Not referenced in the visible code; presumably the bound on the pattern count n -- confirm.
const int MAXn = 1e6;
// Sizes the text buffer `a`.
const int MAXlena = 1e6;
// Sizes the pattern buffer `b`.
const int MAXlenb = 1e6;
// Sizes the per-node arrays (son, fail, cntep, que, vis) and the adjacency-list arrays.
const int MAXsumlenb = 1e6;

// Reads one decimal integer from stdin into `a`.
// Skips every character up to the first digit or '-'; a leading '-' negates the value.
// The character that terminates the digit run is consumed as well.
// If the end of input is reached before any digit or '-', `a` is set to 0 and the
// function returns instead of looping forever.
template <typename T>
inline void read(T &a) {
    // int rather than char, so EOF stays distinguishable from every real byte;
    // without the EOF test the skip loop would never terminate at end of input.
    int c = getchar();
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) {
        a = 0;
        return;
    }
    const bool negative = (c == '-');
    T x = negative ? 0 : (c - '0');
    for (c = getchar(); c >= '0' && c <= '9'; c = getchar()) {
        x = x * 10 + (c - '0');
    }
    a = negative ? -x : x;
}

// Reads several integers in order: read(a, b, c) is read(a); read(b); read(c).
template <typename T, typename ...Argv>
inline void read(T &a, Argv &...argv) {
    read(a);
    read(argv...);
}

// Linked adjacency lists stored in flat arrays. `connect` is not called in the visible code.
int head[MAXsumlenb + 10];  // head[u]: index of the most recently added edge out of u, 0 if none
int cntnex;                 // number of edges added so far; edge indices start at 1
int nex[MAXsumlenb + 10];   // nex[e]: the edge that was at the front of the list before e was added
int to[MAXsumlenb + 10];    // to[e]: endpoint v of edge e

// Adds the directed edge u -> v at the front of u's edge list.
inline void connect(int u, int v) {
    const int e = ++cntnex;
    to[e] = v;
    nex[e] = head[u];
    head[u] = e;
}

// Trie storage. Node 0 is the root (the empty string); a 0 entry in `son`
// means "no such child".
int cntnd;                         // id of the most recently created node (0 = only the root exists)
int son[MAXsumlenb + 10][26 + 2];  // son[u][c]: child of node u via letter index c, or 0
int fail[MAXsumlenb + 10];         // fail[u]: failure link of node u (filled in by evafail)
int cntep[MAXsumlenb + 10];        // cntep[u]: number of inserted patterns that end at node u

// Inserts the pattern b[1..lenb] (1-indexed) into the trie.
// Each b[i] is used directly as a column index of `son`; main() subtracts 'a'
// from every character before calling this.
void insert(int lenb, char *b) {
    int node = 0;
    for (int pos = 1; pos <= lenb; ++pos) {
        int &child = son[node][b[pos]];
        if (!child) child = ++cntnd;  // create the missing child on demand
        node = child;
    }
    ++cntep[node];
}
int l, r;                  // head and tail positions of the BFS queue (que[l..r] is pending)
int que[MAXsumlenb + 10];  // BFS queue of trie node ids, filled from index 1

// Computes fail[] for every trie node (trie-tree version: `son` is left untouched).
// Nodes are processed in BFS order, so fail[u] is final before u's children are handled.
void evafail() {
    l = 1;
    r = 0;
    // Depth-1 nodes fail to the root and seed the queue; the root itself is never enqueued.
    for (int c = 0; c < 26; ++c) {
        const int child = son[0][c];
        if (child == 0) continue;
        fail[child] = 0;
        que[++r] = child;
    }
    while (l <= r) {
        const int u = que[l++];
        for (int c = 0; c < 26; ++c) {
            const int v = son[u][c];
            if (v == 0) continue;
            // Follow u's failure chain until some node has a c-child or the root is reached.
            int f = fail[u];
            while (f != 0 && son[f][c] == 0) f = fail[f];
            if (son[f][c] != 0) f = son[f][c];
            fail[v] = f;
            que[++r] = v;
        }
    }
}
// vis[u]: node u's cntep has already been added to the answer.
// It is never cleared, so match() gives the intended result only on its first call.
bool vis[MAXsumlenb + 10];

// Runs the text a[1..lena] (1-indexed, letters already mapped to column indices)
// through the automaton and returns the sum of cntep over every node reached,
// directly or via failure links, each node counted once.
ll match(int lena, char *a) {
    ll total = 0;
    int state = 0;
    for (int pos = 1; pos <= lena; ++pos) {
        const int c = a[pos];
        // On a mismatch, fall back along failure links until a c-edge exists or we hit the root.
        while (state != 0 && son[state][c] == 0) state = fail[state];
        if (son[state][c] != 0) state = son[state][c];
        // Collect every not-yet-counted node on the failure chain of the current state.
        // Stopping at the first visited node is enough: the rest of its chain was
        // marked when that node was first visited.
        for (int k = state; k != 0 && !vis[k]; k = fail[k]) {
            vis[k] = true;
            total += cntep[k];
        }
    }
    return total;
}

int n;
int lena; char a[MAXlena + 10];
int lenb; char b[MAXlenb + 10];
signed main() {
read(n);
for (int i = 1; i <= n; ++i) {
scanf("%s", b + 1);
lenb = strlen(b + 1);
for (int j = 1; j <= lenb; ++j) b[j] -= 'a';
insert(lenb, b);
}
evafail();
scanf("%s", a + 1);
lena = strlen(a + 1);
for (int i = 1; i <= lena; ++i) a[i] -= 'a';
printf("%lld\n", match(lena, a));
return 0;
}