
/*
  This file implements the CTracer class, a process tracer
  used to determine instruction execution order
  and other related information.
*/

#define LOG_TRACER
//#define LOG_OPCODES
//#define TRACE_DLL_STARTUPS

#include "rwmem.cpp"
#include "t2h.cpp"
#include "int3.cpp"
#include "range.cpp"

/*
  CTracer -- single-stepping process tracer.

  TraceFile() runs a program under the debugger, steps through its code and
  records per-byte information for the executable and every loaded DLL in
  RangeList; DumpLog() writes that information to a text file.
  Derive from this class and override OnTrace(), which TraceFile() calls
  once per traced exception event; the base implementation terminates the
  program.
*/
class CTracer
{
  public:

  CRangeList RangeList;   // one range per module (main image + loaded DLLs)
  CT2HList   T2HList;     // debuggee thread id -> thread handle (+ per-thread state)
  CInt3List  Int3List;    // INT3 breakpoints planted in the debuggee

  PROCESS_INFORMATION pinfo;  // debuggee, as filled in by CreateProcess()
  t2h_struct* t2h;            // thread of the current exception event (valid in OnTrace)
  range_struct* r;            // range containing the current exception address
  CONTEXT ctx;                // registers of t2h at the current exception event
                              // (only CONTROL | INTEGER | SEGMENTS are fetched)

  // CTracer is meant to be used polymorphically (OnTrace is virtual), so
  // destroying a derived tracer through a CTracer* must be well-defined.
  virtual ~CTracer() {}

  virtual void OnTrace();

  int TraceFile(char* cmdline, DWORD tracetime, DWORD traceflags); // TR_xxx
  void DumpLog(char* fname);
}; // class CTracer

/*
  Default per-event hook called by TraceFile() for every traced exception.
  Derived classes are expected to override it. Reaching this base version
  means a plain CTracer was used to trace, which is a usage error, so the
  program is terminated with a failure exit status.
*/
void CTracer::OnTrace()
{
  // virtual method -- implement your own stuff here
  log("virtual method called!\n");
  exit(EXIT_FAILURE);
} // CTracer::OnTrace

int CTracer::TraceFile(char* cmdline, DWORD tracetime, DWORD traceflags)
{
  if (GetVersion() & 0x80000000)
  {
#ifdef LOG_TRACER
    log("ERROR: cant trace under w9x (because its sucks. hint: use win2k instead)\n");
#endif
    return ERR_TRACER_W9X;
  }

  STARTUPINFO sinfo = {sizeof(STARTUPINFO)};
  sinfo.dwFlags = STARTF_USESHOWWINDOW;
  sinfo.wShowWindow = SW_NORMAL;                // SW_HIDE

  int res = CreateProcess(0,cmdline,0,0,0,
                         DEBUG_PROCESS /*| DEBUG_ONLY_THIS_PROCESS*/,0,0,
                         &sinfo, &pinfo);
  if (res == 0)
  {
#ifdef LOG_TRACER
    log("CreateProcess() failed, GetLastError = %i\n", GetLastError());
#endif
    return ERR_TRACER_CREATEPROCESS;
  }

  SetDebugErrorLevel(SLE_WARNING);

  int pcount = 1;

  RangeList.Empty();

  int first_process = 1;
  int first_int3 = 1;
  int was_except = 0;
  int result;
  DWORD debugbreak;

  DWORD endtime = GetTickCount() + tracetime;
  DWORD g_index = 0;

  for(;;)       // main cycle
  {

    if (GetTickCount() > endtime)
    {
      if (was_except == 1)  // if something has been traced
      {
#ifdef LOG_TRACER
        log("TraceTime expired, break (success)\n");
#endif
        result = ERR_TRACER_SUCCESS;
      }
      else
      {
#ifdef LOG_TRACER
        log("TraceTime expired, break (timeout, nothing traced)\n");
#endif
        result = ERR_TRACER_TIMEOUT;
      }
      break;
    }

    DEBUG_EVENT debugevent;

    int res = WaitForDebugEvent(&debugevent, 60*1000);    // max 60 sec
    if (res == 0)
    {
#ifdef LOG_TRACER
      log("WaitForDebugEvent() failed, GetLastError = %i, break\n", GetLastError());
#endif
      result = ERR_TRACER_WAITTIMEOUT;
      break;
    }

    DWORD ContinueStatus = DBG_CONTINUE;

    if (debugevent.dwDebugEventCode == CREATE_PROCESS_DEBUG_EVENT)
    {
#ifdef LOG_TRACER
      log("[CREATE_PROCESS_DEBUG_EVENT]\n");
#endif
      if (!first_process--)
      {
#ifdef LOG_TRACER
        log("ERROR: child process started, terminating\n");
        int res =
#endif
        TerminateProcess(debugevent.u.CreateProcessInfo.hProcess, 0);
#ifdef LOG_TRACER
        if (res == 0)
          log("TerminateProcess() failed, error = %i\n", GetLastError());
#endif
        result = ERR_TRACER_CHILDPROCESS;
        break;
      }

      if (T2HList.Add( debugevent.dwThreadId, debugevent.u.CreateProcessInfo.hThread)==0)
      {
#ifdef LOG_TRACER
        log("ERROR: cant add T2H record\n");
#endif
        result = ERR_TRACER_T2H;
        break;
      }
      if (RangeList.Add( pinfo.hProcess, (DWORD)debugevent.u.CreateProcessInfo.lpBaseOfImage )==0)
      {
#ifdef LOG_TRACER
        log("ERROR: range error\n");
#endif
        result = ERR_TRACER_RANGE;
        break;
      }
      if (Int3List.Insert( pinfo.hProcess, (DWORD)debugevent.u.CreateProcessInfo.lpStartAddress)==0)
      {
#ifdef LOG_TRACER
        log("ERROR: cant insert INT3 into main thread\n");
#endif
        result = ERR_TRACER_CANTINSERTINT3TOTHREAD;
        break;
      }
    } // CREATE_PROCESS_DEBUG_EVENT

    if (debugevent.dwDebugEventCode == LOAD_DLL_DEBUG_EVENT)
    {
#ifdef LOG_TRACER
      log("[LOAD_DLL_DEBUG_EVENT]\n");
#endif
      if (RangeList.Add( pinfo.hProcess, (DWORD)debugevent.u.LoadDll.lpBaseOfDll )==0)
      {
#ifdef LOG_TRACER
        log("ERROR: range error\n");
#endif
        result = ERR_TRACER_RANGE;
        break;
      }

      if (traceflags & TR_TRACE_DLL_STARTUPS)
      {
        DWORD peptr = *(DWORD*) &((range_struct*)RangeList.tail)->mem0[ 0x3C ];
        DWORD eip   = *(DWORD*) &((range_struct*)RangeList.tail)->mem0[ peptr+0x28 ];
        if (eip)
        {
          if (Int3List.Insert( pinfo.hProcess, ((range_struct*)RangeList.tail)->base + eip )==0)
          {
#ifdef LOG_TRACER
            log("ERROR: cant insert INT3 into DLL entry\n");
#endif
            result = ERR_TRACER_CANTINSERTINT3TODLL;
            break;
          }
        }
      } // TR_TRACE_DLL_STARTUPS

    } // LOAD_DLL_DEBUG_EVENT

    if (debugevent.dwDebugEventCode == UNLOAD_DLL_DEBUG_EVENT)
    {
      // RangeList.Del( (DWORD)debugevent.u.LoadDll.lpBaseOfDll );
#ifdef LOG_TRACER
      log("[UNLOAD_DLL_DEBUG_EVENT]\n");
      log("got unload dll event, exiting\n");
#endif
      result = ERR_TRACER_UNLOADDLL;
      break;
    } // UNLOAD_DLL_DEBUG_EVENT

    if (debugevent.dwDebugEventCode == CREATE_THREAD_DEBUG_EVENT)
    {
#ifdef LOG_TRACER
      log("[CREATE_THREAD_DEBUG_EVENT]\n");
#endif
      if (T2HList.Add(debugevent.dwThreadId, debugevent.u.CreateThread.hThread)==0)
      {
#ifdef LOG_TRACER
        log("ERROR: cant add T2H record\n");
#endif
        result = ERR_TRACER_T2H;
        break;
      }

      if (Int3List.Insert( pinfo.hProcess, (DWORD)debugevent.u.CreateThread.lpStartAddress)==0)
      {
        result = ERR_TRACER_CANTINSERTINT3TOTHREAD;
        break;
      }
    } // CREATE_THREAD_DEBUG_EVENT

    if (debugevent.dwDebugEventCode == EXIT_THREAD_DEBUG_EVENT)
    {
#ifdef LOG_TRACER
      log("[EXIT_THREAD_DEBUG_EVENT]\n");
#endif
      T2HList.Del(debugevent.dwThreadId);
    } // EXIT_THREAD_DEBUG_EVENT

    if (debugevent.dwDebugEventCode == EXIT_PROCESS_DEBUG_EVENT)
    {
      pcount=0;
#ifdef LOG_TRACER
      log("[EXIT_PROCESS_DEBUG_EVENT]\n");
      log("got exit process event, exiting\n");
#endif
      result = ERR_TRACER_EXITPROCESS;
      break;
    } // EXIT_PROCESS_DEBUG_EVENT

    if (debugevent.dwDebugEventCode == RIP_EVENT)
    {
      pcount=0;
#ifdef LOG_TRACER
      log("[RIP_EVENT]\n");
#endif
      result = ERR_TRACER_RIPPROCESS;
      break;
    } // RIP_EVENT

    if (debugevent.dwDebugEventCode == EXCEPTION_DEBUG_EVENT)
    {
      was_except = 1;

      t2h = T2HList.FindByTID( debugevent.dwThreadId );
      if (t2h == NULL)
      {
#ifdef LOG_TRACER
        log("ERROR: T2HList::FindByTID(tid=%08X) error\n", debugevent.dwThreadId);
#endif
        result = ERR_TRACER_CANTFINDTHBYTID;
        break;
      }
      SuspendThread(t2h->handle);

      ctx.ContextFlags = CONTEXT_CONTROL |
                         CONTEXT_INTEGER |
                         CONTEXT_SEGMENTS;
                      // CONTEXT_FLOATING_POINT | CONTEXT_DEBUG_REGISTERS;

      if (GetThreadContext(t2h->handle, &ctx) == 0)
      {
#ifdef LOG_TRACER
        log("ERROR: GetThreadContext() error, GetLastError() = %i\n", GetLastError());
#endif
        result = ERR_TRACER_GETCONTEXT;
        break;
      }

      DWORD exaddr = (DWORD)debugevent.u.Exception.ExceptionRecord.ExceptionAddress;

      r = RangeList.Find(exaddr);
      if (r == NULL)
      {
#ifdef LOG_TRACER
        log("ERROR:exception address not in my range, exaddr = %08X\n", exaddr);
#endif
        result = ERR_TRACER_EXADDRNOTINRANGE;
        break;
      }

      if (debugevent.u.Exception.ExceptionRecord.ExceptionCode == EXCEPTION_SINGLE_STEP)
      {
        // assert(exaddr == ctx.Eip);

        BYTE o[2];

        if ((traceflags & TR_DONT_BYPASS_REP)==0)
        {
          if (read_memory( pinfo.hProcess, ctx.Eip, o, 2)==0)
          {
#ifdef LOG_TRACER
            log("ERROR: cant read memory at EIP = %08X\n", ctx.Eip);
#endif
            result = ERR_TRACER_CANTREADMEMATEXADDR;
            break;
          }
        }

        if ( ( (traceflags & TR_DONT_BYPASS_REP)==0 ) &&
             ( (((*(WORD*)&o[0])&0xFCFE)==0xA4F2) ||  // F2/F3 A3/A4/A5/A6 repz/repnz movs/cmps
               (((*(WORD*)&o[0])&0xFAFE)==0xAAF2) ) ) // F2/F3 AA/AB/AE/AF repz/repnz stos/scas
        {
//#ifdef LOG_TRACER
//            log("trying to bypass REP at %08X\n", ctx.Eip);
//#endif
          if (Int3List.Insert( pinfo.hProcess, ctx.Eip+2)==0)
          {
#ifdef LOG_TRACER
            log("ERROR: cant insert INT3 into code at %08X\n", ctx.Eip+2);
#endif
            result = ERR_TRACER_CANTINSERTINT3TOCODE;
            break;
          }
        }
        else
        {
          ctx.EFlags |= 0x0100;       // set TF
        }

        ContinueStatus = DBG_CONTINUE;
      } // EXCEPTION_SINGLE_STEP
      else
      if (debugevent.u.Exception.ExceptionRecord.ExceptionCode == EXCEPTION_BREAKPOINT)
      {

        if ((first_int3 == 1) || (exaddr == debugbreak))
        {
          first_int3 = 0;
          debugbreak = exaddr;
#ifdef LOG_TRACER
          log("DebugBreak, bypassing (exaddr = %08X)\n", exaddr);
#endif
        }
        else
        {
          if (Int3List.Disable( pinfo.hProcess, exaddr)==0)
          {
#ifdef LOG_TRACER
            log("ERROR: cant disable int3 at %08X\n",exaddr);
#endif
            result = ERR_TRACER_CANTDISABLEINT3;
            break;
          }
          ctx.Eip--;                    // IMPORTANT
          ctx.EFlags |= 0x0100;         // set TF

        }
        ContinueStatus = DBG_CONTINUE;
      } // EXCEPTION_BREAKPOINT
      else
      {
#ifdef LOG_TRACER
        log("real exception occured, code = %08X, addr = %08X\n", debugevent.u.Exception.ExceptionRecord.ExceptionCode, exaddr);
#endif
        if (r != NULL)
        {
          r->flag[ exaddr - r->base ] |= R_EXCEPTION;
        }

        LDT_ENTRY selector;
        GetThreadSelectorEntry(t2h->handle, ctx.SegFs, &selector);
        DWORD fsaddr = selector.BaseLow |
                      (selector.HighWord.Bytes.BaseMid << 16) |
                      (selector.HighWord.Bytes.BaseHi  << 24);
        DWORD temp;
        if (read_memory( pinfo.hProcess, fsaddr, (BYTE*)&temp, 4) == 0)
        {
#ifdef LOG_TRACER
          log("cant read DWORD at fsaddr = %08X\n", fsaddr);
#endif
          result = ERR_TRACER_CANTREADMEMATFSADDR;
          break;
        }
        DWORD sehaddr;
        if (read_memory( pinfo.hProcess, temp+4, (BYTE*)&sehaddr, 4) == 0)
        {
#ifdef LOG_TRACER
          log("cant read DWORD at temp+4 = %08X\n", temp+4);
#endif
          result = ERR_TRACER_CANTREADMEMATSEHADDR;
          break;
        }

        // since SEH handler can be in other EXE/DLL
        range_struct* seh_range = RangeList.Find(sehaddr);

        if (seh_range == NULL)
        {
#ifdef LOG_TRACER
          log("exception handler not in range, addr = %08X\n", sehaddr);
#endif
          result = ERR_TRACER_SEHADDRNOTINRANGE;
          break;
        }

#ifdef LOG_TRACER
        log("inserting INT3 into exception handler, addr = %08X\n", sehaddr);
#endif
        if (Int3List.Insert( pinfo.hProcess, sehaddr) == 0)
        {
#ifdef LOG_TRACER
          log("ERROR: cant insert INT3 into exception handler, addr = %08X\n", sehaddr);
#endif
          result = ERR_TRACER_CANTINSERTINT3TOSEHADDR;
          break;
        }

        seh_range->flag[ sehaddr - seh_range->base ] |= R_SEHHANDLER;

        ContinueStatus = DBG_EXCEPTION_NOT_HANDLED;
      } // any other exception

      if (SetThreadContext(t2h->handle, &ctx)==0)
      {
#ifdef LOG_TRACER
        log("ERROR: SetThreadContext() error, GetLastError() = %i\n", GetLastError());
#endif
        result = ERR_TRACER_SETCONTEXT;
        break;
      }
      ResumeThread(t2h->handle);

//#ifdef LOG_OPCODES
//      log("%08X: ", exaddr);
//      BYTE bytes[32];
//      if (read_memory( pinfo.hProcess, exaddr, bytes, 32))
//      {
//        xde_instr diza;
//        xde_disasm(bytes, &diza);
//        for(DWORD i=0; i<diza.disasm_len; i++)
//          log(" %02X", bytes[i]);
//      }
//      log("\n");
//#endif // LOG_OPCODES

      // now update info depending on opcode address

      t2h->l_index++;
      g_index++;

      if (r != NULL)
      {

        DWORD t = exaddr - r->base;

        r->flag[ t ] |= R_OPCODE;

        r->count[ t ]++;

        if (r->g_index[ t ] == 0)
          r->g_index[ t ] = g_index;
        else
          r->flag   [ t ] |= R_MULTIEXEC;

        if (r->l_index[ t ] == 0)
          r->l_index[ t ] = t2h->l_index;
        else
          r->flag[ t ] |= R_MULTIEXEC;

        if (r->thread[ t ] != t2h->id)
        {
          if (r->thread[ t ] == NULL)
            r->thread[ t ] = t2h->id;
          else
          {
            r->flag[ t ] |= R_MULTITHREAD;
            r->thread[ t ] = -1;  // multi-thread
          }
        }

        if (t2h->prev_ins != exaddr)            // to avoid REP, JMP $
          r->prev_ins[ t ] = t2h->prev_ins;

      } // if (r != NULL)

      OnTrace();

      t2h->prev_ins = exaddr;

    } // EXCEPTION_DEBUG_EVENT

    res = ContinueDebugEvent(debugevent.dwProcessId,
                             debugevent.dwThreadId,
                             ContinueStatus);
    if (res == 0)
    {
#ifdef LOG_TRACER
      log("ContinueDebugEvent() failed, GetLastError = %i, break\n", GetLastError());
#endif
      result = ERR_TRACER_CONTDBGEVENT;
      break;
    }
  } // main cycle

  if (Int3List.DeleteAll(pinfo.hProcess)==0)
  {
#ifdef LOG_TRACER
    log("Int3List::DeleteAll() failed\n");
#endif
    result = ERR_TRACER_CANTDELETEALLINT3;
  }

  if (result == ERR_TRACER_SUCCESS)
  {
    if (RangeList.root)
      memcpy(((range_struct*)RangeList.root)->filename, cmdline, 260);

    ForEachInList(RangeList, range_struct, r)
    {
      if (read_memory(pinfo.hProcess, r->base, r->mem1, r->size)==0)
      {
#ifdef LOG_RANGE
        log("ERROR: MODULE(base=%08X,size=%08X): cant read the whole image\n", r->base, r->size);
#endif
        return ERR_TRACER_CANTREADIMAGE;
      }

      for(DWORD t = 0; t < r->size; t++)
        if (r->mem0[ t ] != r->mem1[ t ])
          r->flag[ t ] |= R_VARIABLE;

    } // ForEachInList
  }

#ifdef LOG_TRACER
  log("terminating process\n");
  res =
#endif
  TerminateProcess(pinfo.hProcess, 0);
#ifdef LOG_TRACER
  if (res == 0)
    log("TerminateProcess() failed, error = %i\n", GetLastError());
#endif

  if (pcount)
  {

#ifdef LOG_TRACER
    log("waiting for exit event...\n");
#endif

    while(pcount)       // main cycle
    {
      DEBUG_EVENT debugevent;

      int res = WaitForDebugEvent(&debugevent, 60*1000);    // max 60 sec
      if (res == 0) break;

      if ( (debugevent.dwDebugEventCode == EXIT_PROCESS_DEBUG_EVENT) ||
           (debugevent.dwDebugEventCode == RIP_EVENT)  )
        pcount = 0;

      ContinueDebugEvent(debugevent.dwProcessId,
                         debugevent.dwThreadId,
                         DBG_EXCEPTION_NOT_HANDLED);
    }

#ifdef LOG_TRACER
    log("waiting for process to terminate...\n");
#endif

    res = WaitForSingleObject(pinfo.hProcess,60*1000);
    if (res != WAIT_OBJECT_0)
    {
#ifdef LOG_TRACER
      log("WaitForSingleObject() failed, res=%i, GetLastError=%i\n", res, GetLastError());
#endif
      result = ERR_TRACER_CANTKILL;
    }

  }

  return result;
} // CTracer::TraceFile()

/*
  CTracer::DumpLog

  Writes a text report of RangeList to the file `fname`. For each range it
  writes a header line (base, size, filename). It then writes one line for
  every byte whose flag is nonzero, containing:
    - the byte's address;
    - the recorded prev_ins, l_index, g_index, thread and count values;
    - the mem0 and mem1 bytes (mem1 is the image read back at the end of a
      successful TraceFile());
    - the names of the R_xxx flags that are set.
  If the file cannot be opened, nothing is written.
*/
void CTracer::DumpLog(char* fname)
{
  FILE*f=fopen(fname,"wb");
  if (f == NULL)
    return;   // nothing was opened, so there is nothing to close

  ForEachInList(RangeList, range_struct, r)
  {

    fprintf(f,"--- base=%08X size=%08X name=[%s]\n",
      r->base,
      r->size,
      r->filename);

    for(DWORD i=0; i<r->size; i++)
    if (r->flag[ i ] != 0)
    {
      fprintf(f,"%08X p=%08X l=%08X g=%08X t=%08X c=%08X %02X->%02X [",
        r->base + i,
        r->prev_ins[ i ],
        r->l_index[ i ],
        r->g_index[ i ],
        r->thread[ i ],
        r->count[ i ],
        r->mem0[i],
        r->mem1[i]);
      if (r->flag[i] & R_OPCODE     ) fprintf(f," R_OPCODE");
      if (r->flag[i] & R_MULTITHREAD) fprintf(f," R_MULTITHREAD");
      if (r->flag[i] & R_MULTIEXEC  ) fprintf(f," R_MULTIEXEC");
      if (r->flag[i] & R_EXCEPTION  ) fprintf(f," R_EXCEPTION");
      if (r->flag[i] & R_SEHHANDLER ) fprintf(f," R_SEHHANDLER");
      if (r->flag[i] & R_VARIABLE   ) fprintf(f," R_VARIABLE");
      fprintf(f," ]\n");
    }

  } // ForEachInList

  fclose(f);
} // CTracer::DumpLog()
