#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Loic Jaquemet loic.jaquemet+python@gmail.com
#

# Module metadata.
__author__ = "Loic Jaquemet loic.jaquemet+python@gmail.com"

# Module docstring, assigned explicitly instead of being written as the
# leading string literal of the file.
__doc__ = '''
  Reverse heap analysis.
'''

import argparse
import logging
import os
import sys

import haystack
from haystack import argparse_utils
from haystack.config import Config
from haystack.reverse import signature, reversers

# Logger used by the sub-command handlers of this script; main() sets its
# level and attaches a stdout handler to it.
log = logging.getLogger('haystack-reverse')

def reverseInstances(opt):
  ''' Run reversers.reverseInstances() on the memory dump opt.dumpname.

  The value returned by the reverser is not used here.

  :param opt: parsed command-line options; only opt.dumpname is read.
  :return: None
  '''
  reversers.reverseInstances(opt.dumpname)

def writeReversedTypes(opt):
  ''' Reverse types from a memory dump and write the structure definitions to a file.

  The output path is given by Config.getCacheFilename() for
  Config.REVERSED_TYPES_FILENAME and the dump name of the context.

  :param opt: parsed command-line options; only opt.dumpname is read.
  :return: None
  '''
  context, sizeCache = signature.makeSizeCaches(opt.dumpname)
  context = signature.makeReversedTypes(context, sizeCache)
  outname = Config.getCacheFilename(Config.REVERSED_TYPES_FILENAME, context.dumpname)
  # The with-block closes the file even when toString() or write() raises,
  # so no handle is leaked and buffered data is flushed deterministically.
  with open(outname, 'w') as outfile:
    for revStructType in context.listReversedTypes():
      outfile.write(revStructType.toString())
  log.info('[+] Wrote to %s'%(outname))
  return

def groupStructures(opt):
  ''' Show the sorted groups of structure instances on stdout.

  Groups are built by signature.buildStructureGroup() and each one is handed
  to signature.printStructureGroups().

  :param opt: parsed command-line options; opt.dumpname, opt.size and
              opt.address are read.
  :return: None
  '''
  context, sizeCache = signature.makeSizeCaches(opt.dumpname)
  groups = signature.buildStructureGroup(context, sizeCache, opt.size)
  for chains in groups:
    signature.printStructureGroups(context, chains, opt.address)
  
def saveSignatures(opt):
  ''' Translate a memory dump into a signature based file (NULL, POINTERS, OTHERS).

  The signature produced by signature.makeSignatures() is written to the path
  given by Config.getCacheFilename() for Config.SIGNATURES_FILENAME.

  :param opt: parsed command-line options; only opt.dumpname is read.
  :return: None
  '''
  context, sig = signature.makeSignatures(opt.dumpname)
  outfile = Config.getCacheFilename(Config.SIGNATURES_FILENAME, context.dumpname)
  # Close the file explicitly (via the with-block) before announcing success,
  # instead of relying on garbage collection of an anonymous file object.
  with open(outfile, 'w') as fout:
    fout.write(sig)
  log.info('[+] Signature written to %s'%(outfile))
  return


def show(opt):
  ''' Show the structure found at opt.address.

  Loads the context of the dump, looks up the structure for the given
  offset, decodes its fields and prints its string form. When a ValueError
  is raised during these steps, '[+] Found no structure.' is logged instead.

  :param opt: parsed command-line options; opt.dumpname and opt.address
              are read.
  :return: None
  '''
  log.info('[+] Load context')
  context = reversers.getContext(opt.dumpname)
  log.info('[+] Find Structure at: @%x'%(opt.address))
  try:
    st = context.getStructureForOffset(opt.address)
    st.decodeFields()
    # single-argument parenthesised form: same output with the print
    # statement and with the print function
    print(st.toString())
  except ValueError:
    log.info('[+] Found no structure.')
  return

def printParents(opt):
  ''' Print the parental structures of the structure located at opt.address.

  The address is first resolved with context.getStructureAddrForOffset();
  every structure returned by context.listStructuresForPointerValue() for
  that address is decoded and printed. When a ValueError is raised during
  these steps, '[+] Found no structures.' is logged and the final count is
  not reported.

  :param opt: parsed command-line options; opt.dumpname and opt.address
              are read.
  :return: None
  '''
  log.info('[+] Load context')
  context = reversers.getContext(opt.dumpname)
  log.info('[+] find offsets of struct_addr:%x'%(opt.address))
  count = 0
  try:
    child_address = context.getStructureAddrForOffset(opt.address)
    for st in context.listStructuresForPointerValue(child_address):
      st.decodeFields()
      # single-argument parenthesised form: same output with the print
      # statement and with the print function
      print(st.toString())
      count += 1
  except ValueError:
    log.info('[+] Found no structures.')
    return
  log.info('[+] Found %d structures.'%( count ))
  return



def clean(opt):
  ''' Clean the cached info of the memory dump opt.dumpname.

  Delegates to Config.cleanCache(); its return value is not used.

  :param opt: parsed command-line options; only opt.dumpname is read.
  :return: None
  '''
  log.info('[+] Cleaning cache')
  Config.cleanCache(opt.dumpname)


def graph(opt):
  ''' Disabled: show sorted structure instances groups as a graph.

  The original note names the output format 'gefx' (presumably GEXF - to be
  confirmed). The function currently does nothing and returns None.

  :param opt: parsed command-line options; currently ignored.
  :return: None
  '''
  # TODO change to generic fn, and output graph
  # Former implementation, kept as a reference for the TODO above:
  #log.info('[+] Graphing')
  #context, sizeCache = signature.makeSizeCaches(opt.dumpname)
  #for chains in signature.buildStructureGroup(context, sizeCache, opt.size ):
  #  signature.graphStructureGroups(context, chains, opt.address )
  return None

  
def argparser():
  ''' Build the command-line parser of the haystack-reverser tool.

  The root parser accepts a --debug flag and the positional dumpname. Every
  sub-command registers its handler function as the 'func' default of its
  sub-parser.

  :return: the configured argparse.ArgumentParser
  '''
  def addGroupFilters(parser):
    # --size / --address options shared by the 'group' and 'graph' commands
    parser.add_argument('--size', type=int, action='store', default=None,
          help='Limit to a specific structure size')
    parser.add_argument('--address', type=argparse_utils.int16, action='store', default=None,
          help='Limit to structure similar to the structure pointed at <address>')

  root = argparse.ArgumentParser(
      prog='haystack-reverser',
      description='Several tools to reverse engineer structures on the heap.')
  root.add_argument('--debug', action='store_true', help='Debug mode on.')
  root.add_argument('dumpname', type=argparse_utils.readable, action='store',
      help='Source memory dump by haystack.')

  commands = root.add_subparsers(help='sub-command help')

  # instances
  cmd = commands.add_parser('instances',
      help='List all structures instances with virtual address, member types guess and info.')
  cmd.set_defaults(func=reverseInstances)

  # typemap
  cmd = commands.add_parser('typemap',
      help='Try to reverse generic types from instances\' similarities.')
  cmd.set_defaults(func=writeReversedTypes)

  # group
  cmd = commands.add_parser('group',
      help='Show structure instances groups by size and signature.')
  addGroupFilters(cmd)
  cmd.set_defaults(func=groupStructures)

  # parent
  cmd = commands.add_parser('parent',
      help='Print the parent structures pointing to the structure located at this address.')
  cmd.add_argument('address', type=argparse_utils.int16, action='store', default=None,
      help='Hex address of the child structure.')
  cmd.set_defaults(func=printParents)

  # graph (disabled handler)
  cmd = commands.add_parser('graph',
      help='DISABLED - Show sorted structure instances groups by size and signature in a graph.')
  addGroupFilters(cmd)
  cmd.set_defaults(func=graph)

  # show
  cmd = commands.add_parser('show', help='Show one structure instance.')
  cmd.add_argument('address', type=argparse_utils.int16, action='store', default=None,
      help='Specify the address of the structure, or of a structure member.')
  cmd.set_defaults(func=show)

  # makesig
  # XXX delete ?
  cmd = commands.add_parser('makesig',
      help='Create a simple signature file of the heap - NULL, POINTERS, OTHER VALUES.')
  cmd.set_defaults(func=saveSignatures)

  # clean
  cmd = commands.add_parser('clean', help='Clean the memory dump from cached info.')
  cmd.set_defaults(func=clean)

  return root

def main(argv):
  ''' Parse argv, configure logging and run the selected sub-command.

  With --debug, logging.basicConfig() sends the log to the file 'log'
  (opened in 'w' mode) and the three loggers used by the tool are set to
  DEBUG; otherwise they are set to INFO. In both cases a stdout handler is
  attached to each of them before the handler stored in opts.func is called.

  :param argv: list of command-line arguments, without the program name.
  :return: None
  '''
  opts = argparser().parse_args(argv)

  loggers = [logging.getLogger(name)
             for name in ('haystack-reverse', 'signature', 'reversers')]

  if opts.debug:
    flog = os.path.normpath('log')
    logging.basicConfig(level=logging.DEBUG, filename=flog, filemode='w')
    loglevel = logging.DEBUG
    print('[+] **** COMPLETE debug log to %s'%(flog))
  else:
    loglevel = logging.INFO

  sh = logging.StreamHandler(sys.stdout) # 2.6, 2.7 compat
  for logger in loggers:
    logger.setLevel(loglevel)
    logger.addHandler(sh)

  opts.func(opts)
  



if __name__ == "__main__":
  # Script entry point: append the current working directory to the module
  # search path, then run main() with the command-line arguments (program
  # name stripped).
  sys.path.append(os.getcwd())
  main(sys.argv[1:])


