超過状態を考慮する桁DP 既出問題の解説と実装
おことわり
解説と実装は、[クリックで開閉します] というボタンをクリックして表示させてください。
また、DP テーブルの「$i$ 桁め」についての次元は、使い回しにより省略しています。
$A$ 円の買い物をしたいときに、$X$ 円を払い $X-A$ 円を釣り銭として受け取ったとします。
この三つの数において、桁の数字が三つとも同じである桁の個数を最大化したとき、その桁の個数はいくつですか。
・$1 \le A \le 10^{100}$
クリックで開閉します
dp[XはAより未満/同じ/超過][繰り上がったか] := 揃った最大の桁数
$X-A$ を決めると、$A$ との和で $X$ が決まります。
足し算に繰り上がりがあるので、下位桁から決めていきます。
$A$ より $X$ の方が一桁多いときがあるので、$A$ の先頭桁にゼロをつけておきますが、
$A, X, X-A$ の三つとも同じ桁数のときもあり、そのときは先頭桁のゼロを同じとカウントしてしまうので、
カウントする際、「先頭桁ではない」という条件を加えておきます。
#include <iostream>
#include <vector>
constexpr int inf = 987'654'321;
int cmp(int x, int y) {
if(x < y) { return 0; }
if(x == y) { return 1; }
return 2;
}
int next_state(int x, int y, int cur) {
int res = cmp(x, y);
if(res == 1) { res = cur; }
return res;
}
int f(std::string s) {
s = '0' + s;
int n = static_cast<int>(s.size());
// dp[XがAより未満/同じ/超過][繰り上がったか] := 揃った最大の桁数
std::vector<std::vector<int>> dp(3, std::vector<int>(2, -inf));
dp[1][0] = 0;
for(int i=n-1; i>=0; --i) { // 下から
std::vector<std::vector<int>> ndp(3, std::vector<int>(2, -inf));
for(int state=0; state<3; ++state) {
for(int carry=0; carry<2; ++carry) {
if(dp[state][carry] == -inf) { continue; }
for(int yi=0; yi<10; ++yi) { // X-A を決める
int xi = (carry + yi + s[i]-'0') % 10,
ncarry = (carry + yi + s[i]-'0') / 10,
nstate = next_state(xi, s[i]-'0', state);
bool yes = i > 0 && (yi == xi) && (xi == (s[i]-'0'));
ndp[nstate][ncarry] = std::max(ndp[nstate][ncarry], dp[state][carry] + yes);
}
}
}
dp = ndp;
}
int res = 0;
for(int state : {1, 2}) {
res = std::max(res, dp[state][0]);
}
return res;
}
int main() {
std::string s; std::cin >> s;
int res = f(s);
std::cout << res << '\n';
return 0;
}
$K$ 種類以下の数字を使って 整数 $A$ になるべく近づけたいとき、その差を求めてください。
・$1 \le A \le 10^{15}$
・$1 \le K \le 10$
クリックで開閉します
dp[近づけたい整数は A より未満/同じ/超過][先行ゼロを抜けたか][使った数字の集合] := 差の最小値の絶対値
サンプルの最後を見ると $A$ より大きい範囲も調べる必要がありそうなので、超過状態も考慮します。
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr i64 inf = 987'654'321'987'654'321LL;
std::vector<i64> p10; // p10[i] := 10**i
int cmp(int x, int y) {
if(x < y) { return 0; }
if(x == y) { return 1; }
return 2;
}
int sign(int state) {
return state ? 1 : -1;
}
i64 f(i64 x, int K) {
std::string s = std::to_string(x);
int n = static_cast<int>(s.size());
// dp[近づけたい整数は A より未満/同じ/超過][先行ゼロを抜けたか][使った数字の集合] := 最小値(絶対値を入れるので必ず非負である)
std::vector<std::vector<std::vector<i64>>> dp(3, std::vector<std::vector<i64>>(2, std::vector<i64>(1<<10, inf)));
dp[1][0][0] = 0;
for(int i=0; i<n; ++i) {
std::vector<std::vector<std::vector<i64>>> ndp(3, std::vector<std::vector<i64>>(2, std::vector<i64>(1<<10, inf)));
for(int state=0; state<3; ++state) {
for(int nuke0=0; nuke0<2; ++nuke0) {
for(int S0=0; S0<1<<10; ++S0) {
if(dp[state][nuke0][S0] == inf) { continue; }
for(int d=0; d<10; ++d) {
int nstate = state,
cur_state = cmp(d, s[i]-'0');
if(nstate == 1) { nstate = cur_state; }
int n_nuke0 = nuke0 || d > 0,
S1 = S0;
if(nuke0 || d > 0) { S1 |= 1<<d; }
i64 p = dp[state][nuke0][S0] * sign(state); // これまでの差
i64 q = abs(s[i]-'0' - d) * p10[n-1-i] * sign(cur_state); // 今回の桁による差
ndp[nstate][n_nuke0][S1] = std::min(ndp[nstate][n_nuke0][S1], std::abs(p + q));
}
}
}
}
dp = ndp;
}
i64 res = inf;
for(int state=0; state<3; ++state) {
for(int S0=0; S0<1<<10; ++S0) {
if(__builtin_popcount(S0) <= K) {
res = std::min(res, dp[state][1][S0]);
}
}
}
return res;
}
int main() {
p10.resize(18);
p10[0] = 1;
for(int i=0; i<17; ++i) {
p10[i+1] = p10[i] * 10;
}
i64 A; int K; std::cin >> A >> K;
i64 res = f(A, K);
std::cout << res << '\n';
return 0;
}
与えられる正整数 $N$ について、$X$ と $X+N$ のビットカウントが等しくなるような最小の $X$ を求めてください。
存在しない場合はそれを指摘してください。
・$1 \le$ クエリの数 $\le 100$
・$1 \le N \le 10^{16}$
クリックで開閉します
dp[前回繰り上がったか][XのビットカウントとX+Nのビットカウントとの差 + n] := Xの最小値
$X$ の $i$ ビット目を決めると $X+N$ の $i$ ビット目も決まります。
繰り上がりがあるので、$X$ を下位ビットから決めていきます。
ビットカウントの差は $n$ を $N$ のビット長として $-n$ 以上 $n$ 以下の範囲を取り、マイナスになり得るので、$+n$ して下駄を履かせます。
$X$ と $N$ の大小関係に制約がないので、状態として持つ必要はありません。
#include <bitset>
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr i64 inf = 987'654'321'987'654'321LL;
std::string to_bin(const i64 a) {
std::string s = std::bitset<64>(a).to_string();
size_t p = s.find('1');
if(p == std::string::npos) { p = s.size() - 1; }
return s.substr(p);
}
i64 f(i64 x) {
std::string s = to_bin(x);
s = '0' + s;
int n = static_cast<int>(s.size());
// dp[繰り上がったか][XとX+Nのビットカウントの差 + n] := Xの最小値
std::vector<std::vector<i64>> dp(2, std::vector<i64>(n+n+5, inf));
dp[0][n] = 0;
i64 p2 = 1;
for(int i=n-1; i>=0; --i) { // 下から
std::vector<std::vector<i64>> ndp(2, std::vector<i64>(n+n+5, inf));
for(int carry=0; carry<2; ++carry) {
for(int dif=0; dif<n+n+5; ++dif) {
if(dp[carry][dif] == inf) { continue; }
for(int xi=0; xi<2; ++xi) {
int yi = (carry + xi + s[i]-'0') % 2,
ncarry = (carry + xi + s[i]-'0') / 2,
ndif = dif + (xi - yi);
ndp[ncarry][ndif] = std::min(ndp[ncarry][ndif], dp[carry][dif] + p2 * xi);
}
}
}
dp = ndp;
p2 *= 2;
}
i64 res = dp[0][n];
if(res == inf) { res = -1; }
return res;
}
int main() {
int T; std::cin >> T;
for(int loop=0; loop<T; ++loop) {
i64 n; std::cin >> n;
i64 res = f(n);
std::cout << res << '\n';
}
return 0;
}
非負整数 $N$ が与えられます。
$0 \le x \le y \le N$ の範囲で $(x\ \mathrm{AND}\ y) \lt (x\ \mathrm{XOR}\ y) \lt (x\ \mathrm{OR}\ y)$ を満たす $(x, y)$ の組の個数を求めてください。
・$0 \le N \le 10^{18}$
クリックで開閉します
dp[x が y より未満/同じ/超過][y が N より未満/同じ/超過][AND が XOR より未満/同じ/超過][XOR が OR より未満/同じ/超過] := パターン数
超過状態も認めた上で DP テーブルを更新し、集計するときに超過状態を外すようにすると楽です。
#include <bitset>
#include <cmath>
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr int mod = static_cast<int>(powl(10, 9)) + 7;
std::string to_bin(const i64 a) {
std::string s = std::bitset<64>(a).to_string();
size_t p = s.find('1');
if(p == std::string::npos) { p = s.size() - 1; }
return s.substr(p);
}
int cmp(int x, int y) {
if(x < y) { return 0; }
if(x == y) { return 1; }
return 2;
}
int next_state(int x, int y, int cur_state) {
int nstate = cur_state;
if(nstate == 1) { nstate = cmp(x, y); }
return nstate;
}
i64 f(i64 x) {
std::string s = to_bin(x);
int n = static_cast<int>(s.size());
// dp[xがyより未満/同じ/超過][yがNより未満/同じ/超過][ANDがXORより未満/同じ/超過][XORがORより未満/同じ/超過] := パターン数
std::vector<std::vector<std::vector<std::vector<i64>>>> dp(3, std::vector<std::vector<std::vector<i64>>>(3, std::vector<std::vector<i64>>(3, std::vector<i64>(3, 0))));
dp[1][1][1][1] = 1;
for(int i=0; i<n; ++i) {
std::vector<std::vector<std::vector<std::vector<i64>>>> ndp(3, std::vector<std::vector<std::vector<i64>>>(3, std::vector<std::vector<i64>>(3, std::vector<i64>(3, 0))));
for(int state0=0; state0<3; ++state0) { // x と y の大小
for(int state1=0; state1<3; ++state1) { // y と N の大小
for(int state2=0; state2<3; ++state2) { // x and y と x xor y の大小
for(int state3=0; state3<3; ++state3) { // x xor y と x or y の大小
if(!dp[state0][state1][state2][state3]) { continue; }
for(int xi=0; xi<2; ++xi) { // x の i 桁め
for(int yi=0; yi<2; ++yi) { // y の i 桁め
int nstate0 = next_state(xi, yi, state0),
nstate1 = next_state(yi, s[i]-'0', state1),
nstate2 = next_state(xi & yi, xi ^ yi, state2),
nstate3 = next_state(xi ^ yi, xi | yi, state3);
ndp[nstate0][nstate1][nstate2][nstate3] += dp[state0][state1][state2][state3];
ndp[nstate0][nstate1][nstate2][nstate3] %= mod;
}
}
}
}
}
}
dp = ndp;
}
i64 res = 0;
for(int state0=0; state0<2; ++state0) { // x は y 以下であるべきなので、未満と同じのものを数える
for(int state1=0; state1<2; ++state1) { // y は N 以下 〃
res += dp[state0][state1][0][0]; // AND は XOR 未満、XOR は OR 未満であるべき
res %= mod;
}
}
return res;
}
int main() {
i64 N; std::cin >> N;
i64 res = f(N);
std::cout << res << '\n';
return 0;
}
任意の非負整数 $k$ について、額面が $1 \times 10^{k}$ と $5 \times 10^{k}$ である硬貨があるとします。
$N$ 円の買い物をしたとき、$X (\ge N)$ 円を支払い、$X - N$ 円の釣り銭を受け取るとします。
支払いにおける硬貨の枚数と受け取る釣り銭の硬貨の枚数の合計を最小化したいとき、その最小値を求めてください。
・$1 \le N \le 10^{10{,}000}$
クリックで開閉します
dp[X-N は N より超過か][前回繰り上がったか] := 最小値
$X-N$ 円を決めると、$N$ 円との和で $X$ 円が決まります。
繰り上がりがあるので下位桁からやります。
最初のサンプルを見ると $X$ 円が $N$ 円より一桁大きいことがあると分かるので、$N$ (の文字列表現)の先頭にゼロをつけておきます。
#include <iostream>
#include <vector>
constexpr int inf = 987'654'321;
// 0 1 2 3 4 5 6 7 8 9
const std::vector<int> v = {0, 1, 2, 3, 4, 1, 2, 3, 4, 5};
int f(std::string s) {
s = "0" + s;
int n = static_cast<int>(s.size());
// dp[X-N は N より超過か][前回繰り上がったか] := 最小値
std::vector<std::vector<int>> dp(2, std::vector<int>(2, inf));
dp[0][0] = 0;
for(int i=n-1; i>=0; --i) { // 下から
std::vector<std::vector<int>> ndp(2, std::vector<int>(2, inf));
for(int over=0; over<2; ++over) {
for(int carry=0; carry<2; ++carry) {
if(dp[over][carry] == inf) { continue; }
for(int yi=0; yi<10; ++yi) { // X - N を決める
int nover = yi > s[i]-'0';
if(yi == s[i]-'0') { nover = over; }
int xi = (carry + yi + s[i]-'0') % 10, // X が決まる
ncarry = (carry + yi + s[i]-'0') / 10;
ndp[nover][ncarry] = std::min(ndp[nover][ncarry], dp[over][carry] + v[xi] + v[yi]);
}
}
}
dp = ndp;
}
int res = std::min(dp[0][0], dp[1][0]);
return res;
}
int main() {
std::string N; std::cin >> N;
int res = f(N);
std::cout << res << '\n';
return 0;
}
非負整数 $a, b$ について、
$\begin{cases}a\ \mathrm{XOR}\ b &= u \\ a + b &= v \end{cases}$
となるような非負整数 $u, v$ の組み合わせは $0 \le u, v \le N$ の範囲にいくつありますか。
・$1 \le N \le 10^{18}$
クリックで開閉します
dp[u は N より超過か][v は N より超過か][前回繰り上がったか] := パターン数
変数が多く混乱しますが、動かしやすいのは $a$ と $b$ なので、これらを決めます。
$a$ と $b$ の両方を決めると $u$ と $v$ も決まります。
足し算に繰り上がりがあるので下位ビットからやります。
$a$ と $b$ のビットが $(0, 1)$ のときと $(1, 0)$ のときは $u, v$ が組み合わせとして同じになってしまうので、どちらかを除外する必要があります。
#include <bitset>
#include <cmath>
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr int mod = static_cast<int>(powl(10, 9)) + 7;
std::string to_bin(const i64 a) {
std::string s = std::bitset<64>(a).to_string();
size_t p = s.find('1');
if(p == std::string::npos) { p = s.size() - 1; }
return s.substr(p);
}
int next_state(int val, int si, int cur_state) {
int res = val > si;
if(val == si) { res = cur_state; }
return res;
}
i64 f(i64 x) {
std::string s = to_bin(x);
int n = static_cast<int>(s.size());
// dp[uはNより超過か][vはNより超過か][前回繰り上がったか] := パターン数
std::vector<std::vector<std::vector<i64>>> dp(2, std::vector<std::vector<i64>>(2, std::vector<i64>(2, 0)));
dp[0][0][0] = 1;
for(int i=n-1; i>=0; --i) { // 下から
std::vector<std::vector<std::vector<i64>>> ndp(2, std::vector<std::vector<i64>>(2, std::vector<i64>(2, 0)));
for(int overU=0; overU<2; ++overU) {
for(int overV=0; overV<2; ++overV) {
for(int carry=0; carry<2; ++carry) {
if(!dp[overU][overV][carry]) { continue; }
for(int ai=0; ai<2; ++ai) { // a の i 桁めのビットを決める
for(int bi=0; bi<=ai; ++bi) { // b の i 桁めのビットを決める
int ui = ai ^ bi, // u の i 桁めのビットが決まる
vi = (carry + ai + bi) % 2, // v の i 桁めのビットが決まる
ncarry = (carry + ai + bi) / 2; // v の計算での繰り上がりの有無が決まる
int noverU = next_state(ui, s[i]-'0', overU),
noverV = next_state(vi, s[i]-'0', overV);
ndp[noverU][noverV][ncarry] += dp[overU][overV][carry];
ndp[noverU][noverV][ncarry] %= mod;
}
}
}
}
}
dp = ndp;
}
return dp[0][0][0]; // 最後に繰り上がるのは数えないように
}
int main() {
i64 N; std::cin >> N;
i64 res = f(N);
std::cout << res << '\n';
return 0;
}
与えられる正整数 $N$ と数字和が等しい正整数の最小値 $X$ を求めてください。
ただし、$X \neq N$ でなければいけません。
・$1 \le N \le 10^{15}$
クリックで開閉します
dp[X は N より未満/同じ/超過][数字和] := 最小値
$N$ が例えば $999$ だったとき正解は $1899$ で、必ずしも $X \lt N$ だとは限りません。
したがって超過状態も考慮して桁DPします。
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr i64 inf = 987'654'321'987'654'321LL;
int dsum(i64 x) {
int res = 0;
for(; x; x/=10) {
res += static_cast<int>(x % 10);
}
return res;
}
int cmp(int x, int y) {
if(x < y) { return 0; }
if(x == y) { return 1; }
return 2;
}
i64 f(i64 x) {
std::string s = std::to_string(x);
s = '0' + s;
int n = static_cast<int>(s.size());
int D = dsum(x);
// dp[未満/同じ/超過][数字和] := 最小値
std::vector<std::vector<i64>> dp(3, std::vector<i64>(D+1, inf));
dp[1][0] = 0;
for(int i=0; i<n; ++i) {
std::vector<std::vector<i64>> ndp(3, std::vector<i64>(D+1, inf));
for(int state=0; state<3; ++state) {
for(int acc=0; acc<=D; ++acc) {
if(dp[state][acc] == inf) { continue; }
for(int d=0; d<10; ++d) {
int nstate = state;
if(nstate == 1) { nstate = cmp(d, s[i]-'0'); }
int nacc = acc + d;
if(nacc > D) { continue; }
ndp[nstate][nacc] = std::min(ndp[nstate][nacc], dp[state][acc] * 10 + d);
}
}
}
dp = ndp;
}
i64 res = std::min(dp[0][D], dp[2][D]);
return res;
}
int main() {
i64 N; std::cin >> N;
i64 res = f(N);
std::cout << res << '\n';
return 0;
}
二進表記で与えられる正整数 $L$ について、
$\begin{cases}a + b &\le& L \\ a + b &= a\ \mathrm{XOR}\ b\end{cases}$
を満たす非負整数 $a, b$ の組み合わせはいくつありますか。
・$1 \le L \lt 2^{100{,}001}$
クリックで開閉します
dp[a+bはLより未満/同じ/超過][a+bはa XOR bより未満/同じ/超過][足し算で繰り上がったか] := パターン数
足し算に繰り上がりがあるので、下位桁からやります。
※「一般に $a + b \ge a\ \mathrm{XOR}\ b$ である」という事実を使わず、機械的に実装しています。
#include <cmath>
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr int mod = static_cast<int>(powl(10, 9)) + 7;
int cmp(int x, int y) {
if(x < y) { return 0; }
if(x == y) { return 1; }
return 2;
}
int next_state(int x, int y, int cur) {
int res = cmp(x, y);
if(res == 1) { res = cur; }
return res;
}
i64 f(std::string s) {
int n = static_cast<int>(s.size());
// dp[a+bはLより未満/同じ/超過][a+bはa XOR bより未満/同じ/超過][足し算で繰り上がったか] := パターン数
std::vector<std::vector<std::vector<i64>>> dp(3, std::vector<std::vector<i64>>(3, std::vector<i64>(2, 0)));
dp[1][1][0] = 1;
for(int i=n-1; i>=0; --i) { // 下から
std::vector<std::vector<std::vector<i64>>> ndp(3, std::vector<std::vector<i64>>(3, std::vector<i64>(2, 0)));
for(int state1=0; state1<3; ++state1) {
for(int state2=0; state2<3; ++state2) {
for(int carry=0; carry<2; ++carry) {
if(!dp[state1][state2][carry]) { continue; }
for(int ai=0; ai<2; ++ai) {
for(int bi=0; bi<2; ++bi) {
int nstate1 = next_state((carry+ai+bi)%2, s[i]-'0', state1),
nstate2 = next_state((carry+ai+bi)%2, ai ^ bi, state2),
ncarry = (carry + ai + bi) / 2;
ndp[nstate1][nstate2][ncarry] += dp[state1][state2][carry];
ndp[nstate1][nstate2][ncarry] %= mod;
}
}
}
}
}
dp = ndp;
}
i64 res = dp[0][1][0] + dp[1][1][0];
res %= mod;
return res;
}
int main() {
std::string L; std::cin >> L;
i64 res = f(L);
std::cout << res << '\n';
return 0;
}
与えられる整数 $L, R$ について、
$L \le x \le y \le R$ の範囲で $(y\ \mathrm{XOR}\ x) = (y \bmod{x})$ となるような $(x, y)$ の組はいくつありますか。
・$1 \le L \le R \le 10^{18}$
クリックで開閉します
$x$ と $y$ の最上位ビット位置は同じでなければならず(詳しくは
公式の解説 を参照)、すると $y$ を $x$ で割った余りは $y-x$ になります。
$y-x$ と $x$ を決めると、足し算によって $y$ が決まるので、下位桁から桁DPをします。
dp[Lはxより未満/同じ/超過][xはyより未満/同じ/超過][yはRより未満/同じ/超過][足し算で繰り上がったか][xとyの最上位ビット位置が同じか]
:= パターン数
#include <bitset>
#include <cmath>
#include <iostream>
#include <vector>
using i64 = int64_t;
constexpr int mod = static_cast<int>(powl(10, 9)) + 7;
std::string to_bin(const i64 a) {
std::string s = std::bitset<62>(a).to_string();
return s;
}
int cmp(int x, int y) {
if(x < y) { return 0; }
if(x == y) { return 1; }
return 2;
}
int next_state(int x, int y, int cur) {
int res = cmp(x, y);
if(res == 1) { res = cur; }
return res;
}
i64 f(i64 L, i64 R) {
std::string s = to_bin(L),
t = to_bin(R);
int n = 62;
// dp[Lはxより未満/同じ/超過][xはyより〃][yはRより〃][繰り上がったか][xとyの先頭桁がともに1か] := パターン数
std::vector<std::vector<std::vector<std::vector<std::vector<i64>>>>> dp(3, std::vector<std::vector<std::vector<std::vector<i64>>>>(3, std::vector<std::vector<std::vector<i64>>>(3, std::vector<std::vector<i64>>(2, std::vector<i64>(2, 0)))));
dp[1][1][1][0][0] = 1;
for(int i=n-1; i>=0; --i) {
std::vector<std::vector<std::vector<std::vector<std::vector<i64>>>>> ndp(3, std::vector<std::vector<std::vector<std::vector<i64>>>>(3, std::vector<std::vector<std::vector<i64>>>(3, std::vector<std::vector<i64>>(2, std::vector<i64>(2, 0)))));
for(int state1=0; state1<3; ++state1) {
for(int state2=0; state2<3; ++state2) {
for(int state3=0; state3<3; ++state3) {
for(int carry=0; carry<2; ++carry) {
for(int same=0; same<2; ++same) {
if(!dp[state1][state2][state3][carry][same]) { continue; }
for(int xi=0; xi<2; ++xi) {
for(int zi=0; zi<2; ++zi) {
int yi = (carry + zi + xi) % 2,
ncarry = (carry + zi + xi) / 2;
int nstate1 = next_state(s[i]-'0', xi, state1),
nstate2 = next_state(xi, yi, state2),
nstate3 = next_state(yi, t[i]-'0', state3),
nsame = xi == yi;
if(!xi && !yi) { nsame = same; }
if(zi != (xi ^ yi)) { continue; }
ndp[nstate1][nstate2][nstate3][ncarry][nsame] += dp[state1][state2][state3][carry][same];
ndp[nstate1][nstate2][nstate3][ncarry][nsame] %= mod;
}
}
}
}
}
}
}
dp = ndp;
}
i64 res = 0;
for(int state1=0; state1<2; ++state1) {
for(int state2=0; state2<2; ++state2) {
for(int state3=0; state3<2; ++state3) {
res += dp[state1][state2][state3][0][1];
res %= mod;
}
}
}
return res;
}
int main() {
i64 L, R; std::cin >> L >> R;
i64 res = f(L, R);
std::cout << res << '\n';
return 0;
}
(ΦωΦ)<おしまい