b01lersCTF 2022 - veryfastvm
Pwnのラベルの付いたRev問です。時間内に解けなかったのですが、おもしろかったので書き残します。
この問題は290行くらいのpythonコードでvmが書かれていて、ユーザが入力した命令列を動作させます。
変数が意味のない文字列になっているのでところどころ勝手に変えつつ、何をしたらよいのかを探します。
まず、フラグが出る部分です。magicという命令を入れたときに、リセットが2回、かつレジスタのr0
からr3
が決まった乱数値になっていればフラグがレジスタに出力されます。
elif ins.op == "magic":
if self.RESET == 2:
if tuple(self.REG[0:4]) == self.RANDOMS:
with open("flag.txt", "rb") as fp:
cc = fp.read()
cc = cc.strip()
cc = cc.ljust(len(self.REG)*4, b"\x00")
for i in range(len(self.REG)):
self.REG[i] = struct.unpack("<I", cc[i*4:(i+1)*4])[0]
この乱数値はCpuクラスの初期化時にRANDOMSにロードされて、これがvmが動作を始めた最初に一度だけ、メモリにロードされます。
class Cpu:
PC = 0
instructions = None
REG = None
memory = None
COUNTER = 0
CACHE = None
RESET = 0
RANDOMS = None
def __init__(self):
self.instructions = []
self.CACHE = {}
self.RANDOMS = (random.randint(1,4200000000), random.randint(1,4200000000) , random.randint(1,4200000000), random.randint(1,4200000000))
self.reset()
...
def run(self, debug=1):
ins = self.instructions[0]
for i,v in enumerate(self.RANDOMS):
self.memory[i] = v
while (self.PC>=0 and self.PC<len(self.instructions) and self.RESET<4 and self.COUNTER<20000):
ins = self.instructions[self.PC]
self.execute(ins)
じゃあリセットして、ロードされたのをレジスタに入れてmagic
したら終わりだねと思って組むと、上手くいきません。なぜでしょう。
おなじCpuクラスのreset
が、メモリを消しています。
def reset(self):
self.PC = 0
self.REG = [0 for r in range(reg_nums)]
self.memory = [0 for _ in range(memory_size)]
self.COUNTER = 0
for k in self.CACHE.keys():
self.CACHE[k] = 0
とりあえず何をすれば良いのかがここで分かります。この問題は、リセットを挟みつつ、RANDOMSに乗っている値を復元する必要があると読み取れます。
一見きれいさっぱり消されているように見えますが、CACHEだけはCACHE={}
とすればよいところ、元からあるものをそのまま使っています。これが今回のポイントです。
CACHEは、movfrom
を使ってメモリからレジスタに値をロードするときに、ロードした値を保持して置く機能です。逆にmovto
を使ってメモリに書き込むときにはそのキーを消す。という実装になっています。
elif ins.op == "movfrom":
memory_address = ins.offset + self.REG[ins.dsp]
memory_address = memory_address % len(self.memory)
if memory_address in self.CACHE:
v = self.CACHE[memory_address]
v = (v & 0xffffffff)
self.REG[ins.dst] = v
self.PC += 1
else:
v = self.memory[memory_address]
self.CACHE[memory_address] = v
self.execute(ins)
elif ins.op == "movto":
memory_address = ins.offset + self.REG[ins.dsp]
memory_address = memory_address % len(self.memory)
if memory_address in self.CACHE:
del self.CACHE[memory_address]
v = (self.REG[ins.dst] & 0xffffffff)
self.memory[memory_address] = v
self.PC += 1
結局valueは全部飛んじゃうので、reset前の情報は保持できないのではと思ってほぼ諦めていましたが、日曜日にふと、「CACHEのkeyが残るならそこに残せばいいだけか」と思いつきます。(これが毎度遅いです。)
上記のコードの通り、movfrom
を処理するときに該当のkeyが無い場合は新たに作成し、execute(ins)
としてmovfrom
を再帰的に処理します。この影響で実行した命令数を格納するカウンタが1多くなります。つまりCACHEが存在する場合はカウンタが1小さいので、これを利用してRANDOMSを復元すれば良いのです。
愚直に書くと以下のようになります。キャッシュヒットすると3、しないと4になりますので、そこで場合分けができます。
time
mov r0 r1
movfrom r0 0 r2
time
sub r0 r1
movc r1 3
sub r0 r1
jmpz 2
(キャッシュが存在したときの処理)
...
RANDOMSは16バイトあります。剰余を入れたりと色々な格納方法を考えましたが、結局ビットで分けるのが一番良さそうだと思ったので、1を倍々にした数字をRANDOMSからロードしたレジスタとandして、非ゼロならキャッシュを作るというコードを作りました。無茶苦茶長いやつをやっとこさ作ったところで、55命令という制限を見つけて困ります。ゴルフ問でした。
必死に削って、ちょうど55命令の乱数退避→キャッシュ復元→キャッシュ初期化→magic→キャッシュに乱数を埋める→reset
というコードができて動作も確認しましたが、これも上手くいきません。なんとmagic
で強制停止するコードを使って検証していたため、元のコードではmagic
を呼んだあとhalt
しないと読めないということが原因と気づき、カツカツの命令数で今更timeの条件分岐なんか突っ込めません。ここで静かに寝ました。
知識が浅かったです。皆さんのwriteupを見た後の感想としては、以下のような改善をするべきと感じました。
- jmpのオフセット計算をやる時点で、これはスクリプトにするべき
- 1つの大きなループにする必要はない。最初にスイッチを作って、キャッシュのkeyに乱数を埋める方と、復元してmagicしてhaltする方を作ればよい
- 疑似コードを書くのが良い
これらを踏まえて作り直してみました。最初のtime
は引き算しないのが地味に効きます。ほか、r0
がtime
やjmpg
で使われるので、カウンタや定数で使わないようにする必要があります。
# start()
# r1 = movfrom(0, r0)
# r0 = time
# r1 = 2
# if r0 > 2:
# goto decompose
start:
movfrom r1 0 r0
time
movc r1 2
jmpg r1 decompose
# restore()
# r1 = 1
# r7 = 2
#
# for r6 in range(4): <- out loop
# r4 = 1
# r5 = 0
# r2 = 0
# for r2 in range(32): <- in loop
# r3 = time
# movfrom(100, r8)
# r0 = time
# r0 -= r3
# if r0 > 3:
# pass
# else:
# r5 += r4
# r4 *= r7
# r8 += 1
# r2 += 1
# movto(r5, 999, r6)
# r6 += 1
restore:
movc r1 1
movc r7 2
restore_out_loop:
movc r4 1
movc r5 0
movc r2 0
restore_in_loop:
time
mov r3 r0
movfrom r9 100 r8
time
sub r0 r3
movc r9 3
jmpg r9 not_hit
hit:
add r5 r4
not_hit:
mul r4 r7
add r8 r1
add r2 r1
movc r0 32
jmpg r2 restore_in_loop
movto r5 1000 r6
add r6 r1
movc r0 4
jmpg r6 restore_out_loop
movfrom r3 999 r6
movfrom r2 998 r6
movfrom r1 997 r6
movfrom r0 996 r6
magic
halt
# decompose()
# r4 = 1
# r5 = 2
#
# for r6 in range(4): <- out loop
# r1 = 0
# r2 = 1
# for r1 in range(32): <- in loop
# r3 = movfrom(0, r6)
# if r3 & r2 != 0:
# movfrom(100, r8)
# r2 *= r5
# r8 += 1
# r1 += 1
# r6 += 1
decompose:
movc r4 1
movc r5 2
decomp_out_loop:
movc r1 0
movc r2 1
decomp_in_loop:
movfrom r3 0 r6
and r3 r2
mov r0 r3
jmpz bit_off
bit_on:
movfrom r9 100 r8
bit_off:
mul r2 r5
add r8 r4
add r1 r4
movc r0 32
jmpg r1 decomp_in_loop
add r6 r4
movc r0 4
jmpg r6 decomp_out_loop
reset
ラベルを計算してくれるスクリプトはこちらです。
from pwn import *
if not args.R:
r = process(['python3', 'cpu.py'])
else:
r = remote('ctf.b01lers.com', 9204)
# variable
row_count = 0
actual_code = ''
labels = []
# function
def is_label(line):
if line.strip(' ')[-1] == ':':
assert '\t' not in line, "Colon and Tab are in same line."
return True
return False
def search_label(line):
for l in line.split(' '):
if "jmp" in l:
continue
for x in labels:
if x[1] == l:
return x
print("something wrong")
assert()
def check_size(code):
instructions = len(code.split('\n'))
print("Code size:", instructions)
assert instructions <= 55
# main
with open('./code', 'r') as f:
lines = f.read().split('\n')
# collect all label line -> store with row_count
for line in lines:
if len(line) < 2 or line[0] == "#":
continue
if is_label(line) == True:
labels.append((row_count, line.strip(' ').split(':')[0]))
continue
row_count += 1
row_count = 0
for line in lines:
# ignore newline, comment
if len(line) < 2 or line[0] == "#":
continue
# ignore label line this time
if is_label(line) == True:
continue
# jmp line -> replace label with row_count
if "jmp" in line:
tup = search_label(line)
fixed_line = line.replace(tup[1], str(tup[0] - row_count))
else:
fixed_line = line
# normal line -> just add
actual_code += fixed_line
actual_code += '\n'
row_count += 1
check_size(actual_code)
payload = b''
payload += actual_code.encode()
payload += b'\n' * 3
print("Sending:")
print(actual_code)
r.sendlineafter(b'lines.\n', payload)
r.recvuntil(b'Registers: [')
flagint = []
for _ in range(4):
flagint.append(int(r.recvuntil(b', ', True)))
flag = ''.join([p32(i).decode() for i in flagint])
r.recvuntil(b']\n')
print("FLAG:", flag)
r.close()
簡単かと思ったらそうでもなくて、思いついたと思ってもそれがスッと形にできないが、答えはきれいに書けるようになっていて、おもしろかったです。
解法が思いつかない時も、とりあえずと思って適当なぐあいのスクリプトを用意できるといいかもしれないと思いました。