Author Avatar
Axell 1月 26, 2019
  • 在其它设备中阅读本文章


-

概念
堆即二叉堆,具有快速查找最小/大值,插入,删除,修改的功能

实现

  1. STL priority_queue<> 如果是结构体,需要定义小于号
    [post cid=”36” /]

  2. 手工实现,双映射

    #include <bits/stdc++.h>
    using namespace std;
    #define fa(i) (i>>1)
    #define lc(i) (i<<1)
    #define rc(i) ((i<<1)|1)
    typedef long long ll;
    const int MAXN=100005;

    int cnt,home[MAXN],opt[MAXN],v[MAXN]; //heap[home[i]].v==v[i],v[opt[i]]==heap[i].v
    int n,a[MAXN],Next[MAXN],Pre[MAXN],V[MAXN],id[MAXN],p,k;
    long long d[MAXN],ans;
    bool used[600005];

    struct node{

    ll v;
    int pos;

    }heap[MAXN*4];

    bool operator <(node x,node y){return x.v<y.v;}

    void up(int rt){

    int j=fa(rt);
    while (j){
        if (heap[rt]<heap[j]) swap(heap[rt],heap[j]),swap(home[opt[rt]],home[opt[j]]),swap(opt[rt],opt[j]),rt=j,j=fa(j);
        else break;
    }

    }

    void down(int rt){

    while (lc(rt)<=cnt){
        int l=lc(rt),r=rc(rt);
        if (r>cnt || heap[l]<heap[r]) r=l;
        if (heap[r]<heap[rt]) swap(heap[rt],heap[r]),swap(home[opt[rt]],home[opt[r]]),swap(opt[rt],opt[r]);
        else break;
        rt=r;
    }

    }

    void insert(node x,int pos){ //插入

    heap[++cnt]=x;
    opt[cnt]=pos;
    home[pos]=cnt;
    up(cnt);

    }

    void del(int pos){ //删除

    heap[home[pos]]=heap[cnt];
    opt[home[pos]]=opt[cnt];
    home[opt[cnt]]=home[pos];
    cnt--;
    up(home[pos]);
    down(home[pos]);

    }

    void ch(node x,int pos){ //更改

    heap[home[pos]]=x;
    up(home[pos]);
    down(home[pos]);

    }

例题

数据备份

题面

思路

首先,因为要选出K对,所以选出每一对的必然是相邻的,因此可以产生D数组,储存相邻两地的距离,其次,由于不能重复选取,因此如果选择了D2,则不能选择D1,D3,可以用链表配合堆实现,需要手写堆操作,详见进阶指南P83-84

代码

手写堆

#include <bits/stdc++.h>
using namespace std;
#define fa(i) (i>>1)
#define lc(i) (i<<1)
#define rc(i) ((i<<1)|1)
typedef long long ll;
const int MAXN=100005;

int cnt,home[MAXN],opt[MAXN],v[MAXN]; //heap[home[i]].v==v[i],v[opt[i]]==heap[i].v
int n,a[MAXN],Next[MAXN],Pre[MAXN],V[MAXN],id[MAXN],p,k;
long long d[MAXN],ans;
bool used[600005];

struct node{
    ll v;
    int pos;
}heap[MAXN*4];

bool operator <(node x,node y){return x.v<y.v;}

void up(int rt){
    int j=fa(rt);
    while (j){
        if (heap[rt]<heap[j]) swap(heap[rt],heap[j]),swap(home[opt[rt]],home[opt[j]]),swap(opt[rt],opt[j]),rt=j,j=fa(j);
        else break;
    }
}

void down(int rt){
    while (lc(rt)<=cnt){
        int l=lc(rt),r=rc(rt);
        if (r>cnt || heap[l]<heap[r]) r=l;
        if (heap[r]<heap[rt]) swap(heap[rt],heap[r]),swap(home[opt[rt]],home[opt[r]]),swap(opt[rt],opt[r]);
        else break;
        rt=r;
    }
}

void insert(node x,int pos){
    heap[++cnt]=x;
    opt[cnt]=pos;
    home[pos]=cnt;
    up(cnt);
}

void del(int pos){
    heap[home[pos]]=heap[cnt];
    opt[home[pos]]=opt[cnt];
    home[opt[cnt]]=home[pos];
    cnt--;
    up(home[pos]);
    down(home[pos]);
}

void ch(node x,int pos){
    heap[home[pos]]=x;
    up(home[pos]);
    down(home[pos]);
}

int main(){
    cin>>n>>k;
    for (int i=1;i<=n;++i){
        scanf("%d",&a[i]);
    }
    for (int i=2;i<=n;++i){
        d[i-1]=a[i]-a[i-1];
    }
    for (int i=0;i<=n-1;++i) Next[i]=i+1;
    for (int i=n;i>=1;--i) Pre[i]=i-1;
    for (int i=1;i<=n-1;++i) insert({d[i],i},i),V[i]=d[i];
    p=n-1;
    for (int i=1;i<=k;++i){
        node tmp=heap[1];
        int pos=tmp.pos;
        ans+=tmp.v;
        int P=Pre[pos],N=Next[pos];
        del(pos);

        if (P==0 && N==n) break;
        if (N==n){
            // used[id[P]]=1;
            del(P);
            Next[Pre[P]]=n;
        }else if (P==0){
            // used[id[N]]=1;
            del(N);
            Pre[Next[N]]=0;
        }else{
            // used[id[N]]=1;used[id[P]]=1;
            del(N);del(P);
            V[pos]=V[P]+V[N]-tmp.v;
            tmp.v=V[pos];
            insert(tmp,pos);

            Next[Pre[P]]=pos;
            Pre[Next[N]]=pos;
            Next[pos]=Next[N];
            Pre[pos]=Pre[P];
        }
    }
    printf("%lld\n",ans);
    return 0;
}

STL+去重判断(效率较低)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

int n,a[MAXN],Next[MAXN],Pre[MAXN],V[MAXN],id[MAXN],p,k;
long long d[MAXN],ans;
bool used[600005];

struct node{
    ll v;
    int id,pos;
};

bool operator <(node x,node y){
    return x.v>y.v;
}


int main(){
    // freopen("backup.in","r",stdin);
    // freopen("backup.out","w",stdout);
    cin>>n>>k;
    for (int i=1;i<=n;++i){
        scanf("%d",&a[i]);
    }
    for (int i=2;i<=n;++i){
        d[i-1]=a[i]-a[i-1];
    }
    for (int i=0;i<=n-1;++i) Next[i]=i+1;
    for (int i=n;i>=1;--i) Pre[i]=i-1;
    for (int i=1;i<=n-1;++i) q.push({d[i],i,i}),id[i]=i,V[i]=d[i];
    p=n-1;
    for (int i=1;i<=k;++i){
        while (!q.empty() && used[q.top().id]) q.pop();
        node tmp=q.top(); q.pop();
        int pos=tmp.pos;
        ans+=tmp.v;
        int P=Pre[pos],N=Next[pos];
        used[id[pos]]=1;

        if (P==0 && N==n) break;
        if (N==n){
            // used[id[P]]=1;
            Next[Pre[P]]=n;
        }else if (P==0){
            // used[id[N]]=1;
            Pre[Next[N]]=0;
        }else{
            // used[id[N]]=1;used[id[P]]=1;

            V[pos]=V[P]+V[N]-tmp.v;
            tmp.v=V[pos];
            id[pos]=++p;
            tmp.id=id[pos];
            q.push(tmp);

            Next[Pre[P]]=pos;
            Pre[Next[N]]=pos;
            Next[pos]=Next[N];
            Pre[pos]=Pre[P];
        }
    }
    printf("%lld\n",ans);
    return 0;
}

序列问题

题面见进阶指南P82

#include <bits/stdc++.h>
using namespace std;

int nowi,nowj,n,m;
int a[105][2005],tmp[2005],b[4000005],cnt;
struct node{
    int a,b;
    bool is;
};
bool operator < (node x,node y){
    return ((a[nowi][x.a]+a[nowj][x.b])>(a[nowi][y.a]+a[nowj][y.b]));
}

void solve(){
    priority_queue<node> q;
    q.push({1,1,0});
    for (int i=1;i<=m;++i){
        node t=q.top();
        tmp[i]=a[nowi][t.a]+a[nowj][t.b];
        q.pop();
        q.push({t.a,t.b+1,1});
        if (!t.is) q.push({t.a+1,t.b,0});
    }
    for (int i=1;i<=m;++i) a[nowj][i]=tmp[i];
}

int main(){
    cin>>n>>m;
    for (int i=1;i<=n;++i){
        for (int j=1;j<=m;++j){
            scanf("%d",&a[i][j]);
        }
        sort(a[i]+1,a[i]+m+1);
    }
    for (int i=1;i<=n-1;++i){
        nowi=i,nowj=i+1;
        solve();
    }
    for (int i=1;i<=m;++i) printf("%d ",a[n][i]);
    return 0;
}

知识共享许可协议
本作品采用知识共享署名-非商业性使用-相同方式共享 3.0 未本地化版本许可协议进行许可。

本文链接:https://hs-blog.axell.top/archives/82/