「BZOJ 3796」Mushroom - 后缀数组 + AC 自动机

给三个字符串 $s_1, s_2, s_3$，求一个长度最大的字符串 $w$（只需输出其长度），满足：

  1. $w$ 是 $s_1$ 的子串；
  2. $w$ 是 $s_2$ 的子串；
  3. $s_3$ 不是 $w$ 的子串；

链接

BZOJ 3796

题解

将 $s_1$ 与 $s_2$ 用一个分隔符连接后建立后缀数组，枚举 $s_1$ 的每个后缀，该后缀与后缀数组中离它最近的（前、后各一个）$s_2$ 的后缀的最长公共前缀长度的较大值，即为在满足前两个条件下、以该位置开头的最大答案。

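对于条件 3，对 $s_3$ 建立 AC 自动机，求出 $s_3$ 在 $s_1$ 中的每个匹配位置。设 $s_1$ 的每个位置 $i$ 右边（含 $i$ 本身）第一个匹配 $s_3$ 的起始位置为 $p_i$，则在满足条件 3 的情况下，第 $i$ 个后缀的最大答案为

$$ p_i - i + |s_3| - 1 $$

（若 $i$ 右边不存在匹配位置，则条件 3 对该后缀没有限制。）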

对每个后缀，将两个答案取较小值，再对所有后缀取最大值即可。

代码

#include <cstdio>
#include <climits>
#include <cstring>
#include <queue>
#include <algorithm>

// Capacity of every text-indexed array. main() indexes the text
// "s1 + separator + s2"; the +2 leaves room for the separator and the
// terminating NUL when both strings reach 50000 characters.
// NOTE(review): the 50000-per-string limit is assumed, it is not stated in
// this file -- confirm against the problem statement.
const int MAXN = 100000 + 2;
// Highest sparse-table level; requires (1 << MAXN_LOG) >= MAXN
// (log2(1e5) = 16.609640474436812).
const int MAXN_LOG = 17;
// Alphabet of the Aho-Corasick trie: characters BASE_CHAR .. BASE_CHAR +
// CHARSET_SIZE - 1. Every character passed to Trie::insert / Trie::exec is
// used as the array index (ch - BASE_CHAR), so the alphabet must cover the
// whole input character set; 26 covers 'a'..'z'.
// NOTE(review): assumes the input consists of lowercase letters only -- confirm.
const int CHARSET_SIZE = 26;
const int BASE_CHAR = 'a';

// Aho-Corasick automaton over the alphabet
// [BASE_CHAR, BASE_CHAR + CHARSET_SIZE).
//
// Usage: insert() every pattern, call build() exactly once, then call exec().
// build() overwrites missing child pointers with goto transitions, so calling
// insert() or build() again afterwards is not supported.
// Characters are used directly as array indices (ch - BASE_CHAR) without a
// range check; callers must only pass characters inside the alphabet.
// Nodes are allocated with new and are never freed.
struct Trie {
    struct Node {
        // Children before build(); the complete goto table after build().
        Node *c[CHARSET_SIZE];
        // Suffix (failure) link; valid after build().
        Node *fail;
        // Output link: nearest non-root node on the fail chain that ends a
        // pattern, or NULL when there is none; valid after build().
        Node *out;
        // True when a pattern ends exactly at this node.
        bool isWord;
        // Distance from the root, i.e. the length of the prefix this node spells.
        int depth;

        Node(const bool isWord = false) : fail(NULL), out(NULL), isWord(isWord), depth(0) {
            for (int i = 0; i < CHARSET_SIZE; i++) c[i] = NULL;
        }
    } *root;

    Trie() : root(NULL) {}

    // Adds the pattern [begin, end) to the trie, creating nodes on demand.
    // An empty range marks the root itself as a word.
    void insert(const char *begin, const char *end) {
        // v always points at the pointer slot of the node for the prefix read so far.
        Node **v = &root;
        for (const char *p = begin; p != end; p++) {
            if (!*v) *v = new Node;
            v = &(*v)->c[*p - BASE_CHAR];
        }
        if (!*v) *v = new Node(true);
        else (*v)->isWord = true;
    }

    // Computes depth, fail and output links in BFS order and completes the
    // goto table so that exec() needs no fail-link walking.
    void build() {
        // Nothing inserted yet: build an automaton consisting of the root only
        // instead of dereferencing a NULL root.
        if (!root) root = new Node;

        std::queue<Node *> q;
        root->fail = root;
        root->out = NULL;
        root->depth = 0;
        q.push(root);
        while (!q.empty()) {
            Node *v = q.front();
            q.pop();

            for (int i = 0; i < CHARSET_SIZE; i++) {
                Node *&c = v->c[i];
                if (!c) {
                    // Missing edge: reuse the transition of the fail node, which
                    // was completed earlier because it is shallower than v.
                    c = v->fail->c[i] ? v->fail->c[i] : root;
                    continue;
                }
                c->depth = v->depth + 1;
                Node *u = v->fail;
                while (u != root && !u->c[i]) u = u->fail;
                // Children of the root always fail to the root.
                c->fail = v != root && u->c[i] ? u->c[i] : root;
                // c->fail is shallower than c, so its own output link is final.
                // The root is never reported through output links.
                c->out = (c->fail != root && c->fail->isWord) ? c->fail : c->fail->out;
                q.push(c);
            }
        }
    }

    // Runs the automaton over the text [begin, end). For every occurrence of an
    // inserted pattern, sets match[start] = true where start is the index
    // (relative to begin) of the occurrence's first character. Entries of match
    // that correspond to no occurrence are left untouched.
    // Requires build() to have been called.
    void exec(const char *begin, const char *end, bool *match) {
        Node *v = root;
        for (const char *p = begin; p != end; p++) {
            v = v->c[*p - BASE_CHAR];
            if (v->isWord) match[p - begin - v->depth + 1] = true;
            // Patterns that are proper suffixes of the current path also end here.
            for (Node *u = v->out; u; u = u->out) match[p - begin - u->depth + 1] = true;
        }
    }
} t;

struct SuffixArray {
    int n, sa[MAXN], rk[MAXN], ht[MAXN], st[MAXN][MAXN_LOG + 1], log[MAXN + 1];

    template <typename T>
    void build(const T *s, const int n) {
        this->n = n;
        static int set[MAXN], a[MAXN];
        std::copy(s, s + n, set);
        std::sort(set, set + n);
        int *end = std::unique(set, set + n);
        for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, s[i]) - set;

        static int fir[MAXN], sec[MAXN], tmp[MAXN], _buc[MAXN + 1], *buc = _buc + 1;
        std::fill(buc - 1, buc + n, 0);
        for (int i = 0; i < n; i++) buc[a[i]]++;
        for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
        for (int i = 0; i < n; i++) rk[i] = buc[a[i] - 1];

        for (int t = 1; t < n; t *= 2) {
            for (int i = 0; i < n; i++) fir[i] = rk[i], sec[i] = i + t < n ? rk[i + t] : -1;

            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[sec[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) tmp[n - buc[sec[i]]--] = i;

            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[fir[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) sa[--buc[fir[tmp[i]]]] = tmp[i];

            for (int i = 0; i < n; i++) {
                if (!i) rk[sa[i]] = 0;
                else if (fir[sa[i]] == fir[sa[i - 1]] && sec[sa[i]] == sec[sa[i - 1]]) rk[sa[i]] = rk[sa[i - 1]];
                else rk[sa[i]] = rk[sa[i - 1]] + 1;
            }
        }

        for (int i = 0, k = 0; i < n; i++) {
            if (!rk[i]) continue;
            int j = sa[rk[i] - 1];
            if (k) k--;
            while (i + k < n && j + k < n && s[i + k] == s[j + k]) k++;
            ht[rk[i]] = k;
        }

#ifdef DBG
        printf("build: %s\n", s);
        for (int i = 0; i < n; i++) printf("%d%c", sa[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d%c", rk[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d %s\n", ht[i], &s[sa[i]]);
#endif

        for (int i = 0; i < n; i++) st[i][0] = ht[i];
        for (int j = 1; (1 << j) < n; j++) {
            for (int i = 0; i < n; i++) {
                if (i + (1 << (j - 1)) < n) st[i][j] = std::min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
                else st[i][j] = st[i][j - 1];
            }
        }
        for (int i = 0; i <= n; i++) {
            int x = 0;
            while ((1 << x) <= i) x++;
            log[i] = x - 1;
        }
    }

    int rmq(const int l, const int r) {
        if (l == r) return st[l][0];
        else {
            const int t = log[r - l];
            return std::min(st[l][t], st[r - (1 << t) + 1][t]);
        }
    }

    int lcp(const int i, const int j) {
        if (i == j) return n - i;
        int a = rk[i], b = rk[j];
        if (a > b) std::swap(a, b);
        return rmq(a + 1, b);
    }
} sa1, sa2;

// Reads s1, s2 and s3 from stdin and prints the length of the longest string
// that is a substring of both s1 and s2 and does not contain s3.
int main() {
    // +1 keeps room for the terminating NUL when the text uses all MAXN slots.
    // NOTE(review): scanf("%s") is unbounded; inputs longer than the arrays
    // overflow them.
    static char buf[MAXN + 1];
    scanf("%s", buf);
    const int n1 = strlen(buf);
    // Separator between s1 and s2. It must not occur in either string, so that
    // no common prefix runs across it. '$' also sorts below every lowercase
    // letter, so the suffix starting at the separator gets rank 0, which the
    // rank loops below (starting at 1) never visit.
    buf[n1] = '$';
    scanf("%s", buf + n1 + 1);
    const int n2 = strlen(buf + n1 + 1);
    const int n = n1 + 1 + n2;
    sa1.build(buf, n);

    // s3 overwrites s2 in buf; sa1 does not read the text after build().
    scanf("%s", buf + n1 + 1);
    const int n3 = strlen(buf + n1 + 1);

    t.insert(buf + n1 + 1, buf + n1 + 1 + n3);
    t.build();

    // match[i] is true when an occurrence of s3 starts at position i of s1.
    static bool match[MAXN];
    t.exec(buf, buf + n1, match);

    // nextMatch[i]: longest prefix of the suffix s1[i..] that contains no
    // occurrence of s3. With x the first occurrence start at or after i, the
    // prefix may cover all but the last character of that occurrence.
    // INT_MAX means there is no such occurrence.
    static int nextMatch[MAXN];
    for (int i = n1 - 1, x = INT_MAX; i >= 0; i--) {
        if (match[i]) x = i;
        nextMatch[i] = x == INT_MAX ? INT_MAX : x - i + n3 - 1;
    }

#ifdef DBG
    puts("KMP:");
    buf[n1] = '\0';
    for (int i = 0; i < n1; i++) printf("%d %s\n", nextMatch[i], &buf[i]);
#endif

    // For every rank i holding a suffix of s1: l[i] / r[i] is the nearest rank
    // before / after i that holds a suffix of s2, or -1 when there is none.
    static int l[MAXN], r[MAXN];
    for (int i = 1, x = -1; i < n; i++) {
        if (sa1.sa[i] < n1) l[i] = x;
        else if (sa1.sa[i] > n1) x = i;
    }
    for (int i = n - 1, x = -1; i >= 1; i--) {
        if (sa1.sa[i] < n1) r[i] = x;
        else if (sa1.sa[i] > n1) x = i;
    }

#ifdef DBG
    for (int i = 1; i < n1 + n2 + 1; i++) {
        if (sa1.sa[i] < n1) printf("l[%d] = %d\n", i, l[i]);
        if (sa1.sa[i] < n1) printf("r[%d] = %d\n", i, r[i]);
    }
#endif

    int ans = 0;
    for (int i = 1; i < n; i++) {
        const int pos = sa1.sa[i];
        // Only suffixes of s1 are candidates; skip s2 and the separator.
        if (pos >= n1) continue;

        // The best s2 partner of a suffix is one of its two nearest s2
        // neighbours in rank order; try the left one, then the right one.
        const int neighbour[2] = { l[i], r[i] };
        for (int k = 0; k < 2; k++) {
            if (neighbour[k] == -1) continue;
            // Common substring length, capped so that it does not contain s3.
            const int x = std::min(sa1.lcp(sa1.sa[neighbour[k]], pos), nextMatch[pos]);
            if (x) {
                ans = std::max(ans, x);
#ifdef DBG
                printf("%d\n", x);
#endif
            }
        }
    }

    printf("%d\n", ans);

    return 0;
}