Building a concurrency-proof token refresh flow in Combine

Published on: November 9, 2020

Refreshing access tokens is a common task for many apps that use OAuth or other authentication mechanisms. No matter what your authentication mechanism is, your tokens will expire (eventually) and you'll need to refresh them using a refresh token. Frameworks like RxSwift and Combine provide convenient ways to build pipelines that perform transformation after transformation on a successful network response, allowing you to grab Data, manipulate it, and transform it into an instance of a model object or anything else.

Programming the not-so-happy path where you need to refresh a token is not as simple, especially because in an ideal world you only fire a single token refresh request even if multiple requests fail due to a token error at the same time. You'll want to retry every request as soon as possible without firing more than one token refresh call.

The trick to building something like this is partly a Combine problem (what type of publisher can/should we use) but mostly a concurrency problem (how do we ensure that we only perform a single network call).

In this week's post I'll take a closer look at this problem and show a solution that should be able to hold up even if you're hammering it with token refresh requests.

Setting up a simple networking layer

In this post I will set up a simple mock networking layer that will allow you to experiment with the solution provided in this post even if you don't have a back-end or a token that requires refreshing. I'll start by showing you my models:

struct Token: Decodable {
  let isValid: Bool
}

struct Response: Decodable {
  let message: String
}

enum ServiceErrorMessage: String, Decodable, Error {
  case invalidToken = "invalid_token"
}

struct ServiceError: Decodable, Error {
  let errors: [ServiceErrorMessage]
}

These are just a few simple models. The Token is the key actor here since it's used to authenticate network calls. The Response object models a successful request, and ServiceError and ServiceErrorMessage represent the response that we'll get in case a user isn't authenticated due to a bad or expired token. Your back-end will probably return something entirely different, and your Token will probably have an expiresAt or expiresIn field that you would use to determine whether the current device clock is past the token's expected expiration date. Since different servers use different mechanisms to tell their clients when a token expires, I won't detail that here. Just make sure that your version of isValid is based on the token's expiration time.
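As a sketch of what an expiration-based isValid could look like, here's a hypothetical token that decodes an accessToken and an expiresAt field. Both field names, the Unix timestamp encoding, and the 30 second leeway are assumptions; your server will almost certainly differ. The leeway treats a token as expired slightly early so it doesn't expire while a request is in flight.

```swift
import Foundation

// Hypothetical token shape: `accessToken` and `expiresAt` (seconds since 1970)
// are assumptions; use whatever your server actually returns.
struct ExpiringToken: Decodable {
  let accessToken: String
  let expiresAt: Date

  // Unexpired only if `now` plus a safety margin is still before the expiration date
  func isUnexpired(at now: Date, leeway: TimeInterval = 30) -> Bool {
    return now.addingTimeInterval(leeway) < expiresAt
  }

  // The property the rest of this post relies on, based on the device clock
  var isValid: Bool {
    return isUnexpired(at: Date())
  }
}

let json = #"{ "accessToken": "abc", "expiresAt": 1700000000 }"#.data(using: .utf8)!

let decoder = JSONDecoder()
decoder.dateDecodingStrategy = .secondsSince1970
let token = try! decoder.decode(ExpiringToken.self, from: json)

let expiry = Date(timeIntervalSince1970: 1_700_000_000)
print(token.isUnexpired(at: expiry.addingTimeInterval(-60)))  // true: 60s left, more than the leeway
print(token.isUnexpired(at: expiry.addingTimeInterval(-10)))  // false: inside the 30s leeway
```

Because isValid reads the device clock, a skewed clock can still produce a token that looks valid but isn't, which is exactly the case the forced refresh later in this post is meant to cover.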

The networking itself is modelled by this protocol:

import Combine
import Foundation

protocol NetworkSession: AnyObject {
  func publisher(for url: URL, token: Token?) -> AnyPublisher<Data, Error>
}

Using a protocol for this will help me swap out URLSession for a mock object that allows me to easily experiment with different responses. Note that I'm using Data as an output for my publisher here. This means that callers of publisher(for:token:) won't have access to the response object that's returned by a data task publisher. That's not a problem for me in this case, but if it is for you, make sure that you adjust the output (and adapt the code from this post) accordingly.

Here's what my mock networking object looks like:

class MockNetworkSession: NetworkSession {
  func publisher(for url: URL, token: Token? = nil) -> AnyPublisher<Data, Error> {
    let statusCode: Int
    let data: Data

    if url.absoluteString == "https://donnys-app.com/token/refresh" {
      print("fake token refresh")
      data = """
      {
        "isValid": true
      }
      """.data(using: .utf8)!
      statusCode = 200
    } else {
      if let token = token, token.isValid {
        print("success response")
        data = """
        {
          "message": "success!"
        }
        """.data(using: .utf8)!
        statusCode = 200
      } else {
        print("not authenticated response")
        data = """
        {
          "errors": ["invalid_token"]
        }
        """.data(using: .utf8)!
        statusCode = 401
      }
    }

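    // Typed as URLResponse so the output matches what dataTaskPublisher emits
    let response: URLResponse = HTTPURLResponse(url: url, statusCode: statusCode, httpVersion: nil, headerFields: nil)!

    // Deferred makes sure the Future only starts its work once someone subscribes,
    // and the one second delay fakes network latency
    return Deferred {
      Future<(data: Data, response: URLResponse), Never> { promise in
        DispatchQueue.global().asyncAfter(deadline: .now() + 1, execute: {
          promise(.success((data: data, response: response)))
        })
      }
    }
    .setFailureType(to: URLError.self)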
    .tryMap({ result in
      guard let httpResponse = result.response as? HTTPURLResponse,
            httpResponse.statusCode == 200 else {

        let error = try JSONDecoder().decode(ServiceError.self, from: result.data)
        throw error
      }

      return result.data
    })
    .eraseToAnyPublisher()
  }
}

While there's a bunch of code here, it's really not that complex. The first couple of lines only check which endpoint we're calling and whether we received a valid token. Depending on these variables I either return a refreshed token, a success response or an error response. In reality you would of course make a call to your server rather than write code like I did here. I use Combine's Future to publish my prepared response with a delay. I also do some processing on this response like I would in a real implementation to check which HTTP status code I ended up with. If I get a non-200 status code I decode the body data into a ServiceError and fail the publisher by throwing an error that we can catch later when we call publisher(for:token:). If I get a 200 status code I return a Data object.
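If you want to test the status code handling without Combine or a network at all, you could pull the logic from the tryMap into a plain throwing function. The function names below are hypothetical, and the models are repeated so the snippet stands on its own. Note that a non-200 response whose body isn't a ServiceError makes the decode itself throw a DecodingError, just like in the tryMap.

```swift
import Foundation

// Repeated from the post so this snippet compiles on its own
enum ServiceErrorMessage: String, Decodable, Error {
  case invalidToken = "invalid_token"
}

struct ServiceError: Decodable, Error {
  let errors: [ServiceErrorMessage]
}

// Hypothetical helper with the same logic as the tryMap above:
// a 200 passes the body through, anything else is decoded and thrown as a ServiceError.
func validateResponse(data: Data, statusCode: Int) throws -> Data {
  guard statusCode == 200 else {
    let error = try JSONDecoder().decode(ServiceError.self, from: data)
    throw error
  }
  return data
}

// Hypothetical convenience: did this response fail because of an invalid token?
func isInvalidTokenFailure(data: Data, statusCode: Int) -> Bool {
  do {
    _ = try validateResponse(data: data, statusCode: statusCode)
    return false
  } catch let error as ServiceError {
    return error.errors.contains(.invalidToken)
  } catch {
    return false  // e.g. a DecodingError for an unexpected error body
  }
}

let ok = #"{ "message": "success!" }"#.data(using: .utf8)!
let unauthorized = #"{ "errors": ["invalid_token"] }"#.data(using: .utf8)!

print(isInvalidTokenFailure(data: ok, statusCode: 200))            // false
print(isInvalidTokenFailure(data: unauthorized, statusCode: 401))  // true
```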

While this code might look a bit silly in the context of my mock, let's take a look at how you can extend URLSession to make it conform to NetworkSession:

extension URLSession: NetworkSession {
  func publisher(for url: URL, token: Token?) -> AnyPublisher<Data, Error> {
    var request = URLRequest(url: url)
    if token != nil {
      // "<access token>" is a placeholder: the Token model in this post has no access token field
      request.setValue("Bearer <access token>", forHTTPHeaderField: "Authorization")
    }

    return dataTaskPublisher(for: request)
      .tryMap({ result in
        guard let httpResponse = result.response as? HTTPURLResponse,
              httpResponse.statusCode == 200 else {

          let error = try JSONDecoder().decode(ServiceError.self, from: result.data)
          throw error
        }

        return result.data
      })
      .eraseToAnyPublisher()
  }
}

This extension looks a lot more reasonable. It's not quite as useful as I'd like because you'll probably want to have a version of publisher(for:) that takes a URLRequest that you configure in your networking layer. But again, my point isn't to teach you how to abstract your networking layer perfectly. It's to show you how you can implement a token refresh flow in Combine that can deal with concurrent requests. The abstraction I've written here is perfect to provide some scaffolding for this.
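Here's a minimal sketch of building that URLRequest yourself, so a session method could accept a fully configured request instead of a URL. The makeAuthenticatedRequest function and its accessToken parameter are hypothetical, since the Token model in this post has no access token field. Bearer tokens conventionally go in the Authorization header.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking  // URLRequest lives here on Linux
#endif

// Hypothetical helper: the post's Token has no access token field,
// so the token string is passed in directly.
func makeAuthenticatedRequest(url: URL, accessToken: String?) -> URLRequest {
  var request = URLRequest(url: url)
  if let accessToken = accessToken {
    // Bearer tokens go in the Authorization header
    request.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization")
  }
  return request
}

let url = URL(string: "https://donnys-app.com/authenticated/resource")!
let authenticated = makeAuthenticatedRequest(url: url, accessToken: "abc123")
let anonymous = makeAuthenticatedRequest(url: url, accessToken: nil)

print(authenticated.value(forHTTPHeaderField: "Authorization") ?? "none")  // Bearer abc123
print(anonymous.value(forHTTPHeaderField: "Authorization") ?? "none")      // none
```

A NetworkSession method that takes this URLRequest could pass it straight to dataTaskPublisher(for:), which already has an overload for URLRequest.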

The final piece in this puzzle (for now) is a networking object that makes an authenticated request:

struct NetworkManager {
  private let session: NetworkSession

  init(session: NetworkSession = URLSession.shared) {
    self.session = session
  }

  func performAuthenticatedRequest() -> AnyPublisher<Response, Error> {
    let url = URL(string: "https://donnys-app.com/authenticated/resource")!

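    // No token yet: we'll add one in the next section
    return session.publisher(for: url, token: nil)
      .decode(type: Response.self, decoder: JSONDecoder())
      .eraseToAnyPublisher()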
  }
}

This code doesn't quite meet the mark, and that's fine. We'll fix it in the next section.

When we call performAuthenticatedRequest, we want to obtain a token from somewhere and pass it to publisher(for:token:). If it turns out that this token is invalid, we want to try and refresh it exactly once. If we obtain a new token and still aren't authenticated, it's not very likely that refreshing the token again will fix this. It's probably a better idea to ask the user to log in again or head down some other recovery path that's appropriate for your use case. The key component here is an object that can provide tokens to objects that require them and obtain new tokens as needed, and, most importantly, does this gracefully without duplicate refresh requests. Let's see how.

Building an authenticator

Authenticator, token provider, authentication manager, name it what you will. I will call it an authenticator since it handles the user's authentication status. The exact name doesn't matter much; just pick something descriptive.

The idea of an authenticator is that when asked for a valid token it can go down four routes:

  1. A token refresh is in progress so the result should be shared
  2. We don't have a token so the user should log in
  3. A valid token exists and should be returned
  4. No token refresh is in progress and the token is invalid (or a refresh is forced) so we should start one

Each of the four scenarios above should produce a publisher, and this should all happen in a single method that returns a publisher that emits a token.
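Before involving Combine, it can help to see the decision on its own. This hypothetical route function checks the same conditions in the same order as the validToken(forceRefresh:) implementation later in this post, but returns an enum case instead of a publisher. Checking for an in-flight refresh first means that even a forced refresh joins the running request instead of starting a second one.

```swift
// Minimal stand-in for the post's Token model
struct Token {
  let isValid: Bool
}

// The four outcomes, listed in the order validToken(forceRefresh:) checks them
enum TokenRoute: Equatable {
  case shareInFlightRefresh  // a refresh is running: hand out the same publisher
  case loginRequired         // no token at all: fail and ask the user to log in
  case useCurrentToken       // token looks valid and no refresh is forced: emit it right away
  case startRefresh          // anything else: start a single new refresh
}

// Hypothetical pure version of the decision; the real method returns a publisher per case
func route(isRefreshing: Bool, currentToken: Token?, forceRefresh: Bool) -> TokenRoute {
  if isRefreshing {
    return .shareInFlightRefresh
  }
  guard let token = currentToken else {
    return .loginRequired
  }
  if token.isValid && !forceRefresh {
    return .useCurrentToken
  }
  return .startRefresh
}

print(route(isRefreshing: false, currentToken: Token(isValid: true), forceRefresh: false))  // useCurrentToken
print(route(isRefreshing: false, currentToken: Token(isValid: true), forceRefresh: true))   // startRefresh
print(route(isRefreshing: true, currentToken: Token(isValid: true), forceRefresh: true))    // shareInFlightRefresh
```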

Before I show you my implementation for this method, I want to show you the skeleton for my authenticator:

class Authenticator {
  private let session: NetworkSession
  private var currentToken: Token? = Token(isValid: false)
  private let queue = DispatchQueue(label: "Authenticator.\(UUID().uuidString)")

  // this publisher is shared amongst all calls that request a token refresh
  private var refreshPublisher: AnyPublisher<Token, Error>?

  init(session: NetworkSession = URLSession.shared) {
    self.session = session
  }

  func validToken(forceRefresh: Bool = false) -> AnyPublisher<Token, Error> {
    // magic...
  }
}

Since the authenticator needs to make a network call if the token requires refreshing, it depends on a NetworkSession. It also keeps track of the current token. In this case I use an invalid token as the default. In an app you'll probably want to grab a current token from the user's keychain and use nil as a default token so you can show a login screen if no token exists.

The authenticator will need to deal with concurrency gracefully, and the refreshPublisher property will be used to determine whether a refresh is in progress. Since refreshPublisher could be read and written from multiple threads at the same time, we want to make sure that only one thread can access it at a time. This is what the queue property (a serial queue) is used for. When I kick off a request I assign a value to refreshPublisher, and when the request completes I set this property to nil again.
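To isolate the concurrency part from Combine, here's a sketch of the same idea with a plain Bool: a serial DispatchQueue guards the "refresh in progress" state, so when many requests fail at once only the first caller gets to start a refresh. RefreshGate and its methods are hypothetical names used for illustration only.

```swift
import Foundation

// Hypothetical type: a serial queue guards the "refresh in progress" flag,
// just like the authenticator's queue guards refreshPublisher.
final class RefreshGate {
  private let queue = DispatchQueue(label: "RefreshGate")  // serial by default
  private var isRefreshing = false
  private var refreshesStarted = 0

  // Returns true for exactly one caller until finishRefresh() is called
  func beginRefreshIfNeeded() -> Bool {
    return queue.sync { () -> Bool in
      if isRefreshing {
        return false  // someone else already started a refresh; share its result
      }
      isRefreshing = true
      refreshesStarted += 1
      return true
    }
  }

  func finishRefresh() {
    queue.sync { isRefreshing = false }
  }

  var startedCount: Int {
    return queue.sync { refreshesStarted }
  }
}

let gate = RefreshGate()

// 100 requests fail at the same time and all ask for a refresh
DispatchQueue.concurrentPerform(iterations: 100) { _ in
  _ = gate.beginRefreshIfNeeded()
}
print(gate.startedCount)  // 1

// once the refresh completes, the next caller may start a new one
gate.finishRefresh()
print(gate.beginRefreshIfNeeded())  // true
print(gate.startedCount)            // 2
```

DispatchQueue.concurrentPerform blocks until every iteration has finished, so reading startedCount right after it is safe.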

Learn more about concurrency and synchronizing access in my post on DispatchQueue.sync and DispatchQueue.async.

Note that my validToken method takes a forceRefresh argument. This argument tells the authenticator that it should refresh a token even if the token looks valid. The NetworkManager will pass true for this argument if the server responds with a token error. You'll see why in a moment.

Let's look at the implementation of validToken(forceRefresh:):

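func validToken(forceRefresh: Bool = false) -> AnyPublisher<Token, Error> {
  // queue.sync doesn't escape, so capturing self here can't create a retain cycle
  return queue.sync { () -> AnyPublisher<Token, Error> in
    // scenario 1: we're already loading a new token
    if let publisher = refreshPublisher {
      return publisher
    }

    // scenario 2: we don't have a token at all, the user should probably log in
    guard let token = currentToken else {
      return Fail(error: AuthenticationError.loginRequired)
        .eraseToAnyPublisher()
    }

    // scenario 3: we already have a valid token and don't want to force a refresh
    if token.isValid, !forceRefresh {
      return Just(token)
        .setFailureType(to: Error.self)
        .eraseToAnyPublisher()
    }

    // scenario 4: we need a new token
    let endpoint = URL(string: "https://donnys-app.com/token/refresh")!
    let publisher = session.publisher(for: endpoint, token: nil)
      .decode(type: Token.self, decoder: JSONDecoder())
      .handleEvents(receiveOutput: { [weak self] token in
        guard let self = self else { return }
        // cache the new token; the write goes through the queue to avoid data races
        self.queue.sync { self.currentToken = token }
      }, receiveCompletion: { [weak self] _ in
        guard let self = self else { return }
        // the refresh is done (successfully or not), so a future call may start a new one
        self.queue.sync { self.refreshPublisher = nil }
      }, receiveCancel: { [weak self] in
        guard let self = self else { return }
        // the refresh was cancelled, so don't hand out this publisher again
        self.queue.sync { self.refreshPublisher = nil }
      })
      // share() comes last so decoding and the side effects above run once per refresh,
      // no matter how many callers subscribe
      .share()
      .eraseToAnyPublisher()

    refreshPublisher = publisher
    return publisher
  }
}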

The entire body of validToken(forceRefresh:) is executed synchronously on my queue to ensure that I don't have any data races for refreshPublisher. The first scenario is simple. If we have a refreshPublisher, a request is already in progress and we should return the publisher that we stored earlier. The second scenario occurs if a user hasn't logged in at all or their token went missing. In this case I fail my publisher with an error I defined to tell subscribers that the user should log in. For completeness, here's what that error looks like:

enum AuthenticationError: Error {
  case loginRequired
}

If we have a token that's valid and we're not forcing a refresh, then I use a Just publisher to return a publisher that will emit the existing token immediately. No refresh needed.

Lastly, if we don't have a valid token, or a refresh is forced, I kick off a refresh request. Note that I use the share() operator to make sure that all subscribers of my refresh request share the output of a single request. Normally, if you subscribe to the same data task publisher more than once, it will kick off a network call for each subscriber. The share() operator makes sure that all subscribers receive the same output without triggering a new request.
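To see what share() buys you, here's a small sketch that swaps the network call for a PassthroughSubject you fire by hand. Deferred runs its closure once per subscription, much like a data task publisher starts one request per subscriber, so the counters show how many "requests" each version makes. The names are illustrative only.

```swift
import Combine

var plainSubscriptions = 0
var sharedSubscriptions = 0
let upstream = PassthroughSubject<String, Never>()

// Deferred runs its closure once per subscription, like a data task
// publisher starting one network call per subscriber.
let plain = Deferred { () -> PassthroughSubject<String, Never> in
  plainSubscriptions += 1
  return upstream
}

let shared = Deferred { () -> PassthroughSubject<String, Never> in
  sharedSubscriptions += 1
  return upstream
}
.share()

var received: [String] = []
var cancellables = Set<AnyCancellable>()
for _ in 0..<3 {
  plain.sink { _ in }.store(in: &cancellables)
  shared.sink { received.append($0) }.store(in: &cancellables)
}

upstream.send("fresh-token")

print(plainSubscriptions)   // 3
print(sharedSubscriptions)  // 1
print(received)             // ["fresh-token", "fresh-token", "fresh-token"]
```

Keep in mind that share() doesn't replay values: a subscriber that attaches after the value was emitted won't receive it. That's one reason it makes sense to clear refreshPublisher once the refresh completes, so later callers get the cached token instead of a finished shared publisher.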

I don't subscribe to the output of my refresh request; that's up to the caller of validToken(forceRefresh:). Instead, I use handleEvents to hook into the receiveOutput and receiveCompletion events. When my request produces a token, I cache it for future use. In your app you'll probably want to store the obtained token in the user's keychain. When the refresh request completes (either successfully or with an error) I set refreshPublisher to nil. Note that I do this inside a queue.sync call again to avoid data races.
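Where share() sits relative to handleEvents (and decode) also matters. Operators before share() run once per upstream event, while operators after it run once per subscriber. This sketch uses hypothetical "cache write" counters in place of storing the token to show the difference.

```swift
import Combine

// Stand-in for the refresh call's response
let refreshResponse = PassthroughSubject<String, Never>()

var cacheWritesShareFirst = 0
var cacheWritesShareLast = 0

// share() first: every subscriber gets its own handleEvents, so the side effect repeats
let shareFirst = refreshResponse
  .share()
  .handleEvents(receiveOutput: { _ in cacheWritesShareFirst += 1 })

// handleEvents first, share() last: the side effect runs once and the result is multicast
let shareLast = refreshResponse
  .handleEvents(receiveOutput: { _ in cacheWritesShareLast += 1 })
  .share()

var cancellables = Set<AnyCancellable>()
for _ in 0..<3 {
  shareFirst.sink { _ in }.store(in: &cancellables)
  shareLast.sink { _ in }.store(in: &cancellables)
}

refreshResponse.send("fresh-token")

print(cacheWritesShareFirst)  // 3
print(cacheWritesShareLast)   // 1
```

Writing the same token three times is harmless, but if your handler does something more expensive, like a keychain write, placing share() after it keeps that work to once per refresh.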

Now that you have an authenticator, let's see how it can be used in the NetworkManager from the previous section.

Using the authenticator in your networking code

Since the authenticator should act as a dependency of the network manager, we'll need to make some changes to its init code:

private let session: NetworkSession
private let authenticator: Authenticator

init(session: NetworkSession = URLSession.shared) {
  self.session = session
  self.authenticator = Authenticator(session: session)
}

The same Authenticator can now be used in every network call you make through a single instance of NetworkManager. All that's left to do is use this authenticator in every network call that NetworkManager can perform.

In this case that's only a single method but for you it could be many, many more methods. Make sure that they all use the same instance of Authenticator.

Let's see what my finished example of performAuthenticatedRequest looks like:

func performAuthenticatedRequest() -> AnyPublisher<Response, Error> {
  let url = URL(string: "https://donnys-app.com/authenticated/resource")!

  return authenticator.validToken()
    .flatMap({ token in
      // we can now use this token to authenticate the request
      session.publisher(for: url, token: token)
    })
    .tryCatch({ error -> AnyPublisher<Data, Error> in
      guard let serviceError = error as? ServiceError,
            serviceError.errors.contains(ServiceErrorMessage.invalidToken) else {
        throw error
      }

      return authenticator.validToken(forceRefresh: true)
        .flatMap({ token in
          // we can now use this new token to authenticate the second attempt at making this request
          session.publisher(for: url, token: token)
        })
        .eraseToAnyPublisher()
    })
    .decode(type: Response.self, decoder: JSONDecoder())
    .eraseToAnyPublisher()
}

Before making my network call I call authenticator.validToken(). This will produce a publisher that emits a valid token. If we already have a valid token then the valid token will be published immediately. If we have a token that appears to be expired, validToken() will fire off a refresh immediately and we'll receive a refreshed token eventually. This means that the token that's passed to the flatMap which comes after validToken() should always be valid unless something strange happened and the validity of our token isn't what it looked like initially.

By using flatMap on validToken() you can grab the token and use it to create a new publisher. In this case that should be your network call.

After my flatMap I use tryCatch. Since publisher(for:token:) is expected to fail with an error if we receive a non-200 HTTP status code, we'll want to handle that error in the tryCatch.

I check whether the error I received in my tryCatch is indeed a ServiceError and that its errors array contains ServiceErrorMessage.invalidToken. If I receive something else this could mean that the authenticator noticed that we don't have a token and it failed with a loginRequired error. It could also mean that something else went wrong. We want all these errors to be forwarded to the caller of performAuthenticatedRequest(). But if we received an error due to an expired token, we'll want to attempt one refresh to be sure we can't recover.

Note that I call validToken and pass forceRefresh: true at this point. The reason for this is that I already called validToken before and didn't force a refresh. The token that the authenticator holds appears to be valid but for some reason it's not. We'll want to tell the authenticator to refresh the token even if the token looks valid.

On the next line I flatMap over the output of validToken(forceRefresh:) just like I did before to return a network call.

Either the flatMap or the tryCatch will produce a publisher that emits Data. I can call decode on this publisher to obtain an instance of Response.

The whole chain is erased to AnyPublisher so my return type for performAuthenticatedRequest() is AnyPublisher<Response, Error>.

It takes some setup, and it's definitely not something you'll wrap your head around easily, but this approach makes a ton of sense once you've let it sink in. This is especially true because we begin our initial request with an access token that should be valid, and that has already been refreshed if it appeared to be expired locally. A single refresh will be attempted if the initial token turns out to be invalid anyway, for example because the device clock is off or because the token was marked as expired on the server for security or other reasons.

If the token was refreshed successfully but we still can't perform our request, it's likely that something else is off, and it's highly unlikely that refreshing again will alleviate the issue.
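To convince yourself that this flow retries exactly once, you can replay it with fake publishers. In this sketch fakeValidToken and fakeRequest are hypothetical stand-ins for validToken(forceRefresh:) and publisher(for:token:); the fake server only accepts "fresh-token", and refreshedToken controls whether a forced refresh actually helps. Everything here is synchronous (Just and Fail), so the counters can be read right after subscribing.

```swift
import Combine

// Hypothetical stand-ins: the fake server only accepts "fresh-token"
enum FakeServiceError: Error {
  case invalidToken
}

var requestsMade = 0
var refreshesMade = 0

// Stand-in for publisher(for:token:)
func fakeRequest(token: String) -> AnyPublisher<String, Error> {
  requestsMade += 1
  if token == "fresh-token" {
    return Just("success!").setFailureType(to: Error.self).eraseToAnyPublisher()
  }
  return Fail(error: FakeServiceError.invalidToken).eraseToAnyPublisher()
}

// Stand-in for validToken(forceRefresh:): the cached token is stale,
// and `refreshedToken` decides what a forced refresh hands back
func fakeValidToken(forceRefresh: Bool, refreshedToken: String) -> AnyPublisher<String, Error> {
  if forceRefresh {
    refreshesMade += 1
  }
  return Just(forceRefresh ? refreshedToken : "stale-token")
    .setFailureType(to: Error.self)
    .eraseToAnyPublisher()
}

// Same shape as performAuthenticatedRequest(): try, then refresh and retry exactly once
func performRequest(refreshedToken: String) -> AnyPublisher<String, Error> {
  return fakeValidToken(forceRefresh: false, refreshedToken: refreshedToken)
    .flatMap { token in fakeRequest(token: token) }
    .tryCatch { error -> AnyPublisher<String, Error> in
      guard (error as? FakeServiceError) == .invalidToken else {
        throw error
      }
      // errors from this second attempt are not caught again, they go to the caller
      return fakeValidToken(forceRefresh: true, refreshedToken: refreshedToken)
        .flatMap { token in fakeRequest(token: token) }
        .eraseToAnyPublisher()
    }
    .eraseToAnyPublisher()
}

var outcomes: [String] = []
var cancellables = Set<AnyCancellable>()

// The refresh fixes things: 2 requests, 1 refresh
performRequest(refreshedToken: "fresh-token")
  .sink(receiveCompletion: { completion in
    if case .failure = completion { outcomes.append("failed") }
  }, receiveValue: { outcomes.append($0) })
  .store(in: &cancellables)

// The refresh doesn't help: 2 more requests, 1 more refresh, then the error reaches the caller
performRequest(refreshedToken: "still-bad-token")
  .sink(receiveCompletion: { completion in
    if case .failure = completion { outcomes.append("failed") }
  }, receiveValue: { outcomes.append($0) })
  .store(in: &cancellables)

print(outcomes)       // ["success!", "failed"]
print(requestsMade)   // 4
print(refreshesMade)  // 2
```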

In Summary

In this week's post you saw an approach that uses Combine and DispatchQueue.sync to build a token refresh flow that can handle multiple incoming requests at the same time without firing off new requests when a token refresh is already in progress. The implementation I've shown you will pro-actively refresh the user's token if it's already known that the token is expired. The implementation also features a forced refresh mechanism so you can trigger a token refresh at will, even if the locally cached token appears to be valid.

Flows like these are often built on top of arbitrary requirements, and not every service will work well with the same approach. For that reason I tried to focus on the authenticator itself and the mechanisms that I use to synchronize access and share a publisher rather than showing you how you can design a perfect networking layer that integrates nicely with my Authenticator. I did show you a basic setup that can be thought of as a nice starting point.

If you have any questions or feedback about this article please let me know on Twitter.
