题目链接

题意:你有一个随机数生成器,有$\frac{1}{2}$的概率返回$0$,$\frac{1}{2}$的概率返回$1$。你要用它来构造一个新的随机数生成器,这个新的随机数生成器生成的数是$[1, n]$内的整数,且生成$i$的概率为$\frac{a_i}{a_1 + a_2 + \dots + a_n}$。求在调用一次新随机数生成器的过程中最少期望调用多少次原随机数生成器。$1 \leqslant n \leqslant 10^6$,$1 \leqslant \sum_{i=1}^n a_i \leqslant 10^7$。

我们把过程看成一棵无限大的完全二叉树,调用一次随机数生成器,如果是$0$就进入左子树,否则就进入右子树。那么经过一个深度为$i$的节点的概率为$2^{1-i}$。如果我们需要使生成$i$的概率为$p$,那么我们就是要选择一些节点(可能有无限个),经过它们的概率之和为$p$,如果到达了一个被选择的节点,那么我们就直接返回$i$。

那么如果假设对于权值$i$,我们选择的节点集合为$S_i$,那么期望调用原随机数生成器的次数为$ \sum_{i=1}^{n} \sum_{j \in S_i} dep_j 2 ^ {1 - dep_j}$。我们定义$f(p)$为:我们找到一个节点集合$S$,使得经过它们的概率之和为$p$,且使$f(p)=\sum_{i \in S} dep_i 2 ^ {1 - dep_i}$最小。那么有一个显然的转移:

如果$p \geqslant \frac{1}{2}$,那么$f(p) = \frac{f(2(p - \frac{1}{2}))}{2} + p$(即选择一个根节点的孩子加入集合),否则$f(p) = \frac{f(2p)}{2} + p$。我们发现对于$0 \leqslant i < m$,$f(\frac{i}{m})$之间的转移关系构成了一棵基环森林,我们只要对在环上的部分解一下方程,对不在环上的部分直接转移即可。

时间复杂度$O(m)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>
using namespace std;
const int maxn = 10000000, mod = 998244353;
int n, m, invm, b[maxn + 10], bcnt, f[maxn + 10], vis[maxn + 10];
int a[maxn + 10], ans;

inline int add(int x, int y, int mod = 998244353) {
x += y; return x < mod ? x : x - mod;
}
inline int dec(int x, int y) {
x -= y; return x < 0 ? x + mod : x;
}
inline int mul(int x, int y) {
return 1ll * x * y % mod;
}
inline int fpow(int x, int y) {
int ans = 1;
while (y) {
if (y & 1) ans = mul(ans, x);
y >>= 1; x = mul(x, x);
}
return ans;
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]); m += a[i];
}
invm = fpow(m, mod - 2);
for (int i = 0; i < m; ++i)
if (!vis[i]) {
int p = i; bcnt = 0;
while (!vis[p]) {
b[++bcnt] = p;
vis[p] = i + 1;
p = add(p, p, m);
}
if (vis[p] == i + 1) {
int x = (mod + 1) / 2, y = mul(p, invm);
for (int j = add(p, p, m); j != p; j = add(j, j, m)) {
y = add(y, mul(x, mul(j, invm)));
x = mul(x, (mod + 1) / 2);
}
f[p] = mul(y, fpow(dec(1, x), mod - 2));
}
for (int j = bcnt; j >= 1; --j)
f[b[j]] = add(mul(f[add(b[j], b[j], m)], (mod + 1) / 2), mul(b[j], invm));
}
for (int i = 1; i <= n; ++i) ans = add(ans, f[a[i]]);
printf("%d", ans);
}