String Hash

From EOJ Wiki
Revision as of 12:49, 22 March 2018 by Ultmaster (talk | contribs) (Created page with "<syntaxhighlight lang='cpp'> #include <bits/stdc++.h> using namespace std; #define ENABLE_DOUBLE_HASH typedef long long LL; typedef unsigned long long ULL; const int x = 13...")
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to navigation Jump to search
#include <bits/stdc++.h>
using namespace std;

// Compile-time switch read by String: when defined, hashes are a pair of
// modular hashes (mod p1 in the high 32 bits, mod p2 in the low 32 bits);
// when undefined, a single hash with natural 64-bit wrap-around is used.
#define ENABLE_DOUBLE_HASH

typedef long long LL;
typedef unsigned long long ULL;

// Base of the polynomial hash.
const int x = 135;
// Capacity of every fixed-size buffer and table in this file.
const int maxn = 400000 + 10;
// The two prime moduli used by the double hash.
const int p1 = 1000000007;
const int p2 = 1000000009;

// Power tables filled by init_xp(): powers of x modulo p1, modulo p2, and
// with plain unsigned 64-bit overflow respectively.
ULL xp1[maxn];
ULL xp2[maxn];
ULL xp[maxn];

// Fills the global power tables for every index 0 <= i < maxn:
//   xp1[i] = x^i mod p1, xp2[i] = x^i mod p2,
//   xp[i]  = x^i with natural unsigned 64-bit wrap-around.
// Must run once before any substring hash is computed.
void init_xp() {
    xp1[0] = 1;
    xp2[0] = 1;
    xp[0] = 1;
    for (int k = 0; k + 1 < maxn; ++k) {
        xp1[k + 1] = xp1[k] * x % p1;
        xp2[k + 1] = xp2[k] * x % p2;
        xp[k + 1] = xp[k] * x;
    }
}

// Fixed-capacity string with precomputed suffix polynomial hashes.
//
// After hash(), h[j] holds the hash of the suffix s[j..length-1], i.e.
// s[j] + s[j+1]*x + s[j+2]*x^2 + ..., and h[length] == 0. With
// ENABLE_DOUBLE_HASH the value packs (hash mod p1) in the high 32 bits and
// (hash mod p2) in the low 32 bits; otherwise it is a single hash with
// natural unsigned 64-bit wrap-around.
struct String {
    char s[maxn];           // NUL-terminated text
    int length, subsize;    // text length; number of valid entries in hl
    bool sorted;            // true once hl has been sorted by sort_substring_hash()
    ULL h[maxn], hl[maxn];  // suffix hashes; hashes of all fixed-length substrings

    // Recomputes length from s, fills h[0..length] and returns the hash of
    // the whole string (h[0]; 0 for the empty string).
    ULL hash() {
        length = std::strlen(s);
        ULL res1 = 0, res2 = 0;
        h[length] = 0;  // empty suffix; read by get_substring_hash when right == length
        for (int j = length - 1; j >= 0; --j) {
            // Convert through unsigned char. A plain (possibly signed) char
            // with the high bit set becomes 2^64 - k as ULL; the sum below
            // then wraps modulo 2^64 only some of the time before "% p", so
            // the same byte would contribute different values at different
            // positions and equal substrings could get different hashes.
            const ULL c = static_cast<unsigned char>(s[j]);
        #ifdef ENABLE_DOUBLE_HASH
            res1 = (res1 * x + c) % p1;
            res2 = (res2 * x + c) % p2;
            h[j] = (res1 << 32) | res2;
        #else
            res1 = res1 * x + c;
            h[j] = res1;
        #endif
        }
        return h[0];
    }

    // Hash of s[left..right-1], computed as h[left] - h[right] * x^(right-left).
    // Requires hash() and init_xp() to have run and
    // 0 <= left <= right <= length.
    ULL get_substring_hash(int left, int right) const {
        int len = right - left;
    #ifdef ENABLE_DOUBLE_HASH
        unsigned int mask32 = ~(0u);
        // Unpack both components of each packed suffix hash.
        ULL left1 = h[left] >> 32, right1 = h[right] >> 32;
        ULL left2 = h[left] & mask32, right2 = h[right] & mask32;
        // "+ p" keeps the difference non-negative before the final reduction.
        return (((left1 - right1 * xp1[len] % p1 + p1) % p1) << 32) |
               (((left2 - right2 * xp2[len] % p2 + p2) % p2));
    #else
        return h[left] - h[right] * xp[len];
    #endif
    }

    // Stores the hashes of every substring of length sublen into
    // hl[0..subsize-1] (in position order) and marks hl as unsorted.
    // If sublen is negative or longer than the string, there are no such
    // substrings and subsize becomes 0.
    void get_all_subs_hash(int sublen) {
        // Clamp instead of letting subsize go negative: a negative count
        // would later give sort/binary_search an invalid range.
        subsize = (sublen < 0 || sublen > length) ? 0 : length - sublen + 1;
        for (int i = 0; i < subsize; ++i)
            hl[i] = get_substring_hash(i, i + sublen);
        sorted = 0;
    }

    // Sorts hl so that match() can binary-search it.
    void sort_substring_hash() {
        std::sort(hl, hl + subsize);
        sorted = 1;
    }

    // Returns whether key is among the hashes collected by
    // get_all_subs_hash(). Asserts that sort_substring_hash() was called.
    bool match(ULL key) const {
        assert(sorted);
        if (subsize <= 0) return false;
        return std::binary_search(hl, hl + subsize, key);
    }

    // Copies t into s and records its length. Input longer than maxn - 1
    // characters is truncated rather than written past the end of s.
    void init(const char *t) {
        int copied = 0;
        while (copied < maxn - 1 && t[copied] != '\0') {
            s[copied] = t[copied];
            ++copied;
        }
        s[copied] = '\0';
        length = copied;
    }
};

// Length of the longest common prefix of a[ai...] and b[bi...].
// Binary-searches the prefix length, treating equal substring hashes as
// equal text. Both strings must have been hashed beforehand.
int LCP(const String &a, const String &b, int ai, int bi) {
    int lo = 0;
    int hi = min(a.length - ai, b.length - bi);
    while (lo < hi) {
        // Round up so that "lo = probe" always makes progress.
        const int probe = lo + (hi - lo + 1) / 2;
        const bool same =
            a.get_substring_hash(ai, ai + probe) == b.get_substring_hash(bi, bi + probe);
        if (same)
            lo = probe;
        else
            hi = probe - 1;
    }
    return lo;
}

// Input/output scratch buffers used by main():
//   stra   - first string of the test case
//   strb   - the remaining strings concatenated, each followed by '*'
//   strc   - the token currently being read
//   ansstr - copy of the answer substring that gets printed
char stra[maxn];
char strb[maxn];
char strc[maxn];
char ansstr[maxn];
// Hashed copies of stra (S) and strb (T).
String S;
String T;
// Number of strings in the current test case.
int n;

// Returns 1 if ans is longer than T, or if at least one length-ans window of
// S has a hash that does not occur among T's length-ans windows; returns 0
// when every such window of S is matched. Rebuilds and sorts T's hash list.
int check(int ans) {
    if (ans > T.length) return 1;

    T.get_all_subs_hash(ans);
    T.sort_substring_hash();

    const int windows = S.length - ans + 1;
    for (int pos = 0; pos < windows; ++pos) {
        const bool found = T.match(S.get_substring_hash(pos, pos + ans));
        if (!found) return 1;
    }
    return 0;
}

// Among the length-ans windows of S whose hash is absent from T, returns the
// start index of the lexicographically smallest one (the earliest on ties).
// Asserts that at least one such window exists.
int get_ans_start(int ans) {
    T.get_all_subs_hash(ans);
    T.sort_substring_hash();

    int best = -1;
    for (int pos = 0; pos + ans <= S.length; ++pos) {
        // Windows that also occur in T are not candidates.
        if (T.match(S.get_substring_hash(pos, pos + ans))) continue;
        if (best < 0 || strncmp(S.s + pos, S.s + best, ans) < 0)
            best = pos;
    }
    assert (best >= 0);
    return best;
}

// Driver. For each test case: reads n and n strings, keeps the first as S and
// the concatenation of the others (each followed by '*') as T, then prints
// the shortest substring of S not found in T (lexicographically smallest
// among those of that length), or "Impossible" if S itself is found in T.
// Returns non-zero if the input ends early or the concatenation would not
// fit in strb.
int main() {
    init_xp();  // power tables must exist before any substring hash is taken

    int kases;
    if (scanf("%d", &kases) != 1) return 0;  // no input at all: nothing to do
    for (int tt = 1; tt <= kases; ++tt) {
        printf("Case #%d: ", tt);
        // NOTE(review): the %s reads below are unbounded; tokens are assumed
        // to be shorter than the buffers - confirm against the input limits.
        if (scanf("%d", &n) != 1) return 1;
        if (scanf("%s", stra) != 1) return 1;

        // Build strb = str2 '*' str3 '*' ... strn '*'.
        // Presumably '*' never occurs in the input strings - TODO confirm.
        int pos = 0;
        strb[0] = '\0';  // stays empty when n < 2
        for (int i = 2; i <= n; ++i) {
            if (scanf("%s", strc) != 1) return 1;
            int tmp = strlen(strc);
            // Need room for the token, its '*' separator and the final NUL.
            if (pos + tmp + 2 > (int)sizeof strb) return 1;
            strcpy(strb + pos, strc);
            pos += tmp;
            strb[pos++] = '*';
            strb[pos] = '\0';
        }

        S.init(stra); T.init(strb);
        S.hash(); T.hash();

        int l = 1, r = S.length;
        // check(len) == 0 means every length-len window of S occurs in T, so
        // check(S.length) == 0 means S as a whole occurs in T.
        if (!check(r)) puts("Impossible");
        else {
            // Smallest length for which some window of S is missing from T.
            while (l < r) {
                int mid = (l + r) / 2;
                if (check(mid)) r = mid;
                else l = mid + 1;
            }
            const int start = get_ans_start(l);
            strncpy(ansstr, stra + start, l);
            ansstr[l] = '\0';  // strncpy does not terminate when it copies exactly l chars
            printf("%s\n", ansstr);
        }
    }
    return 0;
}