Digit DP

From EOJ Wiki
Revision as of 12:44, 22 March 2018 by Ultmaster (talk | contribs) (Created page with "Hihocoder 1033: $f(x) = a_0 - a_1 + a_2 - \cdots + (-1)^{n-1} a_{n-1}$. e.g., $f(3214567)=3-2+1-4+5-6+7=4$. Find $\sum_{x=l}^r [f(x)=k] x$. <syntaxhighlight lang='cpp'> #in...")
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to navigation Jump to search

Hihocoder 1033:

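Let $a_0 a_1 \cdots a_{n-1}$ be the decimal digits of $x$, most significant digit first, and define $f(x) = a_0 - a_1 + a_2 - \cdots + (-1)^{n-1} a_{n-1}$. For example, $f(3214567)=3-2+1-4+5-6+7=4$.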

Find $\sum_{x=l}^r [f(x)=k] x$, i.e. the sum of all $x \in [l, r]$ with $f(x) = k$, modulo $10^9+7$.

```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

// All counts and sums are reduced modulo this value.
const int modn = 1000000007;
// Added to the (possibly negative) alternating sum before it is used as an array index.
const int offset = 200;

// p10[i] holds 10^i mod modn; filled in by solve().
int p10[22];
// Memo tables indexed by [len][pos][limit][partial_sum + offset]:
// dp stores the count, f stores the accumulated value sum.
int dp[22][22][2][400];
int f[22][22][2][400];
// t[1..] holds the decimal digits of the current upper bound, least significant first.
int t[22];
// Target alternating digit sum, read in main().
int k;

// save for both limit and unlimited
pair<int, int> dfs(int len, int pos, int limit, int partial_sum) {
    // printf("%d %d %d %d\n", len, pos, limit, partial_sum);
    int &ret = dp[len][pos][limit][partial_sum + offset];
    int &fx = f[len][pos][limit][partial_sum + offset];
    if (ret == -1) {
        if (pos == 0) {
            ret = (partial_sum == k);
            fx = 0;
        } else {
            int upper = limit ? t[pos] : 9;
            ret = 0;
            for (int digit = upper; digit >= 0; --digit) {
                int nlen = len, padd = digit;
                if (pos == len && digit == 0) nlen--;
                if ((len - pos) % 2 == 1) padd = -padd;
                pair<int, int> tmp = dfs(nlen, pos - 1, limit && digit == upper, partial_sum + padd);
                ret = (ret + tmp.first) % modn;
                fx = (fx + 1LL * digit * p10[pos - 1] % modn * tmp.first % modn + tmp.second) % modn;
            }
        }
    }
    return make_pair(ret, fx);
}

// Returns, modulo modn, the sum of all numbers in [0, x] whose alternating
// digit sum equals the global k. Non-positive x yields 0.
//
// Overwrites the globals t, dp, f and p10 on every call.
int solve(long long x) {
    if (x <= 0) return 0;

    // Split x into decimal digits, least significant first, stored from t[1].
    int digits = 0;
    for (; x != 0; x /= 10)
        t[++digits] = x % 10;

    // Fresh memo for this bound: -1 marks "not computed" in dp, sums start at 0.
    memset(dp, -1, sizeof dp);
    memset(f, 0, sizeof f);

    // Powers of ten modulo modn, used to weight each digit position.
    p10[0] = 1;
    for (int e = 0; e < 20; ++e)
        p10[e + 1] = (1LL * p10[e] * 10) % modn;

    // Start at the top digit, tight against the bound, with an empty sum.
    return dfs(digits, digits, 1, 0).second;
}

// Reads l, r and the target k, then prints the sum over [l, r] as the
// difference of two prefix answers, kept non-negative modulo modn.
int main() {
    long long lo, hi;
    cin >> lo >> hi >> k;

    int upto_hi = solve(hi);
    int below_lo = solve(lo - 1);

    // Both values lie in [0, modn), so adding modn before reducing avoids a negative result.
    int answer = (upto_hi - below_lo + modn) % modn;
    cout << answer << endl;
    return 0;
}
```


```cpp
// Digit DP over the base-`base` digits stored in a[] (a[0] is the least
// significant digit). Each completed number contributes `base` if the flag
// `s` is still set when all digits are placed, and 1 otherwise; the function
// returns the total contribution.
//
//   base  - radix of the digits; also the first index into dp.
//   pos   - index of the digit being chosen now; -1 means all digits are placed.
//   len   - index of the highest non-zero digit chosen so far; while
//           len == pos only leading zeros have been placed.
//   s     - stays non-zero only while every digit chosen in the lower half
//           matches its mirror digit tmp[len - pos].
//   limit - true while the chosen prefix equals the prefix of a[], so the
//           current digit is capped at a[pos].
//
// Results are cached in dp only for states reached with limit == false.
// NOTE(review): FOR is a macro defined outside this snippet; it presumably
// iterates i over [0, top + 1) — confirm against its definition.
LL dfs(LL base, LL pos, LL len, LL s, bool limit) {
    if (pos == -1) {
        if (s) return base;
        return 1;
    }

    bool cacheable = !limit;
    if (cacheable && dp[base][pos][len][s] != -1) return dp[base][pos][len][s];

    LL top = limit ? a[pos] : base - 1;
    LL total = 0;
    FOR (i, 0, top + 1) {
        // Remember the digit so the mirrored position can be compared with it later.
        tmp[pos] = i;

        LL next_len = len;
        LL next_s = s;
        if (len == pos) {
            // Still at the leading position: a zero here shortens the number.
            next_len = len - (i == 0);
        } else if (s && pos < (len + 1) / 2) {
            // Lower half: the digit must equal its mirror to keep the flag set.
            next_s = (tmp[len - pos] == i);
        }

        total += dfs(base, pos - 1, next_len, next_s, limit && i == a[pos]);
    }

    if (cacheable) dp[base][pos][len][s] = total;
    return total;
}

// Writes the base-`base` digits of x into a[] (least significant first) and
// runs the digit DP from the top digit, with the flag set and the bound tight.
// For x == 0 no digits are written and dfs is entered with pos == -1.
LL solve(LL x, LL base) {
    LL count = 0;
    for (; x; x /= base)
        a[count++] = x % base;

    LL top = count - 1;
    return dfs(base, top, top, 1, true);
}