Pwnのラベルの付いたRev問です。時間内に解けなかったのですが、おもしろかったので書き残します。

この問題は290行くらいのpythonコードでvmが書かれていて、ユーザが入力した命令列を動作させます。

変数が意味のない文字列になっているのでところどころ勝手に変えつつ、何をしたらよいのかを探します。

まず、フラグが出る部分です。magicという命令を入れたときに、リセットが2回、かつレジスタのr0からr3が決まった乱数値になっていればフラグがレジスタに出力されます。

        elif ins.op == "magic":
            if self.RESET == 2:
                if tuple(self.REG[0:4]) == self.RANDOMS:
                    with open("flag.txt", "rb") as fp:
                        cc = fp.read()
                    cc = cc.strip()
                    cc = cc.ljust(len(self.REG)*4, b"\x00")
                    for i in range(len(self.REG)):
                        self.REG[i] = struct.unpack("<I", cc[i*4:(i+1)*4])[0]

この乱数値はCpuクラスの初期化時にRANDOMSにロードされて、これがvmが動作を始めた最初に一度だけ、メモリにロードされます。

class Cpu:
    PC = 0
    instructions = None
    REG = None
    memory = None
    COUNTER = 0
    CACHE = None
    RESET = 0
    RANDOMS = None
    def __init__(self):
        self.instructions = []
        self.CACHE = {}
        self.RANDOMS = (random.randint(1,4200000000), random.randint(1,4200000000) , random.randint(1,4200000000), random.randint(1,4200000000))
        self.reset()

    ...

    def run(self, debug=1):
        ins = self.instructions[0]
        for i,v in enumerate(self.RANDOMS):
            self.memory[i] = v 
        while (self.PC>=0 and self.PC<len(self.instructions) and self.RESET<4 and self.COUNTER<20000):
            ins = self.instructions[self.PC]
            self.execute(ins)    

じゃあリセットして、ロードされたのをレジスタに入れてmagicしたら終わりだねと思って組むと、上手くいきません。なぜでしょう。

おなじCpuクラスのresetが、メモリを消しています。

    def reset(self):
        self.PC = 0
        self.REG = [0 for r in range(reg_nums)]
        self.memory = [0 for _ in range(memory_size)]
        self.COUNTER = 0
        for k in self.CACHE.keys():
            self.CACHE[k] = 0

とりあえず何をすれば良いのかがここで分かります。この問題は、リセットを挟みつつ、RANDOMSに乗っている値を復元する必要があると読み取れます。

一見きれいさっぱり消されているように見えますが、CACHEだけはCACHE={}とすればよいところ、元からあるものをそのまま使っています。これが今回のポイントです。

CACHEは、movfromを使ってメモリからレジスタに値をロードするときに、ロードした値を保持して置く機能です。逆にmovtoを使ってメモリに書き込むときにはそのキーを消す。という実装になっています。

        elif ins.op == "movfrom":
            memory_address = ins.offset + self.REG[ins.dsp]
            memory_address = memory_address % len(self.memory)
            if memory_address in self.CACHE:
                v = self.CACHE[memory_address]
                v = (v & 0xffffffff)
                self.REG[ins.dst] = v
                self.PC += 1
            else:
                v = self.memory[memory_address]
                self.CACHE[memory_address] = v
                self.execute(ins)
        elif ins.op == "movto":
            memory_address = ins.offset + self.REG[ins.dsp]
            memory_address = memory_address % len(self.memory)
            if memory_address in self.CACHE:
                del self.CACHE[memory_address]
            v = (self.REG[ins.dst] & 0xffffffff)
            self.memory[memory_address] = v
            self.PC += 1

結局valueは全部飛んじゃうので、reset前の情報は保持できないのではと思ってほぼ諦めていましたが、日曜日にふと、「CACHEのkeyが残るならそこに残せばいいだけか」と思いつきます。(これが毎度遅いです。)

上記のコードの通り、movfromを処理するときに該当のkeyが無い場合は新たに作成し、execute(ins)としてmovfromを再帰的に処理します。この影響で実行した命令数を格納するカウンタが1多くなります。つまりCACHEが存在する場合はカウンタが1小さいので、これを利用してRANDOMSを復元すれば良いのです。

愚直に書くと以下のようになります。キャッシュヒットすると3、しないと4になりますので、そこで場合分けができます。

time
mov r0 r1
movfrom r0 0 r2
time
sub r0 r1
movc r1 3
sub r0 r1
jmpz 2
(キャッシュが存在したときの処理)
...

RANDOMSは16バイトあります。剰余を入れたりと色々な格納方法を考えましたが、結局ビットで分けるのが一番良さそうだと思ったので、1を倍々にした数字をRANDOMSからロードしたレジスタとandして、非ゼロならキャッシュを作るというコードを作りました。無茶苦茶長いやつをやっとこさ作ったところで、55命令という制限を見つけて困ります。ゴルフ問でした。

必死に削って、ちょうど55命令の乱数退避→キャッシュ復元→キャッシュ初期化→magic→キャッシュに乱数を埋める→resetというコードができて動作も確認しましたが、これも上手くいきません。なんとmagicで強制停止するコードを使って検証していたため、元のコードではmagicを呼んだあとhaltしないと読めないということが原因と気づき、カツカツの命令数で今更timeの条件分岐なんか突っ込めません。ここで静かに寝ました。

知識が浅かったです。皆さんのwriteupを見た後の感想としては、以下のような改善をするべきと感じました。

  • jmpのオフセット計算をやる時点で、これはスクリプトにするべき
  • 1つの大きなループにする必要はない。最初にスイッチを作って、キャッシュのkeyに乱数を埋める方と、復元してmagicしてhaltする方を作ればよい
  • 疑似コードを書くのが良い

これらを踏まえて作り直してみました。最初のtimeは引き算しないのが地味に効きます。ほか、r0timejmpgで使われるので、カウンタや定数で使わないようにする必要があります。

# start()
#       r1 = movfrom(0, r0)
#       r0 = time
#       r1 = 2
#       if r0 > 2:
#               goto decompose
start:
        movfrom r1 0 r0
        time
        movc r1 2
        jmpg r1 decompose

# restore()
#       r1 = 1
#       r7 = 2
#
#       for r6 in range(4): <- out loop
#               r4 = 1
#               r5 = 0
#               r2 = 0
#               for r2 in range(32): <- in loop
#                       r3 = time
#                       movfrom(100, r8)
#                       r0 = time
#                       r0 -= r3
#                       if r0 > 3:
#                               pass
#                       else:
#                               r5 += r4
#                       r4 *= r7
#                       r8 += 1
#                       r2 += 1
#               movto(r5, 999, r6)
#               r6 += 1
restore:
        movc r1 1
        movc r7 2

restore_out_loop:
        movc r4 1
        movc r5 0
        movc r2 0

restore_in_loop:
        time 
        mov r3 r0
        movfrom r9 100 r8
        time
        sub r0 r3
        movc r9 3
        jmpg r9 not_hit
hit:
        add r5 r4
not_hit:
        mul r4 r7
        add r8 r1
        add r2 r1
        movc r0 32
        jmpg r2 restore_in_loop
        movto r5 1000 r6
        add r6 r1
        movc r0 4
        jmpg r6 restore_out_loop
        movfrom r3 999 r6
        movfrom r2 998 r6
        movfrom r1 997 r6
        movfrom r0 996 r6
        magic
        halt

# decompose()
#       r4 = 1
#       r5 = 2
#
#       for r6 in range(4): <- out loop
#               r1 = 0
#               r2 = 1
#               for r1 in range(32): <- in loop
#                       r3 = movfrom(0, r6)
#                       if r3 & r2 != 0:
#                               movfrom(100, r8)
#                       r2 *= r5
#                       r8 += 1
#                       r1 += 1
#               r6 += 1
decompose:
        movc r4 1
        movc r5 2

decomp_out_loop:
        movc r1 0
        movc r2 1

decomp_in_loop:
        movfrom r3 0 r6
        and r3 r2
        mov r0 r3
        jmpz bit_off
bit_on:
        movfrom r9 100 r8
bit_off:
        mul r2 r5
        add r8 r4
        add r1 r4
        movc r0 32
        jmpg r1 decomp_in_loop
        add r6 r4
        movc r0 4
        jmpg r6 decomp_out_loop
        reset

ラベルを計算してくれるスクリプトはこちらです。

from pwn import *
if not args.R:
        r = process(['python3', 'cpu.py'])
else:
        r = remote('ctf.b01lers.com', 9204)

# variable
row_count = 0
actual_code = ''
labels = []

# function
def is_label(line):
        if line.strip(' ')[-1] == ':':
                assert '\t' not in line, "Colon and Tab are in same line."
                return True
        return False

def search_label(line):
        for l in line.split(' '):
                if "jmp" in l:
                        continue
                for x in labels:
                        if x[1] == l:
                                return x
        print("something wrong")
        assert()

def check_size(code):
        instructions = len(code.split('\n'))
        print("Code size:", instructions)
        assert instructions <= 55 

# main
with open('./code', 'r') as f:
        lines = f.read().split('\n')

# collect all label line -> store with row_count
for line in lines:
        if len(line) < 2 or line[0] == "#":
                continue
        if is_label(line) == True:
                labels.append((row_count, line.strip(' ').split(':')[0]))
                continue
        row_count += 1

row_count = 0
for line in lines:
        # ignore newline, comment
        if len(line) < 2 or line[0] == "#":
                continue
        # ignore label line this time
        if is_label(line) == True:
                continue
        # jmp line -> replace label with row_count
        if "jmp" in line:
                tup = search_label(line)
                fixed_line = line.replace(tup[1], str(tup[0] - row_count))

        else:
                fixed_line = line

        # normal line -> just add
        actual_code += fixed_line
        actual_code += '\n'
        row_count += 1

check_size(actual_code)

payload = b''
payload += actual_code.encode()
payload += b'\n' * 3

print("Sending:")
print(actual_code)
r.sendlineafter(b'lines.\n', payload)
r.recvuntil(b'Registers: [')

flagint = []
for _ in range(4):
        flagint.append(int(r.recvuntil(b', ', True)))
flag = ''.join([p32(i).decode() for i in flagint])
r.recvuntil(b']\n')
print("FLAG:", flag)

r.close()

簡単かと思ったらそうでもなくて、思いついたと思ってもそれがスッと形にできないが、答えはきれいに書けるようになっていて、おもしろかったです。

解法が思いつかない時も、とりあえずと思って適当なぐあいのスクリプトを用意できるといいかもしれないと思いました。