P7324 表达式求值 题解

2021-03-29 13:14:50


这题我做法和官方题解不一样,而且身边问了一圈网上看了一圈甚至没有找到一个和我做法相同的题解,给同学讲我的做法没有一个人愿意听完……于是放上来 ex 一下大家(bushi

首先题意是,“枚举把每个 $?$ 替换成 $\max$ 和 $\min$ 的方案,求最后的运算结果之和”。因为 $\max(a,b)+\min(a,b)=a+b$,所以这个显然可以转化为“枚举把每个 $?$ 替换成参与运算的左边一项和右边一项的方案,求最后的运算结果之和”。例如:$$\max(?(a,\min(b,c),d)=\max(a,d)+\max(\min(b,c),d)$$

接下来我们考虑容斥。现在的运算有 $\max$、$\min$、$+$ 三种,我们可以把 $\min$ 拆开:$$\min(a,b)=a+b-\max(a,b)$$ 这看上去有点奇怪,但其实和上面的拆问号是完全相同的原理。这就相当于:“枚举把每个 $\min$ 分别替换成参与运算的左边一项、右边一项、两边的 $\max$ 的方案,求最后的运算结果之和,其中每有一个替换成两边的 $\max$,运算结果就要乘以 $-1$ 的系数。”

例如:$\max(\min(a,b),?(c,d))=\max(a,c)+\max(b,c)-\max(\max(a,b),c)+\max(a,d)+\max(b,d)-\max(\max(a,b),d)=\max(a,c)+\max(b,c)-\max(a,b,c)+\max(a,d)+\max(b,d)-\max(a,b,d)$

现在我们把表达式化成了这样的形式:整个表达式包含若干个由加号或减号连接的项,每一项都形如一个或多个元素的 $\max$ (单个元素的 $\max$ 就是它本身).

因为我们是枚举了每种 $\min$ 和 $?$ 拆成的两种或三种方案,所以这个表达式的项数为关于 $n$ 的指数级。

但有一个性质是,“一个或多个元素的 $\max$”就是 $\{a_0,a_1,...,a_{m-1}\}$ 的一个子集的 $\max$,子集一共只有 $O(2^m)$ 个。而对于某个子集 $S$ 的项,可以将所有 $\pm \max(S)$ 的项进行合并,合并后记 $c_S$ 为该子集的所有项的系数和,那么答案就等于 $\sum_{S\in \{a_0,...,a_{m-1}\}} c_S\cdot \max(S)$.

我们考虑使用 DP 求出所有的 $c_S$. 假设求出了所有 $c_S$,那显然能在 $O(2^m)$ 时间内对一组 $a_i$ 的取值求出答案,那么就可以在 $O(n2^m)$ 的时间复杂度内解决问题了。

下面讨论如何 DP:建出表达式树,设 $c[u][S]$ 表示 $u$ 这个节点的子树内 $S$ 集合的系数,考虑如何合并两个由运算符连接的节点 $u$ 和 $v$,设合并后的节点为 $u'$:

(1) 运算符为 $\max$: 枚举 $S,T$

$$c[u][S]\times c[v][T]\rightarrow c[u'][S\cup T]$$

(2) 运算符为 $?$: 枚举 $S,T$

$$c[u][S]\times c[v][T]\rightarrow c[u'][S]$$

$$c[u][S]\times c[v][T]\rightarrow c[u'][T]$$

(3) 运算符为 $\min$: 枚举 $S,T$

$$c[u][S]\times c[v][T]\rightarrow c[u'][S]$$

$$c[u][S]\times c[v][T]\rightarrow c[u'][T]$$

$$-c[u][S]\times c[v][T]\rightarrow c[u'][S\cup T]$$

如果暴力做,时间复杂度高达 $O(4^m|S|)$.

这显然是一个集合 or 卷积。可以使用 FMT 优化,时间复杂度 $O(2^mm|S|)$,依然无法通过。

但一个经典的 trick 是,这里不难发现我们只需要给所有值同时乘一个数和做 or 卷积,根据 FMT 的线性性,可以直接带入 FMT 后的点值计算。只需要在叶子节点处进行 FMT,到最后根节点处 IFMT 回来。叶子节点因为原 $c$ 数组只有一个位置有值,所以 FMT 显然能在 $O(2^m)$ 时间内完成,那么 DP 的复杂度就降到 $O(2^m|S|)$ 了。

所以总时间复杂度 $O(2^m(|S|+n))$. (感谢评测机不卡常==)

考场代码如下(未去掉 freopen):

#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define x first
#define y second
#define y1 sjfoahsro
#define tm orijyjwre
using namespace std;

inline int read()
{
    int x = 0,f = 1; char ch = getchar();
    for(;!isdigit(ch);ch=getchar()) {if(ch=='-') {f = -1;}}
    for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
    return x*f;
}

const int mxN = 5e4;
const int mxM = 10;
const int P = 1e9+7;
int n,m,tp,len;
int lobit[(1<<mxM)+3];
llong a[mxN+3][mxM+3];
llong f[mxN/2+3][(1<<mxM)+3],g[(1<<mxM)+3]; char str[mxN+3],op[mxN+3];

void merge(llong x1[],llong x2[],char typ)
{
    if(typ=='<')
    {
        for(int i=0; i<(1<<m); i++)
        {
            x1[i] = (x1[i]*x2[(1<<m)-1]+x2[i]*x1[(1<<m)-1]-x1[i]*x2[i]%P+P)%P;
        }
    }
    else if(typ=='?')
    {
        for(int i=0; i<(1<<m); i++)
        {
            x1[i] = (x1[i]*x2[(1<<m)-1]+x2[i]*x1[(1<<m)-1])%P;
        }
    }
    else
    {
        for(int i=0; i<(1<<m); i++)
        {
            x1[i] = x1[i]*x2[i]%P;
        }
    }
}

int main()
{
    freopen("expr.in","r",stdin);
    freopen("expr.out","w",stdout);
    n = read(),m = read();
    for(int i=0; i<m; i++) for(int j=(1<<i); j<(1<<m); j+=(1<<i)) {lobit[j] = i;}
    for(int j=0; j<m; j++) for(int i=1; i<=n; i++) a[i][j] = read();
    scanf("%s",str+1); len = strlen(str+1); for(int i=1; i<=len; i++) if(str[i]>='0'&&str[i]<='9') {str[i] -= 48;}
    tp = 1;
    for(int i=1; i<=len; i++)
    {
        if(str[i]=='(')
        {
            tp++;
        }
        else if(str[i]==')')
        {
            tp--;
            if(op[tp])
            {
                merge(f[tp],f[tp+1],op[tp]);
                op[tp] = 0;
            }
            else
            {
                for(int j=0; j<(1<<m); j++) {f[tp][j] = f[tp+1][j];}
            }
        }
        else if(str[i]>=0&&str[i]<=9)
        {
            if(op[tp])
            {
                for(int j=0; j<(1<<m); j++) {g[j] = (j&(1<<str[i]))>0?1ll:0ll;}
                merge(f[tp],g,op[tp]);
                op[tp] = 0;
            }
            else
            {
                for(int j=0; j<(1<<m); j++) {f[tp][j] = (j&(1<<str[i]))>0?1ll:0ll;}
            }
        }
        else
        {
            op[tp] = str[i];
        }
//      printf("tp=%d\n",tp);
//      for(int j=1; j<=tp; j++) {for(int k=0; k<(1<<m); k++) printf("%lld ",f[j][k]); printf("op=%c\n",op[j]);}
    }
    for(int i=0; i<m; i++)
    {
        for(int j=0; j<(1<<m); j++) if(j&(1<<i))
        {
            f[1][j] = (f[1][j]-f[1][j^(1<<i)]+P)%P;
        }
    }
    llong ans = 0ll;
    for(int i=1; i<=n; i++)
    {
        g[0] = 0ll;
        for(int j=1; j<(1<<m); j++)
        {
            g[j] = max(g[j^(1<<lobit[j])],a[i][lobit[j]]);
            ans = (ans+g[j]*f[1][j])%P;
        }
    }
    printf("%lld\n",ans);
    return 0;
}