[0CTF] One Time Pad by Z3

This challenge can be solved in two ways: mathematically, or with an SMT solver.

First, I want to point out that there is a carefully written blog post explaining how to solve it the mathematical way, so please read that first.

One Time Pad 1

Now let's take a look at the source code:

One Time Pad 1
#!/usr/bin/env python
# coding=utf-8

from os import urandom

def process(m, k):
    tmp = m ^ k
    res = 0
    for i in bin(tmp)[2:]:
        res = res << 1;
        if (int(i)):
            res = res ^ tmp
        if (res >> 256):
            res = res ^ P
    return res

def keygen(seed):
    key = str2num(urandom(32))
    while True:
        yield key
        key = process(key, seed)

def str2num(s):
    return int(s.encode('hex'), 16)

P = 0x10000000000000000000000000000000000000000000000000000000000000425L

true_secret = open('flag.txt').read()[:32]
assert len(true_secret) == 32
print 'flag{%s}' % true_secret
fake_secret1 = "I_am_not_a_secret_so_you_know_me"
fake_secret2 = "feeddeadbeefcafefeeddeadbeefcafe"
secret = str2num(urandom(32))

generator = keygen(secret)
ctxt1 = hex(str2num(true_secret) ^ generator.next())[2:-1]
ctxt2 = hex(str2num(fake_secret1) ^ generator.next())[2:-1]
ctxt3 = hex(str2num(fake_secret2) ^ generator.next())[2:-1]
f = open('ciphertext', 'w')
f.write(ctxt1+'\n')
f.write(ctxt2+'\n')
f.write(ctxt3+'\n')
f.close()

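OK, the logic here is simple: the script generates a sequence of keys, each derived from the previous one. The first key is random, and every following key is process(previous_key, seed), where seed is a secret random value.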

The function that performs this derivation is process. If we can reverse process, we can recover the key.
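To see what process actually computes, here is a minimal Python 3 sketch (the challenge itself is Python 2). It ports process and compares it against an independent carry-less multiplication reduced modulo P. The helpers clmul and reduce_mod_p are illustrative names of my own, not part of the challenge. If the two always agree, process(m, k) is simply (m ^ k) squared as a binary polynomial modulo P = x^256 + x^10 + x^5 + x^2 + 1.

```python
import random

# Reduction polynomial from the challenge: x^256 + x^10 + x^5 + x^2 + 1
P = 0x10000000000000000000000000000000000000000000000000000000000000425


def process(m, k):
    """Python 3 port of the challenge's process()."""
    tmp = m ^ k
    res = 0
    for bit in bin(tmp)[2:]:  # walk tmp's bits from most to least significant
        res <<= 1             # multiply the accumulator by x
        if bit == "1":
            res ^= tmp        # add tmp when the current bit is set
        if res >> 256:
            res ^= P          # reduce as soon as a degree-256 term appears
    return res


def clmul(a, b):
    """Carry-less (GF(2)[x]) product of two polynomials stored as ints."""
    out = 0
    while b:
        if b & 1:
            out ^= a
        a <<= 1
        b >>= 1
    return out


def reduce_mod_p(a):
    """Reduce a GF(2)[x] polynomial modulo P (result has degree < 256)."""
    while a.bit_length() > 256:
        a ^= P << (a.bit_length() - 257)
    return a


rng = random.Random(0)
for _ in range(100):
    m, k = rng.getrandbits(256), rng.getrandbits(256)
    t = m ^ k
    assert process(m, k) == reduce_mod_p(clmul(t, t))
print("process(m, k) == (m ^ k)^2 mod P on all samples")
```

Since squaring is the whole story, "reversing process" means taking a square root modulo P.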

Since we have two plaintext/ciphertext pairs (the fake secrets), XORing each pair gives us key1 and key2. Based on them, we want to find key0, the key that encrypts the flag.

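Analysing process, we see that it computes (m ^ k)^2 by multiplying binary polynomials modulo P. Squaring over GF(2) distributes over XOR, because the cross terms of (a + b)^2 cancel, so process behaves like XOR: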

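seed = 0x12341234               # arbitrary test seed
p = process
key1 = str2num(urandom(32))     # any test key
key2 = p(key1, seed)            # next key, as keygen would produce it
assert key2 == p(key1, 0) ^ p(seed, 0)
# The same identity holds one step earlier:
# key1 == p(key0, seed) == p(key0, 0) ^ p(seed, 0)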

Therefore, it turns out that p(key0, seed) = p(key0, 0) ^ p(seed, 0).

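Writing both steps this way, key1 = p(key0, 0) ^ p(seed, 0) and key2 = p(key1, 0) ^ p(seed, 0), and XORing them, the unknown p(seed, 0) cancels and we get the final XOR equation: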

key2 ^ key1 == p(key1, 0) ^ p(key0, 0)
=> p(key0, 0) == p(key1, 0) ^ key1 ^ key2
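The derivation can be checked end to end on a locally generated key chain. This is a Python 3 sketch under the assumption that random.getrandbits stand-ins replace the real secret seed and the real first key (both unknown to us). It builds key0, key1 and key2 with the same recurrence as keygen, checks the XOR split of process, and confirms that p(key0, 0) can be computed from key1 and key2 alone, without the seed.

```python
import random

P = 0x10000000000000000000000000000000000000000000000000000000000000425


def process(m, k):
    """Python 3 port of the challenge's process(): (m ^ k)^2 modulo P."""
    tmp = m ^ k
    res = 0
    for bit in bin(tmp)[2:]:
        res <<= 1
        if bit == "1":
            res ^= tmp
        if res >> 256:
            res ^= P
    return res


def keygen(seed, first_key):
    """Same recurrence as the challenge's keygen, but with an explicit first key."""
    key = first_key
    while True:
        yield key
        key = process(key, seed)


rng = random.Random(1)
seed = rng.getrandbits(256)   # stand-in for the secret seed
key0 = rng.getrandbits(256)   # stand-in for the flag key
gen = keygen(seed, key0)
k0, key1, key2 = next(gen), next(gen), next(gen)
assert k0 == key0

# process(a, b) splits into two independent squares
assert process(key1, seed) == process(key1, 0) ^ process(seed, 0)

# The seed cancels: p(key0, 0) follows from key1 and key2 only
p_k0 = process(key1, 0) ^ key1 ^ key2
assert p_k0 == process(key0, 0)
```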

So what remains is to reverse the process function: given p(key0, 0), which is key0 squared modulo P, find key0, i.e. take a square root.

Note: when I tried Sage, I didn't know that Sage already has a sqrt() method for polynomials to compute the square root. Next time, I should check before trying to implement a similar function myself.
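Without Sage or Z3, the square root can also be taken directly. Assuming P is irreducible, so that integers below 2^256 form the field GF(2^256), squaring 256 times is the identity map, which means squaring 255 times undoes a single squaring: sqrt(a) = a^(2^255). A minimal Python 3 sketch; square and gf_sqrt are hypothetical helper names, and the round-trip checks would fail if the irreducibility assumption were wrong.

```python
import random

P = 0x10000000000000000000000000000000000000000000000000000000000000425


def square(a):
    """a^2 modulo P: the same computation as the challenge's process(a, 0)."""
    res = 0
    for bit in bin(a)[2:]:
        res <<= 1
        if bit == "1":
            res ^= a
        if res >> 256:
            res ^= P
    return res


def gf_sqrt(a):
    """Square root via a^(2^255), i.e. 255 repeated squarings.

    Assumes P is irreducible of degree 256, so that a^(2^256) == a.
    """
    for _ in range(255):
        a = square(a)
    return a


rng = random.Random(2)
for _ in range(5):
    x = rng.getrandbits(256)
    assert gf_sqrt(square(x)) == x
```

With this helper, key0 would be gf_sqrt(p_k0), with no solver involved.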

Here is my solver for it, using Z3 of course. Thanks to Van Hoa for pointing out the ZeroExt and BitVecVal functions; without them I could not have completed my solver.

solve.py
#!/usr/bin/python
from z3 import *

P = 0x10000000000000000000000000000000000000000000000000000000000000425L

def get_result_from_z3(result):
    dic = {}
    for d in result.decls():
        dic[d.name().lstrip('S')] = result[d].as_long()
    return dic

def str2num(s):
    return int(s.encode('hex'), 16)


def process(m):
    # process(m, 0) from the challenge: computes m^2 modulo P
    tmp = m
    res = 0
    for i in bin(tmp)[2:]:
        res = res << 1
        if (int(i)):
            res = res ^ tmp
        if (res >> 256):
            res = res ^ P
    return res

x = [BitVec('x_%s' % i, 1) for i in range(256)]

ct0 = 0xaf3fcc28377e7e983355096fd4f635856df82bbab61d2c50892d9ee5d913a07f
ct1 = 0x630eb4dce274d29a16f86940f2f35253477665949170ed9e8c9e828794b5543c
ct2 = 0xe913db07cbe4f433c7cdeaac549757d23651ebdccf69d7fbdfd5dc2829334d1b

fake_secret1 = "I_am_not_a_secret_so_you_know_me"
fake_secret2 = "feeddeadbeefcafefeeddeadbeefcafe"

k1 = ct1 ^ str2num(fake_secret1)
k2 = ct2 ^ str2num(fake_secret2)

p_k0 = k2 ^ k1 ^ process(k1)

s = Solver()
temp =[]

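# Squaring is linear over XOR, so key0^2 is the XOR of process(2**i) over the set bits x_i of key0
for i in range(256):
    # x[i] is a 1-bit vector and z3 only multiplies bit-vectors of equal width,
    # so zero-extend it to 512 bits (every process() output fits in 256 bits)
    # BitVecVal keeps the constant at full width; mixed with the 1-bit x[i],
    # a bare Python int gets cut down to 1 bit, which is why it looked like zero
    l = ZeroExt(511, x[i]) * BitVecVal(process(2**i), 512)
    temp.append(l)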

# xor together

a = temp[0]
for i in temp[1:]:
    a = a^i


s.add(a == p_k0)

if s.check() == sat:

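    res = get_result_from_z3(s.model())

    # Collect x_0..x_255, then reverse so the most significant bit comes first
    t = ''
    for i in range(256):
        t += str(res['x_{}'.format(i)])
    print t[::-1]

    # key0 ^ ct0 is the flag (Python 2: hex() adds '0x' and a trailing 'L')
    bla = int(t[::-1], 2) ^ ct0
    print hex(bla)[2:-1].decode('hex')
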
else:
    print 'UNSAT'
log
1101101100001111100100110110101000000100001000010000110010101100010111010011000101100110001000101000101110010011010110111011010100111000100111110100001111100101100001110110111001110011001111101110110001001110101011011001011011101100011100101101001000100110
t0_B3_r4ndoM_en0Ugh_1s_nec3s5arY
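The Z3 constraint is really a 256 x 256 linear system over GF(2): column i is process(2**i), and the unknown bits x_i select which columns XOR together to p_k0. For comparison, here is a Python 3 sketch that solves the same system with plain Gaussian elimination instead of an SMT solver. gf2_solve and square are hypothetical helpers of my own, not part of the solver above; the ciphertexts and fake secrets are the ones from solve.py.

```python
P = 0x10000000000000000000000000000000000000000000000000000000000000425


def square(a):
    """a^2 modulo P: the same computation as process(a) in solve.py."""
    res = 0
    for bit in bin(a)[2:]:
        res <<= 1
        if bit == "1":
            res ^= a
        if res >> 256:
            res ^= P
    return res


def gf2_solve(columns, target):
    """Return x such that XOR of columns[i] over the set bits i of x == target."""
    basis = {}  # pivot bit -> (vector, which original columns it combines)
    for i, col in enumerate(columns):
        vec, combo = col, 1 << i
        while vec:
            pivot = vec.bit_length() - 1
            if pivot not in basis:
                basis[pivot] = (vec, combo)
                break
            bvec, bcombo = basis[pivot]
            vec ^= bvec       # clear the leading bit
            combo ^= bcombo   # track the same combination of columns
    x = 0
    while target:
        pivot = target.bit_length() - 1
        if pivot not in basis:
            raise ValueError("no solution")
        bvec, bcombo = basis[pivot]
        target ^= bvec
        x ^= bcombo
    return x


def str2num(s):
    return int.from_bytes(s.encode(), "big")


ct0 = 0xaf3fcc28377e7e983355096fd4f635856df82bbab61d2c50892d9ee5d913a07f
ct1 = 0x630eb4dce274d29a16f86940f2f35253477665949170ed9e8c9e828794b5543c
ct2 = 0xe913db07cbe4f433c7cdeaac549757d23651ebdccf69d7fbdfd5dc2829334d1b

k1 = ct1 ^ str2num("I_am_not_a_secret_so_you_know_me")
k2 = ct2 ^ str2num("feeddeadbeefcafefeeddeadbeefcafe")
p_k0 = k2 ^ k1 ^ square(k1)

columns = [square(1 << i) for i in range(256)]
key0 = gf2_solve(columns, p_k0)
flag = (key0 ^ ct0).to_bytes(32, "big").decode()
print(flag)
```

The design choice is the same as in the Z3 version: because squaring modulo P is linear and one-to-one, the square root is the unique solution of this linear system.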