#prob.py
from cipher import STREAM
import random
if __name__ == "__main__":
with open("flag", "rb") as f:
flag = f.read()
assert flag[:3] == b'DH{' and flag[-1:] == b'}'
seed = random.getrandbits(16)
stream = STREAM(seed, 16)
print(f"encrypted flag > {stream.encrypt(flag).hex()}")
#cipher.py
class STREAM:
def __init__(self, seed, size):
self.state = self.num2bits(seed, size)
def num2bits(self, num, size):
assert num < (1 << size)
return bin(num)[2:].zfill(size)
def bits2num(self, bits):
return int('0b' + bits, 2)
def shift(self):
new_bit = self.state[-1]
self.state = new_bit + self.state[:-1]
return new_bit
def getNbits(self, num):
sequence = ""
for _ in range(num):
sequence += self.shift()
return sequence
def encrypt(self, plaintext):
ciphertext = b""
for p in plaintext:
stream = self.bits2num(self.getNbits(8))
c = p ^ stream
ciphertext += bytes([c])
return ciphertext
def decrypt(self, ciphertext):
plaintext = b""
for c in ciphertext:
stream = self.bits2num(self.getNbits(8))
p = c ^ stream
plaintext += bytes([p])
return plaintext
if __name__ == "__main__":
import os
for seed in range(0x100):
Alice = STREAM(seed, 16)
Bob = STREAM(seed, 16)
plaintext = os.urandom(128)
ciphertext = Alice.encrypt(plaintext)
assert plaintext == Bob.decrypt(ciphertext)
prob에서는 cipher 파일에 있는 STREAM 클래스를 가져와 암호화를 수행하고 hex 값으로 플래그를 출력해주는 것을 확인할 수 있다.
그리고 cipher를 살펴보면 클래스 내부에서 암호화 함수뿐만 아니라, 복호화 함수도 제공하는 것을 알 수 있다.
여기서 유심히 봐야 할 점은, prob에서 seed를 생성할 때 16비트를 랜덤하게 생성하여 사용하는 것과, cipher에서 복호화 함수를 제공한다는 것이다.
prob에서 16비트를 난수를 seed로 사용하는데, 16비트의 최대 값은 2^16 = 65,536 이므로, 브루트 포스가 충분히 가능하다고 생각해볼 수 있다.
output으로 주어진 hex 값을 다시 bytes로 바꿔서 decrypt 함수에 65,536번 반복문을 돌려 모든 가능한 경우의 seed 값을 주면서 테스트하면 플래그를 획득할 수 있다.
from cipher import STREAM
brute_force = 2**16
output_hex = "3cef03c64ac240c349971d9e4c951cc14ec4199f409249c21e964ac540c540944f901c934cc240934d96419f4b9e4d9f1cc41dc61dc34e9219c31bc11a914f9141c61ada"
output = bytes.fromhex(output_hex)
for i in range(brute_force):
seed = i
stream = STREAM(seed,16)
flag = stream.decrypt(output)
if flag[:3] == b'DH{' and flag[-1:] == b'}':
print(flag)
break
else:
print(i)