modestudy

[题解分析]

qwb经典study系列(套娃)

  • Challenge-1: AES/CBC字节翻转

  • Challenge-2: AES/CBC Decrypt Oracle(iv固定),发送b'\x00'*32作为密文,解密得到$P_{1}=D_{k}(0)\oplus IV$,$P_{2}=D_{k}(0)\oplus C_{1}=D_{k}(0)$,因此$IV=P_{1}\oplus P_{2}$,即可还原iv

  • Challenge-3: ECB,将第三组cipher换为第五组cipher即可

  • Challenge-4: ECB选择明文攻击 ($E_{k}(pt||salt)$),至多256*len(salt)次即可恢复salt

    (由于访问频率限制问题,在爆破单个salt字节时,宜将256组拼接,在一次请求中完成加密)

    1
    2
    3
    4
    # Concatenate all 256 guesses (prefix + known secret + guess byte), each
    # exactly one 16-byte block, and encrypt them in a single request. Keep
    # all 0x100 result blocks so that the guess 0xff is compared as well.
    mt = b""
    for lsb in range(0x100):
        mt += (prefix + secret + bytes([lsb]))
    ct = chal4_enc(mt)[:16*0x100]
  • Challenge-5: 通过报错发现assert(len(pt) % 2 == 0),因此myblockencrypt_ecb是个分组长度仅为2bytes的ECB,爆破即可(同样由于频率限制,宜拼接一次性发送,打表本地查找)

  • Challenge-6: Padding Oracle Attack(crack_range为第一块,且IV可控)

[exp]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import string, re, random
from tqdm import tqdm
from hashlib import sha256
from binascii import unhexlify, hexlify
from pwn import *

# Module-level connection to the challenge server; every chal*/passpow helper
# in this script talks through this `io` tube.
io = remote("139.224.254.172", "7777")
# context.log_level = 'debug'

def passpow():
    """Solve the connection proof-of-work, then send the team token.

    Reads the line following "sha256", extracts the prefix between "(" and
    "+", and tries random 8-character alphanumeric answers until the first
    byte of sha256(prefix + answer) is below 8.
    """
    io.recvuntil("sha256")
    msg = io.recvline().strip().decode()
    prefix = re.findall(r"\((.*)\+", msg)[0]
    alphabet = string.ascii_letters + string.digits
    while True:
        answer = ''.join(random.choice(alphabet) for _ in range(8))
        if sha256((prefix + answer).encode()).digest()[0] < 8:
            io.sendlineafter("?=", answer)
            break
    io.sendlineafter("teamtoken=", "icq6c57ba701346474cd670a5f3b1c66")

def xor(a, b):
    """Return the bytewise XOR of a and b, truncated to the shorter input."""
    return bytes(x ^ y for x, y in zip(a, b))

def chal1():
    """Challenge 1: forge the cookie by CBC byte flipping.

    The cookie is "<session>;checksum=<hex>". The last session character is
    changed to '1'. The block before the last checksum block is XORed with
    (old last session block ^ new last session block), and only the final
    two checksum blocks are resubmitted.
    """
    io.sendlineafter("[-] your choice:", "1")
    io.recvuntil("[+] cookie:")
    msg = io.recvline().strip().decode()
    session, checksum = re.findall(r"^(.*);checksum=(.*)", msg)[0]
    session = session.encode()
    checksum = unhexlify(checksum)
    payload = session[-16:-1] + b'1'
    forged = xor(xor(payload, checksum[-32:-16]), session[-16:]) + checksum[-16:]
    io.sendlineafter("[-] cookie:",
                     session[:-1].decode() + "1;checksum=" + hexlify(forged).decode())
    print("[+] challenge-1 pass")

def chal2():
    """Challenge 2: recover the fixed CBC IV through the option-1 oracle.

    Option 1 is assumed to CBC-decrypt the submitted c (its reply is parsed
    as plaintext). For c = two zero blocks: P1 = D(0) ^ IV and
    P2 = D(0) ^ C1 = D(0), so IV = P1 ^ P2. The IV is checked against the
    sha256 hash printed by the server before it is submitted.
    """
    io.sendlineafter("[-] your choice:", "2")
    io.recvuntil("sha256")
    msg = io.recvline().strip().decode()
    iv_hash = re.findall(r"\(iv\)=(.*)", msg)[0]  # sha256(iv).hexdigest()
    # Sent as raw bytes, not hex-encoded.
    ct = b'\x00' * 32
    io.sendlineafter("[-] your choice:", "1")
    io.sendlineafter("[-] c:", ct)
    msg = io.recvline().strip().decode()
    pt = unhexlify(re.findall(r"\[\+\] (.*)", msg)[0])
    iv = xor(pt[:16], pt[16:])
    assert sha256(iv).hexdigest() == iv_hash
    io.sendlineafter("[-] your choice:", "2")
    io.sendlineafter("[-] iv(encode hex):", hexlify(iv))
    print("[+] challenge-2 pass")

def chal3():
    """Challenge 3: forge an ECB-encrypted cookie by block cut-and-paste.

    With 16-byte blocks numbered from 0, keep blocks 0-1, put a copy of
    block 4 in place of block 2, then keep blocks 3 onward.
    """
    io.sendlineafter("[-] your choice:", "3")
    io.recvuntil("[+] cookie=")
    # While developing, the plaintext cookie line was printed in 16-byte
    # chunks to work out the block layout; it is skipped over below.
    io.recvuntil("[+] 128bit_ecb_encrypt(cookie):")
    blocks = unhexlify(io.recvline().strip())
    forged = blocks[:16*2] + blocks[16*4:16*5] + blocks[16*3:]
    io.sendlineafter("[-] input your encrypted cookie(encode hex):", hexlify(forged))
    print("[+] challenge-3 pass")

def chal4_enc(m):
    """Encrypt m with the challenge-4 oracle (option 1); return raw ciphertext bytes."""
    io.sendlineafter("[-] your choice:", "1")
    io.sendlineafter("[-] input(encode hex):", hexlify(m))
    io.recvuntil("[+] encrypted msg: ")
    return unhexlify(io.recvline().strip())

def chal4():
    """Challenge 4: recover the ECB salt byte by byte (E_k(input || salt)).

    For each i = 15..0, the reference block is the first ciphertext block of
    i zero bytes + salt. All 256 guesses (zeros + recovered salt + guess) are
    encrypted in one request, and the guess whose block matches the reference
    is the next salt byte. Only the first 16 salt bytes are recovered.

    Raises RuntimeError if no guess matches a byte.
    """
    io.sendlineafter("[-] your choice:", "4")
    io.recvuntil("sha256")
    msg = io.recvline().strip().decode()
    secret_hash = re.findall(r"\(secret\)=(.*)", msg)[0]
    secret = b""
    for i in tqdm(range(15, -1, -1)):
        prefix = b"\x00" * i
        cipher = chal4_enc(prefix)[:16]
        mt = b"".join(prefix + secret + bytes([lsb]) for lsb in range(0x100))
        # Keep all 0x100 blocks; slicing to 16*0xff silently dropped guess 0xff.
        ct = chal4_enc(mt)[:16 * 0x100]
        for lsb in range(0x100):
            if ct[16 * lsb:16 * (lsb + 1)] == cipher:
                secret += bytes([lsb])
                break
        else:
            raise RuntimeError("challenge-4: no guess matched salt byte %d" % (15 - i))
    io.sendlineafter("[-] your choice:", "2")
    assert sha256(secret).hexdigest() == secret_hash
    io.sendlineafter("[-] secret(encode hex):", hexlify(secret))
    print("[+] challenge-4 pass")

def chal5_enc(m):
    """Encrypt m with the challenge-5 myblockencrypt_ecb oracle; return raw bytes."""
    io.sendlineafter("[-] your choice:", "1")
    io.sendlineafter("[-] input(encode hex):", hexlify(m))
    io.recvuntil("[+] myblockencrypt_ecb(your_input).encode(\"hex\"):")
    return unhexlify(io.recvline().strip())

def chal5():
    """Challenge 5: decode the secret through a 2-byte-block ECB codebook.

    Every possible 2-byte block (0x0000..0xffff) is encrypted in a single
    request. The secret's ciphertext is then decoded block by block with a
    lookup table.
    """
    io.sendlineafter("[-] your choice:", "5")
    io.recvuntil("sha256")
    msg = io.recvline().strip().decode()
    secret_hash = re.findall(r"\(secret\)=(.*)", msg)[0]
    io.recvuntil("[+] myblockencrypt_ecb(secret).encode(\"hex\")=")
    secret_cipher = unhexlify(io.recvline().strip())
    mt = b"".join(i.to_bytes(2, "big") for i in range(0x10000))
    ct = chal5_enc(mt)
    # Map each ciphertext block to the first plaintext that produced it. This
    # matches list.index semantics but is O(1) per lookup instead of O(0x10000).
    decode = {}
    for i in range(0x10000):
        decode.setdefault(ct[2 * i:2 * (i + 1)], i)
    secret = b"".join(
        decode[secret_cipher[k:k + 2]].to_bytes(2, "big")
        for k in range(0, len(secret_cipher), 2)
    )
    io.sendlineafter("[-] your choice:", "2")
    assert sha256(secret).hexdigest() == secret_hash
    io.sendlineafter("[-] secret(encode hex):", hexlify(secret))
    print("[+] challenge-5 pass")

def chal6_dec(c):
    """Padding oracle: submit iv+c (hex); True iff the reply contains "success"."""
    io.sendlineafter("[-] your choice:", "1")
    io.sendlineafter("[-] input your iv+c (encode hex):", hexlify(c))
    resp = io.recvline().strip().decode()
    return "success" in resp

def chal6(known=b'\xe6?\xee\x07\xa3\x07\xaa\xc8\x8em\xb5\xa0\x90'):
    """Challenge 6: recover the first plaintext block with a padding oracle.

    The server prints iv + CBC ciphertext. Only the first ciphertext block is
    attacked, by forging the IV and asking chal6_dec whether iv' + block has
    valid padding.

    known -- trailing plaintext bytes of that block recovered in an earlier
             session; the attack resumes at byte 15 - len(known). The default
             is saved progress from a previous run and is only correct if the
             server's secret is identical on every connection. Pass b'' to
             recover all 16 bytes from scratch.

    Raises ValueError if known is longer than 16 bytes, and RuntimeError if no
    IV byte yields valid padding.
    """
    if len(known) > 16:
        raise ValueError("known suffix longer than one block")
    io.sendlineafter("[-] your choice:", "6")
    io.recvuntil("[+] iv+aes128_cbc(key,iv,padding(secret)):")
    msg = unhexlify(io.recvline().strip())
    iv, cipher = msg[:16], msg[16:32]
    secret = known
    n = len(secret)
    # Re-apply the known bytes so they decrypt to the padding value n + 1.
    iv = iv[:16 - n] + xor(iv[16 - n:], xor(secret, bytes([n + 1]) * n))
    for i in range(n, 16):
        pad = i + 1
        for j in tqdm(range(0x100)):
            ct = iv[:15 - i] + bytes([j]) + iv[16 - i:] + cipher
            if not chal6_dec(ct):
                continue
            if i == 0:
                # For the last byte, padding such as \x02\x02 may also be valid.
                # Changing the byte before it keeps only a genuine \x01 valid.
                probe = ct[:14] + bytes([ct[14] ^ 0xff]) + ct[15:]
                if not chal6_dec(probe):
                    continue
            # iv[15 - i] is still the server's original IV byte here.
            secret = bytes([iv[15 - i] ^ j ^ pad]) + secret
            # Turn the recovered tail into padding value pad + 1 for the next byte.
            iv = iv[:15 - i] + xor(ct[15 - i:16], bytes([pad ^ (pad + 1)]) * pad)
            break
        else:
            raise RuntimeError("challenge-6: no valid padding for byte %d" % (15 - i))
    print(secret)
    io.sendlineafter("[-] your choice:", "2")
    io.sendlineafter("[-] secret(encode hex):", hexlify(secret))

def main():
    """Pass the PoW, solve challenges 1-6 in order, then pick menu option 7."""
    passpow()
    chal1()
    chal2()
    chal3()
    chal4()
    chal5()
    chal6()
    io.sendlineafter("[-] your choice:", "7")
    io.interactive()

if __name__ == '__main__':
    main()

fault

[题解分析]

SM4加密原理

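SM4明文长度128bits,分为4个32bits的字$X_{0},X_{1},X_{2},X_{3}$,进行32轮加密,第$i$轮为$X_{i+4}=X_{i}\oplus T(X_{i+1}\oplus X_{i+2}\oplus X_{i+3}\oplus rk_{i})$,其中$T(\cdot)=L(\tau(\cdot))$,$\tau$为对4个字节分别查S盒。每轮的示意图如下(原图缺失):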

其中SM4的S盒均相同,线性变换L为lambda y:(y ^^ (rotl(y, 2)) ^^ (rotl(y, 10)) ^^ (rotl(y, 18)) ^^ (rotl(y, 24)))

抽象成1*32矩阵与32*32矩阵相乘的变换即可,发现该32*32的系数矩阵可逆,即L可逆(注意转置系数矩阵,且要在GF(2)下)

第32轮加密后,要实现一次$X_{32},X_{33},X_{34},X_{35}$四字的逆序,即C=($X_{35},X_{34},X_{33},X_{32}$)

SM4密钥扩展

系统参数$FK=(FK_{0},FK_{1},FK_{2},FK_{3}),FK_{i}\in F_{2}^{32}$及$CK=(CK_{0},CK_{1}…,CK_{31}),CK_{i}\in F_{2}^{32}$固定

密钥为$(MK_{0},MK_{1},MK_{2},MK_{3})$,则初始化

$(K_{0},K_{1},K_{2},K_{3})=(MK_{0}\oplus FK_{0},MK_{1}\oplus FK_{1},MK_{2}\oplus FK_{2},MK_{3}\oplus FK_{3})$

迭代$rk_{i}=K_{i+4}=K_{i}\oplus T(K_{i+1}\oplus K_{i+2}\oplus K_{i+3}\oplus CK_{i}),i=0,1,…,31$

即可获得完整的32个轮密钥$rk_{i}$

(T函数与SM4加密轮函数中的S+L基本类似,只需将L改作lambda y:(y ^^ (rotl(y, 13)) ^^ (rotl(y, 23)))即可)

SM4故障差分分析

基于能在SM4的某轮轮函数中注入错误,挺有意思的攻击- -

攻击流程如下:

  • 选择随机明文,发送至服务端进行加密,拿到返回的ct
  • 选择和上一步相同的明文,但在某轮注入错误,获取返回的ct_ast,同时也记录密文差分ct_diff
  • 降轮次的常规差分攻击or直接对S盒的差分攻击

由于本题中的出错轮次(0~31),出错字节下标(0~15),出错字节异或上的fault(0~0xfe)均可控

1
2
3
4
5
6
7
# Excerpt from the challenge server's fault-injection handler. Each value is
# reduced modulo its range before use: r in 0..31, f in 0..0xfe, p in 0..15.
self.send(b"give me the value of r f p", False)
tmp = self.recv(prompt=b":")
r, f, p = tmp.split(b" ")
r = int(r) % 0x20  # faulted round
f = int(f) % 0xff  # XOR value of the fault; note 0xff itself reduces to 0
p = int(p) % 16  # index of the faulted byte
ct = self.encrypt2(key, unhexlify(pt), r, f, p)

因此我们可以直接将其转化为对单个S盒的差分攻击(常规差分攻击基于高概率差分特征,爆破key部分比特观察counter来实现,但对于单个S盒的差分攻击则只需求交集即可)

且在对单个S盒的差分攻击中,若固定输入差分或输出差分其中任一者,至多只能将key候选降至2个(例如固定输入差分$\Delta_{in}$时,$x$与$x\oplus\Delta_{in}$总是同时满足差分关系);想要确定唯一的key,则必须有至少两个输出差分来进行共同差分分析,取key交集

也正是因为这一点,我们要破解key[31],所选注入fault的地方应为第29轮的第16个字节,且要注入两次不同错误来共同分析(比如当我们选择第31轮出错时,由于题目只允许在每一轮的后四个字节之一出错,因此无法篡改$X_{31}'$,其值将固定为0x00000000)

在第29轮注入fault(假设为0x22)后,

可以发现第29轮的差分为0x00000000, 0x00000000, 0x00000000, 0x00000022

则第32轮轮函数的输出差分为$X_{31}'\oplus X_{35}'$,输入差分为$X_{32}'\oplus X_{33}'\oplus X_{34}'$,且ct、ct_ast(从而ct_diff)均已知,对输出差分逆L变换后即可对S盒进行差分攻击

上述攻击成功后,会令K[31]留下两个candidate,此时再注入fault=0x33,重复操作即可唯一确定K[31]

得到K[31]后,我们可对攻击K[31]时的所有密文数据进行一轮解密,得到31轮加密后的结果及加密差分,类似对S盒进行差分分析,直到恢复出K[30]和K[29]

但由于SM4的密钥生成算法,至少需要四个连续K,才能恢复出完整32个子密钥,因此我们对第26轮第16个字节注入错误0x22和0x33,并利用已得到的K[29:32]进行解密得到第29轮加密后的结果及加密差分,此时情况已等价于求解K[31]

恢复出子密钥K[28:32]后,逆向得到32个子密钥,解密得到flag

[exp]

Exp被我写得极其冗余…凑合看8

get_data.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import string, re
from tqdm import tqdm
from hashlib import sha256
from binascii import hexlify, unhexlify
from Crypto.Util.number import *
from pwn import *
from pwnlib.util.iters import mbruteforce

# Module-level connection to the challenge server, shared by proof_of_work,
# get_enc_flag, encrypt and inject_fault.
io = remote("39.101.134.52", "8006")
# context.log_level = 'debug'

def proof_of_work():
    """Solve sha256(XXX + suffix) == digest (XXX: 3 alphanumerics), send the team token.

    Raises RuntimeError if the brute force finds no answer.
    """
    io.recvuntil("sha256")
    msg = io.recvline().strip().decode()
    suffix = re.findall(r"XXX\+([^\)]+)", msg)[0]
    digest = re.findall(r"== (.*)", msg)[0]
    proof = mbruteforce(
        lambda x: sha256((x + suffix).encode("latin-1")).hexdigest() == digest,
        string.ascii_letters + string.digits, length=3, method='fixed')
    if proof is None:  # mbruteforce found nothing; sending None would fail obscurely
        raise RuntimeError("proof of work not found")
    io.sendlineafter("Give me XXX:", proof)
    io.sendlineafter("teamtoken:", "icq6c57ba701346474cd670a5f3b1c66")

def get_enc_flag():
    """Skip one line, then return the following line (the encrypted flag) as str."""
    io.recvline()
    return io.recvline().strip().decode()

def encrypt(m):
    """Encrypt m with menu option 1 (no fault); return the ciphertext hex string."""
    io.sendlineafter('> ', '1')
    io.sendlineafter('your plaintext in hex:', hexlify(m))
    msg = io.recvline().strip().decode()
    return re.findall(r"your ciphertext in hex:(.*)", msg)[0]

def inject_fault(m, r, f, p):
    """Encrypt m with menu option 2, injecting fault (round r, XOR value f, byte p).

    Returns the faulty ciphertext hex string.
    """
    io.sendlineafter('> ', '2')
    io.sendlineafter('your plaintext in hex:', hexlify(m))
    io.sendlineafter('give me the value of r f p:', '{} {} {}'.format(r, f, p))
    msg = io.recvline().strip().decode()
    return re.findall(r"your ciphertext in hex:(.*)", msg)[0]

def xor(a, b):
    """Return the bytewise XOR of a and b, truncated to the shorter input."""
    return bytes(x ^ y for x, y in zip(a, b))

def _collect_fault_samples(fw, r, f, p, count=5):
    """Write `count` "(c,c_ast,c_diff)" lines for random plaintexts and fault (r, f, p)."""
    for _ in tqdm(range(count)):
        m = int.to_bytes(getRandomNBitInteger(128), 16, 'big')
        c = encrypt(m)
        c_ast = inject_fault(m, r, f, p)
        c_diff = hexlify(xor(unhexlify(c), unhexlify(c_ast))).decode()
        fw.write('({},{},{})\n'.format(c, c_ast, c_diff))


def main():
    """Collect DFA data into the file 'data'.

    Line 1 is the encrypted flag. It is followed by four groups of 5 samples,
    all faulting byte p = 15: round 29 with f = 0x22 then 0x33 (material for
    rk[31], rk[30], rk[29]), then round 26 with f = 0x22 then 0x33 (material
    for rk[28]).
    """
    proof_of_work()
    enc_flag = get_enc_flag()
    p = 15
    # `with` closes the file even if the connection drops mid-collection.
    with open('data', 'w') as fw:
        fw.write(enc_flag + '\n')
        for r in (29, 26):
            for f in (0x22, 0x33):
                _collect_fault_samples(fw, r, f, p)

if __name__ == '__main__':
    main()

exp_fault.sage

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import re
from Crypto.Util.number import *

def get_uint32_be(key_data):
    """Pack the first four bytes of key_data (big-endian) into a 32-bit int."""
    b0, b1, b2, b3 = key_data[0], key_data[1], key_data[2], key_data[3]
    return (b0 << 24) | (b1 << 16) | (b2 << 8) | b3

def put_uint32_be(n):
    """Split a 32-bit int into a list of four bytes, most significant first."""
    return [(n >> shift) & 0xff for shift in (24, 16, 8, 0)]

def get_uint128_be(key_data):
    """Pack the first four 32-bit words of key_data (big-endian) into a 128-bit int."""
    w0, w1, w2, w3 = key_data[0], key_data[1], key_data[2], key_data[3]
    return (w0 << 96) | (w1 << 64) | (w2 << 32) | w3

def put_uint128_be(n):
    """Split a 128-bit int into a list of four 32-bit words, most significant first."""
    return [(n >> shift) & 0xffffffff for shift in (96, 64, 32, 0)]

def rotl(x, n):
    """Rotate the 32-bit word x left by n bits."""
    return ((x << n) & 0xffffffff) | ((x >> (32 - n)) & 0xffffffff)

def L(y):
    """SM4 encryption linear transform L (Sage `^^` is XOR)."""
    return y ^^ rotl(y, 2) ^^ rotl(y, 10) ^^ rotl(y, 18) ^^ rotl(y, 24)

# SM4 key-schedule system parameters: FK is XORed into the master key words,
# and CK[i] is the per-round constant (inv_key reads SM4_CK[i] when running
# the key schedule backwards).
SM4_FK = [0xa3b1bac6,0x56aa3350,0x677d9197,0xb27022dc]
SM4_CK = [
0x00070e15,0x1c232a31,0x383f464d,0x545b6269,
0x70777e85,0x8c939aa1,0xa8afb6bd,0xc4cbd2d9,
0xe0e7eef5,0xfc030a11,0x181f262d,0x343b4249,
0x50575e65,0x6c737a81,0x888f969d,0xa4abb2b9,
0xc0c7ced5,0xdce3eaf1,0xf8ff060d,0x141b2229,
0x30373e45,0x4c535a61,0x686f767d,0x848b9299,
0xa0a7aeb5,0xbcc3cad1,0xd8dfe6ed,0xf4fb0209,
0x10171e25,0x2c333a41,0x484f565d,0x646b7279
]

'''
generate differential dist for SBOX
'''
# SM4 S-box: 256 entries indexed by the input byte. The same table is used
# by the round function (dec_one_round), the key schedule (inv_key) and the
# differential table diff_dist.
SM4_BOXES_TABLE = [
0xd6,0x90,0xe9,0xfe,0xcc,0xe1,0x3d,0xb7,0x16,0xb6,0x14,0xc2,0x28,0xfb,0x2c,
0x05,0x2b,0x67,0x9a,0x76,0x2a,0xbe,0x04,0xc3,0xaa,0x44,0x13,0x26,0x49,0x86,
0x06,0x99,0x9c,0x42,0x50,0xf4,0x91,0xef,0x98,0x7a,0x33,0x54,0x0b,0x43,0xed,
0xcf,0xac,0x62,0xe4,0xb3,0x1c,0xa9,0xc9,0x08,0xe8,0x95,0x80,0xdf,0x94,0xfa,
0x75,0x8f,0x3f,0xa6,0x47,0x07,0xa7,0xfc,0xf3,0x73,0x17,0xba,0x83,0x59,0x3c,
0x19,0xe6,0x85,0x4f,0xa8,0x68,0x6b,0x81,0xb2,0x71,0x64,0xda,0x8b,0xf8,0xeb,
0x0f,0x4b,0x70,0x56,0x9d,0x35,0x1e,0x24,0x0e,0x5e,0x63,0x58,0xd1,0xa2,0x25,
0x22,0x7c,0x3b,0x01,0x21,0x78,0x87,0xd4,0x00,0x46,0x57,0x9f,0xd3,0x27,0x52,
0x4c,0x36,0x02,0xe7,0xa0,0xc4,0xc8,0x9e,0xea,0xbf,0x8a,0xd2,0x40,0xc7,0x38,
0xb5,0xa3,0xf7,0xf2,0xce,0xf9,0x61,0x15,0xa1,0xe0,0xae,0x5d,0xa4,0x9b,0x34,
0x1a,0x55,0xad,0x93,0x32,0x30,0xf5,0x8c,0xb1,0xe3,0x1d,0xf6,0xe2,0x2e,0x82,
0x66,0xca,0x60,0xc0,0x29,0x23,0xab,0x0d,0x53,0x4e,0x6f,0xd5,0xdb,0x37,0x45,
0xde,0xfd,0x8e,0x2f,0x03,0xff,0x6a,0x72,0x6d,0x6c,0x5b,0x51,0x8d,0x1b,0xaf,
0x92,0xbb,0xdd,0xbc,0x7f,0x11,0xd9,0x5c,0x41,0x1f,0x10,0x5a,0xd8,0x0a,0xc1,
0x31,0x88,0xa5,0xcd,0x7b,0xbd,0x2d,0x74,0xd0,0x12,0xb8,0xe5,0xb4,0xb0,0x89,
0x69,0x97,0x4a,0x0c,0x96,0x77,0x7e,0x65,0xb9,0xf1,0x09,0xc5,0x6e,0xc6,0x84,
0x18,0xf0,0x7d,0xec,0x3a,0xdc,0x4d,0x20,0x79,0xee,0x5f,0x3e,0xd7,0xcb,0x39,
0x48,
]

# diff_dist[(din, dout)] is the set of S-box inputs x for which
# SM4_BOXES_TABLE[x] ^^ SM4_BOXES_TABLE[x ^^ din] == dout.
# Every (din, dout) key in 0..255 x 0..255 is pre-created (possibly empty).
diff_dist = {(a, b): set() for a in range(0x100) for b in range(0x100)}
for x0 in range(0x100):
    s0 = SM4_BOXES_TABLE[x0]
    for x1 in range(0x100):
        pair = (x0 ^^ x1, s0 ^^ SM4_BOXES_TABLE[x1])
        diff_dist[pair].update((x0, x1))

'''
get data for DFA attack
'''
# File layout (written by get_data.py): line 0 = encrypted flag hex; lines
# 1-10 = "(c,c_ast,c_diff)" for the round-29 faults; later lines = the same
# for the round-26 faults. The word order of each 128-bit value is reversed
# on load, undoing SM4's final word swap.
enc_flag, cipher1, cipher2 = None, [], []
with open('data', 'r') as fr:
    for lineno, line in enumerate(fr):
        if lineno == 0:
            enc_flag = line.strip()
            continue
        ct, ct_ast, ct_diff = re.findall(r"\((.*),(.*),(.*)\)", line)[0]
        ct_words = put_uint128_be(int(ct, 16))[::-1]
        if lineno <= 10:
            # (state, state difference) pairs
            diff_words = put_uint128_be(int(ct_diff, 16))[::-1]
            cipher1.append((get_uint128_be(ct_words), get_uint128_be(diff_words)))
        else:
            # (state, faulty state) pairs; the diff is formed after partial decryption
            ast_words = put_uint128_be(int(ct_ast, 16))[::-1]
            cipher2.append((get_uint128_be(ct_words), get_uint128_be(ast_words)))

'''
COEF prepared for function `inv_L`
'''
# Row r holds ones at columns r, r+2, r+10, r+18 and r+24 (mod 32), mirroring
# the rotation amounts of L. COEF is the inverse over GF(2) of this matrix's
# transpose.
coef_list = []
for r in range(32):
    ones = {(r + d) % 32 for d in (0, 2, 10, 18, 24)}
    coef_list.append([1 if c in ones else 0 for c in range(32)])
COEF = (Matrix(GF(2), coef_list).T).inverse()

def inv_L(ct):
    """Invert the linear transform L on a 32-bit word via GF(2) matrix algebra.

    The word is turned into a 1x32 bit row vector (MSB first), multiplied by
    COEF, and read back as an integer.
    """
    bits = [int(b) for b in bin(ct)[2:].rjust(32, '0')]
    out_row = (Matrix(GF(2), bits) * COEF)[0]
    return int(''.join(map(str, out_row)), 2)

'''
use a round key to decrypt(only one round)
'''
def dec_one_round(cipher, rk):
    """Undo one SM4 round with round key rk.

    The state (w0, w1, w2, w3) becomes (w3 ^^ L(S(w0 ^^ w1 ^^ w2 ^^ rk)), w0, w1, w2),
    where S applies the S-box to each byte.
    """
    w0, w1, w2, w3 = put_uint128_be(cipher)
    sbox_in = put_uint32_be(w0 ^^ w1 ^^ w2 ^^ rk)
    t = L(get_uint32_be([SM4_BOXES_TABLE[b] for b in sbox_in]))
    return get_uint128_be([t ^^ w3, w0, w1, w2])

def inv_key(key, i):
    """Step the SM4 key schedule backwards by one word.

    Given key = [K_{i+1}, K_{i+2}, K_{i+3}, K_{i+4}], return
    K_i = K_{i+4} ^^ T'(K_{i+1} ^^ K_{i+2} ^^ K_{i+3} ^^ CK_i), where T' is the
    S-box layer followed by y ^^ rotl(y, 13) ^^ rotl(y, 23).
    """
    k1, k2, k3, k4 = key[0], key[1], key[2], key[3]
    sbox_in = put_uint32_be(k1 ^^ k2 ^^ k3 ^^ SM4_CK[i])
    s = get_uint32_be([SM4_BOXES_TABLE[b] for b in sbox_in])
    return k4 ^^ (s ^^ rotl(s, 13) ^^ rotl(s, 23))

def _recover_round_key(samples, out_xor):
    """Recover one 32-bit round key by intersecting S-box differential candidates.

    samples -- list of 10 (state, state_diff) pairs. Each 128-bit value packs
               the words (w0, w1, w2, w3) as in get_uint128_be, where w3 is
               the word written by the attacked round:
               w3 = w_old ^^ L(S(w0 ^^ w1 ^^ w2 ^^ rk))  (cf. dec_one_round).
    out_xor -- (a, b): the known difference of w_old for samples[0:5] and
               samples[5:10] respectively (0 when w_old is unaffected).

    Per byte, the S-box input difference comes from (w0 ^^ w1 ^^ w2)' and the
    output difference from inv_L(w3' ^^ w_old'). diff_dist gives the S-box
    inputs s consistent with that pair, and each key byte candidate is
    s ^^ (w0 ^^ w1 ^^ w2). One fault value leaves at least two candidates
    per byte (diff_dist always holds both x and x ^^ din). The first five
    samples are therefore used until every byte has <= 2 candidates, and the
    second fault value is used to reach one.

    Raises ValueError if any key byte is not narrowed to exactly one value.
    """
    key = [set(range(0x100)) for _ in range(4)]
    for half, limit in ((0, 2), (1, 1)):
        for ct, ct_diff in samples[5 * half:5 * half + 5]:
            ct_list = put_uint128_be(ct)
            ct_diff_list = put_uint128_be(ct_diff)
            y_diff = put_uint32_be(inv_L(ct_diff_list[3] ^^ out_xor[half]))
            x_diff = put_uint32_be(ct_diff_list[0] ^^ ct_diff_list[1] ^^ ct_diff_list[2])
            x = put_uint32_be(ct_list[0] ^^ ct_list[1] ^^ ct_list[2])
            for j in range(4):
                key[j] &= {x[j] ^^ s for s in diff_dist[(x_diff[j], y_diff[j])]}
            if all(len(cands) <= limit for cands in key):
                break
    if any(len(cands) != 1 for cands in key):
        # Previously list(...)[0] picked an arbitrary byte or hit IndexError.
        raise ValueError("round key not uniquely determined: %r" % (key,))
    return get_uint32_be([next(iter(cands)) for cands in key])

def crack_key_31():
    """Recover rk[31] from cipher1 (round-29 faults; last-word diff 0x22 / 0x33)."""
    key_31 = _recover_round_key(cipher1, (0x00000022, 0x00000033))
    print("[+] KEY_31: " + hex(key_31)[2:].rjust(8, '0'))
    return key_31

def crack_key_30(key_31):
    """Peel round 31 off cipher1 in place, then recover rk[30].

    After undoing a round, the difference words shift by one position and the
    new leading word's difference is the known fault value (0x22 for
    samples 0-4, 0x33 for samples 5-9). crack_key_29 relies on the updated
    cipher1.
    """
    for i in range(5):
        cipher1[i] = (dec_one_round(cipher1[i][0], key_31), get_uint128_be([0x00000022] + put_uint128_be(cipher1[i][1])[:3]))
        cipher1[i+5] = (dec_one_round(cipher1[i+5][0], key_31), get_uint128_be([0x00000033] + put_uint128_be(cipher1[i+5][1])[:3]))
    key_30 = _recover_round_key(cipher1, (0, 0))
    print("[+] KEY_30: " + hex(key_30)[2:].rjust(8, '0'))
    return key_30

def crack_key_29(key_30):
    """Peel round 30 off cipher1 in place (new leading diff word 0), then recover rk[29]."""
    for i in range(10):
        cipher1[i] = (dec_one_round(cipher1[i][0], key_30), get_uint128_be([0x00000000] + put_uint128_be(cipher1[i][1])[:3]))
    key_29 = _recover_round_key(cipher1, (0, 0))
    print("[+] KEY_29: " + hex(key_29)[2:].rjust(8, '0'))
    return key_29

def crack_key_28(key_31, key_30, key_29):
    """Recover rk[28] from cipher2 (round-26 faults).

    Both the correct and the faulty ciphertext of every sample are decrypted
    through rounds 31, 30 and 29 in place, and the pair is replaced by
    (state, state ^^ faulty_state). The situation then mirrors crack_key_31.
    """
    for i in range(10):
        c, c_ast = cipher2[i]
        for rk in (key_31, key_30, key_29):
            c, c_ast = dec_one_round(c, rk), dec_one_round(c_ast, rk)
        cipher2[i] = (c, c ^^ c_ast)
    key_28 = _recover_round_key(cipher2, (0x00000022, 0x00000033))
    print("[+] KEY_28: " + hex(key_28)[2:].rjust(8, '0'))
    return key_28

def decrypt(enc_flag, key):
    """Decrypt the hex string enc_flag given the last four round keys.

    key -- [rk28, rk29, rk30, rk31]. Earlier round keys are rebuilt with
           inv_key for i = 31 down to 4, giving rk0..rk31; the full list is
           printed.
    Each 32-hex-char (16-byte) block is handled independently: its words are
    reversed (undoing SM4's final swap) and dec_one_round is applied with
    rk31 down to rk0.

    Returns the concatenated plaintext bytes, 16 per block (padding is not
    stripped).
    """
    round_keys = list(key)
    window = list(key)
    for i in range(31, 3, -1):
        window = [inv_key(window, i)] + window[:3]
        round_keys.insert(0, window[0])
    print(round_keys)
    flag = b''
    for off in range(0, len(enc_flag), 32):
        words = put_uint128_be(int(enc_flag[off:off+32], 16))[::-1]
        state = get_uint128_be(words)
        for rk in reversed(round_keys):
            state = dec_one_round(state, rk)
        # Fixed 16-byte width: long_to_bytes(state) dropped leading zero bytes
        # of a block, corrupting the output.
        flag += int(state).to_bytes(16, 'big')
    return flag

if __name__ == '__main__':
    print("[+] ENC_FLAG: " + enc_flag)
    # Order matters: crack_key_30 and crack_key_29 continue from cipher1 as
    # updated in place by the previous step.
    key_31 = crack_key_31()
    key_30 = crack_key_30(key_31)
    key_29 = crack_key_29(key_30)
    key_28 = crack_key_28(key_31, key_30, key_29)
    flag = decrypt(enc_flag, [key_28, key_29, key_30, key_31])
    print("[+] FLAG: ", end="")
    print(flag)
'''
[+] ENC_FLAG: 727c7eaf8523804472be1ff976134a65d73adba28d7345930e093b85afda0ac5b5656f7514c7736aac88a71bb6643d74
[+] KEY_31: 0fbe90cc
[+] KEY_30: 16c7d67f
[+] KEY_29: 35391715
[+] KEY_28: f28f45db
[3882886102, 2010179002, 1572873813, 3432904981, 3427876090, 3281517115, 2835549035, 4196177867, 1356056259, 2774742866, 1808643826, 2696649042, 654419433, 3294131222, 1850743528, 2512000155, 2125249481, 1701367673, 2144276591, 2240982232, 3335194424, 4293333648, 2630048960, 3717105645, 3211860000, 304813210, 2907456674, 334588889, 4069475803, 892933909, 382195327, 264147148]
[+] FLAG: b'flag{a9028a9c58c5749cdad329c09c80cc1b}\n\n\n\n\n\n\n\n\n\n'
'''

https://eprint.iacr.org/2010/063.pdf

flag_system

[题解分析]

魔改md5碰撞题- -就只有长亭拿了个一血(以后有空再补吧)