#!/home/other/Compil/python-haystack/venv/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Loic Jaquemet loic.jaquemet+python@gmail.com
#

import argparse
import logging
import sys

import os

from haystack import argparse_utils
from haystack.reverse import config
from haystack.reverse import context
from haystack.reverse import api
from haystack import dump_loader

# Module-level logger used by the sub-command handlers defined below.
log = logging.getLogger('haystack-reverse')


def reverse_instances(opt):
    """
    Initialise the reverse cache by running the reversing heuristics on a dump.

    Loads the memory dump named by ``opt.dumpname`` and hands the resulting
    memory handler to ``api.reverse_instances``. The value returned by that
    call is not needed here and is discarded.

    :param opt: parsed command line options; only ``opt.dumpname`` is read.
    :return: None
    """
    # go through the 4 step process. Double linked list, basic type, Pointers and graph.
    memory_handler = dump_loader.load(opt.dumpname)
    api.reverse_instances(memory_handler)


def show_reversed_types(opt):
    """
    Show the list of reversed records types after reverse_instances.

    Each reversed type is printed to stdout using its ``to_string()``
    representation.

    :param opt: parsed command line options; only ``opt.dumpname`` is read.
    :return: None
    """
    memory_handler = dump_loader.load(opt.dumpname)
    process_context = memory_handler.get_reverse_context()
    for record_type in process_context.list_reversed_types():
        # FIXME use a file outputer for stdout ?
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(record_type.to_string())


def list_strings(opt):
    """
    Print every reversed record field that is flagged as a string.

    A ``Heap_address,Size,txt`` header line is written first, followed by one
    line per string field giving the field address (record address plus field
    offset), the field length and the field value.

    :param opt: parsed command line options; only ``opt.dumpname`` is read.
    :return: None
    """
    handler = dump_loader.load(opt.dumpname)
    reverse_ctx = handler.get_reverse_context()
    print('Heap_address,Size,txt')
    # walk every field of every record of every heap context
    for heap_ctx in reverse_ctx.list_contextes():
        for rec in heap_ctx.listStructures():
            for fld in rec.get_fields():
                field_addr = rec.address + fld.offset
                if not fld.is_string():
                    continue
                size = len(fld)
                # the value is requested with 10 more bytes than the field length
                text = rec.get_value_for_field(fld, size + 10)
                try:
                    print('0x%x,0x%x bytes,%s' % (field_addr, size, text))
                except IOError:
                    # ignore the pipe errors: stop listing quietly
                    return


def show_record(opt):
    """
    Show the record for a specific address.

    The record is looked up in the context that owns ``opt.address`` and its
    ``to_string()`` representation is printed to stdout. When the lookup
    raises ``ValueError`` a "no record" message is logged and nothing is
    printed.

    :param opt: parsed command line options; ``opt.dumpname`` and
        ``opt.address`` are read.
    :return: None
    """
    memory_handler = dump_loader.load(opt.dumpname)
    ctx = context.get_context_for_address(memory_handler, opt.address)
    # Only the lookup is guarded: a ValueError raised while rendering the
    # record must not be reported as "no record found".
    try:
        record = ctx.get_record_at_address(opt.address)
    except ValueError:
        log.info('[+] Found no record at: @%x', opt.address)
        return
    log.info('[+] Found record at: @%x', opt.address)
    print(record.to_string())


def show_hex(opt):
    """
    Show the Hex values for the record at that address.

    The record is looked up in the context that owns ``opt.address`` and the
    ``repr()`` of its ``bytes`` attribute is printed to stdout. When the
    lookup raises ``ValueError`` a "no record" message is logged and nothing
    is printed.

    :param opt: parsed command line options; ``opt.dumpname`` and
        ``opt.address`` are read.
    :return: None
    """
    memory_handler = dump_loader.load(opt.dumpname)
    ctx = context.get_context_for_address(memory_handler, opt.address)
    # Only the lookup is guarded: a ValueError raised while reading the
    # record bytes must not be reported as "no record found".
    try:
        record = ctx.get_record_at_address(opt.address)
    except ValueError:
        log.info('[+] Found no record at: @%x', opt.address)
        return
    log.info('[+] Found record at: @%x', opt.address)
    print(repr(record.bytes))


def show_predecessor(opt):
    """
    Show the predecessors that point to a record at a particular address.

    The child record is looked up at ``opt.address``; when that lookup raises
    ``ValueError`` a "no record" message is logged and the function returns.
    Otherwise every record returned by ``api.get_record_predecessors`` is
    printed to stdout, prefixed by its address, or a single notice line is
    printed when there is none.

    :param opt: parsed command line options; ``opt.dumpname`` and
        ``opt.address`` are read.
    :return: None
    """
    memory_handler = dump_loader.load(opt.dumpname)
    ctx = context.get_context_for_address(memory_handler, opt.address)
    try:
        child_record = ctx.get_record_at_address(opt.address)
    except ValueError:
        log.info('[+] Found no record at: @%x', opt.address)
        return

    log.info('[+] Showing predecessors of record at: 0x%x', opt.address)
    records = api.get_record_predecessors(memory_handler, child_record)
    if len(records) == 0:
        print('# [+] No parents records found.')
    else:
        for p_record in records:
            print('#0x%x\n%s\n' % (p_record.address, p_record.to_string()))

    # DEBUG (disabled): pointers found by the pointer searcher.
    # The snippets below are intentionally left commented out.
    # print '---'
    # process_context = memory_handler.get_reverse_context()
    # for heap_context in process_context.list_contextes():
    #    p_records = heap_context.listStructuresForPointerValue(child_record.address)
    #    if len(p_records) == 0:
    #        continue
    #    else:
    #        for p_record in p_records:
    #            print '#0x%x\n%s\n' % (p_record.address, p_record.to_string())

    # for heap_context in process_context.list_contextes():
    # heap = memory_handler.get_mapping_for_address(0xc30000)
    # finder = memory_handler.get_heap_finder()
    # heap_walker = finder.get_heap_walker(heap)
    # heap_context = process_context.get_context_for_heap(heap)
    # offsets, values = heap_context.get_heap_pointers_from_allocated(heap_walker)
    # for addr in map(long, offsets):
    #     print hex(addr),
    # print ''
    return


def clean(opt):
    """
    Remove the cache folder associated with a memory dump.

    The folder name is logged before the removal is requested.

    :param opt: parsed command line options; only ``opt.dumpname`` is read.
    :return: None
    """
    cache_folder = config.get_cache_folder_name(opt.dumpname)
    log.info("Removing cache folder %s", cache_folder)
    config.remove_cache_folder(opt.dumpname)


def argparser():
    """
    Build the command line parser for ``haystack-reverse``.

    The parser takes a global ``--debug`` flag, the ``dumpname`` positional
    and one mandatory sub-command (``analyze``, ``types``, ``strings``,
    ``show``, ``hex``, ``parents`` or ``clean``). Each sub-command stores its
    handler function in the ``func`` attribute of the parsed namespace; the
    name of the chosen sub-command is stored in ``command``.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    rootparser = argparse.ArgumentParser(prog='haystack-reverse',
                                         description='Several tools to reverse engineer records from heap.')

    rootparser.add_argument('--debug', action='store_true', help='Debug mode on.')
    rootparser.add_argument('dumpname', type=argparse_utils.readable, help='Memory dump filename')

    subparsers = rootparser.add_subparsers(dest='command', help='sub-command help')
    # Python 3 treats sub-commands as optional by default, which would leave
    # the namespace without a ``func`` attribute. Setting the attribute after
    # creation works on Python 2.7 as well, and ``dest`` gives argparse a
    # name to use in the "required" error message.
    subparsers.required = True

    instances = subparsers.add_parser('analyze', help='Step 1. Run heuristics to reverse allocations record types')
    instances.set_defaults(func=reverse_instances)

    typemap = subparsers.add_parser('types', help='Show reversed record types')
    typemap.set_defaults(func=show_reversed_types)

    strings = subparsers.add_parser('strings', help='Show fields that are assumed to be strings types')
    strings.set_defaults(func=list_strings)

    # The three address-based commands parse their argument with
    # argparse_utils.int16.
    show_parser = subparsers.add_parser('show', help='Show one record instance')
    show_parser.add_argument('address', type=argparse_utils.int16,
                             help='Specify the address of the record, or encompassed by the record')
    show_parser.set_defaults(func=show_record)

    hex_parser = subparsers.add_parser('hex', help='Show the bytes for one record instance')
    hex_parser.add_argument('address', type=argparse_utils.int16,
                            help='Specify the address of the record, or encompassed by the record')
    hex_parser.set_defaults(func=show_hex)

    parent = subparsers.add_parser('parents', help='List the predecessors pointing to the record at this address')
    parent.add_argument('address', type=argparse_utils.int16,
                        help='Hex address of the child structure')
    parent.set_defaults(func=show_predecessor)

    clean_parser = subparsers.add_parser('clean', help='Clean the memory dump from cached info.')
    clean_parser.set_defaults(func=clean)

    return rootparser


def main(argv):
    """
    Parse the command line, configure logging and run the chosen sub-command.

    With ``--debug`` the root logger is set to DEBUG and writes to a file
    named ``log`` (opened in ``'w'`` mode); otherwise logging is configured
    at INFO level with the default handler.

    :param argv: list of command line arguments, without the program name.
    :return: None
    """
    options = argparser().parse_args(argv)

    if not options.debug:
        logging.basicConfig(level=logging.INFO)
    else:
        log_path = os.path.normpath('log')
        logging.basicConfig(level=logging.DEBUG, filename=log_path, filemode='w')
        print('[+] **** COMPLETE debug log to %s' % log_path)

    # every sub-parser registers its handler under ``func``
    options.func(options)

if __name__ == '__main__':
    # Add the current working directory to the module search path, then run
    # the tool with the arguments that follow the program name.
    sys.path.append(os.getcwd())
    main(sys.argv[1:])
