A State Machine for Multi-Threaded Coroutines in Kotlin Multiplatform

5 April 2020

State machines work well for Kotlin Multiplatform, where we can write business logic once in Kotlin and re-implement rendering logic with platform-specific code for Android and iOS.

A state machine lays out a finite set of states that the UI can be in at a given moment in response to a finite set of actions/events. A well-defined state machine gives us a readable narrative of which actions led to which states over the life of our program. This saves us from cycles lost adding breadcrumbs for tricky bugs, as well as cycles lost to incorrect inferences that are easy to make when reviewing poorly crafted logs.
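
To make the "readable narrative" idea concrete, here is a minimal sketch of a Mealy machine in plain Kotlin, with no coroutines involved. The door example and every name in it (DoorState, DoorAction, Transition, step, narrate) are hypothetical and chosen only for illustration; they are not part of the state machine class shown later.

```kotlin
// Hypothetical door example; none of these names come from the article's code.
enum class DoorState { Closed, Open }
enum class DoorAction { Push, Pull }

// In a Mealy machine the output depends on both the current state and the input.
data class Transition(val next: DoorState, val output: String)

fun step(state: DoorState, action: DoorAction): Transition = when (state) {
    DoorState.Closed -> when (action) {
        DoorAction.Push -> Transition(DoorState.Open, "door swings open")
        DoorAction.Pull -> Transition(DoorState.Closed, "nothing happens")
    }
    DoorState.Open -> when (action) {
        DoorAction.Push -> Transition(DoorState.Open, "nothing happens")
        DoorAction.Pull -> Transition(DoorState.Closed, "door swings shut")
    }
}

// Replays a list of actions and records one readable line per transition.
fun narrate(initial: DoorState, actions: List<DoorAction>): List<String> {
    var state = initial
    return actions.map { action ->
        val transition = step(state, action)
        val line = "$state + $action -> ${transition.next} (${transition.output})"
        state = transition.next
        line
    }
}
```

Calling narrate(DoorState.Closed, listOf(DoorAction.Push, DoorAction.Pull)) yields one line per action, which is the kind of log a well-defined state machine gives you for free.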

Writing a state machine class that plays well with multi-threaded coroutines (native-mt) took some trial and error, so I wanted to share code for a Mealy state machine that is just 36 lines of Kotlin. This first implementation works on coroutines-native-mt 1.3.5 and below:

StateMachine.kt

import kotlinx.atomicfu.atomic
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*

// kprint is the author's logging helper, defined elsewhere.
class StateMachine<R : Any, T>(
  val scope: CoroutineScope,
  private val initialState: T,
  private val sideEffects: List<(Flow<R>, () -> T) -> Flow<R>>,
  private val reducer: suspend (accumulator: T, value: R) -> T
) {
  private val _state = ConflatedBroadcastChannel(initialState)
  val state = _state.asFlow()
  private val isInitialized = atomic(false)
  private val inputActions: BroadcastChannel<R> = BroadcastChannel(Channel.BUFFERED)

  init {
    scope.launch {
      val lastState = atomic(initialState)
      val flowList = sideEffects.map { sideEffect ->
        sideEffect(inputActions.asFlow(), { lastState.value })
      }.run {
        toMutableList().apply {
          add(inputActions.asFlow())
        }
      }
      flowList.onBindMerge { isInitialized.value = true }
        .onEach { kprint("result: $it") }
        .onCompletion { inputActions.cancel() }
        .scan(lastState.value, reducer)
        .distinctUntilChanged()
        .collect { outputState ->
          kprint("state emitted: $outputState")
          lastState.value = outputState
          _state.send(outputState)
        }
    }
  }

  fun dispatchAction(action: R) = scope.launch {
    kprint("Received input action: $action")
    // Wait until every merged flow has started collecting.
    while (!isInitialized.value) {
      yield()
    }
    inputActions.send(action)
  }
}

/**
 * Ensures that every downstream flow is bound to the upstream
 */
fun <T> List<Flow<T>>.onBindMerge(onBind: () -> Unit): Flow<T> {
  val boundFlows = atomic(0)
  return channelFlow {
    forEach { flow ->
      launch {
        flow.onStart {
          if (boundFlows.incrementAndGet() == size) onBind()
        }.collect {
          send(it)
        }
      }
    }
  }
}

We need lastState to be atomic because the state emitted from the flow will always be frozen. The atomic ref wraps the frozen object in a reference that we can still update, so we can swap in each new state without encountering an InvalidMutabilityException.
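
As a rough illustration of that idea, the sketch below swaps immutable snapshots through an atomic reference. It is a JVM-only analogue: it uses java.util.concurrent.atomic.AtomicReference rather than atomicfu's atomic(...), it does not involve Kotlin/Native freezing at all, and CounterState and StateHolder are hypothetical names.

```kotlin
import java.util.concurrent.atomic.AtomicReference

// Hypothetical immutable state; stands in for a frozen object on Kotlin/Native.
data class CounterState(val count: Int)

class StateHolder(initial: CounterState) {
    // Only the reference changes; the CounterState instances are never mutated.
    private val ref = AtomicReference(initial)

    val current: CounterState
        get() = ref.get()

    fun publish(next: CounterState) {
        ref.set(next)
    }
}
```

A reader that captured holder.current before a publish call keeps seeing its old snapshot, while later reads see the new one.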

We need onBindMerge because the onStart hook for the merge operator is called before the merged flows are bound, as seen in this sample:


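import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*

fun main() = runBlocking<Unit> {
   // Local flag; inside this block it shadows CoroutineScope.isActive.
   val isActive = AtomicBoolean(false)
   val inputChannel = BroadcastChannel<Int>(Channel.BUFFERED)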
   launch {
      println("waiting to send first elem")
      while (!isActive.get()) {
        yield()
      }
      inputChannel.send(0)
      println("sent first elem")
   }

   launch {
      delay(3000)
      val outputFlow = merge(
          inputChannel.asFlow().onStart { println("flow 1: started") }
              .onEach { println("flow 1: $it") },
          inputChannel.asFlow().onStart { println("flow 2: started") }
              .onEach { println("flow 2: $it") },
          inputChannel.asFlow().onStart { println("flow 3: started") }
              .onEach { println("flow 3: $it") }
      ).onStart { isActive.compareAndSet(false, true) }
      delay(3000)
      outputFlow.collect {
         println("elem: $it")
      }
   }
}

emits:
waiting to send first elem
sent first elem
flow 1: started
flow 2: started
flow 3: started
  

We can clean this up by using a rendezvous channel and the broadcast extension function. A rendezvous channel's send function suspends until a corresponding call to receive is made, so we know no actions can be dispatched until a downstream observer is listening.
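
The suspending behaviour of a rendezvous send can be seen in isolation with a small sketch like the one below. It assumes kotlinx.coroutines is on the classpath; rendezvousEvents and the event strings are hypothetical names used only for this illustration.

```kotlin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield

// Records the order of events around a rendezvous send/receive pair.
fun rendezvousEvents(): List<String> = runBlocking {
    val events = mutableListOf<String>()
    val channel = Channel<Int>() // default capacity is rendezvous (no buffer)

    val sender = launch {
        events += "send called"
        channel.send(1) // suspends here until a receiver takes the element
        events += "send returned"
    }

    yield() // give the sender a chance to run until it suspends in send()
    events += "receiver ready"
    val value = channel.receive()
    sender.join()
    events += "received $value"
    events
}
```

In the returned list, "send returned" cannot appear before "receiver ready", because send only resumes once receive has been called.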

StateMachine.kt
    
class StateMachine<R : Any, T>(
    val scope: CoroutineScope,
    private val initialState: T,
    private val sideEffects: List<(Flow<R>, () -> T) -> Flow<R>>,
    private val reducer: suspend (accumulator: T, value: R) -> T
) {
    val viewState: ViewStateFlow<T> = ViewStateFlow()
    private val inputActions: Channel<R> = Channel()

    init {
        scope.launch {
            val lastState = atomic(initialState)
            val multicaster = inputActions.multicast(scope)
            val flowList = sideEffects.map { sideEffect ->
                sideEffect(multicaster.asFlow(), { lastState.value })
            }.run {
                toMutableList().apply {
                    add(multicaster.asFlow())
                }
            }
            flowList.merge().onEach { kprint("result $it") }
                .onCompletion { inputActions.cancel() }
                .scan(lastState.value, reducer)
                .distinctUntilChanged()
                .collect { outputState ->
                    lastState.value = outputState
                    viewState.send(outputState)
                }
        }
    }

    fun dispatchAction(action: R) = scope.launch {
        kprint("Received input action: $action")
        inputActions.send(action)
    }
}

fun <T> Channel<T>.multicast(scope: CoroutineScope): BroadcastChannel<T> {
    val channel = this
    return scope.broadcast {
        for (x in channel) {
            send(x)
        }
    }.also {
        it.invokeOnClose { channel.cancel() }
    }
}
    
  

We can confirm that we no longer need onBindMerge with this sample.


fun main() = runBlocking<Unit> {
   val inputChannel = Channel<Int>()
   launch {
      println("waiting to send first elem")
      inputChannel.send(0)
      println("sent first elem")
   }

   launch {
      val multicaster = inputChannel.multicast(this)
      delay(3000)
      val outputFlow = merge(
          multicaster.asFlow().onStart { println("flow 1: started") }
              .onEach { println("flow 1: $it") },
          multicaster.asFlow().onStart { println("flow 2: started") }
              .onEach { println("flow 2: $it") },
          multicaster.asFlow().onStart { println("flow 3: started") }
              .onEach { println("flow 3: $it") }
      )
      delay(3000)
      outputFlow.collect {
      }
   }
}

emits:
waiting to send first elem
flow 1: started
flow 3: started
flow 2: started
sent first elem
flow 3: 0
flow 1: 0
flow 2: 0
  

The first implementation of dispatchAction did not properly guarantee that downstream flows were collecting. The pipeline happened, by luck, not to miss any events when run on the Android emulator, but the first dispatched events were missed on iOS. One benefit of multiplatform development is that race conditions are more likely to surface early, because you are forced to test regularly in dissimilar environments.

A sample usage for a page with a search input, filtering dialog, and list of results looks like:

SearchStateMachine.kt
/**
* Define a finite set of input actions and a finite set of output view states.
*
* This lives in a common module that both the Android and iOS common modules depend on.
*/
class Search {

  /**
   * Defines default options for each filter section
   */
  data class FilterOptions(
    val filterCountryList: List = defaultCountryList,
    val filterExitDirectionList: List = exitDirectionList
  )

  sealed class Action {
    object InitializeFiltersAction : Search.Action()
    data class FilterOptionsLoadedAction(val filterOptions: FilterOptions) : Search.Action()

    object TapFilterCancelBtn : Search.Action()
    object TapOpenFilterIcon : Search.Action()
    object BackButtonTapAction : Search.Action()

    data class QueryChangeAction(
      val filterState: FilterState,
      val query: String
    ) : Search.Action()

    data class FilterUpdateAction(
      val filterState: FilterState,
      val query: String
    ) : Search.Action()

    data class SearchLoadedAction(
      val items: List<SearchResult>,
      val diffResult: DiffResult?
    ) : Search.Action()
  }

  enum class StateType {
    InitialState,
    FilterOptionsLoaded,
    ShowResultsState,
    CloseFilterState,
    ClosePageState,
    OpenFilterState
  }

  /*
  * StateType is our workaround to not wanting to redefine copy-able properties for each class
  * definition within a sealed class.
  */
  data class State(
    val type: StateType,
    val items: List<SearchResult>,
    val diffResult: DiffResult?,
    val isFilterWindowVisible: Boolean = false,
    val filterOptions: FilterOptions = FilterOptions()
  )
}

/**
* Sample usage for screen with search bar at the top, a filter icon that launches a modal/dialog with advanced filter
* options, and displays results as a list.
*/
class SearchStateMachine(
  scope: CoroutineScope,
  private val exitDb: KExitDatabase
) {
  private val stateMachine = StateMachine(
    scope = scope,
    initialState = Search.State(Search.StateType.InitialState, emptyList(), null),
    reducer = ::reducer,
    sideEffects = listOf(
      ::generateFilterOptions,
      ::queryDatabase,
      ::filterDatabase
    )
  )
  val viewState = stateMachine.viewState
  val dispatchAction = stateMachine::dispatchAction

  // Needed for Android so we can generate diffResult.
  val lastList: AtomicRef<List<SearchResult>> = atomic(listOf())

  private fun generateFilterOptions(input: Flow<Search.Action>, state: () -> Search.State): Flow<Search.Action> =
    input.filterIsInstance<Search.Action.InitializeFiltersAction>()
      .map {
        val defaultOptions = state().filterOptions
        val filterOptions = withContext(Dispatchers.Default) {
          val countryList = exitDb.exitEntityQueries.listAllCountries().executeAsList()
          defaultOptions.copy(filterCountryList = countryList)
        }
        Search.Action.FilterOptionsLoadedAction(filterOptions = filterOptions)
      }

  private fun filterDatabase(input: Flow<Search.Action>, state: () -> Search.State): Flow<Search.Action> =
    input.filterIsInstance<Search.Action.FilterUpdateAction>()
      .map { action ->
        performFilterSearch(action.filterState, action.query)
      }

  private fun queryDatabase(input: Flow<Search.Action>, state: () -> Search.State): Flow<Search.Action> =
    input.filterIsInstance<Search.Action.QueryChangeAction>()
      .map { action ->
        performQuerySearch(action.filterState, action.query)
      }

  private suspend fun performFilterSearch(filterState: FilterState, query: String) = withContext(Dispatchers.Default) {
      val searchResultList = KSearchQueryExecutor(exitDb).queryDatabase(filterState, query)
      val prevList = lastList.value
      val diffResult = KDiffUtil.calculateDiff(
        SearchResultItemDiffHelper(
          newList = searchResultList,
          oldList = prevList
        )
      )
      Search.Action.SearchLoadedAction(searchResultList, diffResult)
    }

  private suspend fun performQuerySearch(filterState: FilterState, query: String) = withContext(Dispatchers.Default) {
      val newList = KSearchQueryExecutor(exitDb).queryByName(filterState, query)
      val prevList = lastList.value
      val diffResult = KDiffUtil.calculateDiff(
        SearchResultItemDiffHelper(
          newList = newList,
          oldList = prevList
        )
      )
      Search.Action.SearchLoadedAction(newList, diffResult)
    }

  private suspend fun reducer(state: Search.State, action: Search.Action): Search.State {
    kprint("reducer: curState=$state action=$action")
    return when (action) {
      is Search.Action.InitializeFiltersAction -> state
      is Search.Action.FilterOptionsLoadedAction -> {
        state.copy(type = Search.StateType.FilterOptionsLoaded, filterOptions = action.filterOptions)
      }
      is Search.Action.TapFilterCancelBtn -> state.copy(
        type = Search.StateType.CloseFilterState,
        isFilterWindowVisible = false
      )
      is Search.Action.TapOpenFilterIcon -> {
        state.copy(type = Search.StateType.OpenFilterState, isFilterWindowVisible = true)
      }
      is Search.Action.QueryChangeAction -> state
      is Search.Action.BackButtonTapAction -> {
        if (state.isFilterWindowVisible) {
          state.copy(type = Search.StateType.CloseFilterState, isFilterWindowVisible = false)
        } else {
          state.copy(type = Search.StateType.ClosePageState, isFilterWindowVisible = false)
        }
      }
      is Search.Action.FilterUpdateAction -> state.copy(
        type = Search.StateType.CloseFilterState,
        isFilterWindowVisible = false
      )
      is Search.Action.SearchLoadedAction -> {
        lastList.value = action.items
        state.copy(type = Search.StateType.ShowResultsState, items = action.items, diffResult = action.diffResult)
      }
    }
  }
}
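
The comment on Search.State above describes StateType as a workaround for not redefining copy-able properties in each subclass of a sealed class. The sketch below contrasts the two shapes under simplified assumptions; PageStateType, PageState, SealedPageState and openFilter are hypothetical names, not types from the search example.

```kotlin
// Hypothetical types for illustration; not part of the article's Search class.
enum class PageStateType { Initial, ShowResults, OpenFilter }

// Shape used in the article: one data class plus an enum tag.
// copy() carries every untouched property forward automatically.
data class PageState(
    val type: PageStateType,
    val items: List<String> = emptyList(),
    val isFilterWindowVisible: Boolean = false
)

// Alternative shape: a sealed hierarchy. Each subclass redeclares the shared
// properties, and each transition has to pass them along by hand.
sealed class SealedPageState {
    abstract val items: List<String>

    data class ShowResults(override val items: List<String>) : SealedPageState()
    data class OpenFilter(override val items: List<String>) : SealedPageState()
}

// Enum-tagged transition: only the changed fields are named.
fun openFilter(state: PageState): PageState =
    state.copy(type = PageStateType.OpenFilter, isFilterWindowVisible = true)

// Sealed-class transition: shared fields are copied explicitly.
fun openFilter(state: SealedPageState): SealedPageState =
    SealedPageState.OpenFilter(items = state.items)
```

The trade-off is that the enum-tagged data class cannot stop a state from carrying fields that make no sense for its type, which a sealed hierarchy would enforce at compile time.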
  
  

Last year, many development teams significantly cut back on logging to reduce their GDPR risk exposure. Another approach that does not sacrifice consistent logging is to use the Redacted compiler plugin: apply its @Redacted annotation to any field that holds personally identifiable information (PII), and the value is redacted from the data class's toString implementation.
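
To show the effect being described without depending on the plugin, here is a hand-written stand-in: a data class whose toString is overridden manually so the PII field never reaches a log line. UserProfile, logLine and the "[REDACTED]" placeholder are hypothetical; the plugin's actual generated output may differ.

```kotlin
// Hypothetical model; the field names are illustrative only.
data class UserProfile(
    val id: Long,
    val email: String // PII: should never appear in logs
) {
    // Hand-written equivalent of what a redaction plugin would generate for us.
    override fun toString(): String = "UserProfile(id=$id, email=[REDACTED])"
}

// Mimics the state machine's logging, which interpolates the state's toString.
fun logLine(profile: UserProfile): String = "state emitted: $profile"
```

Because the state machine logs every action and state through string interpolation, redacting at the toString level keeps the full transition narrative while leaving the sensitive value out.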

This MVI sample and RxRedux introduced me to state machine patterns on Android. Thank you, Roman Elizarov, for patiently bearing with my basic questions on GitHub; I am happy I was able to make a reproducer for one coroutines bug in the process of getting this pattern worked out. Running test code with Kotlin 1.4-M1 is a significant performance improvement over 1.3.71; I am looking forward to moving faster in learning and development as the ecosystem upgrades next month.