Difference between revisions of "KMP and Aho-Corasick Automaton"
Jump to navigation
Jump to search
(Created page with "== Aho-Corasick Automaton == Ordinary Match: <syntaxhighlight lang='cpp'> #include <bits/stdc++.h> using namespace std; typedef long long LL; const int MOD = 998244353; con...") |
|||
Line 1: | Line 1: | ||
+ | == KMP == | ||
+ | |||
+ | <syntaxhighlight lang='cpp'> | ||
+ | #include <bits/stdc++.h> | ||
+ | const int maxn = 1e6 + 10; | ||
+ | using namespace std; | ||
+ | |||
+ | char T[maxn]; // template | ||
+ | char P[maxn]; // pattern | ||
+ | int f[maxn], f2[maxn]; // optimized | ||
+ | |||
+ | int find(char *T, char *P, int *f) { | ||
+ | int n = strlen(T); | ||
+ | int m = strlen(P); | ||
+ | int j = 0, cnt = 0; | ||
+ | for (int i = 0; i < n; i++) { | ||
+ | while (j && T[i] != P[j]) j = f[j]; | ||
+ | if (T[i] == P[j]) j++; | ||
+ | if (j == m) cnt++; | ||
+ | // if (j == m) printf("%d\n", i - m + 1); | ||
+ | } | ||
+ | return cnt; | ||
+ | } | ||
+ | |||
+ | void getFail(char *P, int *f, int *f2) { | ||
+ | int m = strlen(P); | ||
+ | f[0] = f[1] = 0; | ||
+ | f2[0] = f2[1] = 0; | ||
+ | for (int i = 1; i < m; i++) { | ||
+ | int j = f2[i]; | ||
+ | while (j && P[i] != P[j]) j = f2[j]; | ||
+ | f2[i + 1] = f[i + 1] = (P[i] == P[j]) ? j + 1 : 0; | ||
+ | if (f[i + 1] == j + 1 && P[i + 1] == P[j + 1]) f[i + 1] = f[j + 1]; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | int main() { | ||
+ | int n; | ||
+ | scanf("%d", &n); | ||
+ | while (n--) { | ||
+ | scanf("%s%s", P, T); | ||
+ | getFail(P, f, f2); | ||
+ | cout << find(T, P, f2) << endl; | ||
+ | } | ||
+ | } | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | === Extended KMP === | ||
+ | |||
+ | <syntaxhighlight lang='cpp'> | ||
+ | #include <bits/stdc++.h> | ||
+ | using namespace std; | ||
+ | |||
+ | /* | ||
+ | Define template S, pattern T, len(S)=n, len(T)=m | ||
+ | Find the longest common prefix of T and every suffix of S | ||
+ | ex[i]: the LCP between T and S[i..n-1] | ||
+ | */ | ||
+ | |||
+ | const int maxn = 1e6 + 10; | ||
+ | int nt[maxn], ex[maxn]; | ||
+ | char s[maxn], t[maxn]; | ||
+ | |||
+ | void get_next(char *str) { | ||
+ | int i = 0, j, po, len = strlen(str); | ||
+ | nt[0] = len; | ||
+ | while (str[i] == str[i + 1] && i + 1 < len) | ||
+ | i++; | ||
+ | nt[1] = i; | ||
+ | po = 1; | ||
+ | for (i = 2; i < len; i++) { | ||
+ | if (nt[i - po] + i < nt[po] + po) | ||
+ | nt[i] = nt[i - po]; | ||
+ | else { | ||
+ | j = nt[po] + po - i; | ||
+ | if (j < 0) j = 0; | ||
+ | while (i + j < len && str[j] == str[j + i]) | ||
+ | j++; | ||
+ | nt[i] = j; | ||
+ | po = i; | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | void exkmp(char *s1, char *s2) { | ||
+ | int i = 0, j, po, len = strlen(s1), l2 = strlen(s2); | ||
+ | get_next(s2); | ||
+ | while (s1[i] == s2[i] && i < l2 && i < len) | ||
+ | i++; | ||
+ | ex[0] = i; | ||
+ | po = 0; | ||
+ | for (i = 1; i < len; i++) { | ||
+ | if (nt[i - po] + i < ex[po] + po) | ||
+ | ex[i] = nt[i - po]; | ||
+ | else { | ||
+ | j = ex[po] + po - i; | ||
+ | if (j < 0) j = 0; | ||
+ | while (i + j < len && j < l2 && s1[j + i] == s2[j]) | ||
+ | j++; | ||
+ | ex[i] = j; | ||
+ | po = i; | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | |||
+ | int main() { | ||
+ | const int modn = 1e9 + 7; | ||
+ | int T; scanf("%d", &T); | ||
+ | while (T--) { | ||
+ | memset(nt, 0, sizeof nt); | ||
+ | memset(ex, 0, sizeof ex); | ||
+ | scanf("%s", s); scanf("%s", t); | ||
+ | int slen = strlen(s), tlen = strlen(t); | ||
+ | reverse(s, s + slen); | ||
+ | reverse(t, t + tlen); | ||
+ | exkmp(s, t); | ||
+ | int ans = 0; | ||
+ | for (int i = 0; i < slen; ++i) | ||
+ | ans = (ans + 1LL * ex[i] * (ex[i] + 1) / 2) % modn; | ||
+ | printf("%d\n", ans); | ||
+ | } | ||
+ | } | ||
+ | </syntaxhighlight> | ||
+ | |||
== Aho-Corasick Automaton == | == Aho-Corasick Automaton == | ||
Latest revision as of 12:55, 22 March 2018
KMP
#include <bits/stdc++.h>
const int maxn = 1e6 + 10;
using namespace std;
char T[maxn]; // template
char P[maxn]; // pattern
int f[maxn], f2[maxn]; // optimized
int find(char *T, char *P, int *f) {
int n = strlen(T);
int m = strlen(P);
int j = 0, cnt = 0;
for (int i = 0; i < n; i++) {
while (j && T[i] != P[j]) j = f[j];
if (T[i] == P[j]) j++;
if (j == m) cnt++;
// if (j == m) printf("%d\n", i - m + 1);
}
return cnt;
}
void getFail(char *P, int *f, int *f2) {
int m = strlen(P);
f[0] = f[1] = 0;
f2[0] = f2[1] = 0;
for (int i = 1; i < m; i++) {
int j = f2[i];
while (j && P[i] != P[j]) j = f2[j];
f2[i + 1] = f[i + 1] = (P[i] == P[j]) ? j + 1 : 0;
if (f[i + 1] == j + 1 && P[i + 1] == P[j + 1]) f[i + 1] = f[j + 1];
}
}
int main() {
int n;
scanf("%d", &n);
while (n--) {
scanf("%s%s", P, T);
getFail(P, f, f2);
cout << find(T, P, f2) << endl;
}
}
Extended KMP
#include <bits/stdc++.h>
using namespace std;
/*
Define template S, pattern T, len(S)=n, len(T)=m
Find the longest common prefix of T and every suffix of S
ex[i]: the LCP between T and S[i..n-1]
*/
const int maxn = 1e6 + 10;
int nt[maxn], ex[maxn];
char s[maxn], t[maxn];
void get_next(char *str) {
int i = 0, j, po, len = strlen(str);
nt[0] = len;
while (str[i] == str[i + 1] && i + 1 < len)
i++;
nt[1] = i;
po = 1;
for (i = 2; i < len; i++) {
if (nt[i - po] + i < nt[po] + po)
nt[i] = nt[i - po];
else {
j = nt[po] + po - i;
if (j < 0) j = 0;
while (i + j < len && str[j] == str[j + i])
j++;
nt[i] = j;
po = i;
}
}
}
void exkmp(char *s1, char *s2) {
int i = 0, j, po, len = strlen(s1), l2 = strlen(s2);
get_next(s2);
while (s1[i] == s2[i] && i < l2 && i < len)
i++;
ex[0] = i;
po = 0;
for (i = 1; i < len; i++) {
if (nt[i - po] + i < ex[po] + po)
ex[i] = nt[i - po];
else {
j = ex[po] + po - i;
if (j < 0) j = 0;
while (i + j < len && j < l2 && s1[j + i] == s2[j])
j++;
ex[i] = j;
po = i;
}
}
}
int main() {
const int modn = 1e9 + 7;
int T; scanf("%d", &T);
while (T--) {
memset(nt, 0, sizeof nt);
memset(ex, 0, sizeof ex);
scanf("%s", s); scanf("%s", t);
int slen = strlen(s), tlen = strlen(t);
reverse(s, s + slen);
reverse(t, t + tlen);
exkmp(s, t);
int ans = 0;
for (int i = 0; i < slen; ++i)
ans = (ans + 1LL * ex[i] * (ex[i] + 1) / 2) % modn;
printf("%d\n", ans);
}
}
Aho-Corasick Automaton
Ordinary Match:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MOD = 998244353;
const int CHAR_SIZE = 26;
const int MAX_SIZE = 1e6 + 100;
int mp(char ch) {
return ch - 'a';
}
struct AC_Machine {
int ch[MAX_SIZE][CHAR_SIZE], danger[MAX_SIZE], fail[MAX_SIZE];
int sz;
void init() {
sz = 1;
memset(ch[0], 0, sizeof ch[0]);
memset(danger, 0, sizeof danger);
}
void _insert(const string &s, int m) {
int n = s.size();
int u = 0, c;
for (int i = 0; i < n; i++) {
c = mp(s[i]);
if (!ch[u][c]) {
memset(ch[sz], 0, sizeof ch[sz]);
danger[sz] = 0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
danger[u] |= 1 << m;
}
void _build() {
queue<int> Q;
fail[0] = 0;
for (int c = 0, u; c < CHAR_SIZE; c++) {
u = ch[0][c];
if (u) {
Q.push(u);
fail[u] = 0;
}
}
int r;
while (!Q.empty()) {
r = Q.front();
Q.pop();
danger[r] |= danger[fail[r]];
for (int c = 0, u; c < CHAR_SIZE; c++) {
u = ch[r][c];
if (!u) {
ch[r][c] = ch[fail[r]][c];
continue;
}
fail[u] = ch[fail[r]][c];
Q.push(u);
}
}
}
} ac;
char s[MAX_SIZE];
int main() {
int n; scanf("%d", &n);
ac.init();
while (n--) {
scanf("%s", s);
ac._insert(s, 0);
}
ac._build();
scanf("%s", s);
int u = 0; n = strlen(s);
for (int i = 0; i < n; ++i) {
u = ac.ch[u][mp(s[i])];
if (ac.danger[u]) {
puts("YES");
return 0;
}
}
puts("NO");
return 0;
}
With Status-compressed DP:
string ss;
int dp[2][MAX_SIZE][1 << 10];
int ans[50];
int main() {
int n, m, i, j, k, s, x, u;
cin >> m >> n;
assert (m <= 10 && n <= 200);
ac.init();
set<string> spool;
for (i = 0; i < m; i++) {
cin >> ss;
spool.insert(ss);
ac._insert(ss, i);
}
ac._build();
memset(dp, 0, sizeof dp);
dp[0][0][0] = 1;
for (i = 0, x = 1; i < n; i++, x ^= 1) {
memset(dp[x], 0, sizeof dp[x]);
for (j = 0; j < ac.sz; j++) {
for (s = 0; s < (1 << m); s++) {
if (dp[x ^ 1][j][s] == 0)
continue;
for (k = 0; k < CHAR_SIZE; k++) {
u = ac.ch[j][k];
dp[x][u][s | ac.danger[u]] =
mod(dp[x][u][s | ac.danger[u]] + dp[x ^ 1][j][s]);
}
}
}
}
memset(ans, 0, sizeof ans);
for (s = 0; s < (1 << m); s++) {
int kk = __builtin_popcount(s);
for (j = 0; j < ac.sz; j++)
ans[kk] = mod(ans[kk] + dp[x ^ 1][j][s]);
}
int final = 0;
for (i = 0; i <= m; ++i) {
// cout << ans[i] << endl;
final = (final + 1LL * ans[i] * (i + 1) * (i + 1)) % MOD;
}
cout << final << endl;
return 0;
}