computer/PS

[백준] 1492 - 합 (C++)

ketrewq 2025. 5. 12. 07:38

 

 

1. 문제 소개 

$N$과 $K$가 주어질 때, $1^K + 2^K + \cdots + N^K$를 $1{,}000{,}000{,}007$로 나눈 나머지를 구하는 문제이다.

 

 

2. 풀이 

 

Berlekamp-Massey 알고리즘

 

Berlekamp-Massey 알고리즘

Berlekamp-Massey 알고리즘은 특정한 DP의 점화식을 찾아주는 알고리즘이다. $10^{18}$ 번째 피보나치 수를 찾기 위해서 행렬 곱셈을 짜고, 타일 채우기 문제를 풀기 위해서 수많은 점화식과 씨름하던

koosaga.com

 

 

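Berlekamp-Massey 알고리즘을 사용했다. 위의 글을 읽어보면 친절하게 설명돼 있는데, 한 30% 이해한 것 같다. 

이 문제에 적용할 수 있는 이유는 다음과 같다. $S(N) = 1^K + 2^K + \cdots + N^K$는 $N$에 대한 $K+1$차 다항식이므로, 수열 $S(0), S(1), S(2), \dots$는 길이 $K+2$인 선형 점화식을 만족한다. Berlekamp-Massey 알고리즘은 길이 $L$인 점화식을 처음 $2L$개의 항만으로 찾아낼 수 있다. 그래서 코드에서는 $S(0)$부터 $S(3K+3)$까지 $3K+4 \ge 2(K+2)$개의 항을 직접 계산해 넘겨 준다. 점화식을 찾은 뒤에는 특성다항식에 대한 $x^N$의 나머지를 분할 정복으로 구해 $O(K^2 \log N)$에 $S(N)$을 계산한다. 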

 

 

#include <iostream>
#include <algorithm>
#include <vector>
#include <ctime>
#include <random>

using namespace std;

// Modulus for all arithmetic in this file. 1e9 + 7 is prime, which is what
// lets ipow(a, MOD - 2) act as a modular inverse (Fermat's little theorem).
const int MOD = 1e9 + 7;

using lint = long long;

// Returns x^p mod MOD, always in [0, MOD), by binary exponentiation.
// x may be any 64-bit value: it is reduced into [0, MOD) first. This means
// the squaring `piv * piv` cannot overflow, and negative bases (e.g. the
// discrepancies computed in berlekamp_massey) give a canonical result.
// p must be non-negative; the loop stops as soon as p <= 0, so a negative
// exponent yields 1 instead of looping forever.
lint ipow(lint x, lint p)
{
    lint piv = x % MOD;
    if (piv < 0)
        piv += MOD;
    lint ret = 1;
    while (p > 0)
    {
        if (p & 1)
            ret = ret * piv % MOD;
        piv = piv * piv % MOD;
        p >>= 1;
    }
    return ret;
}
// Berlekamp-Massey over Z/MOD.
//
// Finds a shortest linear recurrence satisfied by the sequence x, i.e.
// coefficients c[0..L-1] such that
//   x[i] == c[0]*x[i-1] + c[1]*x[i-2] + ... + c[L-1]*x[i-L]   (mod MOD)
// for every i in [L, x.size()).
//
// Returns: the coefficients, each normalized into [0, MOD). An empty result
//          means every x[i] is 0 mod MOD (no discrepancy was ever found).
// The recurrence is only trustworthy when x holds at least 2L terms of the
// true sequence (the usual requirement of the algorithm).
// Entries of x are expected to be already reduced mod MOD; the callers in
// this file pass values in [0, MOD).
vector<int> berlekamp_massey(vector<int> x)
{
    // cur: current candidate recurrence.
    // ls:  the candidate that was current before cur's last length change.
    vector<int> ls, cur;
    // lf: index at which `ls` failed.
    // ld: the discrepancy (t - x[lf]) % MOD at that index; it may be negative.
    // Both are assigned in the cur.empty() branch before they are first read.
    int lf, ld;
    for (int i = 0; i < x.size(); i++)
    {
        // t = value the current recurrence predicts for x[i].
        lint t = 0;
        for (int j = 0; j < cur.size(); j++)
        {
            t = (t + 1ll * x[i - j - 1] * cur[j]) % MOD;
        }
        // Prediction matches (mod MOD): keep the current recurrence.
        if ((t - x[i]) % MOD == 0)
            continue;
        // First nonzero discrepancy. A length-(i + 1) recurrence (all zeros)
        // is vacuously consistent with x[0..i]; remember this failure.
        if (cur.empty())
        {
            cur.resize(i + 1);
            lf = i;
            ld = (t - x[i]) % MOD;
            continue;
        }
        // Correction term: the old fallback recurrence `ls`, shifted by
        // i - lf and scaled by k = (t - x[i]) / ld. Adding it to cur repairs
        // the prediction at index i.
        lint k = -(x[i] - t) * ipow(ld, MOD - 2) % MOD;
        vector<int> c(i - lf - 1);
        c.push_back(k);
        for (auto &j : ls)
            c.push_back(-j * k % MOD);
        if (c.size() < cur.size())
            c.resize(cur.size());
        // Both summands lie in (-MOD, MOD), so the int sum stays below 2^31.
        for (int j = 0; j < cur.size(); j++)
        {
            c[j] = (c[j] + cur[j]) % MOD;
        }
        // If the corrected recurrence is at least as long as cur, cur becomes
        // the new fallback, remembered with this index and discrepancy.
        if (i - lf + (int)ls.size() >= (int)cur.size())
        {
            tie(ls, lf, ld) = make_tuple(cur, i, (t - x[i]) % MOD);
        }
        cur = c;
    }
    // Coefficients may be negative after the `%` operations; map to [0, MOD).
    for (auto &i : cur)
        i = (i % MOD + MOD) % MOD;
    return cur;
}
// Returns the n-th term (0-based) mod MOD of the sequence defined by the
// linear recurrence `rec` and its first terms `dp`:
//   a_i = rec[0]*a_{i-1} + ... + rec[m-1]*a_{i-m}   for i >= m (m = rec.size()),
//   a_i = dp[i]                                    for i <  m.
// It computes x^n modulo the characteristic polynomial
//   x^m - rec[0] x^{m-1} - ... - rec[m-1]
// by binary exponentiation, then combines the resulting coefficients with
// dp[0..m-1]. Cost: O(m^2 log n).
//
// rec and dp entries are expected in [0, MOD) (as berlekamp_massey and the
// prefix-sum table produce them). dp must hold at least m terms; n must be
// non-negative (a negative n is treated like 0 rather than looping forever).
// An empty recurrence describes the all-zero sequence, so 0 is returned.
int get_nth(vector<int> rec, vector<int> dp, lint n)
{
    const int m = rec.size();
    if (m == 0)
        return 0;

    // s = x^0 = 1. t = x^1 reduced modulo the characteristic polynomial;
    // for m == 1 that polynomial is x - rec[0], so x reduces to rec[0].
    vector<int> s(m), t(m);
    s[0] = 1;
    if (m != 1)
        t[1] = 1;
    else
        t[0] = rec[0];

    // Multiplies two residues of degree < m and reduces the product using
    // x^m = rec[0] x^{m-1} + ... + rec[m-1]. Operands are taken by reference
    // to avoid copying them on every call.
    auto mul = [&rec, m](const vector<int> &v, const vector<int> &w)
    {
        vector<int> prod(2 * m);
        for (int j = 0; j < m; j++)
        {
            for (int k = 0; k < m; k++)
            {
                // Both summands are < MOD, so the int sum stays below 2^31.
                prod[j + k] += 1ll * v[j] * w[k] % MOD;
                if (prod[j + k] >= MOD)
                    prod[j + k] -= MOD;
            }
        }
        // Fold the high-degree coefficients back down, highest degree first.
        for (int j = 2 * m - 1; j >= m; j--)
        {
            for (int k = 1; k <= m; k++)
            {
                prod[j - k] += 1ll * prod[j] * rec[k - 1] % MOD;
                if (prod[j - k] >= MOD)
                    prod[j - k] -= MOD;
            }
        }
        prod.resize(m);
        return prod;
    };

    while (n > 0)
    {
        if (n & 1)
            s = mul(s, t);
        t = mul(t, t);
        n >>= 1;
    }

    // a_n = sum_i s[i] * a_i, since x^n == sum_i s[i] x^i (mod char. poly).
    lint ret = 0;
    for (int i = 0; i < m; i++)
        ret += 1ll * s[i] * dp[i] % MOD;
    return ret % MOD;
}
// Returns the n-th term (0-based, mod MOD) of the sequence whose first terms
// are x. Terms already present in x are returned as-is. Otherwise the
// sequence is extended with the linear recurrence Berlekamp-Massey finds in
// x. If that recurrence is empty (every sample term is 0 mod MOD), the
// answer is 0.
int guess_nth_term(vector<int> x, lint n)
{
    // Requested index already lies inside the sample: nothing to extrapolate.
    if (n < x.size())
        return x[n];

    const vector<int> recurrence = berlekamp_massey(x);
    if (recurrence.empty())
        return 0;
    return get_nth(recurrence, x, n);
}
// One entry of a sparse matrix: A[x][y] = v, with 0-based indices.
// A given (x, y) position is expected to appear at most once in an entry list.
struct elem
{
    int x; // row index
    int y; // column index
    int v; // value stored at (x, y)
};
vector<int> get_min_poly(int n, vector<elem> M)
{
    // smallest poly P such that A^i = sum_{j < i} {A^j \times P_j}
    vector<int> rnd1, rnd2;
    mt19937 rng(0x14004);
    auto randint = [&rng](int lb, int ub)
    {
        return uniform_int_distribution<int>(lb, ub)(rng);
    };
    for (int i = 0; i < n; i++)
    {
        rnd1.push_back(randint(1, MOD - 1));
        rnd2.push_back(randint(1, MOD - 1));
    }
    vector<int> gobs;
    for (int i = 0; i < 2 * n + 2; i++)
    {
        int tmp = 0;
        for (int j = 0; j < n; j++)
        {
            tmp += 1ll * rnd2[j] * rnd1[j] % MOD;
            if (tmp >= MOD)
                tmp -= MOD;
        }
        gobs.push_back(tmp);
        vector<int> nxt(n);
        for (auto &i : M)
        {
            nxt[i.x] += 1ll * i.v * rnd1[i.y] % MOD;
            if (nxt[i.x] >= MOD)
                nxt[i.x] -= MOD;
        }
        rnd1 = nxt;
    }
    auto sol = berlekamp_massey(gobs);
    reverse(sol.begin(), sol.end());
    return sol;
}
// Determinant mod MOD of the sparse n x n matrix A given by the entry list M.
// The method is randomized (Wiedemann-style) and may err with small
// probability.
//
// Column y is scaled by a random nonzero d_y, giving A*D with D = diag(d).
// With high probability the minimal polynomial of A*D then equals its
// characteristic polynomial, whose constant term yields det(A*D). The
// scaling is finally undone with a single modular inverse of prod(d).
lint det(int n, vector<elem> M)
{
    // The 0 x 0 matrix has determinant 1. A non-positive n is treated the
    // same way instead of indexing an empty polynomial.
    if (n <= 0)
        return 1;

    mt19937 rng(0x14004);
    auto randint = [&rng](int lb, int ub)
    {
        return uniform_int_distribution<int>(lb, ub)(rng);
    };
    vector<int> scale;
    for (int i = 0; i < n; i++)
        scale.push_back(randint(1, MOD - 1));
    for (auto &e : M)
        e.v = 1ll * e.v * scale[e.y] % MOD;

    const vector<int> poly = get_min_poly(n, M);
    // A minimal polynomial of degree < n means the scaled matrix is singular
    // (with high probability), so det = 0. This also avoids indexing an empty
    // result.
    if ((int)poly.size() < n)
        return 0;

    // chi(x) = x^n - sum_k poly[k] x^k, so
    // det(AD) = (-1)^n * chi(0) = (-1)^(n+1) * poly[0].
    lint sol = poly[0];
    if (n % 2 == 0)
        sol = (MOD - sol) % MOD;

    // Undo the column scaling: det(A) = det(AD) / prod(d).
    lint scale_prod = 1;
    for (int d : scale)
        scale_prod = scale_prod * d % MOD;
    return sol * ipow(scale_prod, MOD - 2) % MOD;
}

vector<int> get_first_values(int k, int need)
{
    vector<int> dp(need + 1);
    for (int i = 1; i <= need; i++)
    {
        long long t = ipow(i, k);
        dp[i] = (dp[i - 1] + t) % MOD;
    }
    return dp;
}

int main() {
    long long N;
    int K;
    cin >> N >> K;

    int sample_size = 3 * K + 3;

    vector<int> dp = get_first_values(K, sample_size);
    vector<int> initial(dp.begin(), dp.end()); 
    vector<int> rec = berlekamp_massey(initial);

    int ans = get_nth(rec, initial, N);
    cout << ans << "\n";
}

'computer > PS' 카테고리의 다른 글

[백준] 4179 - 불! (C++)  (0) 2025.05.10
[백준] 3015 - 오아시스 재결합 (C++)  (0) 2025.05.09
[알고리즘] BFS/DFS와 동적 계획법  (0) 2024.12.02