基于8bit伽罗华域的单字节RAID6编解码

此外, 完整的 GF 四则运算实现可参考 tomerfiliba/reedsolomon 仓库 master 分支中的 reedsolo.py。



class Block():
    '''A single data block holding exactly one byte (8 bits) of payload.

    ``byte_data`` is the payload as an int in 0..255.  ``label`` is a
    free-form annotation that only shows up when the block is printed; it
    is ignored when two blocks are compared.
    '''

    # Blocks are mutable and compare by value, so they stay unhashable.
    __hash__ = None

    def __init__(self, byte_data, label='default'):
        '''Create a block from one byte of input.

        Args:
            byte_data: the payload.  An int is used as-is; ``bytes`` /
                ``bytearray`` contribute their first byte; anything else is
                converted with ``str()`` and the code point of its first
                character is used.
            label: annotation shown in the printed form.

        Raises:
            ValueError: if the resulting value does not fit in one byte.
            IndexError: if a bytes-like or string input is empty.
        '''
        if isinstance(byte_data, int):
            value = byte_data
        elif isinstance(byte_data, (bytes, bytearray)):
            # str(b'A') is "b'A'", so bytes must not take the string path.
            value = byte_data[0]
        else:
            value = ord(str(byte_data)[0])  # treat as text, keep the first character
        if not 0 <= value <= 255:
            raise ValueError(f'byte_data must fit in one byte (0..255), got {value!r}')
        self.byte_data = value
        self.label = label

    def __repr__(self):
        return f'<Block 0x{self.byte_data:02x}={self.byte_data:08b}={self.byte_data!r} ({self.label})>'

    # str() and repr() intentionally share one format.
    __str__ = __repr__

    def __eq__(self, other):
        if not isinstance(other, Block):
            # Let Python fall back to its default handling, which yields False.
            return NotImplemented
        return self.byte_data == other.byte_data


def make_gf_power_log_tables():
    '''Build the power (antilog) table and the log table of GF(2^8).

    The field is reduced by the primitive polynomial
    P8(x) = x^8 + x^4 + x^3 + x^2 + 1 = 285 = 0x11d = 0b100011101
    and the generator is g = x = 0x02.

    Returns:
        ``(power, log)``: two 256-entry lists.  ``power[i]`` is g^i and
        ``log[v]`` is the exponent i with g^i == v.  ``log[0]`` is left at
        its initial 0 (zero is never produced by the loop), and ``log[1]``
        is forced to 0 although g^255 is 1 as well.

    # Sample entries
    #     power table  g^i  decimal   binary      hex            log table    log(g,i)   decimal   binary      hex
    # power table 2^0   =   1 = 0b00000001 = 0x01      # log table  log(2,0  ) =   0 = 0b00000000 = 0x00
    # power table 2^1   =   2 = 0b00000010 = 0x02      # log table  log(2,1  ) =   0 = 0b00000000 = 0x00
    # power table 2^2   =   4 = 0b00000100 = 0x04      # log table  log(2,2  ) =   1 = 0b00000001 = 0x01
    # power table 2^3   =   8 = 0b00001000 = 0x08      # log table  log(2,3  ) =  25 = 0b00011001 = 0x19
    # power table 2^4   =  16 = 0b00010000 = 0x10      # log table  log(2,4  ) =   2 = 0b00000010 = 0x02
    # power table 2^5   =  32 = 0b00100000 = 0x20      # log table  log(2,5  ) =  50 = 0b00110010 = 0x32
    # power table 2^6   =  64 = 0b01000000 = 0x40      # log table  log(2,6  ) =  26 = 0b00011010 = 0x1a
    # power table 2^7   = 128 = 0b10000000 = 0x80      # log table  log(2,7  ) = 198 = 0b11000110 = 0xc6
    # power table 2^8   =  29 = 0b00011101 = 0x1d      # log table  log(2,8  ) =   3 = 0b00000011 = 0x03
    # power table 2^9   =  58 = 0b00111010 = 0x3a      # log table  log(2,9  ) = 223 = 0b11011111 = 0xdf
    # power table 2^10  = 116 = 0b01110100 = 0x74      # log table  log(2,10 ) =  51 = 0b00110011 = 0x33
    # power table 2^11  = 232 = 0b11101000 = 0xe8      # log table  log(2,11 ) = 238 = 0b11101110 = 0xee
    # power table 2^12  = 205 = 0b11001101 = 0xcd      # log table  log(2,12 ) =  27 = 0b00011011 = 0x1b
    # power table 2^13  = 135 = 0b10000111 = 0x87      # log table  log(2,13 ) = 104 = 0b01101000 = 0x68
    # power table 2^14  =  19 = 0b00010011 = 0x13      # log table  log(2,14 ) = 199 = 0b11000111 = 0xc7
    # power table 2^15  =  38 = 0b00100110 = 0x26      # log table  log(2,15 ) =  75 = 0b01001011 = 0x4b
    # ...                                              # ...
    # ...                                              # ...
    # power table 2^254 = 142 = 0b10001110 = 0x8e      # log table  log(2,254) =  88 = 0b01011000 = 0x58
    # power table 2^255 =   1 = 0b00000001 = 0x01      # log table  log(2,255) = 175 = 0b10101111 = 0xaf
    '''
    reduction_poly = 0x11d  # x^8 + x^4 + x^3 + x^2 + 1
    power = [0] * 256
    log = [0] * 256

    value = 1
    for exponent in range(256):
        power[exponent] = value
        log[value] = exponent
        # Multiply by the generator (x): shift left, then reduce if the
        # result spilled into bit 8.
        value <<= 1
        if value & 0x100:
            value ^= reduction_poly
    # The last iteration stored log[1] = 255 because g^255 == 1; the
    # canonical logarithm of 1 is 0.
    log[1] = 0

    # To dump the tables in the layout shown in the docstring:
    # for i, elem in enumerate(power):
    #     print(f'power table 2^{i: <3} = {elem: >3} = 0b{elem:08b} = 0x{elem:02x}')
    # for i, elem in enumerate(log):
    #     print(f'log table  log(2,{i: <3}) = {elem: >3} = 0b{elem:08b} = 0x{elem:02x}')
    return power, log


# GF(2^8) lookup tables, built once at import time and read by the
# gf_exp / gf_mul / gf_inv helpers below.
gf_power_table, gf_log_table = make_gf_power_log_tables()


def gf_add(a, b):
    '''Add two field elements; the sum is the bitwise XOR of the operands.

    XOR is its own inverse, so this function also serves as subtraction.
    '''
    return a ^ b


def gf_exp(n):
    '''Return g**n read from the power table, where g = 0x02 is the generator.

    Only exponents 0..255 are accepted, matching the 256 entries of
    ``gf_power_table``.
    '''
    assert n >= 0 and n <= 255
    return gf_power_table[n]


def gf_mul(a, b):
    '''Multiply two field elements through the log / power tables.

    Uses a * b = 2^(log(a) + log(b)).  (Division is the same identity with
    the logs subtracted.)  A zero operand short-circuits to 0 before any
    table lookup.
    '''
    if a != 0 and b != 0:
        # Exponents are ordinary integers: add them arithmetically (not
        # with XOR) and wrap at 255.
        exponent = gf_log_table[a] + gf_log_table[b]
        return gf_power_table[exponent % 255]
    return 0


def gf_inv(a):
    '''Return the multiplicative inverse of ``a`` in GF(2^8).

    Because g**255 == 1, the inverse of g**k is g**(255 - k); k is read
    directly from the log table instead of searching the power table.

    Zero has no inverse; by this module's convention ``gf_inv(0)`` is 0.

    Raises:
        ValueError: if ``a`` is outside 0..255.
    '''
    if not 0 <= a <= 255:
        raise ValueError(f'{a!r} is not an element of GF(2^8)')
    if a == 0:
        return 0   # None
    # log(1) is 0, so a == 1 resolves to gf_power_table[255], which is 1.
    return gf_power_table[255 - gf_log_table[a]]


def gf_div(a, b):
    '''Divide ``a`` by ``b`` in GF(2^8) by multiplying with the inverse of ``b``.

    NOTE: ``gf_inv(0)`` returns 0, so a zero divisor makes this return 0
    instead of raising.
    '''
    inverse_of_b = gf_inv(b)
    return gf_mul(a, inverse_of_b)


def find_first(array, key):
    '''Return the index of the first element equal to ``key``, or None.

    Elements are matched with ``==`` in a single pass; an element that is
    not equal to ``key`` (even if it is the same object, e.g. NaN) is not a
    match.
    '''
    for index, elem in enumerate(array):
        if elem == key:
            return index
    return None


def find_last(array, key):
    '''Return the index of the last element equal to ``key``, or None.

    ``array`` must support ``len()`` and integer indexing.  Elements are
    matched with ``==`` in a single backwards pass; an element that is not
    equal to ``key`` (even if it is the same object, e.g. NaN) is not a
    match.
    '''
    for index in range(len(array) - 1, -1, -1):
        if array[index] == key:
            return index
    return None


def raid5_encode(data_list):
    '''RAID5 encoding: XOR all blocks together into the parity block P.

    Args:
        data_list: iterable of Block objects.

    Returns:
        A new Block labelled 'p' holding the XOR of every ``byte_data``
        (0 for an empty input).
    '''
    folded = 0
    for member in data_list:
        folded ^= member.byte_data
    parity = Block(0b00000000, 'p')
    parity.byte_data = folded
    return parity


def raid5_decode(data_list):
    '''RAID5 decoding: rebuild the lost block from the survivors and P.

    Args:
        data_list: the surviving data blocks together with P.

    Returns:
        The recovered block, as produced by ``raid5_encode`` (label 'p').
    '''
    # XOR-ing the survivors with P cancels them out and leaves the missing
    # block, so decoding is the very same operation as encoding.
    recovered = raid5_encode(data_list)
    return recovered


def raid6_encode(data_list):
    '''RAID6 encoding: compute the P and Q blocks for a list of data blocks.

    P is the plain XOR parity from ``raid5_encode``.  Q is the GF(2^8) sum
    of g^i * d_i, where d_i is the block at position i and g = 0x02.

    Args:
        data_list: list of data Blocks, in stripe order.

    Returns:
        ``(P, Q)``; Q is a new Block labelled 'q'.

    NOTE: ``gf_exp`` only accepts exponents up to 255, which bounds the
    stripe length, and g^255 equals g^0, so a 256th block would share its
    coefficient with block 0.
    '''
    # print(f'got encode {len(data_list)}')
    block_p = raid5_encode(data_list)
    q_data = 0
    for position, block in enumerate(data_list):
        weighted = gf_mul(gf_exp(position), block.byte_data)
        q_data = gf_add(q_data, weighted)
    return block_p, Block(q_data, 'q')


def raid6_decode(data_list):
    '''RAID6 decoding: recover two lost data blocks from the survivors, P and Q.

    Only the "two data blocks lost" case is supported.

    Args:
        data_list: the stripe in order: the data slots followed by P and Q
            as the last two entries.  Exactly two data slots must hold the
            placeholder '?'.

    Returns:
        ``(dx, dy)``: Blocks labelled 'dx' and 'dy' for the missing slot
        with the lower and the higher index respectively.

    Raises:
        ValueError: if the number of '?' placeholders is not exactly two,
            if a placeholder sits in the P or Q slot, or if the two missing
            slots share the same Q coefficient and cannot be separated.

    With ix < iy the missing slots, dp = P ^ P' and dq = Q ^ Q' (P', Q'
    being the syndromes of the stripe with the missing blocks zeroed):
        dp = dx + dy
        dq = g^ix * dx + g^iy * dy
    which solves to
        dx = a * dp + b * dq,   dy = dx + dp
        a  = g^(iy-ix) / (1 + g^(iy-ix))
        b  = g^(-ix)   / (1 + g^(iy-ix))
    '''
    missing = [i for i, item in enumerate(data_list) if item == '?']
    ix = missing[0] if missing else None
    iy = missing[-1] if missing else None

    print('raid6_decode get total', data_list, 'should recover', (ix, iy))
    if len(missing) != 2:
        raise ValueError(f"expected exactly two '?' placeholders, found {len(missing)}")
    if iy >= len(data_list) - 2:
        raise ValueError("'?' placeholders must be data slots, not the trailing P/Q slots")

    p = data_list[-2].byte_data
    q = data_list[-1].byte_data
    # Syndromes of the stripe with every missing slot replaced by zero.
    p_nxy, q_nxy = raid6_encode([d if isinstance(d, Block) else Block(0) for d in data_list[:-2]])
    dp = gf_add(p, p_nxy.byte_data)
    dq = gf_add(q, q_nxy.byte_data)

    g_gap = gf_exp(iy - ix)
    denominator = gf_add(1, g_gap)
    if denominator == 0:
        # g^ix == g^iy: gf_inv(0) would return 0 and the result would be garbage.
        raise ValueError(f'slots {ix} and {iy} share the same Q coefficient')
    denominator_inv = gf_inv(denominator)

    a = gf_mul(g_gap, denominator_inv)
    # g^(255 - ix) equals g^(-ix) because g^255 == 1.
    b = gf_mul(gf_exp(255 - ix), denominator_inv)
    dx = gf_add(gf_mul(a, dp), gf_mul(b, dq))
    dy = gf_add(dx, dp)
    return Block(dx, label='dx'), Block(dy, label='dy')


def ut_basic():
    '''Print a few sample blocks and check the RAID5 XOR parity on small inputs.'''
    data_zero = Block(0b110100, 'zero')
    data_one = Block(0b11111111, 'one')
    data_random = Block(0b11011001, 'random')
    print(data_zero, data_one, data_random)

    # ((first byte, second byte), expected parity byte)
    parity_cases = [
        ((0b10011001, 0b11001100), 0b01010101),
        ((0, 1), 1),
        ((255, 255), 0),
        ((255, 0), 255),
    ]
    for operands, expected in parity_cases:
        assert raid5_encode([Block(value) for value in operands]) == Block(expected)
    assert raid5_encode([Block(255), Block(1)]) == Block(254), f'cal {raid5_encode([Block(255), Block(1)])}'


def ut_raid6_1():
    '''Print the (P, Q) pair produced by raid6_encode for a few fixed stripes.

    Nothing is asserted here; the output is meant for eyeballing.
    '''
    zero = Block(0b000000, 'zero')
    one = Block(0b11111111, 'one')

    stripes = [
        [zero] * 3,
        [one] * 3,
        [one] * 4,
        [one] * 5,
        [Block(0b00000001), Block(0b00000100)],
        [Block(0b00000100), Block(0b00000001)],
    ]
    for stripe in stripes:
        print(raid6_encode(stripe))


def ut_raid6_2():
    '''Encode fixed stripes, drop two data blocks, and check raid6_decode restores them.

    Each scenario prints its tag, the encoded P/Q and the decoded blocks.
    '''
    def blocks(*values):
        # Unlabelled blocks, one per byte value.
        return [Block(value) for value in values]

    random1 = Block(0b11011001, 'random1')
    random2 = Block(0b00011101, 'random2')
    random3 = Block(0b10001000, 'random3')

    # (tag printed first, full data stripe, indices of the two lost slots)
    scenarios = [
        (111, blocks(0b00000001, 0b00000100, 0b00000000), (1, 2)),
        (222, blocks(0b00000001, 0b00000001, 0b00000001, 0b00000001), (0, 2)),
        (333, blocks(0b00010001, 0b00010001, 0b00000001, 0b00000001), (0, 2)),
        (444, blocks(0b00011001, 0b00000001, 0b11000001), (1, 2)),
        (555, blocks(0b00000001, 0b00000100, 0b00000100, 0b00100100), (0, 2)),
        (666, [random1, random2, random3], (1, 2)),
    ]
    for tag, stripe, lost in scenarios:
        print(tag)
        p, q = raid6_encode(stripe)
        print('raid6_encode result', p, q)
        # Replace the lost slots with the '?' placeholder and append P, Q.
        damaged = ['?' if slot in lost else block for slot, block in enumerate(stripe)] + [p, q]
        recover1, recover2 = raid6_decode(damaged)
        print('decoded got', recover1, recover2)
        assert [recover1, recover2] == [stripe[slot] for slot in lost]


def ut_raid6_random_gen():
    '''Randomised round trip: 200 random stripes, two random losses each.

    Every round builds 5..10 random data blocks, encodes P/Q, replaces two
    distinct data slots with '?', and asserts that raid6_decode returns the
    two removed blocks (lower index first).
    '''
    import random
    rounds = 200
    for _ in range(rounds):
        raid_size = random.randint(5, 10)
        data_list = [Block(random.randint(0, 255), label=f'data{i}') for i in range(raid_size)]
        p, q = raid6_encode(data_list)

        killed_index1, killed_index2 = sorted(random.sample(list(range(raid_size)), 2))
        killed_data1 = data_list[killed_index1]
        killed_data2 = data_list[killed_index2]

        print(f'gen raid6 {raid_size}+2')
        print(f'    datablocks = {data_list}')
        print(f'    encode     p={p} q={q}')
        print(f'    random kill index {killed_index1} {killed_index2} {killed_data1} {killed_data2}')

        # Survivors keep their slot; the killed slots become '?'.
        current_data_list = [
            d if i not in (killed_index1, killed_index2) else '?'
            for i, d in enumerate(data_list)
        ]
        current_data_list += [p, q]

        recovered1, recovered2 = raid6_decode(current_data_list)
        print(f'    recovered  {recovered1} {recovered2}')
        assert [recovered1, recovered2] == [killed_data1, killed_data2]


def ut_raid6_basic_cal():
    '''Spot-check gf_mul, gf_exp and gf_inv against fixed expected values.'''
    # ((a, b), a * b)
    mul_cases = [
        ((0, 0), 0),
        ((1, 1), 1),
        ((3, 0), 0),
        ((2, 6), 0x0c),
        ((2, 253), 0xe7),
        ((2, 255), 0xe3),
        ((3, 3), 0x05),
        ((5, 7), 0x1b),
        ((255, 1), 0xff),
        ((1, 255), 0xff),
        ((255, 2), 0xe3),
        ((254, 2), 0xe1),
    ]
    for (lhs, rhs), product in mul_cases:
        assert gf_mul(lhs, rhs) == product

    # (n, g^n)
    exp_cases = [
        (0, 0x01),
        (1, 0x02),
        (2, 0x04),
        (3, 0x08),
        (9, 0x3a),
        (254, 0x8e),
        (255, 0x01),  # in simraid 0x00
    ]
    for exponent, value in exp_cases:
        assert gf_exp(exponent) == value

    # (a, a^-1); gf_inv(0) is 0 by convention
    inv_cases = [
        (0, 0x00),
        (1, 0x01),
        (2, 0x8e),   # 142
        (3, 0xf4),
        (4, 0x47),   # 71
        (8, 0xad),   # 173
        (32, 0x6c),
        (33, 0xed),
        (34, 0x39),
        (254, 0x7e),
        (255, 0xfd),
    ]
    for element, inverse in inv_cases:
        assert gf_inv(element) == inverse

    # (n, inverse of 1 + g^n): the denominator used by raid6_decode
    denominator_cases = [
        (0, 0x00),
        (1, 0xf4),
        (2, 0xa7),
        (3, 0x9d),
        (32, 0xac),
        (48, 0x04),
        (50, 0x47),
        (53, 0xd0),
        (254, 0xf5),
        (255, 0x00),
    ]
    for exponent, inverse in denominator_cases:
        assert gf_inv(1 ^ gf_exp(exponent)) == inverse


def ut_raid6_from_real_raid6_encoder():
    '''Cross-check raid6_encode against P/Q bytes taken from a real RAID6 encoder.

    The vectors come from a 3+2 RAID6 that had random data written to it;
    each row is (A, B, C, P, Q).
    '''
    vectors = [
        (0xbb, 0x38, 0x24, 0xa7, 0x5b),
        (0xfe, 0x08, 0xd0, 0x26, 0x89),
        (0x96, 0x8f, 0xfa, 0xe3, 0x5a),
        (0x62, 0x51, 0x34, 0x07, 0x10),
    ]
    for *payload, p_byte, q_byte in vectors:
        assert (Block(p_byte), Block(q_byte)) == raid6_encode([Block(d) for d in payload])


if __name__ == '__main__':
    # Run the self-checks: field arithmetic first, then the encoder against
    # the recorded real-encoder vectors, then the encode/decode round trips.
    # ut_basic()
    ut_raid6_basic_cal()
    ut_raid6_from_real_raid6_encoder()
    ut_raid6_1()
    ut_raid6_2()
    ut_raid6_random_gen()