Fork me on GitHub

以前听说过这么一个问题,就是求一个二进制数里面有多少个1。最近在写代码的时候也碰到了需要对标志位计数的情况。后来发现bit-counting还有蛮多有趣的话题。一半听来一半看来一个并行算法,觉得挺有意思,记在下面。不过还是从最简单的counting开始。

最简单的是循环与,挨个看每一位是1还是0。这个方法要循环很多次,从低位到高位逐位检查,直到更高位没有1为止。

int count_bits_naive(unsigned int a)
{
    int count = 0;
    while (a)
    {
        count += a & 1;
        a = a >> 1;
    }
    return count;
}

一个比较经典的是self anding的办法,它的动机是减少循环次数,可不可以只数1,不管0。有没有办法不管什么情况下,都能一次消除一个1?于是它利用了这样一个现象,就是一个数减去1的话,它的最末一个1,以及之后的所有0,全部变反,前面的位不变。因此要消除最末一个1,可以把一个数减一再与自己。这样一次消除一个1,消除到0就停止循环。因此可以减少循环次数到1的个数。

int count_bits_self_anding(unsigned int a)
{
    int count = 0;
    while (a)
    {
        a = a & (a - 1);
        count++;
    }
    return count;
}

虽然是以32位的int为例,但其实思想并不局限于32位,当位串很长的时候,naive的办法就挫爆了,当位串中1很多的时候,self anding的办法也挫爆了。

仔细想想上面的算法,许多地方运算都不“充分”。第一个在与运算的时候,只有一位真正参与了运算。第二个稍好一点,但是也只有最末一个1后面的部分参与了运算。ALU的部分被“闲置”,就表明这里边还有不少油水可以榨。

首先,一段二进制串,要统计里面1的个数,可以把它分为两段。前后分别统计,然后再加起来。这就是递归的意思。剩下的,就是想如何把这前后两段的统计同时进行。因为统计一个只有一半长的子串,应该只用到运算器一半的能力就可以了吧。如果这样的话,前后两段不就可以同时进行了。一个数里面1的个数,必然比这个数本身要小很多哪。用a表示这个数,c表示其中1的个数。那a至少也是2^c – 1。所以,在装a的地方放个c是绰绰有余的。

所以并行算法的关键就是把a分段,然后做段与段之间的加法。每段保存的是这段里面1的个数。只不过从把递归变为循环,它是从局部到整体来计算。当段长度为1时,a本身就是c。所以第一步把相邻两个长度为1的段加到一起。比如11 00 10 01,就变为10 00 01 01。因为11中间1的个数是10,00中间1的个数是00,10和01中间1的个数是01。然后再取段长为2,把相邻段加起来,1000 0101变成0010 0010,因为10+00=0010,01+01=0010。以此类推,最后把前16位和后16位加起来就得到了整个32位里面的1的个数。由于1的个数总是比a本身要小,所以才没有进位的问题,两个段长为n的段,加起来长度肯定不超过2n,所以不用担心后面的段相加会影响前面。

int count_bits_parallel(unsigned int a)
{
    static const unsigned int mask[] = {
        0x55555555,        // 01010101010101010101010101010101
        0x33333333,        // 00110011001100110011001100110011
        0x0F0F0F0F,        // 00001111000011110000111100001111
        0x00FF00FF,        // 00000000111111110000000011111111
        0x0000FFFF,        // 00000000000000001111111111111111
    };
    a = ((a >> 1) & mask[0]) + (a & mask[0]);
    a = ((a >> 2) & mask[1]) + (a & mask[1]);
    a = ((a >> 4) & mask[2]) + (a & mask[2]);
    a = ((a >> 8) & mask[3]) + (a & mask[3]);
    a = ((a >> 16) & mask[4]) + (a & mask[4]);
    return a;
}

mask用来提取子串,a右移是为了错位,把前段挪到和后段对齐,然后前后段分别和mask做与运算滤掉不需要的位,再加起来。

然后有一个更加tricky的方法:

int count_bits_tricky(unsigned int a)
{
    a = a - ((a >> 1) & 0x55555555);
    a = (a & 0x33333333) + ((a >> 2) & 0x33333333);
    a = (a + (a >> 4)) & 0x0F0F0F0F;
    a *= 0x01010101;
    return a >> 24;
}

第一步改成a = a - ((a » 1) & 0x55555555);省去一个与运算。

对应真值表:

00 : 00 – 00 = 00; 01 : 01 – 00 = 01; 10 : 10 – 01 = 01; 11 : 11 – 01 = 10;

因此是一样的。

然后,从长度4位开始,加起来的和就不会溢出了。4位里面1的个数最多为4,也就是0100。0100+0100=1000仍然可以保存在一个4位的段里面,不会溢出,因此可以少做一次与。

第三步可以改成(a + (a » 4)) & 0x0F0F0F0F。右移四位直接加,这时候高四位上可能会有不想要的东西,但是低四位是求得的原来高四位和第四位的和,而且不会有进位,也就是说这个和肯定就在这低四位了。因此用mask提取出来就是正确的值了。但是段长为2的时候是不能这样做的,因为之前保存的2bit里面1的个数最大可能是2,这样的话10+10=100会进位到高2位去。因此第二步必须高两位和低两位分别提取出来再做加法。

第四步用一个乘法,相当于直接把4个段长为8的串相加。之所以到这一步才用乘法,还是由于溢出的问题。长8的串保存的是对应的串中1的个数,因此最大也就是8,四个8相加也只有32,放在一个8位长的串中没有问题。第三步如果要用乘法,就是8个4位长的串相加,最大情况32,4位就放不下了。

但这实际上是把移位相加的工作交给乘法指令去做了。所以一开始我并不认为这样会快多少,就自作聪明的写了个:

int count_bits_foo(unsigned int a)
{
    a = a - ((a >> 1) & 0x55555555);
    a = (a & 0x33333333) + ((a >> 2) & 0x33333333);
    a = (a + (a >> 4)) & 0x0F0F0F0F;
    a = (a + (a >> 8)) & 0x00FF00FF;
    return (a + (a >> 16)) & 0x0000FFFF;
}

把一个乘法指令换成了两次移位相加再做与,测试了一下,所花时间大概是原来算法的1.5倍。虽然加减,与或,移位这些运算会比较快,但是乘法也没有那么慢的,六个简单指令换一个乘法,果然还是败了。

BTW,这里有各种很tricky的位运算方法。

2013-04-04


blog comments powered by Disqus