L3-3 可怜的简单题

代码存档（杜教筛求莫比乌斯函数前缀和 + 整除分块）：

#include <cstdio>
#include <iostream>
#include <cstring>
#include <bitset>

using namespace std;

// Local debugging helpers: `IN;` reads stdin from in.txt, `OUT;` writes stdout to out1.txt.
#define IN (std::freopen("in.txt", "r", stdin))
#define OUT (std::freopen("out1.txt", "w", stdout))

// Sieve limit: prefix sums of the Moebius function are tabulated directly for
// every argument up to N; larger arguments go through the memoized recursion.
const int N = 2.2e7;

// Input values n and p, the running answer, and the number of primes found.
__int128 n, p, ans = 1, cnt;
// m[i]: Moebius value mu(i) while the sieve in main runs, then the prefix sum
// M(i) = mu(1) + ... + mu(i) reduced mod p (may be negative).
__int128 m[N + 18];
// w[k]: memoized M(x) for arguments x > N, stored at slot k = n / x.
// NOTE(review): only slots below n / N are ever touched; the size could likely
// shrink once the input bound on n is confirmed.
__int128 w[N + 18];
// Primes found by the linear sieve. pi(2.2e7) is about 1.39e6, below the
// Rosser-Schoenfeld bound 1.25506 * N / ln(N) ~ 1.63e6 < N / 10, so this is
// ample. (A __int128[N + 18] array here reserved ~350 MB for ~22 MB of data.)
long long prime[N / 10 + 18];
// v: "w[k] is valid" flags for the memo; vis: composite marks for the sieve.
bitset<N + 9> v, vis;

// Reads the next (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if the input ends before a digit.
inline __int128 read()
{
    __int128 x = 0, f = 1;
    // int, not char: EOF must stay distinguishable from every byte value,
    // otherwise the skip loop below never terminates at end of input.
    int ch = getchar();
    while (ch < '0' || ch > '9')
    {
        if (ch == EOF)
            return 0;
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        // Add the digit value, not the raw character, so the intermediate
        // never exceeds the final value (no signed overflow near the max).
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Writes x to stdout in decimal (no trailing newline). Zero prints as "0",
// and negative values get a leading '-'.
void print(__int128 x)
{
    // Take the magnitude as unsigned so negating the most negative value
    // is well defined (unsigned arithmetic wraps modulo 2^128).
    unsigned __int128 mag = static_cast<unsigned __int128>(x);
    if (x < 0)
    {
        putchar('-');
        mag = -mag;
    }
    char digits[40]; // 2^128 has 39 decimal digits
    int len = 0;
    do
    {
        digits[len++] = static_cast<char>('0' + static_cast<int>(mag % 10));
        mag /= 10;
    } while (mag != 0);
    while (len > 0)
        putchar(digits[--len]);
}

// Memo slot for a large argument x: the quotient of the global input n by x.
auto id = [](__int128 x) -> __int128
{
    const __int128 slot = n / x;
    return slot;
};

// Mertens prefix sum M(x) = mu(1) + ... + mu(x), reduced mod p (the result
// may be negative). Arguments up to N are read from the table m[] filled by
// the sieve; larger ones use M(x) = 1 - sum_{l=2}^{x} M(floor(x / l)),
// summed over runs of l that share the same quotient, and are memoized in
// w[] at slot id(x) with v[] marking which slots are filled.
// `self` is this lambda, passed explicitly so the closure can recurse.
auto mu = [](auto self, __int128 x)
{
    if (x <= N)
        return m[x];

    const __int128 slot = id(x);
    if (v[slot])
        return w[slot];

    __int128 acc = 1;
    __int128 lo = 2;
    while (lo <= x)
    {
        const __int128 q = x / lo;  // shared quotient of this run
        const __int128 hi = x / q;  // last l with the same quotient
        acc -= (hi - lo + 1) * self(self, q);
        acc %= p;
        lo = hi + 1;
    }

    v[slot] = 1;
    acc %= p;
    return w[slot] = acc;
};

int main()
{
   // IN;
    //OUT;

    n = read(), p = read();
    if(n==1){
        print(1);
        return 0;
    }
    m[1] = 1;
    for (__int128 i = 2; i <= N; ++i)
    {
        if (!vis[i])
            m[prime[++cnt] = i] = -1;
        for (__int128 j = 1; j <= cnt && i * prime[j] <= N; ++j)
        {
            vis[i * prime[j]] = 1;
            if (i % prime[j] == 0)
                break;
            m[i * prime[j]] = -m[i];
        }
        m[i] = (m[i] + m[i - 1]) % p;
    }

    auto inv = [](__int128 n)
    {
        __int128 a = n, res = 1, b = p - 2;
        while (b > 0)
        {
            if (b & 1)
                res = res * a % p;
            a = a * a % p, b >>= 1;
        }
        return res;
    };

    mu(mu, n);

    for (__int128 l = 1, r; l <= n; l = r + 1)
    {
        r = n / (n / l);
        __int128 tmp = ((mu(mu, r) - mu(mu, l - 1)) % p + p) % p * (n / l) % p * inv(n - n / l) % p;
        tmp = (tmp + p) % p;
        ans = ((ans - tmp) % p + p) % p;
    }
    ans = (ans + p) % p;
    print(ans);
    return 0;
}
上一篇
下一篇