패턴을 찾아야 하는 문제다.
d[n][x]를 n*n크기에서 3을 x개만 가지고 있을 때 가능한 경우의 수라고 할 때
X 0 1 2 3 4 5
N = 1 : 0 1
N = 2 : 2 0 2
N = 3 : 12 18 0 6
N = 4 : 216 192 144 0 24
N = 5 : 5280 5400 2400 1200 0 120
이런 형태로 간다. 일단 바로 보이는건
d[N][N-1] = 0
d[N][N] = N!
다. 모든 요소가 N!를 약수로 갖고 있고 각 행의 합은 1,4,36,576으로 (n!)^2이다.
N!으로 나누어 보자
X 0 1 2 3 4 5
N = 1 : 0 1
N = 2 : 1 0 1
N = 3 : 2 3 0 1
N = 4 : 9 8 6 0 1
N = 5 : 44 45 20 10 0 1
이렇게 변한다.
그리고 d[N][i]를 어떻게 만드는지 생각해 보자.
i==0이면 N*N배열을 3 하나도 쓰지 않고 문제의 조건에 맞춰서 만드는 경우의 수이다.
결론부터 말하면 NC2*N!에서 중복을 제거한 것이다. 그럼 i>0인 경우는?
d[N][i]은 d[N-1][i-1]에다가 3을 하나 끼워넣고 각각의 열 혹은 행이 오는 모든 경우의 수, N!를 곱한 후 중복을 제거한게 된다.
저 3을 끼워넣는것도 Combination으로 예상된다. 그렇다면 다시한번 nCx로 나누어 보자
X 0 1 2 3 4 5
N = 1 : 0 1
N = 2 : 1 0 1
N = 3 : 2 1 0 1
N = 4 : 9 2 1 0 1
N = 5 : 44 9 2 1 0 1
패턴이 보인다.
알아본 결과 1 0 1 2 9 44 ... 로 나가는 수열은 subfactorial 이라고 하며 sf(n) = (n-1)*(sf(n-1) + sf(n-2)) 점화식을 가진다.
역으로 타고 가게 되면 d[n][x] = sf(n-x) * nCx * n! 이 나온다!
sf(n), n!은 N만에 미리 구해놓을 수 있다.
nCx가 문제인데 일반적으로 우리는 이항계수를 2차원 배열을 통해 n^2만에 구한다.
근데 여기선 n이 100만이기 때문에 저렇게 하면 절대로 안된다. 이럴땐 구글이다.
이항계수 logn을 검색하니 좌르륵 뜬다. 갓 jason님의 코드를 참고해서 해보자. nCx가 logp만에 구해진다!
이제 factorial, subfactorial, inv(n!)를 전처리 하고 알맞게 계산을 하면 된다.
#include <iostream> #include <algorithm> #include <memory.h> using namespace std; typedef long long ll; #define mod 1000000007 int t, tc, N, X; ll fact[1000001]; ll d[1000001]; ll a[1000001]; ll dd[1000001]; inline ll _power(ll x, ll y) { ll ret = 1; while (y) { if (y & 1) ret = (ret*x) % mod; x = (x*x) % mod; y >>= 1; } return ret; } int main() { ios::sync_with_stdio(false); cin.tie(0); d[0] = 1; d[1] = 0; fact[0] = fact[1] = dd[0] = dd[1] = 1; a[1] = 1; for (ll i = 2; i <= 1000000; ++i) { d[i] = ((i - 1)*(d[i - 1] + d[i - 2])) % mod; fact[i] = (fact[i - 1] * i) % mod; a[i] = ((a[i - 1]) * ((i*i) % mod) % mod) % mod; dd[i] = _power(fact[i], mod - 2LL); } cin >> t; for (tc = 1; tc <= t; ++tc) { cin >> N >> X; ll ans = a[N], w; for (int i = 0; i < X; ++i) { w = (d[N - i] * fact[N]) % mod; ll temp = (((dd[N - i] * dd[i]) % mod)*fact[N]) % mod; temp = (temp*w) % mod; ans = (ans - temp + mod) % mod; } cout << '\n'; cout << "#" << tc << " " << ans << '\n'; } return 0; } | cs |