# Rekall Memory Forensics
# Copyright (c) 2010, 2011, 2012 Michael Ligh <michael.ligh@mnin.org>
# Copyright 2013 Google Inc. All Rights Reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at
# your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#

# pylint: disable=protected-access

from rekall import obj
from rekall.plugins.windows import common
from rekall.plugins.overlays.windows import pe_vtypes

# 32 bit (I386) Rekall vtype definitions for the callback related structures
# used by the CallbackScan plugin. Each entry has the form
#   'StructName': [struct size (or None), {field: [offset, type description]}]
callback_types = {
    '_NOTIFICATION_PACKET' : [0x10, {
        'ListEntry' : [0x0, ['_LIST_ENTRY']],
        'DriverObject' : [0x8, ['pointer', ['_DRIVER_OBJECT']]],
        'NotificationRoutine' : [0xC, ['unsigned int']],
    }],
    '_KBUGCHECK_CALLBACK_RECORD' : [0x20, {
        'Entry' : [0x0, ['_LIST_ENTRY']],
        'CallbackRoutine' : [0x8, ['unsigned int']],
        'Buffer' : [0xC, ['pointer', ['void']]],
        'Length' : [0x10, ['unsigned int']],
        'Component' : [0x14, ['pointer', ['String', dict(length=64)]]],
        'Checksum' : [0x18, ['pointer', ['unsigned int']]],
        'State' : [0x1C, ['unsigned char']],
    }],
    '_KBUGCHECK_REASON_CALLBACK_RECORD' : [0x1C, {
        'Entry' : [0x0, ['_LIST_ENTRY']],
        'CallbackRoutine' : [0x8, ['unsigned int']],
        'Component' : [0xC, ['pointer', ['String', dict(length=8)]]],
        'Checksum' : [0x10, ['pointer', ['unsigned int']]],
        'Reason' : [0x14, ['unsigned int']],
        'State' : [0x18, ['unsigned char']],
    }],
    '_SHUTDOWN_PACKET' : [0xC, {
        'Entry' : [0x0, ['_LIST_ENTRY']],
        'DeviceObject' : [0x8, ['pointer', ['_DEVICE_OBJECT']]],
    }],
    # Three 4 byte members; Context ends at 0xC, so the struct is 0xC bytes
    # long (a size of 0x8 would leave Context outside the struct).
    '_EX_CALLBACK_ROUTINE_BLOCK' : [0xC, {
        'RundownProtect' : [0x0, ['unsigned int']],
        'Function' : [0x4, ['unsigned int']],
        'Context' : [0x8, ['unsigned int']],
    }],
    '_GENERIC_CALLBACK' : [0xC, {
        'Callback' : [0x4, ['pointer', ['void']]],
        'Associated' : [0x8, ['pointer', ['void']]],
    }],
    '_REGISTRY_CALLBACK_LEGACY' : [0x38, {
        'CreateTime' : [0x0, ['WinFileTime', {}]],
    }],
    '_REGISTRY_CALLBACK' : [None, {
        'ListEntry' : [0x0, ['_LIST_ENTRY']],
        'Function' : [0x1C, ['pointer', ['void']]],
    }],
    '_DBGPRINT_CALLBACK' : [0x14, {
        'Function' : [0x8, ['pointer', ['void']]],
    }],
    '_NOTIFY_ENTRY_HEADER' : [None, {
        'ListEntry' : [0x0, ['_LIST_ENTRY']],
        'EventCategory' : [0x8, ['Enumeration', dict(
            target='long', choices={
                0: 'EventCategoryReserved',
                1: 'EventCategoryHardwareProfileChange',
                2: 'EventCategoryDeviceInterfaceChange',
                3: 'EventCategoryTargetDeviceChange'})]],
        'CallbackRoutine' : [0x14, ['unsigned int']],
        'DriverObject' : [0x1C, ['pointer', ['_DRIVER_OBJECT']]],
    }],
}


# 64 bit (AMD64) Rekall vtype definitions for the callback related structures.
# Same format as the 32 bit table:
#   'StructName': [struct size (or None), {field: [offset, type description]}]
callback_types_x64 = {
    '_GENERIC_CALLBACK' : [0x18, {
        'Callback' : [0x8, ['pointer', ['void']]],
        'Associated' : [0x10, ['pointer', ['void']]],
    }],
    '_NOTIFICATION_PACKET' : [0x30, {
        'ListEntry' : [0x0, ['_LIST_ENTRY']],
        'DriverObject' : [0x10, ['pointer', ['_DRIVER_OBJECT']]],
        'NotificationRoutine' : [0x18, ['address']],
    }],
    # DeviceObject is an 8 byte pointer at 0x10, so the struct needs at least
    # 0x18 bytes.
    '_SHUTDOWN_PACKET' : [0x18, {
        'Entry' : [0x0, ['_LIST_ENTRY']],
        'DeviceObject' : [0x10, ['pointer', ['_DEVICE_OBJECT']]],
    }],
    # Function is an 8 byte pointer at 0x10, so the struct needs at least
    # 0x18 bytes.
    '_DBGPRINT_CALLBACK' : [0x18, {
        'Function' : [0x10, ['pointer', ['void']]],
    }],
    '_NOTIFY_ENTRY_HEADER' : [None, {
        'ListEntry' : [0x0, ['_LIST_ENTRY']],
        'EventCategory' : [0x10, ['Enumeration', dict(
            target='long', choices={
                0: 'EventCategoryReserved',
                1: 'EventCategoryHardwareProfileChange',
                2: 'EventCategoryDeviceInterfaceChange',
                3: 'EventCategoryTargetDeviceChange'})]],
        'CallbackRoutine' : [0x20, ['address']],
        'DriverObject' : [0x30, ['pointer', ['_DRIVER_OBJECT']]],
    }],
    '_REGISTRY_CALLBACK' : [0x50, {
        'ListEntry' : [0x0, ['_LIST_ENTRY']],
        'Function' : [0x20, ['pointer', ['void']]], # other could be 28
    }],

    # reactos/include/ddk/wdm.h :987
    '_KBUGCHECK_CALLBACK_RECORD' : [None, {
        'Entry' : [0x0, ['_LIST_ENTRY']],
        'CallbackRoutine' : [0x10, ['Pointer']],
        'Component' : [0x28, ['Pointer', dict(
            target='String',
            target_args=dict(
                length=8
            )
        )]],
    }],

    # reactos/include/ddk/wdm.h :962
    '_KBUGCHECK_REASON_CALLBACK_RECORD' : [None, {
        'Entry' : [0x0, ['_LIST_ENTRY']],
        'CallbackRoutine' : [0x10, ['Pointer']],
        'Component' : [0x18, ['Pointer', dict(
            target='String',
        )]],
    }],
}


class _SHUTDOWN_PACKET(obj.Struct):
    """A shutdown notification record (IoRegisterShutdownNotification)."""

    def sanity_check(self, vm):
        """Decide whether this candidate packet is plausible within `vm`.

        The list links and the DeviceObject pointer must all be valid
        addresses in `vm`, and the object header found just before the
        referenced device must report the object type "Device".

        Returns:
          True if every check passes, False otherwise.
        """
        # Each pointer is validated in turn; stop at the first bad one.
        if not vm.is_valid_address(self.Entry.Flink):
            return False
        if not vm.is_valid_address(self.Entry.Blink):
            return False
        if not vm.is_valid_address(self.DeviceObject):
            return False

        target_device = self.DeviceObject.dereference(vm=vm)

        # The _OBJECT_HEADER starts `Body` bytes before the object it
        # describes, so step back from the device by that amount.
        body_delta = self.obj_profile.get_obj_offset("_OBJECT_HEADER", "Body")
        header = self.obj_profile.Object(
            "_OBJECT_HEADER",
            offset=target_device.obj_offset - body_delta,
            vm=vm)

        return header.get_object_type(vm) == "Device"

class AbstractCallbackScanner(common.PoolScanner):
    """Common base class for the kernel callback pool scanners.

    Subclasses override scan() to yield (callback type, callback address,
    details) tuples instead of pool object headers.
    """


class PoolScanFSCallback(AbstractCallbackScanner):
    """Pool scanner for file system registration callbacks (tag IoFs)."""

    checks = [
        ('PoolTagCheck', dict(tag="IoFs")),
        ('CheckPoolSize', dict(condition=lambda x: x == 0x18)),
        ('CheckPoolType', dict(non_paged=True, paged=True, free=True)),
    ]

    def scan(self, **kwargs):
        """Yield one callback tuple per matching IoFs allocation.

        A _NOTIFICATION_PACKET is read from just past each pool header. The
        tuple is ("IoRegisterFsRegistrationChange", NotificationRoutine,
        None).
        """
        matches = super(PoolScanFSCallback, self).scan(**kwargs)
        for hit in matches:
            packet = self.profile.Object(
                '_NOTIFICATION_PACKET',
                vm=self.address_space,
                offset=hit.end())
            routine = packet.NotificationRoutine
            yield ("IoRegisterFsRegistrationChange", routine, None)


class PoolScanShutdownCallback(AbstractCallbackScanner):
    """Pool scanner for shutdown notification callbacks (tag IoSh)."""

    checks = [
        ('PoolTagCheck', dict(tag="IoSh")),
        ('CheckPoolSize', dict(condition=lambda x: x == 0x18)),
        ('CheckPoolType', dict(non_paged=True, paged=True, free=True)),
        ('CheckPoolIndex', dict(value=0)),
    ]

    def __init__(self, kernel_address_space=None, **kwargs):
        """Remember the kernel address space used to follow packet pointers."""
        super(PoolScanShutdownCallback, self).__init__(**kwargs)
        self.kernel_address_space = kernel_address_space

    def scan(self, offset=0, **kwargs):
        """Yield ("IoRegisterShutdownNotification", handler, driver name).

        Each packet is read from the scanned address space, but its pointers
        are validated and followed in self.kernel_address_space. The handler
        reported is the owning driver's IRP_MJ_SHUTDOWN dispatch entry.
        """
        matches = super(PoolScanShutdownCallback, self).scan(
            offset=offset, **kwargs)

        for hit in matches:
            packet = self.profile._SHUTDOWN_PACKET(
                offset=hit.end(), vm=self.address_space)

            # Discard candidates whose pointers do not check out.
            if packet.sanity_check(self.kernel_address_space):
                # sanity_check already validated the DeviceObject pointer.
                device = packet.DeviceObject.dereference(
                    vm=self.kernel_address_space)
                owner = device.DriverObject

                yield ("IoRegisterShutdownNotification",
                       owner.MajorFunction['IRP_MJ_SHUTDOWN'],
                       owner.DriverName)


class PoolScanGenericCallback(AbstractCallbackScanner):
    """Pool scanner for generic callback blocks (tag Cbrb)."""

    checks = [
        ('PoolTagCheck', dict(tag="Cbrb")),
        ('CheckPoolSize', dict(condition=lambda x: x == 0x18)),
        ('CheckPoolType', dict(non_paged=True, paged=True, free=True)),
    ]

    def scan(self, **kwargs):
        """
        Yield generic callbacks found in Cbrb allocations.

        The same structure backs several registration APIs:

        * PsSetCreateProcessNotifyRoutine
        * PsSetThreadCreateNotifyRoutine
        * PsSetLoadImageNotifyRoutine
        * CmRegisterCallback (on XP only)
        * DbgkLkmdRegisterCallback (on Windows 7 only)

        The structure alone does not reveal which API registered it, so every
        hit is reported as "GenericKernelCallback" with no details.
        """
        matches = super(PoolScanGenericCallback, self).scan(**kwargs)
        for hit in matches:
            block = self.profile.Object(
                '_GENERIC_CALLBACK',
                vm=self.address_space,
                offset=hit.end())
            yield ("GenericKernelCallback", block.Callback, None)


class PoolScanDbgPrintCallback(AbstractCallbackScanner):
    """Pool scanner for DebugPrint callbacks (tag DbCb) on Vista and 7."""

    checks = [
        ('PoolTagCheck', dict(tag="DbCb")),
        ('CheckPoolSize', dict(condition=lambda x: x == 0x20)),
        ('CheckPoolType', dict(non_paged=True, paged=True, free=True)),
    ]

    def scan(self, offset=0, **kwargs):
        """Yield ("DbgSetDebugPrintCallback", Function, None) per hit."""
        matches = super(PoolScanDbgPrintCallback, self).scan(
            offset=offset, **kwargs)

        for hit in matches:
            record = self.profile.Object(
                '_DBGPRINT_CALLBACK',
                vm=self.address_space,
                offset=hit.end())
            yield ("DbgSetDebugPrintCallback", record.Function, None)


class PoolScanRegistryCallback(AbstractCallbackScanner):
    """PoolScanner for Registry Callbacks (tag CMcb) on Vista and 7"""
    checks = [('PoolTagCheck', dict(tag="CMcb")),
              # Seen as 0x38 on Vista SP2 and 0x30 on 7 SP0.
              # NOTE(review): the condition below rejects 0x30, which
              # contradicts the Windows 7 observation above. Confirm how
              # CheckPoolSize measures the allocation before relaxing it.
              ('CheckPoolSize', dict(condition=lambda x: x >= 0x38)),
              ('CheckPoolType', dict(non_paged=True, paged=True, free=True)),
              ('CheckPoolIndex', dict(value=4)),
             ]

    def scan(self, offset=0, **kwargs):
        """
        Enumerate registry callbacks on Vista and 7.

        These callbacks are installed via CmRegisterCallback
        or CmRegisterCallbackEx. A _REGISTRY_CALLBACK is read from just past
        each matching pool header, and the tuple
        ("CmRegisterCallback", Function, None) is yielded for it.
        """
        for pool_header in super(PoolScanRegistryCallback, self).scan(
                offset=offset, **kwargs):

            callback = self.profile.Object(
                '_REGISTRY_CALLBACK', offset=pool_header.end(),
                vm=self.address_space)

            yield "CmRegisterCallback", callback.Function, None


class PoolScanPnp9(AbstractCallbackScanner):
    """Pool scanner for IoRegisterPlugPlayNotification entries.

    Despite the class name, allocations tagged Pnp9, PnpD and PnpC all match.
    """

    checks = [
        ('MultiPoolTagCheck', dict(tags=["Pnp9", "PnpD", "PnpC"])),
        # seen as 0x2C on W7, 0x28 on vistasp0 (4 less but needs 8 less)
        ('CheckPoolSize', dict(condition=lambda x: x >= 0x30)),
        ('CheckPoolType', dict(non_paged=True, paged=True, free=True)),
        ('CheckPoolIndex', dict(value=1)),
    ]

    def __init__(self, kernel_address_space=None, **kwargs):
        """Store the kernel address space, then run the base constructor."""
        # Deliberately assigned before the base constructor runs.
        self.kernel_address_space = kernel_address_space
        super(PoolScanPnp9, self).__init__(**kwargs)

    def scan(self, offset=0, **kwargs):
        """Enumerate IoRegisterPlugPlayNotification registrations.

        Yields (EventCategory, CallbackRoutine, driver name) for each
        _NOTIFY_ENTRY_HEADER found just past a matching pool header.
        """
        matches = super(PoolScanPnp9, self).scan(offset=offset, **kwargs)

        for hit in matches:
            notify_entry = self.profile.Object(
                "_NOTIFY_ENTRY_HEADER",
                vm=self.address_space,
                offset=hit.end())

            # Follow the DriverObject pointer in the kernel address space.
            owner = notify_entry.DriverObject.dereference(
                vm=self.kernel_address_space)

            # The _OBJECT_HEADER starts `Body` bytes before the object it
            # describes; its NameInfo carries the driver's name.
            body_delta = owner.obj_profile.get_obj_offset(
                "_OBJECT_HEADER", "Body")
            owner_header = self.profile.Object(
                "_OBJECT_HEADER",
                offset=owner.obj_offset - body_delta,
                vm=owner.obj_vm)

            owner_name = owner_header.NameInfo.Name.v()

            yield (notify_entry.EventCategory,
                   notify_entry.CallbackRoutine,
                   owner_name)


class CallbackScan(common.WindowsCommandPlugin):
    """Print system-wide notification routines by scanning for them.

    Note this plugin is quite inefficient - consider using the callbacks plugin
    instead.
    """

    # NOTE(review): the double underscore name-mangles this attribute to
    # _CallbackScan__name, so it is not a plain `name` attribute like the one
    # on Callbacks. This may be deliberate, given the docstring discourages
    # this plugin. Confirm before renaming it to `name`.
    __name = "callback_scan"

    def __init__(self, scan_in_kernel_address_space=False, **kwargs):
        """Prepare a private 32 bit profile carrying the callback vtypes.

        Args:
          scan_in_kernel_address_space: When True, the pool scanners run over
            the kernel address space instead of the physical address space.

        Raises:
          obj.ProfileError: if the profile's arch is not I386.
        """
        super(CallbackScan, self).__init__(**kwargs)
        self.scan_in_kernel_address_space = scan_in_kernel_address_space

        if self.profile.metadata("arch") != "I386":
            raise obj.ProfileError("This plugin only supports 32 bit profiles "
                                   "for now.")

        # Copy first, then extend the copy, so the plugin specific vtypes and
        # classes do not leak into the profile shared with other plugins.
        self.profile = self.profile.copy()
        self.profile.add_types(callback_types)
        self.profile.add_classes({
            '_SHUTDOWN_PACKET': _SHUTDOWN_PACKET,
        })
        pe_vtypes.PEProfile.Initialize(self.profile)

    def get_kernel_callbacks(self):
        """
        Enumerate the Create Process, Create Thread, and Image Load callbacks.

        Each routine table symbol is read as an array of 8 _EX_FAST_REF
        entries pointing at _GENERIC_CALLBACK records, and empty slots are
        skipped. If a table cannot be read this way, the
        PoolScanGenericCallback scanner can still pick up the pool
        allocations associated with the callbacks.
        """

        routines = ["PspLoadImageNotifyRoutine",
                    "PspCreateThreadNotifyRoutine",
                    "PspCreateProcessNotifyRoutine"]

        for symbol in routines:
            # The list is an array of 8 _EX_FAST_REF objects
            callbacks = self.profile.get_constant_object(
                symbol,
                target="Array",
                target_args=dict(
                    count=8,
                    target='_EX_FAST_REF',
                    target_args=dict(
                        target="_GENERIC_CALLBACK",
                    )
                )
            )

            for callback in callbacks:
                if callback.Callback:
                    yield "GenericKernelCallback", callback.Callback, None

    def get_bugcheck_callbacks(self):
        """
        Enumerate generic Bugcheck callbacks.

        These structures don't exist in tagged pools. The kernel symbol
        KeBugCheckCallbackListHead is itself the _LIST_ENTRY list head (it is
        not a pointer to one), so it is read the same way as
        KeBugCheckReasonCallbackListHead in get_bugcheck_reason_callbacks.
        Reading it as a Pointer would treat the first record's links as the
        head, skipping that record and reporting the head as a record.
        """
        list_head = self.profile.get_constant_object(
            "KeBugCheckCallbackListHead",
            target="_LIST_ENTRY")

        for l in list_head.list_of_type(
                "_KBUGCHECK_CALLBACK_RECORD", "Entry"):
            yield ("KeBugCheckCallbackListHead", l.CallbackRoutine,
                   l.Component.dereference())

    def get_registry_callbacks_legacy(self):
        """
        Enumerate registry change callbacks.

        On XP these are registered using CmRegisterCallback.

        On Vista and Windows 7, these callbacks are registered using the
        CmRegisterCallbackEx function.
        """
        # The vector is an array of 100 _EX_FAST_REF objects
        addrs = self.profile.get_constant_object(
            "CmpCallBackVector",
            target="Array",
            target_args=dict(
                count=100,
                target="_EX_FAST_REF")
            )

        for addr in addrs:
            callback = addr.dereference_as("_EX_CALLBACK_ROUTINE_BLOCK")
            if callback:
                yield "Registry", callback.Function, None

    def get_bugcheck_reason_callbacks(self):
        """
        Enumerate Bugcheck Reason callbacks.

        Walks the KeBugCheckReasonCallbackListHead list of
        _KBUGCHECK_REASON_CALLBACK_RECORD entries.
        """
        bugs = self.profile.get_constant_object(
            "KeBugCheckReasonCallbackListHead",
            target="_LIST_ENTRY")

        for l in bugs.list_of_type(
                "_KBUGCHECK_REASON_CALLBACK_RECORD", "Entry"):
            yield ("KeRegisterBugCheckReasonCallback", l.CallbackRoutine,
                   l.Component.dereference())

    def generate_hits(self):
        """Yield (type, callback address, details) tuples from all sources.

        Pool scanners run first (the Vista+ scanners only when the profile
        version is >= 6.0), followed by the symbol based enumerations. The
        legacy registry callback vector is only walked for version 5.1.
        """
        # Get the OS version we're analyzing
        version = self.profile.metadata('version')

        # Scan physical memory unless scanning the kernel address space was
        # requested.
        address_space = self.physical_address_space
        if self.scan_in_kernel_address_space:
            address_space = self.kernel_address_space

        # Scanners valid on every supported version.
        scanners = dict(
            PoolScanFSCallback=PoolScanFSCallback(
                address_space=address_space,
                profile=self.profile),

            PoolScanShutdownCallback=PoolScanShutdownCallback(
                profile=self.profile,
                address_space=address_space,
                kernel_address_space=self.kernel_address_space),

            PoolScanGenericCallback=PoolScanGenericCallback(
                address_space=address_space,
                profile=self.profile),
            )

        # Valid for Vista and later
        if version >= 6.0:
            scanners.update(
                PoolScanDbgPrintCallback=PoolScanDbgPrintCallback(
                    address_space=address_space,
                    profile=self.profile),

                PoolScanRegistryCallback=PoolScanRegistryCallback(
                    address_space=address_space,
                    profile=self.profile),

                PoolScanPnp9=PoolScanPnp9(
                    profile=self.profile,
                    address_space=address_space,
                    kernel_address_space=self.kernel_address_space),
                )

        # Each scanner makes its own pass over the address space.
        for scanner in scanners.values():
            for info in scanner.scan():
                yield info

        # These routines are valid on all OS versions
        for info in self.get_bugcheck_callbacks():
            yield info

        for info in self.get_bugcheck_reason_callbacks():
            yield info

        for info in self.get_kernel_callbacks():
            yield info

        # Valid for XP
        if version == 5.1:
            for info in self.get_registry_callbacks_legacy():
                yield info

    def render(self, renderer):
        """Render one row per callback with its resolved symbol name."""
        renderer.table_header([("Type", "type", "36"),
                               ("Callback", "callback", "[addrpad]"),
                               ("Symbol", "symbol", "50"),
                               ("Details", "details", ""),
                              ])

        for (sym, cb, detail) in self.generate_hits():
            symbol_name = self.session.address_resolver.format_address(cb)
            renderer.table_row(sym, cb, symbol_name, detail)


class Callbacks(common.WindowsCommandPlugin):
    """Enumerate callback routines.

    This plugin just enumerates installed callback routines from various
    sources. It does not scan for them.

    This plugin is loosely based on the original Volatility plugin of the same
    name but much expanded using new information.

    Reference:
    <http://www.codemachine.com/notes.html>
    """

    name = "callbacks"

    def __init__(self, **kwargs):
        """Register the callback vtypes matching the profile's architecture."""
        super(Callbacks, self).__init__(**kwargs)

        # NOTE(review): types are added to self.profile in place rather than
        # to a copy, presumably because resolver.get_constant_object()
        # resolves nt! symbols through the session's kernel profile. Confirm
        # before switching to a private copy.
        # The 64 bit layouts were previously applied unconditionally, which
        # mis-parses the records on 32 bit (I386) images.
        if self.profile.metadata("arch") == "I386":
            self.profile.add_types(callback_types)
        else:
            self.profile.add_types(callback_types_x64)

    def get_generic_callbacks(self, renderer):
        """Render the image load, thread and process notify routine tables.

        Each table is read as an array of _EX_FAST_REF entries pointing at
        _GENERIC_CALLBACK records. Its length comes from the matching *Count
        symbol.
        """
        resolver = self.session.address_resolver
        for table, table_length in [
                ("nt!PspLoadImageNotifyRoutine",
                 "nt!PspLoadImageNotifyRoutineCount"),
                ("nt!PspCreateThreadNotifyRoutine",
                 "nt!PspCreateThreadNotifyRoutineCount"),
                ("nt!PspCreateProcessNotifyRoutine",
                 "nt!PspCreateProcessNotifyRoutineCount")]:
            # The *Count variables are 32 bit ULONGs. Reading them as 64 bit
            # values pulls the 4 following bytes into the array length.
            array_length = resolver.get_constant_object(
                table_length, "unsigned int")

            # NOTE(review): only the first `array_length` slots are walked. If
            # the kernel leaves holes when a routine is unregistered, later
            # entries would be missed - confirm for the target Windows build.
            array = resolver.get_constant_object(
                table,
                target="Array",
                count=array_length,
                target_args=dict(
                    target="_EX_FAST_REF",
                    target_args=dict(
                        target="_GENERIC_CALLBACK"
                    )
                )
            )

            for callback in array:
                function = callback.Callback
                renderer.table_row(
                    table, callback, function,
                    resolver.format_address(function, max_distance=2**64)
                    )

    def get_bugcheck_callbacks(self, renderer):
        """Render the bugcheck and bugcheck reason callback lists."""
        resolver = self.session.address_resolver

        for list_head_name, record_type in [
                ("nt!KeBugCheckCallbackListHead", "_KBUGCHECK_CALLBACK_RECORD"),
                ("nt!KeBugCheckReasonCallbackListHead",
                 "_KBUGCHECK_REASON_CALLBACK_RECORD")]:
            list_head = resolver.get_constant_object(
                list_head_name, "_LIST_ENTRY")

            for record in list_head.list_of_type(record_type, "Entry"):
                function = record.CallbackRoutine

                renderer.table_row(
                    list_head_name,
                    record,
                    function,
                    resolver.format_address(function, max_distance=2**64),
                    record.Component
                )

    def render(self, renderer):
        """Render all enumerated callbacks in a single table."""
        renderer.table_header([
            ("Type", "type", "36"),
            ("Offset", "offset", "[addrpad]"),
            ("Callback", "callback", "[addrpad]"),
            ("Symbol", "symbol", "50"),
            ("Details", "details", ""),
        ])

        self.get_generic_callbacks(renderer)
        self.get_bugcheck_callbacks(renderer)
