AC 自动机有两种实现,trie 树版和 trie 图版。
注意事项:
- trie 树版可以直接将根(空串的节点)的下标设为 1,而 trie 图版若将根设为 1,还需将 son 数组的初值全部设为 1。
- 在函数 evafail 中,首先插入队列的不应该是根节点,而应该是根节点的儿子节点(否则根的儿子的 fail 会被算成它自己)。这与 KMP 的函数 evafail 中变量 i 从 2 开始的原因一致。
trie 树版
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| int cntnd, son[MAXsumlenb + 10][26 + 2], fail[MAXsumlenb + 10]; int cntep[MAXsumlenb + 10]; void insert(int lenb, char *b) { int cur = 0; for (int i = 1; i <= lenb; ++i) { if (son[cur][b[i]] == 0) son[cur][b[i]] = ++cntnd; cur = son[cur][b[i]]; } ++cntep[cur]; } int l, r; int que[MAXsumlenb + 10]; void evafail() { l = 1, r = 0; for (int i = 0; i < 26; ++i) { if (son[0][i]) { fail[son[0][i]] = 0; que[++r] = son[0][i]; } } while (l <= r) { int i = que[l++]; for (int k = 0; k < 26; ++k) { if (son[i][k]) { int j = fail[i]; while (j && son[j][k] == 0) j = fail[j]; if (son[j][k]) j = son[j][k]; fail[son[i][k]] = j; que[++r] = son[i][k]; } } } } bool vis[MAXsumlenb + 10]; ll match(int lena, char *a) { ll ans = 0; int j = 0; for (int i = 1; i <= lena; ++i) { while (j && son[j][a[i]] == 0) j = fail[j]; if (son[j][a[i]]) j = son[j][a[i]]; } return ans; }
|
trie 图版
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
| int cntnd, son[MAXsumlenb + 10][26 + 2], fail[MAXsumlenb + 10]; int cntep[MAXsumlenb + 10]; void insert(int lenb, char *b) { int cur = 0; for (int i = 1; i <= lenb; ++i) { if (son[cur][b[i]] == 0) son[cur][b[i]] = ++cntnd; cur = son[cur][b[i]]; } ++cntep[cur]; } int l, r; int que[MAXsumlenb + 10]; void evafail() { l = 1, r = 0; for (int i = 0; i < 26; ++i) { if (son[0][i]) { fail[son[0][i]] = 0; que[++r] = son[0][i]; } } while (l <= r) { int i = que[l++]; for (int k = 0; k < 26; ++k) { if (son[i][k]) { fail[son[i][k]] = son[fail[i]][k]; que[++r] = son[i][k]; } else { son[i][k] = son[fail[i]][k]; } } } } bool vis[MAXsumlenb + 10]; ll match(int lena, char *a) { ll ans = 0; int j = 0; for (int i = 1; i <= lena; ++i) { j = son[j][a[i]]; } return ans; }
|
题目
Luogu P3808 【模板】AC 自动机(简单版)(注:该题数据过水,不建议作为测试代码是否正确的标准)
以 trie 树版为例,代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
#include <bits/stdc++.h>
using namespace std;

// 64-bit signed integer alias.
using ll = long long;

// Size limits, each equal to 10^6.
constexpr int MAXn = 1000000;        // not referenced in the visible code; presumably the pattern-count limit
constexpr int MAXlena = 1000000;     // capacity of the text buffer
constexpr int MAXlenb = 1000000;     // capacity of the single-pattern buffer
constexpr int MAXsumlenb = 1000000;  // capacity of the per-node arrays
// Reads one decimal integer (with an optional leading '-') from stdin into a.
// Any characters before the first digit or '-' are skipped, and the single
// character that terminates the number is consumed as well.
// If end of input is reached before a number starts, a is set to 0 and the
// function returns; keeping the getchar() result in an int is what makes EOF
// detectable, so the skip loop cannot spin forever on truncated input.
template <typename T>
inline void read(T &a) {
    int c = std::getchar();
    while (c != EOF && c != '-' && (c < '0' || c > '9')) {
        c = std::getchar();
    }
    if (c == EOF) {
        a = T(0);
        return;
    }
    const bool negative = (c == '-');
    T x = negative ? 0 : (c - '0');
    for (c = std::getchar(); c >= '0' && c <= '9'; c = std::getchar()) {
        x = x * 10 + (c - '0');
    }
    a = negative ? -x : x;
}

// Reads several integers in order: read(a, b, c) is read(a); read(b); read(c).
template <typename T, typename... Argv>
inline void read(T &a, Argv &...argv) {
    read(a);
    read(argv...);
}
// Linked adjacency lists stored in flat arrays.
int head[MAXsumlenb + 10];  // head[u]: id of the most recently added edge out of u, 0 if none
int cntnex;                 // last edge id handed out; ids start at 1
int nex[MAXsumlenb + 10];   // nex[e]: the edge that was head[u] before e was added
int to[MAXsumlenb + 10];    // to[e]: endpoint of edge e

// Adds a directed edge u -> v at the front of u's list.
inline void connect(int u, int v) {
    const int e = ++cntnex;
    to[e] = v;
    nex[e] = head[u];
    head[u] = e;
}
int cntnd, son[MAXsumlenb + 10][26 + 2], fail[MAXsumlenb + 10]; int cntep[MAXsumlenb + 10]; void insert(int lenb, char *b) { int cur = 0; for (int i = 1; i <= lenb; ++i) { if (son[cur][b[i]] == 0) son[cur][b[i]] = ++cntnd; cur = son[cur][b[i]]; } ++cntep[cur]; } int l, r; int que[MAXsumlenb + 10]; void evafail() { l = 1, r = 0; for (int i = 0; i < 26; ++i) { if (son[0][i]) { fail[son[0][i]] = 0; que[++r] = son[0][i]; } } while (l <= r) { int i = que[l++]; for (int k = 0; k < 26; ++k) { if (son[i][k]) { int j = fail[i]; while (j && son[j][k] == 0) j = fail[j]; if (son[j][k]) j = son[j][k]; fail[son[i][k]] = j; que[++r] = son[i][k]; } } } } bool vis[MAXsumlenb + 10]; ll match(int lena, char *a) { ll ans = 0; int j = 0; for (int i = 1; i <= lena; ++i) { while (j && son[j][a[i]] == 0) j = fail[j]; if (son[j][a[i]]) j = son[j][a[i]]; int k = j; while (k && vis[k] == 0) { vis[k] = 1; ans += cntep[k]; k = fail[k]; } } return ans; }
int n;                  // number of pattern strings
int lena;               // length of the text
char a[MAXlena + 10];   // text, stored from index 1
int lenb;               // length of the pattern currently being read
char b[MAXlenb + 10];   // current pattern, stored from index 1

// Reads n patterns and one text, all stored 1-indexed and remapped from
// 'a'-based letters to small integers, then prints the result of match().
signed main() {
    read(n);
    for (int idx = 1; idx <= n; ++idx) {
        scanf("%s", b + 1);
        lenb = strlen(b + 1);
        // Shift each letter by 'a' so it can index son[][] directly.
        for (char *p = b + 1; p <= b + lenb; ++p) *p -= 'a';
        insert(lenb, b);
    }
    evafail();

    scanf("%s", a + 1);
    lena = strlen(a + 1);
    for (char *p = a + 1; p <= a + lena; ++p) *p -= 'a';

    printf("%lld\n", match(lena, a));
    return 0;
}
|