「NOI2016」优秀的拆分 - Hash

如果一个字符串可以被拆分为 $AABB$ 的形式,其中 $A$ 和 $B$ 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。

例如,对于字符串 aabaabaa,如果令 $A=\texttt{aab}$,$B=\texttt{a}$,我们就找到了这个字符串拆分成 $AABB$ 的一种方式。

一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。比如我们令 $A=\texttt{a}$,$B=\texttt{baa}$,也可以用 $AABB$ 表示出上述字符串;但是,字符串 abaabaa 就没有优秀的拆分。

现在给出一个长度为 $n$ 的字符串 $S$,我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。

题解

点我去看题解

代码

Hash T 飞了 ……

#pragma GCC optimize("O3")

#include <cstdio>
#include <climits>
#include <cassert>
#include <cstring>
#include <algorithm>

// Hash value type. Unsigned 128-bit arithmetic wraps around on overflow, so
// the wrap-around itself serves as the hash modulus (no explicit `%`).
typedef unsigned __int128 hash1_t;

// Capacity (in characters) that the buffers below are sized for.
const int MAXN = 30000;
// Multiplier of the polynomial rolling hash.
const hash1_t BASE1 = 233;

// Input string buffer. One byte larger than MAXN so that a string of exactly
// MAXN characters still has room for its terminating NUL.
char s[MAXN + 1];
// Length of the string currently stored in `s`.
int n;
// Per-position counters filled in by main(): first used as difference arrays,
// then turned into prefix sums. forward[] is indexed by the start position of
// a square "AA", backward[] by its end position.
long long forward[MAXN + 1], backward[MAXN + 1];

// hash1[i] is the rolling hash of the prefix s[0..i];
// base1[i] is BASE1 raised to the i-th power (both wrapping in hash1_t).
hash1_t hash1[MAXN], base1[MAXN + 1];

inline __attribute__((always_inline)) bool compare(const int a, const int b, const int len) {
    return hash1[b + len - 1] - hash1[b - 1] * base1[len] == hash1[a + len - 1] - hash1[a - 1] * base1[len];
        // && hash2[b + len - 1] - hash2[b - 1] * base2[len] == hash2[a + len - 1] - hash2[a - 1] * base2[len]
        // && hash3[b + len - 1] - hash3[b - 1] * base3[len] == hash3[a + len - 1] - hash3[a - 1] * base3[len];
    // register long long h1 = ((hash[b + len - 1] - hash[b - 1] * base[len] % MOD) % MOD + MOD) % MOD;
    // register long long h2 = ((hash[a + len - 1] - hash[a - 1] * base[len] % MOD) % MOD + MOD) % MOD;
    // register hash_t h1 = hash[b + len - 1] - hash[b - 1] * base[len];
    // register hash_t h2 = hash[a + len - 1] - hash[a - 1] * base[len];
    // const bool res = h1 == h2;
    // assert(res == (memcmp(&s[a], &s[b], len) == 0));
    // return h1 == h2;
    // return memcmp(&s[a], &s[b], len) == 0;
}

/**
 * Length of the longest common prefix of the suffixes of `s` that start at
 * positions `a` and `b`, as judged by compare() (hash equality).
 *
 * The arguments may be given in either order. If the smaller index is
 * negative or the larger index is >= n, the result is 0.
 *
 * @return the largest m in [0, n - max(a, b)] such that
 *         compare(min(a, b), max(a, b), m) holds (0 if none)
 */
inline __attribute__((always_inline)) int lcp(int a, int b) {
    if (a > b) std::swap(a, b);
    if (a < 0 || b >= n) return 0;

    // Binary search for the largest matching length. `lo` is always a length
    // known to match (0 trivially), `hi` the largest length still possible.
    int lo = 0, hi = n - b;
    while (lo != hi) {
        // Upper-biased midpoint so that `lo = mid` always makes progress.
        const int mid = lo + (hi - lo) / 2 + 1;
        if (compare(a, b, mid)) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return lo;
}

/**
 * Length of the longest common suffix of the prefixes of `s` that end at
 * positions `a` and `b` (both inclusive), as judged by compare() (hash
 * equality).
 *
 * The arguments may be given in either order. If the smaller index is
 * negative or the larger index is >= n, the result is 0.
 *
 * @return the largest m in [0, min(a, b) + 1] such that the m characters
 *         ending at `a` match the m characters ending at `b` (0 if none)
 */
inline __attribute__((always_inline)) int lcs(int a, int b) {
    if (a > b) std::swap(a, b);
    if (a < 0 || b >= n) return 0;

    // Binary search for the largest matching length. At most a + 1 characters
    // exist at or before position `a`, which bounds the answer from above.
    int lo = 0, hi = a + 1;
    while (lo != hi) {
        // Upper-biased midpoint so that `lo = mid` always makes progress.
        const int mid = lo + (hi - lo) / 2 + 1;
        if (compare(a - mid + 1, b - mid + 1, mid)) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return lo;
}

int main() {
    // base1[0] = base2[0] = base3[0] = 1;
    base1[0] = 1;
    // for (int i = 1; i <= MAXN; i++) base[i] = base[i - 1] * BASE % MOD;
    for (register int i = 1; i <= MAXN; i++) {
        base1[i] = base1[i - 1] * BASE1;
        // base2[i] = base2[i - 1] * BASE2;
        // base3[i] = base3[i - 1] * BASE3;
    }

    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%s", s);
        n = strlen(s);
        for (register int i = 0; i < n; i++) s[i] -= 'a' - 1;

        // hash1[0] = hash2[0] = hash3[0] = s[0];
        hash1[0] = s[0];
        // for (int i = 1; i < n; i++) hash[i] = (hash[i - 1] * BASE + s[i]) % MOD;
        for (register int i = 1; i < n; i++) {
            hash1[i] = (hash1[i - 1] * BASE1 + s[i]);
            // hash2[i] = (hash2[i - 1] * BASE2 + s[i]);
            // hash3[i] = (hash3[i - 1] * BASE3 + s[i]);
        }

        for (register int k = 1; k <= n / 2; k++) {
            // printf("k = %d\n", k);
            for (register int i = 0; i < n; i += k) {
                const register int a = std::min(lcs(i, i + k), k) - 1, b = std::min(lcp(i, i + k), k);
                // printf("lcs(%d, %d) - 1 = %d, lcp(%d, %d) = %d\n", i, i + k, a, i, i + k, b);
                if (a + b >= k) {
                    register int l, r;
                    l = i + k + k - a - 1, r = l + (a + b - k);
                    // l = std::max(i + k + k - a, i + k), r = std::min(l + (a + b - k), i + k + k - 1);
                    // printf("[%d, %d]\n", l, r);
                    // for (int i = l; i <= r; i++) backward[i]++;
                    backward[l]++, backward[r + 1]--;
                    r = r - 2 * k + 1, l = l - 2 * k + 1;
                    // for (int i = l; i <= r; i++) forward[i]++;
                    forward[l]++, forward[r + 1]--;
                }
            }
        }

        for (register int i = 1; i < n; i++) forward[i] += forward[i - 1], backward[i] += backward[i - 1];

        register long long ans = 0;
        for (register int i = 1; i < n; i++) ans += backward[i - 1] * forward[i];

        for (register int i = 0; i < n; i++) {
            hash1[i] = 0;
            // hash2[i] = 0;
            // hash3[i] = 0;
            backward[i] = forward[i] = 0;
        }
        forward[n] = backward[n] = 0;

        printf("%lld\n", ans);
    }

    return 0;
}