diff --git a/src/drivers/linker.py b/src/drivers/linker.py index ec13cb6..2bace2d 100644 --- a/src/drivers/linker.py +++ b/src/drivers/linker.py @@ -15,6 +15,7 @@ MAC_SIZE = int(sancus.config.SECURITY / 8) KEY_SIZE = MAC_SIZE +CONNECTION_STRUCT_SIZE = 6 + KEY_SIZE class SmEntry: @@ -61,7 +62,7 @@ def add_sym(file, sym_map): args += [file, file] call_prog('msp430-elf-objcopy', args) - return file + return file def parse_size(val): @@ -90,18 +91,23 @@ def get_symbol(elf_file, name): def get_io_sym_map(sm_name): sym_map = { - '__sm_handle_input': '__sm_{}_handle_input'.format(sm_name), - '__sm_num_inputs': '__sm_{}_num_inputs'.format(sm_name), - '__sm_num_connections': '__sm_{}_num_connections'.format(sm_name), - '__sm_io_keys': '__sm_{}_io_keys'.format(sm_name), - '__sm_input_callbacks': '__sm_{}_input_callbacks'.format(sm_name), - '__sm_output_nonce': '__sm_{}_output_nonce'.format(sm_name), - '__sm_send_output': '__sm_{}_send_output'.format(sm_name), - '__sm_set_key': '__sm_{}_set_key'.format(sm_name), - '__sm_X_exit': '__sm_{}_exit'.format(sm_name), - '__sm_X_stub_malloc': '__sm_{}_stub_malloc'.format(sm_name), + '__sm_handle_input': '__sm_{}_handle_input'.format(sm_name), + '__sm_num_inputs': '__sm_{}_num_inputs'.format(sm_name), + '__sm_num_connections': '__sm_{}_num_connections'.format(sm_name), + '__sm_max_connections': '__sm_{}_max_connections'.format(sm_name), + '__sm_io_connections': '__sm_{}_io_connections'.format(sm_name), + '__sm_input_callbacks': '__sm_{}_input_callbacks'.format(sm_name), + '__sm_send_output': '__sm_{}_send_output'.format(sm_name), + '__sm_set_key': '__sm_{}_set_key'.format(sm_name), + '__sm_attest': '__sm_{}_attest'.format(sm_name), + '__sm_X_exit': '__sm_{}_exit'.format(sm_name), + '__sm_X_stub_malloc': '__sm_{}_stub_malloc'.format(sm_name), '__sm_X_stub_reactive_handle_output': - '__sm_{}_stub_reactive_handle_output'.format(sm_name) + '__sm_{}_stub_reactive_handle_output'.format(sm_name), + '__sm_X_public_start': '__sm_{}_public_start'.format(sm_name), + '__sm_X_public_end': '__sm_{}_public_end'.format(sm_name), + '__sm_X_secret_start': '__sm_{}_secret_start'.format(sm_name), + '__sm_X_secret_end': '__sm_{}_secret_end'.format(sm_name) } return sym_map @@ -113,7 +119,7 @@ def get_io_sect_map(sm_name): '.rela.sm.X.text': '.rela.sm.{}.text'.format(sm_name), } - for entry in ('__sm{}_set_key', '__sm{}_handle_input'): + for entry in ('__sm{}_set_key', '__sm{}_attest', '__sm{}_handle_input'): map['.sm.X.{}.table'.format(entry.format(''))] = \ '.sm.{}.{}.table'.format(sm_name, entry.format('_' + sm_name)) map['.rela.sm.X.{}.table'.format(entry.format(''))] = \ @@ -134,15 +140,19 @@ def create_io_stub(sm, stub): def sort_entries(entries): - # If the set_key entry exists, it should have index 0 and if the - # handle_input entry exists, it should have index 1. This is accomplished by - # mapping those entries to __ and ___ respectively since those come + # If the set_key entry exists, it should have index 0, if the + # attest entry exists, it should have index 1 and if + # handle_input entry exists, it should have index 2. This is accomplished by + # mapping those entries to __, ___ and ___ respectively since those come # alphabetically before any valid entry name. def sort_key(entry): if re.match(r'__sm_\w+_set_key', entry.name): return '__' - if re.match(r'__sm_\w+_handle_input', entry.name): + if re.match(r'__sm_\w+_attest', entry.name): return '___' + if re.match(r'__sm_\w+_handle_input', entry.name): + return '____' + return entry.name entries.sort(key=sort_key) @@ -199,7 +209,7 @@ def sort_key(entry): for a in archive_files: debug("Unpacking archive for Sancus SM inspection: " + a) file_name = a - if ':' in a: + if ':' in a: # support calls such as -lib:/full/path file_name = file_name.split(':')[1] @@ -256,6 +266,7 @@ def sort_key(entry): elf_relocations = defaultdict(list) added_set_key_stub = False +added_attest_stub = False added_input_stub = False added_output_stub = False @@ -396,6 +407,14 @@ def sort_key(entry): input_files_to_scan.append(generated_file) added_set_key_stub = True + if not added_attest_stub: + # Generate the attest stub file + generated_file = create_io_stub(sm, 'sm_attest.o') + generated_object_files.append(generated_file) + # And register it to also be scanned by this loop later + input_files_to_scan.append(generated_file) + added_attest_stub = True + if which == 'input': dest = sms_inputs @@ -740,32 +759,45 @@ def sort_key(entry): call_prog('msp430-gcc', ['-c', '-o', o_file, c_file]) - input_callbacks += ' {}(.sm.{}.callbacks)\n'.format(o_file, sm) + input_callbacks += ' KEEP({}(.sm.{}.callbacks))\n'.format(o_file, sm) input_callbacks += ' . = ALIGN(2);' - # Table of connection keys - io_keys = '' + """ + Table of connections: in a reactive application, a connection links the + output of a SM (defined using the macro `SM_OUTPUT`) to the input of another + (defined using the macro `SM_INPUT`). + These connections are stored in a `Connection` array on each SM (see + `reactive_stubs_support.h`). The array is allocated here with a fixed size, + according to the `num_connections` parameter in the SM config (default 0). + """ - if len(ios) > 0: - io_keys += '__sm_{}_io_keys = .;\n'.format(sm) - io_keys += ' . += {};\n'.format(len(ios) * KEY_SIZE) - io_keys += ' . = ALIGN(2);' + num_connections = '' + io_connections = '' + + if hasattr(sm_config[sm], "num_connections"): + sm_num_connections = sm_config[sm].num_connections + else: + sm_num_connections = 0 - # Nonce used by outputs - outputs_nonce = '' + if len(ios) > 0: + # make sure we allocate space even if num_connections is zero + io_connections_size = max(sm_num_connections * CONNECTION_STRUCT_SIZE, 2) - if len(outputs) > 0: - outputs_nonce += '__sm_{}_output_nonce = .;\n'.format(sm) - outputs_nonce += ' . += 2;\n' - outputs_nonce += ' . = ALIGN(2);' + num_connections += '__sm_{}_num_connections = .;\n'.format(sm) + num_connections += ' . += 2;\n' + num_connections += ' . = ALIGN(2);' + io_connections += '__sm_{}_io_connections = .;\n'.format(sm) + io_connections += ' . += {};\n'.format(io_connections_size) + io_connections += ' . = ALIGN(2);' text_sections.append(text_section.format(sm, entry_file, isr_file, exit_file, '\n '.join(tables), input_callbacks, '\n '.join(extra_labels))) + data_sections.append(data_section.format(sm, '\n '.join(id_syms), - args.sm_stack_size, io_keys, - outputs_nonce)) + args.sm_stack_size, num_connections, + io_connections)) if sm in sms_entries: num_entries = len(sms_entries[sm]) @@ -782,7 +814,7 @@ def sort_key(entry): symbols.append('__sm_{}_io_{}_idx = {};'.format(sm, io, index)) # Add symbols for the number of connections/inputs - symbols.append('__sm_{}_num_connections = {};'.format(sm, len(ios))) + symbols.append('__sm_{}_max_connections = {};'.format(sm, sm_num_connections)) symbols.append('__sm_{}_num_inputs = {};'.format(sm, len(inputs))) if args.prepare_for_sm_text_section_wrapping: diff --git a/src/sancus_support/reactive.h b/src/sancus_support/reactive.h index db60b61..d3bebc5 100644 --- a/src/sancus_support/reactive.h +++ b/src/sancus_support/reactive.h @@ -6,19 +6,20 @@ #include typedef uint16_t io_index; +typedef uint16_t conn_index; typedef uint8_t io_data __attribute__((aligned(2))); // The ASM symbols are used for the linker to be able to detect inputs/outputs -#define SM_OUTPUT_AUX(sm, name) \ - asm("__sm_" #sm "_output_tag_" #name " = 0\n"); \ - SM_FUNC(sm) void name(const io_data* data, size_t len) \ - { \ - extern char __sm_##sm##_io_##name##_idx; \ - SM_FUNC(sm) void __sm_##sm##_send_output(unsigned int, \ - const void*, size_t); \ - __sm_##sm##_send_output((io_index)&__sm_##sm##_io_##name##_idx, \ - data, len); \ +#define SM_OUTPUT_AUX(sm, name) \ + asm("__sm_" #sm "_output_tag_" #name " = 0\n"); \ + SM_FUNC(sm) uint16_t name(const io_data* data, size_t len) \ + { \ + extern char __sm_##sm##_io_##name##_idx; \ + SM_FUNC(sm) uint16_t __sm_##sm##_send_output(unsigned int, \ + const void*, size_t); \ + return __sm_##sm##_send_output((io_index)&__sm_##sm##_io_##name##_idx, \ + data, len); \ } #define SM_OUTPUT(sm, name) SM_OUTPUT_AUX(sm, name) diff --git a/src/sancus_support/sm_support.h b/src/sancus_support/sm_support.h index 0be83e0..0537f12 100644 --- a/src/sancus_support/sm_support.h +++ b/src/sancus_support/sm_support.h @@ -4,6 +4,7 @@ #include "config.h" #include +#include #if __GNUC__ >= 5 || __clang_major__ >= 5 @@ -199,15 +200,12 @@ extern char __unprotected_sp; #endif -#define __OUTSIDE_SM( p, sm ) \ - ( ((void*) p < (void*) &__PS(sm)) || ((void*) p >= (void*) &__PE(sm)) ) && \ - ( ((void*) p < (void*) &__SS(sm)) || ((void*) p >= (void*) &__SE(sm)) ) - /* * Returns true iff whole buffer [p,p+len-1] is outside of the sm SancusModule */ -#define sancus_is_outside_sm( sm, p, len) \ - ( __OUTSIDE_SM(p, sm) && __OUTSIDE_SM((p+len-1), sm) ) +#define sancus_is_outside_sm(sm, p, len) \ + ( is_buffer_outside_region(&__PS(sm), &__PE(sm), p, len) && \ + is_buffer_outside_region(&__SS(sm), &__SE(sm), p, len) ) /** * Interrupt vector for the Sancus violation ISR. @@ -311,6 +309,37 @@ sm_id sancus_enable_wrapped(struct SancusModule* sm, unsigned nonce, void* tag); #undef always_inline #define always_inline static inline __attribute__((always_inline)) +/* + * Returns true if buf is outside the memory region [start, end) + * if start >= end, immediately return false + */ +always_inline int is_buffer_outside_region(void *start_p, void *end_p, + void *buf_p, size_t len) { + uintptr_t start = (uintptr_t) start_p; + uintptr_t end = (uintptr_t) end_p; + uintptr_t buf = (uintptr_t) buf_p; + uintptr_t buf_end; + + // make sure start < end, otherwise return false + if (start >= end) { + return 0; + } + + if(len > 0) { + buf_end = buf + len - 1; + } + else { + buf_end = buf; + } + + /* check for int overflow and finally validate `buf` falls outside */ + if( (buf <= buf_end) && ((end <= buf) || (start > buf_end))) { + return 1; + } + + return 0; +} + /** * Disable the protection of the calling module. */ diff --git a/src/stubs/CMakeLists.txt b/src/stubs/CMakeLists.txt index 6fe1dc9..c0fda87 100644 --- a/src/stubs/CMakeLists.txt +++ b/src/stubs/CMakeLists.txt @@ -22,6 +22,7 @@ set(EXTRA_FLAGS -I${CMAKE_SOURCE_DIR}/src/sancus_support) add_object(sm_output.o sm_output.c ${EXTRA_FLAGS}) add_object(sm_input.o sm_input.c ${EXTRA_FLAGS}) add_object(sm_set_key.o sm_set_key.c ${EXTRA_FLAGS}) +add_object(sm_attest.o sm_attest.c ${EXTRA_FLAGS}) set(STUBS ${CMAKE_CURRENT_BINARY_DIR}/sm_entry.o @@ -32,6 +33,7 @@ set(STUBS ${CMAKE_CURRENT_BINARY_DIR}/sm_output.o ${CMAKE_CURRENT_BINARY_DIR}/sm_input.o ${CMAKE_CURRENT_BINARY_DIR}/sm_set_key.o + ${CMAKE_CURRENT_BINARY_DIR}/sm_attest.o ${CMAKE_CURRENT_BINARY_DIR}/sm_mmio_entry.o ${CMAKE_CURRENT_BINARY_DIR}/sm_mmio_exclusive.o diff --git a/src/stubs/reactive_stubs_support.h b/src/stubs/reactive_stubs_support.h index c5aca41..3c7a251 100644 --- a/src/stubs/reactive_stubs_support.h +++ b/src/stubs/reactive_stubs_support.h @@ -3,29 +3,54 @@ #include "reactive.h" -void reactive_handle_output(io_index output_id, void* data, size_t len); +void reactive_handle_output(conn_index conn_id, void* data, size_t len); typedef uint8_t IoKey[SANCUS_KEY_SIZE]; typedef void (*InputCallback)(const void*, size_t); typedef enum { - Ok = 0x0, - IllegalConnection = 0x1, - MalformedPayload = 0x2 + Ok = 0x0, + IllegalConnection = 0x1, + MalformedPayload = 0x2, + IllegalParameters = 0x3, + BufferInsideSM = 0x4, + CryptoError = 0x5, + InternalError = 0x6 } ResultCode; +typedef struct Connection { + io_index io_id; + conn_index conn_id; + uint16_t nonce; + IoKey key; +} Connection; + +// The size of the Connection struct is also hardcoded in linker.py. Hence, +// we need to make sure that it does not change at compile time (e.g. due to +// optimizations). +// Besides, if the struct changes, we need to adjust this value here and in +// linker.py (check the CONNECTION_STRUCT_SIZE global variable) as well. +_Static_assert (sizeof(Connection) == 6 + SANCUS_KEY_SIZE, + "Size of Connection struct differs from the expected value"); + // These will be allocated by the linker -extern IoKey __sm_io_keys[]; +extern Connection __sm_io_connections[]; extern InputCallback __sm_input_callbacks[]; -extern uint16_t __sm_output_nonce; -extern char __sm_num_connections; -#define SM_NUM_CONNECTIONS (size_t)&__sm_num_connections +extern char __sm_max_connections; +#define SM_MAX_CONNECTIONS (size_t)&__sm_max_connections + +extern uint16_t __sm_num_connections; extern char __sm_num_inputs; #define SM_NUM_INPUTS (size_t)&__sm_num_inputs #define SM_NAME X +// declare symbols for the public/secret regions +#define __SECTION(sect, name) sect(name) +extern char __SECTION(__PS, SM_NAME), __SECTION(__PE, SM_NAME), + __SECTION(__SS, SM_NAME), __SECTION(__SE, SM_NAME); + #endif diff --git a/src/stubs/sm_attest.c b/src/stubs/sm_attest.c new file mode 100644 index 0000000..6bfec25 --- /dev/null +++ b/src/stubs/sm_attest.c @@ -0,0 +1,16 @@ +#include "reactive_stubs_support.h" + +uint16_t SM_ENTRY(SM_NAME) __sm_attest(const uint8_t* challenge, size_t len, + uint8_t *result) +{ + if( !sancus_is_outside_sm(SM_NAME, (void *) challenge, len) || + !sancus_is_outside_sm(SM_NAME, (void *) result, SANCUS_TAG_SIZE) ) { + return BufferInsideSM; + } + + if( !sancus_tag(challenge, len, result) ) { + return CryptoError; + } + + return Ok; +} diff --git a/src/stubs/sm_input.c b/src/stubs/sm_input.c index 992d2e6..05f3f55 100644 --- a/src/stubs/sm_input.c +++ b/src/stubs/sm_input.c @@ -2,24 +2,42 @@ #include -#define AD_SIZE 2 - -void SM_ENTRY(SM_NAME) __sm_handle_input(uint16_t conn_id, +uint16_t SM_ENTRY(SM_NAME) __sm_handle_input(uint16_t conn_idx, const void* payload, size_t len) { - if (conn_id >= SM_NUM_INPUTS) - return; + // sanitize input buffer + if(!sancus_is_outside_sm(SM_NAME, (void *) payload, len)) { + return BufferInsideSM; + } - const size_t data_len = len - AD_SIZE - SANCUS_TAG_SIZE; - const uint8_t* cipher = (uint8_t*)payload + AD_SIZE; - const uint8_t* tag = cipher + data_len; + // check correctness of other parameters + if(len < SANCUS_TAG_SIZE || conn_idx >= __sm_num_connections) { + return IllegalParameters; + } + + Connection *conn = &__sm_io_connections[conn_idx]; + // check if io_id is a valid input ID + if (conn->io_id >= SM_NUM_INPUTS) { + return IllegalConnection; + } + + // associated data only contains the nonce, therefore we can use this + // this trick to build the array fastly (i.e. by swapping the bytes) + const uint16_t nonce_rev = conn->nonce << 8 | conn->nonce >> 8; + const size_t data_len = len - SANCUS_TAG_SIZE; + const uint8_t* cipher = payload; + const uint8_t* tag = cipher + data_len; // TODO check for stack overflow! uint8_t* input_buffer = alloca(data_len); - if (sancus_unwrap_with_key(__sm_io_keys[conn_id], payload, AD_SIZE, - cipher, data_len, tag, input_buffer)) - { - __sm_input_callbacks[conn_id](input_buffer, data_len); + if (sancus_unwrap_with_key(conn->key, &nonce_rev, sizeof(nonce_rev), + cipher, data_len, tag, input_buffer)) { + conn->nonce++; + __sm_input_callbacks[conn->io_id](input_buffer, data_len); + return Ok; } + + // here only if decryption fails + return CryptoError; } diff --git a/src/stubs/sm_output.c b/src/stubs/sm_output.c index 6febf29..cc71424 100644 --- a/src/stubs/sm_output.c +++ b/src/stubs/sm_output.c @@ -2,14 +2,38 @@ #include -SM_FUNC(SM_NAME) void __sm_send_output(io_index index, +SM_FUNC(SM_NAME) uint16_t __sm_send_output(io_index index, const void* data, size_t len) { - const size_t nonce_len = sizeof(__sm_output_nonce); - const size_t payload_len = nonce_len + len + SANCUS_TAG_SIZE; - uint8_t* payload = malloc(payload_len); - *(uint16_t*)payload = __sm_output_nonce++; - sancus_wrap_with_key(__sm_io_keys[index], payload, nonce_len, data, len, - payload + nonce_len, payload + nonce_len + len); - reactive_handle_output(index, payload, payload_len); + const size_t payload_len = len + SANCUS_TAG_SIZE; + + // search for all the connections associated to the index. + // Unfortunately, this operation is O(n) with n = number of connections + int i; + for (i=0; i<__sm_num_connections; i++) { + Connection *conn = &__sm_io_connections[i]; + if (conn->io_id != index) { + continue; + } + + uint8_t* payload = malloc(payload_len); + + if (payload == NULL) { + return InternalError; + } + + if ( !sancus_is_outside_sm(SM_NAME, (void *) payload, payload_len) ) { + return BufferInsideSM; + } + + // associated data only contains the nonce, therefore we can use this + // this trick to build the array fastly (i.e. by swapping the bytes) + uint16_t nonce_rev = conn->nonce << 8 | conn->nonce >> 8; + sancus_wrap_with_key(conn->key, &nonce_rev, sizeof(nonce_rev), data, + len, payload, payload + len); + conn->nonce++; + reactive_handle_output(conn->conn_id, payload, payload_len); + } + + return Ok; } diff --git a/src/stubs/sm_set_key.c b/src/stubs/sm_set_key.c index e982d4c..7694958 100644 --- a/src/stubs/sm_set_key.c +++ b/src/stubs/sm_set_key.c @@ -1,21 +1,43 @@ #include "reactive_stubs_support.h" -void SM_ENTRY(SM_NAME) __sm_set_key(const uint8_t* ad, const uint8_t* cipher, - const uint8_t* tag, uint8_t* result) +#define AD_SIZE 6 + +uint16_t SM_ENTRY(SM_NAME) __sm_set_key(const uint8_t* ad, const uint8_t* cipher, + const uint8_t* tag, uint16_t *conn_idx) { - uint16_t conn_id = (ad[2] << 8) | ad[3]; - ResultCode code = Ok; - - if (conn_id >= SM_NUM_CONNECTIONS) - code = IllegalConnection; - else if (!sancus_unwrap(ad, 4, cipher, SANCUS_KEY_SIZE, tag, - __sm_io_keys[conn_id])) - { - code = MalformedPayload; + if( !sancus_is_outside_sm(SM_NAME, (void *) ad, AD_SIZE) || + !sancus_is_outside_sm(SM_NAME, (void *) cipher, SANCUS_KEY_SIZE) || + !sancus_is_outside_sm(SM_NAME, (void *) tag, SANCUS_TAG_SIZE) || + !sancus_is_outside_sm(SM_NAME, (void *) conn_idx, sizeof(uint16_t)) ) { + return BufferInsideSM; + } + + // Note: make sure we only use AD_SIZE bytes of the buffer `ad` + conn_index conn_id = (ad[0] << 8) | ad[1]; + io_index io_id = (ad[2] << 8) | ad[3]; + uint16_t nonce = (ad[4] << 8) | ad[5]; + + // check if there is still space left in the array + if (__sm_num_connections == SM_MAX_CONNECTIONS) { + return InternalError; + } + + // check nonce + if(nonce != __sm_num_connections) { + return MalformedPayload; } - result[0] = 0; - result[1] = code; - uint8_t result_ad[] = {ad[0], ad[1], result[0], result[1]}; - sancus_tag(result_ad, sizeof(result_ad), result + 2); + Connection *conn = &__sm_io_connections[__sm_num_connections]; + *conn_idx = __sm_num_connections; + + if (!sancus_unwrap(ad, AD_SIZE, cipher, SANCUS_KEY_SIZE, tag, conn->key)) { + return CryptoError; + } + + __sm_num_connections++; + conn->io_id = io_id; + conn->conn_id = conn_id; + conn->nonce = 0; + + return Ok; }