Model interpretability - Machine Learning Best Practices in Healthcare and Life Sciences

Model interpretability

Interpretability is the degree to which a human can understand the cause of a decision. The higher the interpretability of an ML model, the easier it is to comprehend the model’s predictions. Interpretability facilitates:

  • Understanding how the model arrives at its predictions

  • Debugging and auditing ML model predictions

  • Bias detection to ensure fair decision making

  • Robustness checks to ensure that small changes in the input do not lead to large changes in the output

  • Methods that provide recourse for those who have been adversely affected by model predictions

In the context of GxP compliance, model interpretability provides a mechanism to ensure the safety and effectiveness of ML solutions by increasing the transparency around model predictions, as well as the behavior of the underlying algorithm. Promoting transparency is a key aspect of the patient-centered approach, and is especially important for AI/ML-based SaMD, which may learn and change over time.

There is often a trade-off between model performance, which concerns what the model predicts and how accurately, and model interpretability, which concerns why the model made a given prediction.

For some solutions, high model performance is sufficient; for others, the ability to interpret the model's decisions is key. The demand for interpretability increases when incorrect predictions are costly, especially in high-risk applications.

[Figure: Trade-off between performance and model interpretability]

Based on the model complexity, methods for model interpretability can be classified into intrinsic analysis and post hoc analysis.

  • Intrinsic analysis can be applied to interpret models that have low complexity (simple relationships between the input variables and the predictions). These models are based on:

    • Linear models, such as linear regression, where the prediction is a weighted sum of the inputs

    • Decision trees, where the prediction is based on a set of if-then rules

    The simple relationship between the inputs and the output results in high model interpretability, but often leads to lower model performance, because these models cannot easily capture complex non-linear relationships and interactions.

  • Post hoc analysis can be applied to the simpler models described earlier as well as to more complex models, such as neural networks, that can capture non-linear interactions. These methods are often model-agnostic: they interpret a trained model using only its inputs and output predictions. Post hoc analysis can be performed at a local level or at a global level.

    • Local methods enable you to zoom in on a single data point and observe the behavior of the model in that neighborhood. They are an essential component for debugging and auditing ML model predictions. Examples of local methods include:

      • Local Interpretable Model-Agnostic Explanations (LIME), which fits a sparse, linear approximation of the model's behavior around a single data point

      • SHapley Additive exPlanations (SHAP), a game-theoretic approach based on Shapley values, which computes each input variable's average marginal contribution to the output across all combinations of the other input variables

      • Counterfactual explanations, which describe the smallest change in the input variables that causes a change in the model's prediction

      • Integrated gradients, which attribute the model's prediction to individual input variables by accumulating the model's gradients along a path from a baseline input to the actual input

      • Saliency maps, a pixel attribution method that highlights the pixels in an image that are most relevant to the prediction

    • Global methods enable you to zoom out and provide a holistic view that explains the overall behavior of the model. These methods are helpful for verifying that the model is robust and has the least possible bias to allow for fair decision making. Examples of global methods include:

      • Aggregating local explanations, as defined previously, across multiple data points

      • Permutation feature importance, which measures the importance of an input variable by randomly shuffling its values and computing the resulting increase in the model's prediction error

      • Partial dependence plots, which show the marginal effect of an input variable on the model's prediction, averaged over the values of the other input variables

      • Surrogate methods, which are simpler interpretable models that are trained to approximate the behavior of the original complex model
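
As a minimal sketch of intrinsic analysis, the following Python example fits a linear regression and reads each prediction directly as the intercept plus one contribution per input variable (weight times input value). It assumes NumPy and a synthetic dataset with known weights, both introduced here only for illustration; the helper name `explain_linear` is hypothetical.

```python
import numpy as np

# Synthetic data (assumption for illustration): 3 input variables with known weights.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
true_w = np.array([2.0, -1.0, 0.0])
y = X @ true_w + 0.5 + rng.normal(scale=0.1, size=500)

# Ordinary least squares; the column of ones supplies the intercept.
X1 = np.column_stack([np.ones(len(X)), X])
coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
intercept, weights = coef[0], coef[1:]

def explain_linear(x):
    """Hypothetical helper: per-variable contributions weight_i * x_i.

    The intercept plus the sum of the contributions is exactly the prediction,
    so the model's decision can be read off term by term."""
    contributions = weights * x
    prediction = intercept + contributions.sum()
    return prediction, contributions

pred, contrib = explain_linear(X[0])
```

The fitted weights are themselves the global explanation, and `contrib` is the local explanation for one data point.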
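
The idea behind LIME, listed among the local methods above, can be sketched as a weighted local surrogate. This is a simplified, LIME-style illustration, not the `lime` library: it perturbs one data point with Gaussian noise, weights each perturbed sample by its proximity using an exponential kernel, and fits a weighted linear model whose slopes serve as the local explanation. It omits LIME's feature-selection (sparsity) step and interpretable data representations. The black-box function and the values of `scale` and `kernel_width` are assumptions.

```python
import numpy as np

def black_box(X):
    # Hypothetical stand-in for a complex model: non-linear in both inputs.
    return np.sin(X[:, 0]) + X[:, 1] ** 2

def lime_style_explanation(predict, x, n_samples=2000, scale=0.1, kernel_width=0.25, seed=0):
    rng = np.random.default_rng(seed)
    # 1. Perturb the data point of interest with Gaussian noise.
    Z = x + rng.normal(scale=scale, size=(n_samples, x.size))
    y = predict(Z)
    # 2. Weight each perturbed sample by its proximity to x (exponential kernel).
    d2 = np.sum((Z - x) ** 2, axis=1)
    w = np.exp(-d2 / kernel_width ** 2)
    # 3. Weighted least squares: scale rows by sqrt(w), then solve ordinary least squares.
    A = np.column_stack([np.ones(n_samples), Z - x])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[1:]  # local slopes, one per input variable

x = np.array([0.0, 1.0])
local_weights = lime_style_explanation(black_box, x)
```

At the point (0, 1), the fitted slopes should be close to the true local gradient of `black_box`, which is (1, 2).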
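
Integrated gradients, listed among the local methods above, can be sketched for a hypothetical logistic model whose gradient is known in closed form. The path integral from the baseline to the input is approximated with a midpoint Riemann sum; the weights, baseline, and number of steps are illustrative assumptions. A useful sanity check is the completeness property: the attributions sum, approximately, to the model's output at the input minus its output at the baseline.

```python
import numpy as np

w = np.array([1.5, -2.0, 0.5])  # hypothetical logistic model weights
b = 0.1

def model(x):
    return 1.0 / (1.0 + np.exp(-(x @ w + b)))

def model_grad(x):
    # Analytic gradient of the sigmoid output with respect to the input.
    p = model(x)
    return p * (1.0 - p) * w

def integrated_gradients(x, baseline, steps=200):
    # Midpoint Riemann sum of gradients along the straight path baseline -> x.
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.array([model_grad(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

x = np.array([1.0, -0.5, 2.0])
baseline = np.zeros(3)
attributions = integrated_gradients(x, baseline)
```

For real neural networks, the gradient would come from the framework's automatic differentiation instead of a hand-written formula.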
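
For a linear classifier, the smallest counterfactual explanation in Euclidean (L2) distance has a closed form: project the data point onto the decision boundary and step slightly past it. The sketch below assumes a hypothetical two-variable linear classifier and a small `margin`; for non-linear models, counterfactuals are usually found by iterative search or optimization, often with extra constraints so that the suggested change is plausible and actionable.

```python
import numpy as np

w = np.array([0.8, -1.2])  # hypothetical linear classifier weights
b = -0.5

def predict(x):
    return int(x @ w + b > 0)

def minimal_counterfactual(x, margin=1e-3):
    # The closest point on the boundary w.x + b = 0 is the orthogonal
    # projection of x onto it; scaling by (1 + margin) steps just past it.
    score = x @ w + b
    delta = -(score / (w @ w)) * w * (1.0 + margin)
    return x + delta, delta

x = np.array([2.0, 0.5])
x_cf, delta = minimal_counterfactual(x)
```

Here the original point is classified as 1 and the counterfactual as 0; `delta` is the smallest change to the input variables that flips the prediction.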
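
Permutation feature importance, listed among the global methods above, can be sketched in a few lines of NumPy. The model is a hypothetical stand-in that ignores its third input, so shuffling that variable leaves the error unchanged, while shuffling the first variable (the largest weight) increases the mean squared error the most. The dataset, model, and number of repeats are assumptions; in practice, importance is usually computed on held-out data, and scikit-learn provides an implementation as `sklearn.inspection.permutation_importance`.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
y = 3.0 * X[:, 0] + 1.0 * X[:, 1] + rng.normal(scale=0.1, size=1000)

def model(X):
    # Hypothetical trained model: ignores the third input variable.
    return 3.0 * X[:, 0] + 1.0 * X[:, 1]

def mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def permutation_importance(predict, X, y, n_repeats=5, seed=0):
    rng = np.random.default_rng(seed)
    base_error = mse(y, predict(X))
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        increases = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling column j breaks the link between variable j and the target.
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            increases.append(mse(y, predict(X_perm)) - base_error)
        importances[j] = np.mean(increases)
    return importances

importances = permutation_importance(model, X, y)
```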
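
A partial dependence curve, as described in the list above, can be computed by fixing one input variable to each value on a grid and averaging the model's predictions over the observed values of the other variables. The model, data, and grid below are hypothetical. For this model, the curve equals `v**2 + v * mean(x1)`, so the effect of the `x0 * x1` interaction is averaged away, which is a known limitation of partial dependence plots.

```python
import numpy as np

def model(X):
    # Hypothetical non-linear model with an interaction term.
    return X[:, 0] ** 2 + X[:, 0] * X[:, 1]

def partial_dependence(predict, X, feature, grid):
    pd = []
    for v in grid:
        # Set the chosen variable to v for every row, keep the others as observed.
        X_mod = X.copy()
        X_mod[:, feature] = v
        pd.append(predict(X_mod).mean())
    return np.array(pd)

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
grid = np.linspace(-2, 2, 9)
pd_curve = partial_dependence(model, X, feature=0, grid=grid)
```

Plotting `pd_curve` against `grid` gives the partial dependence plot for the first variable.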

It is recommended to start the ML journey with a simple model that is both inherently interpretable and provides sufficient model performance. In later iterations, if you need to improve the model performance, AWS recommends increasing the model complexity and leveraging post hoc analysis methods to interpret the results.

Selecting both a local method and a global method lets you interpret the behavior of the model for a single data point as well as across all data points in the dataset. It is also essential to validate the stability of model explanations, because post hoc methods are susceptible to adversarial attacks: small, carefully chosen perturbations of the input can cause large changes in the model's prediction, in its explanation, or in both.
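
One simple way to probe the stability of explanations is to perturb an input slightly, recompute its explanation, and record the largest relative change. The sketch below uses gradient-times-input attributions (a swapped-in local method, chosen because its gradient is analytic) for a hypothetical model `tanh(k * w.x)`, evaluated close to the point where `w.x = 0`. The model, `scale`, and sample count are assumptions. Random perturbations only give a lower bound on worst-case sensitivity; an adversarial attack would instead search for the perturbation that changes the explanation the most.

```python
import numpy as np

w = np.array([1.0, 1.0])  # hypothetical model weights

def grad_x_input(x, k):
    # Gradient-times-input attribution for f(x) = tanh(k * w.x).
    z = k * (x @ w)
    return x * k * (1.0 - np.tanh(z) ** 2) * w

def explanation_sensitivity(explain, x, scale=0.01, n_samples=50, seed=0):
    """Largest relative change in the explanation under small random input perturbations."""
    rng = np.random.default_rng(seed)
    e0 = explain(x)
    changes = []
    for _ in range(n_samples):
        x_pert = x + rng.normal(scale=scale, size=x.shape)
        changes.append(np.linalg.norm(explain(x_pert) - e0) / np.linalg.norm(e0))
    return max(changes)

x = np.array([1.0, -0.99])  # close to w.x = 0, where the steep model changes fastest
smooth = explanation_sensitivity(lambda v: grad_x_input(v, k=1.0), x)
steep = explanation_sensitivity(lambda v: grad_x_input(v, k=50.0), x)
```

Because both calls use the same seed, they see identical perturbations; the steep model (`k=50`) nevertheless produces a much larger relative change in its explanation than the smooth model (`k=1`).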

Amazon SageMaker Clarify provides tools to detect bias in ML models and understand model predictions. SageMaker Clarify uses a model-agnostic feature attribution approach and provides a scalable and efficient implementation of SHAP.
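
To make the SHAP definition concrete, the following sketch computes exact Shapley values by enumerating every subset of input variables, replacing "absent" variables with a baseline value (one common convention, assumed here). This is not how SageMaker Clarify computes SHAP: exact enumeration needs on the order of 2^n model evaluations for n variables, which is why practical implementations rely on approximations or model-specific algorithms. The model, rows, and baseline are hypothetical. The last lines aggregate the local values into a global importance score by averaging absolute Shapley values across rows, one way to implement the "aggregating local explanations" global method.

```python
from itertools import combinations
from math import factorial

def model(x):
    # Hypothetical model with an interaction between variables 0 and 1.
    return 2.0 * x[0] + x[0] * x[1] - 0.5 * x[2]

def shapley_values(f, x, baseline):
    """Exact Shapley values; variables outside a subset take their baseline value."""
    n = len(x)

    def value(subset):
        z = [x[i] if i in subset else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for subsets S of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi

baseline = [0.0, 0.0, 0.0]
phi = shapley_values(model, [1.0, 2.0, 4.0], baseline)

# Global view: mean absolute Shapley value per variable across several rows.
rows = [[1.0, 2.0, 4.0], [-1.0, 0.5, 2.0], [0.5, -1.0, -2.0]]
all_phi = [shapley_values(model, r, baseline) for r in rows]
global_importance = [sum(abs(p[i]) for p in all_phi) / len(rows) for i in range(3)]
```

For the first row, `phi` is approximately `[3.0, 1.0, -2.0]`: the interaction term `x0 * x1` is split equally between variables 0 and 1, and the values sum to `model(x) - model(baseline)`.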