Rui's Blog

[2019 SOSP] Parity Models: Erasure-Coded Resilience for Prediction Serving Systems


This work uses erasure codes to reduce tail latency in ML inference.

Background & Motivation

ML inference, typically done in large-scale clusters, is latency-sensitive. However, slowdowns (e.g., from network or compute contention) and failures in clusters can cause inference queries to miss their SLOs. This work aims to alleviate the effects of slowdowns and failures to reduce tail latency.
Erasure coding is a technique widely deployed in systems (e.g., storage systems, communication systems) for resource-efficient recovery of lost or corrupted data. The difference between erasure codes for ML serving and for traditional settings is that ML serving must handle computation over the inputs. In other words, the encoding and decoding must hold over the computation F: we need to recover F(X) for an unavailable query X, not just X itself.
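As a refresher, here is a minimal sketch of the simplest erasure code: k data units protected by one parity unit (r = 1). It uses a real-valued sum as the parity, which is the same kind of summation encoder this paper uses for queries; storage systems typically use XOR or Reed-Solomon codes over finite fields instead. The function names are my own.

```python
import numpy as np

def encode(data_units):
    """Single parity unit: element-wise sum of the k data units (a k, r=1 code)."""
    return np.sum(data_units, axis=0)

def recover(parity, surviving_units):
    """Rebuild the one missing unit by subtracting the survivors from the parity."""
    return parity - np.sum(surviving_units, axis=0)

data = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
parity = encode(data)

# Pretend data[1] was lost; recover it from the parity and the two other units.
rebuilt = recover(parity, [data[0], data[2]])
print(rebuilt)  # [3. 4.]
```

Any one of the k units can be rebuilt this way, at the cost of storing a single extra unit.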
The problem boils down to: How do we design the erasure codes for ML inference?

Design & Implementation

Existing coded-computation approaches hand-craft erasure codes, which is relatively straightforward for a linear computation F but far more challenging for non-linear computations like the neural networks used in ML serving. The authors overcome this challenge by taking a learning-based approach. Just slap on a NN, problem solved!
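To see why linearity matters, the sketch below runs the deployed model F itself on the parity P = X1 + X2 (as a hand-crafted linear code would) and tries to decode F(X2) as F(P) - F(X1). The weights and inputs are toy values I picked for illustration: decoding is exact for a linear F, but a single ReLU breaks it.

```python
import numpy as np

W = np.array([[1.0, 2.0], [3.0, 4.0]])  # toy weights, chosen for illustration

def f_linear(x):
    return W @ x

def f_nonlinear(x):
    return np.maximum(W @ x, 0.0)  # a ReLU makes F non-linear

def decode(f_of_parity, f_x1):
    """Subtraction decoder: estimate F(X2) as F(P) - F(X1)."""
    return f_of_parity - f_x1

x1 = np.array([1.0, 2.0])
x2 = np.array([-2.0, -1.0])
p = x1 + x2  # summation encoder: P = X1 + X2

lin_ok = np.allclose(decode(f_linear(p), f_linear(x1)), f_linear(x2))
nonlin_ok = np.allclose(decode(f_nonlinear(p), f_nonlinear(x1)), f_nonlinear(x2))
print(lin_ok, nonlin_ok)  # True False
```

For the ReLU version, the decoder returns [-4, -10] while the true F(X2) is [0, 0], so the code no longer "holds over F".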
But wait! Using NNs as encoders/decoders is computationally expensive. Instead, the authors keep the encoder and decoder simple and fast, and learn a new model that operates over parities, namely the parity model. In this diagram (k = 2), the parity model Fp takes as input the parity query P = X1 + X2 and is trained so that Fp(P) ≈ F(X1) + F(X2). If F(X2) is slowed down or lost, the decoder reconstructs it as F(X2) ≈ Fp(P) - F(X1).
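A toy, numpy-only sketch of the parity-model workflow for k = 2 follows. Everything here is an assumption for illustration: F is a made-up non-linear "deployed model", and Fp is a least-squares fit over hand-picked features rather than the neural network the paper trains. The training objective is the one described above: make Fp(X1 + X2) close to F(X1) + F(X2) under mean squared error.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 2))  # stand-in weights of the hypothetical deployed model F

def F(x):
    """Hypothetical deployed model: a fixed non-linear function of each query row."""
    return np.tanh(x @ W.T)

def features(p):
    """Hand-picked features for a toy parity model (a real parity model is a NN)."""
    return np.hstack([p, p ** 2, np.ones((p.shape[0], 1))])

# Training data for Fp: parities P = X1 + X2 paired with targets F(X1) + F(X2).
n = 2000
X1 = rng.standard_normal((n, 2))
X2 = rng.standard_normal((n, 2))
P = X1 + X2
target = F(X1) + F(X2)

# Fit Fp by least squares, i.e. minimise the squared error ||Fp(P) - target||^2.
theta, *_ = np.linalg.lstsq(features(P), target, rcond=None)

def Fp(p):
    return features(p) @ theta

def reconstruct(fp_of_parity, f_x1):
    """Subtraction decoder: approximate the unavailable F(X2) as Fp(P) - F(X1)."""
    return fp_of_parity - f_x1

train_mse = np.mean((Fp(P) - target) ** 2)

# Serving time: F(x2) is late, so rebuild it from the parity output and F(x1).
x1, x2 = rng.standard_normal((1, 2)), rng.standard_normal((1, 2))
estimate = reconstruct(Fp(x1 + x2), F(x1))
print(estimate, F(x2))
```

Because F(X1) + F(X2) is not a deterministic function of X1 + X2 alone, Fp can only approximate it, which is the source of the degraded accuracy of reconstructed predictions.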
What a brilliant idea. We can also tweak the settings of this process, e.g., using a larger degree of query multiplexing (the erasure-code parameter k, i.e., the number of queries encoded into one parity query), or using different encoders/decoders instead of the simple summation encoder.
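Generalizing the summation encoder and subtraction decoder to k queries is a one-line change each. In this sketch (function names are mine), a linear F stands in for a perfect parity model (for linear F, Fp = F is exact), so only the decoding arithmetic is being checked. With one parity query per k queries, the extra inference capacity is 1/k, so a larger k is cheaper but asks the parity model to approximate a sum of more outputs.

```python
import numpy as np

def encode(queries):
    """Summation encoder for k queries: P = X1 + ... + Xk."""
    return np.sum(queries, axis=0)

def decode(fp_of_parity, available_outputs):
    """Recover the one unavailable output as Fp(P) minus the k-1 outputs that arrived."""
    return fp_of_parity - np.sum(available_outputs, axis=0)

def parity_overhead(k):
    """One parity query per k data queries (r = 1) costs 1/k extra capacity."""
    return 1.0 / k

W = np.array([[1.0, -1.0], [0.5, 2.0]])

def F(x):
    """Linear stand-in model, so Fp = F is an exact parity model here."""
    return W @ x

queries = [np.array([1.0, 0.0]), np.array([0.0, 2.0]), np.array([3.0, 1.0])]  # k = 3
P = encode(queries)

# Pretend F(queries[2]) is slow; rebuild it from F(P) and the two outputs that arrived.
est = decode(F(P), [F(q) for q in queries[:2]])
print(np.allclose(est, F(queries[2])))  # True
print(parity_overhead(len(queries)))
```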
For example, for image tasks, the encoder can downsample multiple query images and concatenate them into a single image-sized parity query.
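A minimal sketch of such a downsample-and-concatenate encoder for k = 4 follows. The specifics are my assumptions, not the paper's exact preprocessing: 2x2 average pooling for downsampling and a 2x2 tiling, so the parity query has the same shape as a regular query and can be fed to a parity model with the same input size.

```python
import numpy as np

def downsample2x(img):
    """2x2 average pooling over height and width; assumes both are even."""
    h, w = img.shape[:2]
    return img.reshape(h // 2, 2, w // 2, 2, *img.shape[2:]).mean(axis=(1, 3))

def concat_encoder(images):
    """Hypothetical k=4 encoder: shrink each image by 2x and tile them in a 2x2 grid."""
    assert len(images) == 4, "this sketch only handles k = 4"
    small = [downsample2x(im) for im in images]
    top = np.concatenate([small[0], small[1]], axis=1)
    bottom = np.concatenate([small[2], small[3]], axis=1)
    return np.concatenate([top, bottom], axis=0)

# Four constant-colour 32x32 RGB "queries" make the tiling easy to inspect.
images = [np.full((32, 32, 3), float(i)) for i in range(4)]
parity = concat_encoder(images)
print(parity.shape)  # (32, 32, 3)
```

Each quadrant of the parity holds one query (top-left is image 0, bottom-right is image 3).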


Note that although there is an accuracy loss, the inaccuracy only comes into play when predictions would otherwise be slowed down or straight up fail, which would violate the latency requirements anyway. This still sounds like a pretty good tradeoff, although I am curious about the accuracy loss on larger datasets and on models with more complex architectures.
[Figure: Evaluation of the accuracy loss]
[Figure: Tail latency reduction in the presence of resource contention]