LuoguP4071 「SDOI2016」题解

错排问题 + 组合数

题目描述

求有多少种长度为 n 的序列 A,满足以下条件:

1 ~ n 这 n 个数在序列中各出现了一次

若第 i 个数 A[i] 的值为 i,则称 i 是稳定的。序列恰好有 m 个数是稳定的

满足条件的序列可能很多,序列数对$10^9+7$取模

Inputs and Outputs examples

Inputs’ e.g. #1

1
2
3
4
5
6
5
1 0
1 1
5 2
100 50
10000 5000

Outputs’ e.g. #1

1
2
3
4
5
0
1
20
578028887
60695423

分析

比较巧妙的一道题,较完美的结合了组合数和错排问题qwq

首先我们要知道一个经典小学奥数模型,错排问题.


错排问题的模型是这个样子的:有n封信,每封信都装错了,求装错的方法有多少种。

设$d_i$为有$i$封信装错的方案数,显然$d_1 = 0,d_2 = 1$.

当$n \geq 3$,我们钦定一个位置$k$,让第$n$位的数来这里,此时会有两种情况:

  • 第$k$位的数与第$n$位的数交换,也就是$k$位于第$n$位,此时的错排相当于$d_{n-2}$
  • $k$不位于第$n$位,此时的错排相当于$d_{n-1}$

由于$1 \leq k < n$,那么$k$的取值方法有$n-1$种,那么我们就可以得到一个递推式:

$$d_n = (n-1)(d_{n-1} + d_{n-2})$$

然后我们再结合题目,它要求有$m$个数必须满足$A_i = i$,也就是说会有$n-m$个数错排,所以答案显然就是:

$${\rm Ans} = C_n^m \times D_{n-m}$$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// luogu-judger-enable-o2
/* Headers */
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cctype>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
#include<climits>
#include<iostream>
#include<map>
#define FOR(i,a,b,c) for(int i=(a);i<=(b);i+=(c))
#define ROF(i,a,b,c) for(int i=(a);i>=(b);i-=(c))
#define FORL(i,a,b,c) for(long long i=(a);i<=(b);i+=(c))
#define ROFL(i,a,b,c) for(long long i=(a);i>=(b);i-=(c))
#define FORR(i,a,b,c) for(register int i=(a);i<=(b);i+=(c))
#define ROFR(i,a,b,c) for(register int i=(a);i>=(b);i-=(c))
#define lowbit(x) x&(-x)
#define LeftChild(x) x<<1
#define RightChild(x) (x<<1)+1
#define RevEdge(x) x^1
#define FILE_IN(x) freopen(x,"r",stdin);
#define FILE_OUT(x) freopen(x,"w",stdout);
#define CLOSE_IN() fclose(stdin);
#define CLOSE_OUT() fclose(stdout);
#define IOS(x) std::ios::sync_with_stdio(x)
#define Dividing() printf("-----------------------------------\n");
namespace FastIO{
const int BUFSIZE = 1 << 20;
char ibuf[BUFSIZE],*is = ibuf,*its = ibuf;
char obuf[BUFSIZE],*os = obuf,*ot = obuf + BUFSIZE;
inline char getch(){
if(is == its)
its = (is = ibuf)+fread(ibuf,1,BUFSIZE,stdin);
return (is == its)?EOF:*is++;
}
inline int getint(){
int res = 0,neg = 0,ch = getch();
while(!(isdigit(ch) || ch == '-') && ch != EOF)
ch = getch();
if(ch == '-'){
neg = 1;ch = getch();
}
while(isdigit(ch)){
res = (res << 3) + (res << 1)+ (ch - '0');
ch = getch();
}
return neg?-res:res;
}
inline void flush(){
fwrite(obuf,1,os-obuf,stdout);
os = obuf;
}
inline void putch(char ch){
*os++ = ch;
if(os == ot) flush();
}
inline void putint(int res){
static char q[10];
if(res==0) putch('0');
else if(res < 0){putch('-');res = -res;}
int top = 0;
while(res){
q[top++] = res % 10 + '0';
res /= 10;
}
while(top--) putch(q[top]);
}
inline void space(bool x){
if(!x) putch('\n');
else putch(' ');
}
}
inline void read(int &x){
int rt = FastIO::getint();
x = rt;
}
inline void print(int x,bool enter){
FastIO::putint(x);
FastIO::flush();
FastIO::space(enter);
}
/* definitions */
const int mod = 1e9 + 7;
const int MAXN = 1e6 + 7;
#define type_T long long
type_T T,n,m;
type_T d[MAXN],f[MAXN],inv[MAXN];
/* functions */
inline type_T quick_multi(type_T a,type_T b,type_T p){
return 1ll * a * b % p;
}
inline type_T quick_power(type_T a,type_T b,type_T p){
type_T res = 1,base = a;
while(b){
if(b & 1) res = quick_multi(res,base,p);
b >>= 1; base = quick_multi(base,base,p);
}
return res;
}
inline void init(){
f[0] = 1;
FORL(i,1,MAXN,1){
f[i] = quick_multi(f[i-1],i,mod);
inv[i] = quick_power(f[i],(type_T)mod-2,(type_T)mod);
}
d[1] = 0; d[2] = 1; d[3] = 2;
FORL(i,4,MAXN,1){
d[i] = quick_multi(i-1,(d[i-1] + d[i-2]),mod);
}
}
inline type_T C(int n,int m){
return quick_multi(quick_multi(f[n],inv[m],mod),inv[n-m],mod);
}
int main(int argc,char *argv[]){
init();
scanf("%lld",&T);
while(T--){
scanf("%lld%lld",&n,&m);
if(n - m == 1) printf("0\n");
else if(m == n) printf("1\n");
else if(m == 0) printf("%lld\n",d[n]);
else printf("%lld\n",quick_multi(C(n,m),d[n-m],mod));
}
return 0;
}

THE END