题解 P5469 【[NOI2019]机器人】

2019-07-20 12:19:39


既然这题没题解那我就发一个吧,讲得不清楚请见谅

膜拜当场A掉的巨佬(我是NOI Cu滚粗狗)

欢迎移步https://www.cnblogs.com/suncongbo/p/11217236.html

首先我们考虑,一个序列位置最右边的最大值可以走遍整个序列,并且其余任何点都不能跨过这个位置。

所以我们可以区间dp, $dp[l][r][x]$表示区间$[l,r]$最大值不超过$x$的方案数,枚举最大值点$mid$及其值$k$, $dp[l][r][x]=\sum_{mid}\sum_{k}dp[l][mid-1][k]\times dp[mid+1][r][k-1]$, 也可以设$dp[l][r][x]$表示区间$[l,r]$的最大值恰好为$x$的方案数,枚举最大值点$mid$则有$dp[l][r][x]=\sum_{mid}\sum_{k\le x}dp[l][mid-1][k]\sum_{k<x}dp[mid+1][r][k]$.

可获得$35$分,当然如果你有梦想数组开大点卡卡常就有$50$分了。(然而我在考场上没梦想$35$分滚粗了)

然后正解的话,恰好为$x$那种状态比较好。

首先离散化,那么我们发现当$k$在每一段区间内时,转移是类似的。

一个结论是,当$k$在某一段区间内时$dp[l][r][k]$是关于$k$的不超过$(r-l)$次多项式。

证明: 首先$l=r$时显然是$0$次多项式,当$l<r$时,我们枚举$mid$然后左边有一个$mid-1-l$次多项式右边有一个$r-mid-1$次多项式,又因为转移要对左右两边多项式做前缀和再相乘(这个具体见下一段),所以次数要$+1$($k$次多项式的前缀和是$(k+1)$次多项式),所以总次数为$(mid-1-l+1)+(r-mid-1+1)=r-l$.

这里解释一下如何转移: $dp[l][r][x]$, 左右两边分开考虑,考虑现在枚举的$k$, 假设$k$在$x$区间前面的区间里,那么这个值与$x$区间内的自变量无关了,变成了“常数”(因为这个区间所包含的数无论如何都比自变量小),而这个“常数”的值就是$dp[l][r][k']$ ($k'$为$k$所在的区间)这个多项式在每个$k'$区间内的点的点值之和,把这个值加到$dp[l][r][x]$多项式的常数项里。

假设$k$在$x$区间里,那么新的多项式直接就等于这个多项式在区间内的小于等于自变量的前缀和(如果现在枚举的是左边),或者多项式在区间内小于自变量的前缀和(如果现在枚举的是右边)。

于是记忆化搜索一波,使用多项式前缀和进行转移,这样枚举$mid$之后复杂度为多项式次数的平方。

多项式前缀和需要预处理$s_k(x)=\sum^{x}_{i=0}x^k$, 这是一个$(k+1)$次多项式,所以Lagrange插值求出来系数。据说有其他的搞法,但是我只能想到这一种。

裸做$80$分起步(我裸做了一波得了$85$)

剪枝优化可获得$100$分。

好难写啊,我好菜啊……

代码

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<vector>
#include<algorithm>
#define llong long long
using namespace std;

const int P = 1e9+7;
const int N = 301;

llong quickpow(llong x,llong y)
{
    llong cur = x,ret = 1ll;
    for(int i=0; y; i++)
    {
        if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
        cur = cur*cur%P;
    }
    return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}

llong aux[N+4],aux2[N+4];
struct Polynomial
{
    vector<llong> a; int n;
    Polynomial() {}
    Polynomial(int _n) {n = _n; for(int i=0; i<=n; i++) a.push_back(0ll);}
    void clear() {n = 0; a.clear(); a.push_back(0ll);}
    void output() {printf("deg%d, ",n); for(int i=0; i<=n; i++) printf("%lld ",a[i]); puts("");}
    Polynomial operator +(Polynomial &arg) const
    {
        Polynomial ret(max(n,arg.n));
        for(int i=0; i<=min(n,arg.n); i++)
        {
            ret.a[i] = (a[i]+arg.a[i])%P;
        }
        for(int i=min(n,arg.n)+1; i<=n; i++) ret.a[i] = a[i];
        for(int i=min(n,arg.n)+1; i<=arg.n; i++) ret.a[i] = arg.a[i];
        return ret;
    }
    Polynomial operator -(Polynomial &arg) const
    {
        Polynomial ret(max(n,arg.n));
        for(int i=0; i<=min(n,arg.n); i++)
        {
            ret.a[i] = (a[i]-arg.a[i]+P)%P;
        }
        for(int i=min(n,arg.n)+1; i<=n; i++) ret.a[i] = a[i];
        for(int i=min(n,arg.n)+1; i<=arg.n; i++) ret.a[i] = P-arg.a[i];
        return ret;
    }
    Polynomial operator *(Polynomial &arg) const
    {
        Polynomial ret(n+arg.n);
        for(int i=0; i<=n; i++)
        {
            for(int j=0; j<=arg.n; j++)
            {
                ret.a[i+j] = (ret.a[i+j]+a[i]*arg.a[j])%P;
            }
        }
        return ret;
    }
    llong calc(llong x)
    {
        llong ret = 0ll;
        for(int i=n; i>=0; i--)
        {
            ret = (ret*x+a[i])%P;
        }
        return ret;
    }
    void interpoly(int _n,llong ax[],llong ay[])
    {
        n = _n; for(int i=0; i<=n; i++) a.push_back(0ll);
        for(int i=0; i<=n+1; i++) aux[i] = 0ll;
        aux[0] = 1ll;
        for(int i=0; i<=n; i++)
        {
            for(int j=i+1; j>0; j--)
            {
                aux[j] = (aux[j-1]-aux[j]*ax[i]%P+P)%P;
            }
            aux[0] = P-aux[0]*ax[i]%P;
        }
        for(int i=0; i<=n; i++)
        {
            llong tmp = 1ll;
            for(int j=0; j<=n; j++)
            {
                if(i==j) continue;
                tmp = tmp*(ax[i]-ax[j]+P)%P;
            }
            llong coe = mulinv(tmp);
            for(int j=n+1; j>=0; j--) {aux2[j] = aux[j];}
            for(int j=n; j>=0; j--)
            {
                a[j] = (a[j]+aux2[j+1]*coe%P*ay[i])%P;
                aux2[j] = (aux2[j]+ax[i]*aux2[j+1])%P;
            }
        }
    }
};
Polynomial tmp1,tmp2,tmp3;
Polynomial dp[2661][(N<<1)+3],sdp[2661][(N<<1)+3];
llong lval[2661][(N<<1)+3],rval[2661][(N<<1)+3];
int dpid[N+4][N+4];
Polynomial spw[N+4];
struct Interval
{
    llong lb,rb; //[1,2n]
} a[N+3];
vector<llong> disc;
llong spwx[N+3],spwy[N+3];
int mx[N+3][N+3];
int n,nsta;
llong ans;

int getid(llong x) {return lower_bound(disc.begin(),disc.end(),x)-disc.begin();} //no +1

Polynomial prefixsum(Polynomial poly)
{
    Polynomial ret(poly.n+1);
    for(int i=0; i<=poly.n; i++)
    {
        for(int j=0; j<=i+1; j++)
        {
            ret.a[j] = (ret.a[j]+poly.a[i]*spw[i].a[j])%P;
        }
    }
    return ret;
}

void dfs(int l,int r,int x)
{
    Polynomial tmp1,tmp2,tmp3;
    if(dpid[l][r] && dp[dpid[l][r]][x].a.size()>0) {return;}
    if(!dpid[l][r]) {nsta++; dpid[l][r] = nsta;}
    if(l==r)
    {
        if(!dpid[l][r]) {nsta++; dpid[l][r] = nsta;}
        dp[dpid[l][r]][x] = Polynomial(0); dp[dpid[l][r]][x].a[0] = (x>=a[l].lb&&x<=a[l].rb) ? 1ll : 0ll;
        sdp[dpid[l][r]][x] = prefixsum(dp[dpid[l][r]][x]);
        lval[dpid[l][r]][x] = sdp[dpid[l][r]][x].calc(disc[x-1]);
        rval[dpid[l][r]][x] = sdp[dpid[l][r]][x].calc(disc[x]);
        return;
    }
    dp[dpid[l][r]][x].clear(); sdp[dpid[l][r]][x].clear();
    if(mx[l][r]>x) return;
    for(int lenl=(r-l+1)>>1; lenl<=(r-l+1)+1-((r-l+1)>>1); lenl++)
    {
        int mid = l+lenl-1;
        if(!(x>=a[mid].lb && x<=a[mid].rb)) {continue;} //注意此处要特判
        tmp1.clear(); tmp2.clear();
        if(mid>l)
        {
            for(int k=1; k<=x; k++)
            {
                dfs(l,mid-1,k);
                if(k<x)
                {
                    tmp1.a[0] = (tmp1.a[0]+rval[dpid[l][mid-1]][k]-lval[dpid[l][mid-1]][k]+P)%P;
                }
                else
                {
                    tmp1 = tmp1+sdp[dpid[l][mid-1]][k];
                    tmp1.a[0] = (tmp1.a[0]-lval[dpid[l][mid-1]][k]+P)%P;
                }
            }
        }
        else
        {
            tmp1 = Polynomial(0); tmp1.a[0] = 1ll;
        }
        if(mid<r)
        {
            for(int k=0; k<=x; k++)
            {
                dfs(mid+1,r,k);
                if(k<x)
                {
                    tmp2.a[0] = (tmp2.a[0]+rval[dpid[mid+1][r]][k]-lval[dpid[mid+1][r]][k]+P)%P;
                }
                else
                {
                    tmp2 = tmp2+sdp[dpid[mid+1][r]][k];
                    tmp2 = tmp2-dp[dpid[mid+1][r]][k];
                    tmp2.a[0] = (tmp2.a[0]-lval[dpid[mid+1][r]][k]+P)%P;
                }
            }
        }
        else
        {
            tmp2 = Polynomial(0); tmp2.a[0] = 1ll;
        }
        tmp3 = tmp1*tmp2;
        dp[dpid[l][r]][x] = dp[dpid[l][r]][x]+tmp3;
    }
    sdp[dpid[l][r]][x] = prefixsum(dp[dpid[l][r]][x]);
    lval[dpid[l][r]][x] = sdp[dpid[l][r]][x].calc(disc[x-1]);
    rval[dpid[l][r]][x] = sdp[dpid[l][r]][x].calc(disc[x]);
}

int main()
{
    scanf("%d",&n);
    for(int i=1; i<=n; i++) {scanf("%lld%lld",&a[i].lb,&a[i].rb); disc.push_back(a[i].lb-1); disc.push_back(a[i].rb);}
    for(int i=0; i<=n; i++)
    {
        spwx[0] = 0ll; spwy[0] = 0ll;
        for(int j=1; j<=i+1; j++)
        {
            spwx[j] = j;
            spwy[j] = (spwy[j-1]+quickpow(j,i))%P;
        }
        spw[i].interpoly(i+1,spwx,spwy);
    }
    sort(disc.begin(),disc.end()); disc.erase(unique(disc.begin(),disc.end()),disc.end());
    for(int i=1; i<=n; i++) {a[i].lb = getid(a[i].lb); a[i].rb = getid(a[i].rb);}
    nsta = 1; for(int i=0; i<disc.size(); i++)
    {
        dp[1][i] = Polynomial(0); dp[1][i].a[0] = 1ll;
        sdp[1][i] = prefixsum(dp[1][i]);
        lval[1][i] = sdp[1][i].calc(disc[i-1]);
        rval[1][i] = sdp[1][i].calc(disc[i]);
    }
    for(int i=1; i<=n; i++)
    {
        mx[i][i] = a[i].lb;
        for(int j=i+1; j<=n; j++)
        {
            mx[i][j] = max(mx[i][j-1],(int)a[j].lb);
        }
    }
    ans = 0ll;
    for(int i=1; i<disc.size(); i++)
    {
        dfs(1,n,i);
        ans = (ans+sdp[dpid[1][n]][i].calc(disc[i])-sdp[dpid[1][n]][i].calc(disc[i-1])+P)%P;
    }
    printf("%lld\n",ans);
    return 0;
}