Difference between revisions of "KMP and Aho-Corasick Automaton"

From EOJ Wiki
Jump to navigation Jump to search
(Created page with "== Aho-Corasick Automaton == Ordinary Match: <syntaxhighlight lang='cpp'> #include <bits/stdc++.h> using namespace std; typedef long long LL; const int MOD = 998244353; con...")
 
 
Line 1: Line 1:
 +
== KMP ==
 +
 +
<syntaxhighlight lang='cpp'>
 +
#include <bits/stdc++.h>
 +
const int maxn = 1e6 + 10;
 +
using namespace std;
 +
 +
char T[maxn]; // template
 +
char P[maxn];  // pattern
 +
int f[maxn], f2[maxn];  // optimized
 +
 +
int find(char *T, char *P, int *f) {
 +
    int n = strlen(T);
 +
    int m = strlen(P);
 +
    int j = 0, cnt = 0;
 +
    for (int i = 0; i < n; i++) {
 +
        while (j && T[i] != P[j]) j = f[j];
 +
        if (T[i] == P[j]) j++;
 +
        if (j == m) cnt++;
 +
//        if (j == m) printf("%d\n", i - m + 1);
 +
    }
 +
    return cnt;
 +
}
 +
 +
void getFail(char *P, int *f, int *f2) {
 +
    int m = strlen(P);
 +
    f[0] = f[1] = 0;
 +
    f2[0] = f2[1] = 0;
 +
    for (int i = 1; i < m; i++) {
 +
        int j = f2[i];
 +
        while (j && P[i] != P[j]) j = f2[j];
 +
        f2[i + 1] = f[i + 1] = (P[i] == P[j]) ? j + 1 : 0;
 +
        if (f[i + 1] == j + 1 && P[i + 1] == P[j + 1]) f[i + 1] = f[j + 1];
 +
    }
 +
}
 +
 +
int main() {
 +
    int n;
 +
    scanf("%d", &n);
 +
    while (n--) {
 +
        scanf("%s%s", P, T);
 +
        getFail(P, f, f2);
 +
        cout << find(T, P, f2) << endl;
 +
    }
 +
}
 +
</syntaxhighlight>
 +
 +
=== Extended KMP ===
 +
 +
<syntaxhighlight lang='cpp'>
 +
#include <bits/stdc++.h>
 +
using namespace std;
 +
 +
/*
 +
Define template S, pattern T, len(S)=n, len(T)=m
 +
Find the longest common prefix of T and every suffix of S
 +
ex[i]: the LCP between T and S[i..n-1]
 +
*/
 +
 +
const int maxn = 1e6 + 10;
 +
int nt[maxn], ex[maxn];
 +
char s[maxn], t[maxn];
 +
 +
void get_next(char *str) {
 +
    int i = 0, j, po, len = strlen(str);
 +
    nt[0] = len;
 +
    while (str[i] == str[i + 1] && i + 1 < len)
 +
        i++;
 +
    nt[1] = i;
 +
    po = 1;
 +
    for (i = 2; i < len; i++) {
 +
        if (nt[i - po] + i < nt[po] + po)
 +
            nt[i] = nt[i - po];
 +
        else {
 +
            j = nt[po] + po - i;
 +
            if (j < 0) j = 0;
 +
            while (i + j < len && str[j] == str[j + i])
 +
                j++;
 +
            nt[i] = j;
 +
            po = i;
 +
        }
 +
    }
 +
}
 +
 +
void exkmp(char *s1, char *s2) {
 +
    int i = 0, j, po, len = strlen(s1), l2 = strlen(s2);
 +
    get_next(s2);
 +
    while (s1[i] == s2[i] && i < l2 && i < len)
 +
        i++;
 +
    ex[0] = i;
 +
    po = 0;
 +
    for (i = 1; i < len; i++) {
 +
        if (nt[i - po] + i < ex[po] + po)
 +
            ex[i] = nt[i - po];
 +
        else {
 +
            j = ex[po] + po - i;
 +
            if (j < 0) j = 0;
 +
            while (i + j < len && j < l2 && s1[j + i] == s2[j])
 +
                j++;
 +
            ex[i] = j;
 +
            po = i;
 +
        }
 +
    }
 +
 +
 +
int main() {
 +
    const int modn = 1e9 + 7;
 +
    int T; scanf("%d", &T);
 +
    while (T--) {
 +
        memset(nt, 0, sizeof nt);
 +
        memset(ex, 0, sizeof ex);
 +
        scanf("%s", s); scanf("%s", t);
 +
        int slen = strlen(s), tlen = strlen(t);
 +
        reverse(s, s + slen);
 +
        reverse(t, t + tlen);
 +
        exkmp(s, t);
 +
        int ans = 0;
 +
        for (int i = 0; i < slen; ++i)
 +
            ans = (ans + 1LL * ex[i] * (ex[i] + 1) / 2) % modn;
 +
        printf("%d\n", ans);
 +
    }
 +
}
 +
</syntaxhighlight>
 +
 
== Aho-Corasick Automaton ==
 
== Aho-Corasick Automaton ==
  

Latest revision as of 12:55, 22 March 2018

KMP

#include <bits/stdc++.h>
const int maxn = 1e6 + 10;
using namespace std;

char T[maxn]; // template
char P[maxn];  // pattern
int f[maxn], f2[maxn];  // optimized

int find(char *T, char *P, int *f) {
    int n = strlen(T);
    int m = strlen(P);
    int j = 0, cnt = 0;
    for (int i = 0; i < n; i++) {
        while (j && T[i] != P[j]) j = f[j];
        if (T[i] == P[j]) j++;
        if (j == m) cnt++;
//        if (j == m) printf("%d\n", i - m + 1);
    }
    return cnt;
}

void getFail(char *P, int *f, int *f2) {
    int m = strlen(P);
    f[0] = f[1] = 0;
    f2[0] = f2[1] = 0;
    for (int i = 1; i < m; i++) {
        int j = f2[i];
        while (j && P[i] != P[j]) j = f2[j];
        f2[i + 1] = f[i + 1] = (P[i] == P[j]) ? j + 1 : 0;
        if (f[i + 1] == j + 1 && P[i + 1] == P[j + 1]) f[i + 1] = f[j + 1];
    }
}

int main() {
    int n;
    scanf("%d", &n);
    while (n--) {
        scanf("%s%s", P, T);
        getFail(P, f, f2);
        cout << find(T, P, f2) << endl;
    }
}

Extended KMP

#include <bits/stdc++.h>
using namespace std;

/*
 Define template S, pattern T, len(S)=n, len(T)=m
 Find the longest common prefix of T and every suffix of S
 ex[i]: the LCP between T and S[i..n-1]
 */

const int maxn = 1e6 + 10;
int nt[maxn], ex[maxn];
char s[maxn], t[maxn];

void get_next(char *str) {
    int i = 0, j, po, len = strlen(str);
    nt[0] = len;
    while (str[i] == str[i + 1] && i + 1 < len)
        i++;
    nt[1] = i;
    po = 1;
    for (i = 2; i < len; i++) {
        if (nt[i - po] + i < nt[po] + po)
            nt[i] = nt[i - po];
        else {
            j = nt[po] + po - i;
            if (j < 0) j = 0;
            while (i + j < len && str[j] == str[j + i])
                j++;
            nt[i] = j;
            po = i;
        }
    }
}

void exkmp(char *s1, char *s2) {
    int i = 0, j, po, len = strlen(s1), l2 = strlen(s2);
    get_next(s2);
    while (s1[i] == s2[i] && i < l2 && i < len)
        i++;
    ex[0] = i;
    po = 0;
    for (i = 1; i < len; i++) {
        if (nt[i - po] + i < ex[po] + po)
            ex[i] = nt[i - po];
        else {
            j = ex[po] + po - i;
            if (j < 0) j = 0;
            while (i + j < len && j < l2 && s1[j + i] == s2[j])
                j++;
            ex[i] = j;
            po = i;
        }
    }
}  

int main() {
    const int modn = 1e9 + 7;
    int T; scanf("%d", &T);
    while (T--) {
        memset(nt, 0, sizeof nt);
        memset(ex, 0, sizeof ex);
        scanf("%s", s); scanf("%s", t);
        int slen = strlen(s), tlen = strlen(t);
        reverse(s, s + slen);
        reverse(t, t + tlen);
        exkmp(s, t);
        int ans = 0;
        for (int i = 0; i < slen; ++i)
            ans = (ans + 1LL * ex[i] * (ex[i] + 1) / 2) % modn;
        printf("%d\n", ans);
    }
}

Aho-Corasick Automaton

Ordinary Match:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int MOD = 998244353;
const int CHAR_SIZE = 26;
const int MAX_SIZE = 1e6 + 100;

int mp(char ch) {
    return ch - 'a';
}

struct AC_Machine {
    int ch[MAX_SIZE][CHAR_SIZE], danger[MAX_SIZE], fail[MAX_SIZE];
    int sz;
    void init() {
        sz = 1;
        memset(ch[0], 0, sizeof ch[0]);
        memset(danger, 0, sizeof danger);
    }
    void _insert(const string &s, int m) {
        int n = s.size();
        int u = 0, c;
        for (int i = 0; i < n; i++) {
            c = mp(s[i]);
            if (!ch[u][c]) {
                memset(ch[sz], 0, sizeof ch[sz]);
                danger[sz] = 0;
                ch[u][c] = sz++;
            }
            u = ch[u][c];
        }
        danger[u] |= 1 << m;
    }
    void _build() {
        queue<int> Q;
        fail[0] = 0;
        for (int c = 0, u; c < CHAR_SIZE; c++) {
            u = ch[0][c];
            if (u) {
                Q.push(u);
                fail[u] = 0;
            }
        }
        int r;
        while (!Q.empty()) {
            r = Q.front();
            Q.pop();
            danger[r] |= danger[fail[r]];
            for (int c = 0, u; c < CHAR_SIZE; c++) {
                u = ch[r][c];
                if (!u) {
                    ch[r][c] = ch[fail[r]][c];
                    continue;
                }
                fail[u] = ch[fail[r]][c];
                Q.push(u);
            }
        }
    }
} ac;

char s[MAX_SIZE];

int main() {
    int n; scanf("%d", &n);
    ac.init();
    while (n--) {
        scanf("%s", s);
        ac._insert(s, 0);
    }
    ac._build();

    scanf("%s", s);
    int u = 0; n = strlen(s);
    for (int i = 0; i < n; ++i) {
        u = ac.ch[u][mp(s[i])];
        if (ac.danger[u]) {
            puts("YES");
            return 0;
        }
    }
    puts("NO");
    return 0;
}

With Status-compressed DP:

string ss;
int dp[2][MAX_SIZE][1 << 10];
int ans[50];

int main() {
    int n, m, i, j, k, s, x, u;
    cin >> m >> n;
    assert (m <= 10 && n <= 200);
    ac.init();
    set<string> spool;
    for (i = 0; i < m; i++) {
        cin >> ss;
        spool.insert(ss);
        ac._insert(ss, i);
    }
    ac._build();

    memset(dp, 0, sizeof dp);
    dp[0][0][0] = 1;
    for (i = 0, x = 1; i < n; i++, x ^= 1) {
        memset(dp[x], 0, sizeof dp[x]);
        for (j = 0; j < ac.sz; j++) {
            for (s = 0; s < (1 << m); s++) {
                if (dp[x ^ 1][j][s] == 0)
                    continue;
                for (k = 0; k < CHAR_SIZE; k++) {
                    u = ac.ch[j][k];
                    dp[x][u][s | ac.danger[u]] =
                            mod(dp[x][u][s | ac.danger[u]] + dp[x ^ 1][j][s]);
                }
            }
        }
    }

    memset(ans, 0, sizeof ans);
    for (s = 0; s < (1 << m); s++) {
        int kk = __builtin_popcount(s);
        for (j = 0; j < ac.sz; j++)
            ans[kk] = mod(ans[kk] + dp[x ^ 1][j][s]);
    }
    int final = 0;
    for (i = 0; i <= m; ++i) {
        // cout << ans[i] << endl;
        final = (final + 1LL * ans[i] * (i + 1) * (i + 1)) % MOD;
    }
    cout << final << endl;

    return 0;
}