「BZOJ 2565」最长双回文串 - Manacher

输入串 ,求 的最长双回文子串 ,即可将 分为两部分 )且 都是回文串。

链接

BZOJ 2565

题解

首先用 Manacher 求出以每个位置为中心的最长回文半径

表示以第 个位置为最右字符的最长回文串长度, 表示以第 个位置为最左字符的最长回文串长度。扫描 ,更新每个回文串的两端点。

此时只有以每个位置为中心的最长回文子串的左右端点的 有效。注意到 一定不小于 ,对于 也有相似的结论。正反分别扫一遍,更新 即可。

最后枚举每一个分隔符 #,统计答案即可。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>

const int MAXN = 1e5;

char s1[MAXN + 2], s2[MAXN * 2 + 2];
int n, len, r[MAXN * 2 + 2];

inline void prepare() {
    n = strlen(s1 + 1);

    s2[++len] = '@';
    s2[++len] = '#';
    for (int i = 1; i <= n; i++) {
        s2[++len] = s1[i];
        s2[++len] = '#';
    }

    s2[++len] = '\0';
}

inline void manacher() {
    int right = 0, pos = -1;
    for (int i = 1; i <= len; i++) {
        int x;

        if (right < i) x = 1;
        else x = std::min(r[2 * pos - i], right - i);

        while (s2[i - x] == s2[i + x]) x++;

        if (x + i > right) {
            right = x + i;
            pos = i;
        }

        r[i] = x;
    }
}

int main() {
    scanf("%s", s1 + 1);

    prepare();
    manacher();

    // puts(s2 + 1);
    // for (int i = 1; i <= len; i++) printf("%d%c", r[i], i == len ? '\n' : ' ');
    // for (int i = 1; i <= len; i++) printf("%c%c", s2[i], i == len ? '\n' : ' ');

    static int right[MAXN * 2 + 2], left[MAXN * 2 + 2];
    for (int i = 1; i <= len; i++) {
        if (i + r[i] - 1 <= len) right[i + r[i] - 1] = std::max(right[i + r[i] - 1], r[i] - 1);
        if (i - r[i] + 1 >= 0) left[i - r[i] + 1] = std::max(left[i - r[i] + 1], r[i] - 1);
    }

    for (int i = 2; i <= len; i++) {
        left[i] = std::max(left[i], left[i - 1] - 1);
        // if (s2[i - 1 - right[i - 1]] == s2[i + 1]) right[i] = std::max(right[i], right[i - 1] + 1);
    }

    for (int i = len - 1; i >= 1; i--) {
        right[i] = std::max(right[i], right[i + 1] - 1);
        // if (s2[i + 1 + left[i + 1]] == s2[i - 1]) left[i] = std::max(left[i], left[i + 1] + 1);
    }

    // for (int i = 1; i <= len; i++) printf("%d%c", right[i], i == len ? '\n' : ' ');
    // for (int i = 1; i <= len; i++) printf("%d%c", left[i], i == len ? '\n' : ' ');

    int ans = 0;
    for (int i = 1; i <= len; i++) {
        if (s2[i] == '#') {
            ans = std::max(ans, right[i] + left[i]);
        }
    }

    printf("%d\n", ans);

    return 0;
}