from axion_api import axion
from PythonQt.Qt import *

import reven

def pick_memory_access(client, point):
    """Pick the reference memory access of the instruction at *point*.

    Accesses whose logical address equals ``reven.logical_address(0, 0)`` are
    ignored. Among the remaining accesses reported by
    ``client.memory_get_history_instruction(point)``:

    * if at least one is a write, the last write access is returned;
    * otherwise the last access is returned.

    Returns the access object (not its address), or None when no access
    qualifies.
    """
    null_logical = reven.logical_address(0, 0)
    result = None

    for access in client.memory_get_history_instruction(point):
        # `!=` is kept on purpose: only this comparison is known to work on
        # reven logical addresses here.
        if access.logical != null_logical:
            # Writes take priority: once a write is picked, a later non-write
            # access must not replace it (a later write still does).
            if result is not None and result.write and not access.write:
                continue
            result = access
    return result

def get_matching_instruction(client, point):
    """Return the execution point of the instruction matching the one at *point*.

    Three heuristics are tried in order, and the first applicable one wins:

    1. ss: if the instruction changes ss (sysenter/sysexit-like), search for
       the next (ss decreased) or previous (ss increased) instruction that
       writes ss.
    2. memory: if the instruction has a non-null memory access (see
       pick_memory_access), search forward for the next read of an address it
       writes, or backward for the previous write of an address it reads.
    3. esp: if the instruction changes esp, search forward for the next read
       of the new stack top (esp decreased) or backward for the previous write
       of the old stack top (esp increased).

    Returns whatever the selected reven search call returns (callers check
    ``valid()`` on it), or None when no heuristic applies.
    """
    # A range of size 1 gives the context before and after the selected instruction.
    point_range = reven.execution_range(point.run_name, point.sequence_identifier, 1, point.instruction_index)
    # An empty vector of logical address ranges skips memory: we don't need any here.
    context = client.run_get_running_context_between(point_range, reven.vector_of_logical_address_range())

    # ss values before and after the selected instruction
    ss_before = context.before.numeric_registers['ss'].value
    ss_after = context.after.numeric_registers['ss'].value

    # ss heuristic (basic handling of sysenter/sysexit)
    if ss_before != ss_after:
        # ss is modified by the instruction: search the next or previous write to ss.
        return client.run_search_next_register_use(point, forward=(ss_before > ss_after), read=False, write=True, register_name="ss")

    # memory heuristic
    access = pick_memory_access(client, point)
    if access is not None:
        if access.write:
            return client.run_search_next_memory_use(point, forward=True, read=True, write=False, address=access.logical)

        if access.read:
            return client.run_search_next_memory_use(point, forward=False, read=False, write=True, address=access.logical)
        # An access that is neither read nor write falls through to the esp heuristic.

    # esp values before and after the selected instruction
    esp_before = context.before.numeric_registers['esp'].value
    esp_after = context.after.numeric_registers['esp'].value

    # esp heuristic
    if esp_before != esp_after:
        if esp_before > esp_after:
            # Stack grew: find the next read of the new stack top.
            stack = reven.logical_address(ss_after, esp_after)
            return client.run_search_next_memory_use(point, forward=True, read=True, write=False, address=stack)
        # Stack shrank: find the previous write of the old stack top.
        stack = reven.logical_address(ss_before, esp_before)
        return client.run_search_next_memory_use(point, forward=False, read=False, write=True, address=stack)

    # No heuristic applied.
    return None



def axion_callback():
    """Jump Axion's selection to the instruction matching the selected one.

    Reads the selected instruction and the Reven server address from Axion,
    runs get_matching_instruction on it, then either selects the result in
    Axion or reports through axion.status_message that nothing was found.
    """
    # get current selected instruction from axion
    run, seq, instr = axion.selected_sequence()

    # connect to Reven server project instance for later service calls
    host, port = axion.connection_info()
    client = reven.reven_connection(host.encode(), port)

    # Reven service calls take an execution point, so build one for the selected instruction
    point = reven.execution_point(run.encode(), seq, instr)

    # retrieve data from Reven and use custom heuristic to find an instruction matching the current one
    result = get_matching_instruction(client, point)

    # `is None` avoids calling an overloaded __eq__ on a reven execution point with None
    if result is None or not result.valid():
        axion.status_message("No matching instruction recorded for [%s@%d:%d]" % (run, seq, instr))
    else:
        # select matching instruction in axion
        axion.select_sequence(result.run_name, result.sequence_identifier, result.instruction_index)
        axion.status_message("Jumped to matching instruction at [%s@%d:%d] from [%s@%d:%d]" % (result.run_name, result.sequence_identifier, result.instruction_index, run, seq, instr))
