P3868 [TJOI2009]猜数字

一. 思路

首先拿道题看到条件给出两组数,其中一组互素,让猜一个数字。自然而然往 crt 上想。但为什么是 crt 呢?

i[1,k]\forall i \in [1,k],有 bi(nai)b_i \mid (n - a_i),这句话可以化为一组同余方程,再移项可得标准的 crt 格式:

              {na10(mod b1)na20(mod b2)nak0(mod bk) {na1(mod b1)na2(mod b2)nak(mod bk)~~~~~~~~~~~~~~\begin{cases} n - a_1 \equiv 0 &(\operatorname{mod}~b_1) \\ n - a_2 \equiv 0 &(\operatorname{mod}~b_2) \\ \cdots \\ n - a_k \equiv 0 &(\operatorname{mod}~b_k) \\ \end{cases} \\ ~\\ \Longrightarrow\begin{cases} n \equiv a_1 &(\operatorname{mod}~b_1) \\ n \equiv a_2 &(\operatorname{mod}~b_2) \\ \cdots \\ n \equiv a_k &(\operatorname{mod}~b_k) \\ \end{cases}

然后 crt 求解就好了。crt ——中国剩余定理就是提供了一个解同余方程组 {xa1(mod m1)xa2(mod m2)xan(mod mn)\begin{cases}x \equiv a_1 &(\operatorname{mod}~m_1) \\ x \equiv a_2 &(\operatorname{mod}~m_2) \\ \cdots \\ x \equiv a_n &(\operatorname{mod}~m_n) \\ \end{cases} 的公式,即 x=i=1n{ai×Mi×ti}x = \sum\limits_{i = 1}^{n} \{ a_i \times M_i \times t_i \},其中 Mi=j=1nmjmiM_i = \dfrac{\prod\limits_{j = 1}^{n}m_j}{m_i}ti=inv(Mi)t_i = \operatorname{inv}(M_i)。公式的推导详见 OI-Wiki

二. 坑点

交代码上去一看,为什么只有 90 分?最后一个点 WA 掉了,并且显示第一行第一列输出了减号。看来是爆 long long 了。所以需要龟速快速乘防止爆 long long。(本蒟蒻不会long double)

三. 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include<cstdio>
#define int long long
#define re register
const int MAXn = 10;

template <typename T>
inline T qmul(T x, T y, T mod) {
if (x == 0 || y == 0) return 0;
T ret = 0;
while (y) {
if (y & 1) ret = ((ret % mod) + (x % mod)) % mod;
y >>= 1;
x = ((x % mod) + (x % mod)) % mod;
}
return ret;
}

int exgcd(int a, int b, int &x, int &y) {
if (!b) {
x = 1; y = 0;
return a;
} else {
int d = exgcd(b, a % b, y, x);
y -= a / b * x;
return d;
}
}

int inv(int a, int m) {
int k, inv;
exgcd(a, m, inv, k);
return (inv % m + m) % m;
}

int crt(int cnta, int *a, int *m) {
int prod = 1, ans = 0;
for (re int i = 1; i <= cnta; ++i) {
prod *= m[i];
}
for (re int i = 1, M; i <= cnta; ++i) {
M = prod / m[i];
ans = (ans + qmul(qmul(a[i], M, prod), inv(M, m[i]), prod)) % prod;
}
return ans;
}

int n, a[MAXn + 10], m[MAXn + 10];
signed main() {
scanf("%lld", &n);
for (re int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
}
for (re int i = 1; i <= n; ++i) {
scanf("%lld", &m[i]);
}
printf("%lld\n", crt(n, a, m));
}