修改 TensorFlow 训练脚本 - Amazon SageMaker

修改 TensorFlow 训练脚本

在此部分中,您将学习如何修改 TensorFlow 训练脚本以配置 SageMaker 模型并行性库,用于自动分区和手动分区。这些示例还包括与 Horovod 集成的示例,用于混合模型和数据并行性。

注意

要了解库支持哪些 TensorFlow 版本,请参阅支持的框架和 AWS 区域

TensorFlow 的自动拆分中列出了为使用库而必须对训练脚本进行的修改。

要了解如何修改训练脚本以在 Horovod 中使用混合模型和数据并行性,请参阅 TensorFlow 和 Horovod 的自动拆分,用于混合模型和数据并行化

如果您要使用手动分区,另请参阅 TensorFlow 的手动拆分

提示

有关演示如何将 TensorFlow 训练脚本与 SageMaker 模型并行性库结合使用的端到端笔记本示例,请参阅 TensorFlow 示例

以下主题显示了训练脚本的示例,您可以使用这些脚本来配置 SageMaker 的模型并行性库,以便对 TensorFlow 模型进行自动分区和手动分区。

注意

默认情况下启用自动分区。除非另行指定,否则示例脚本使用自动分区。

TensorFlow 的自动拆分

要使用 SageMaker 模型并行性库运行 TensorFlow 模型,训练脚本需要进行以下更改:

  1. 使用 smp.init() 导入和初始化库。

  2. 通过从 smp.DistributedModel 继承来定义 Keras 模型,而不是从 Keras 模型类继承。从 smp.DistributedModel 对象的调用方法返回模型输出。请注意,从调用方法返回的任何张量都将在模型并行设备之间广播,这会产生通信开销,因此在调用方法之外不需要的任何张量(例如中间激活)都不应返回。

  3. tf.Dataset.batch() 方法中设置 drop_remainder=True。这是为了确保批次大小始终可以被微批次数量整除。

  4. 使用 smp.dp_rank() 在数据管道中植入随机操作(例如 shuffle(ds, seed=smp.dp_rank())),以确保存有不同模型分区的 GPU 之间数据样本的一致性。

  5. 将向前和向后逻辑放在步进函数中,然后用 smp.step 进行修饰。

  6. 使用 StepOutput 方法(例如 reduce_mean)对微批次的输出进行后处理。smp.step 函数必须具有一个取决于 smp.DistributedModel 的输出的返回值。

  7. 如果有评估步骤,则同样将向前逻辑放在 smp.step 修饰的函数中,然后使用 StepOutput API 对输出进行后处理。

要了解有关 SageMaker 的模型并行性库 API 的更多信息,请参阅 API 文档

以下 Python 脚本是进行更改后的训练脚本的示例。

import tensorflow as tf # smdistributed: Import TF2.x API import smdistributed.modelparallel.tensorflow as smp # smdistributed: Initialize smp.init() # Download and load MNIST dataset. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data( "MNIST-data-%d" % smp.rank() ) x_train, x_test = x_train / 255.0, x_test / 255.0 # Add a channels dimension x_train = x_train[..., tf.newaxis] x_test = x_test[..., tf.newaxis] # smdistributed: If needed, seed the shuffle with smp.dp_rank(), and drop_remainder # in batching to make sure batch size is always divisible by number of microbatches train_ds = ( tf.data.Dataset.from_tensor_slices((x_train, y_train)) .shuffle(10000, seed=smp.dp_rank()) .batch(256, drop_remainder=True) ) # smdistributed: Define smp.DistributedModel the same way as Keras sub-classing API class MyModel(smp.DistributedModel): def __init__(self): super(MyModel, self).__init__() # define layers def call(self, x, training=None): # define forward pass and return the model output model = MyModel() loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) optimizer = tf.keras.optimizers.Adam() train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy") # smdistributed: Define smp.step. Return any tensors needed outside @smp.step def get_grads(images, labels): predictions = model(images, training=True) loss = loss_object(labels, predictions) grads = optimizer.get_gradients(loss, model.trainable_variables) return grads, loss, predictions @tf.function def train_step(images, labels): gradients, loss, predictions = get_grads(images, labels) # smdistributed: Accumulate the gradients across microbatches gradients = [g.accumulate() for g in gradients] optimizer.apply_gradients(zip(gradients, model.trainable_variables)) # smdistributed: Merge predictions and average losses across microbatches train_accuracy(labels, predictions.merge()) return loss.reduce_mean() for epoch in range(5): # Reset the metrics at the start of the next epoch train_accuracy.reset_states() for images, labels in train_ds: loss = train_step(images, labels) accuracy = train_accuracy.result()

如果您已准备好训练脚本,请继续到步骤 2:使用 SageMaker Python SDK 启动训练作业。如果要运行混合模型和数据并行训练作业,请继续到下一个部分。

TensorFlow 和 Horovod 的自动拆分,用于混合模型和数据并行化

您可以将 SageMaker 模型并行性库与 Horovod 结合使用,用于混合模型和数据并行性。要详细了解库如何拆分模型以用于混合并行性,请参阅管道并行性(可用于 PyTorch 和 TensorFlow)

在本步骤中,我们将重点介绍如何修改训练脚本以适应 SageMaker 模型并行性库。

要正确设置训练脚本,以便选取要在 步骤 2:使用 SageMaker Python SDK 启动训练作业 中设置的混合并行度配置,请使用库的帮助程序函数 smp.dp_rank()smp.mp_rank(),它们分别自动检测数据并行秩和模型并行秩。

要查找该库支持的所有 MPI 基元,请参阅《SageMaker Python SDK 文档》中的 MPI 基础知识

脚本中需要进行以下更改:

  • 添加 hvd.allreduce

  • 按照 Horovod 的要求,在第一个批次之后广播变量

  • 使用 smp.dp_rank() 在数据管道中植入随机排序和/或分片操作。

注意

使用 Horovod 时,您不可在训练脚本中直接调用 hvd.init。相反,您必须在 SageMaker Python SDK modelparallel 参数的 步骤 2:使用 SageMaker Python SDK 启动训练作业 中,将 "horovod" 设置为 True。这使得库可以根据模型分区的设备分配,在内部初始化 Horovod。直接在训练脚本中调用 hvd.init() 可能会导致问题。

注意

在训练脚本中直接使用 hvd.DistributedOptimizer API 可能会导致训练性能和速度不佳,因为 API 会隐式地将 AllReduce 操作放入 smp.step 中。我们建议您在 smp.step 返回的梯度上调用 accumulate()reduce_mean() 后,直接调用 hvd.allreduce,以将模型并行性库与 Horovod 一起使用,如下例所示。

要了解有关 SageMaker 的模型并行性库 API 的更多信息,请参阅 API 文档

import tensorflow as tf import horovod.tensorflow as hvd # smdistributed: Import TF2.x API import smdistributed.modelparallel.tensorflow as smp # smdistributed: Initialize smp.init() # Download and load MNIST dataset. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data( "MNIST-data-%d" % smp.rank() ) x_train, x_test = x_train / 255.0, x_test / 255.0 # Add a channels dimension x_train = x_train[..., tf.newaxis] x_test = x_test[..., tf.newaxis] # smdistributed: Seed the shuffle with smp.dp_rank(), and drop_remainder # in batching to make sure batch size is always divisible by number of microbatches train_ds = ( tf.data.Dataset.from_tensor_slices((x_train, y_train)) .shuffle(10000, seed=smp.dp_rank()) .batch(256, drop_remainder=True) ) # smdistributed: Define smp.DistributedModel the same way as Keras sub-classing API class MyModel(smp.DistributedModel): def __init__(self): super(MyModel, self).__init__() # define layers def call(self, x, training=None): # define forward pass and return model outputs model = MyModel() loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) optimizer = tf.keras.optimizers.Adam() train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy") # smdistributed: Define smp.step. Return any tensors needed outside @smp.step def get_grads(images, labels): predictions = model(images, training=True) loss = loss_object(labels, predictions) grads = optimizer.get_gradients(loss, model.trainable_variables) return grads, loss, predictions @tf.function def train_step(images, labels, first_batch): gradients, loss, predictions = get_grads(images, labels) # smdistributed: Accumulate the gradients across microbatches # Horovod: AllReduce the accumulated gradients gradients = [hvd.allreduce(g.accumulate()) for g in gradients] optimizer.apply_gradients(zip(gradients, model.trainable_variables)) # Horovod: Broadcast the variables after first batch if first_batch: hvd.broadcast_variables(model.variables, root_rank=0) hvd.broadcast_variables(optimizer.variables(), root_rank=0) # smdistributed: Merge predictions across microbatches train_accuracy(labels, predictions.merge()) return loss.reduce_mean() for epoch in range(5): # Reset the metrics at the start of the next epoch train_accuracy.reset_states() for batch, (images, labels) in enumerate(train_ds): loss = train_step(images, labels, tf.constant(batch == 0))

TensorFlow 的手动拆分

使用 smp.partition 上下文管理器将操作放在特定的分区中。未放在任何 smp.partition 上下文中的任何操作都放在 default_partition 中。要了解有关 SageMaker 的模型并行性库 API 的更多信息,请参阅 API 文档

import tensorflow as tf # smdistributed: Import TF2.x API. import smdistributed.modelparallel.tensorflow as smp # smdistributed: Initialize smp.init() # Download and load MNIST dataset. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data( "MNIST-data-%d" % smp.rank() ) x_train, x_test = x_train / 255.0, x_test / 255.0 # Add a channels dimension x_train = x_train[..., tf.newaxis] x_test = x_test[..., tf.newaxis] # smdistributed: If needed, seed the shuffle with smp.dp_rank(), and drop_remainder # in batching to make sure batch size is always divisible by number of microbatches. train_ds = ( tf.data.Dataset.from_tensor_slices((x_train, y_train)) .shuffle(10000, seed=smp.dp_rank()) .batch(256, drop_remainder=True) ) # smdistributed: Define smp.DistributedModel the same way as Keras sub-classing API. class MyModel(smp.DistributedModel): def __init__(self): # define layers def call(self, x): with smp.partition(0): x = self.layer0(x) with smp.partition(1): return self.layer1(x) model = MyModel() loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) optimizer = tf.keras.optimizers.Adam() train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy") # smdistributed: Define smp.step. Return any tensors needed outside @smp.step def get_grads(images, labels): predictions = model(images, training=True) loss = loss_object(labels, predictions) grads = optimizer.get_gradients(loss, model.trainable_variables) return grads, loss, predictions @tf.function def train_step(images, labels): gradients, loss, predictions = get_grads(images, labels) # smdistributed: Accumulate the gradients across microbatches gradients = [g.accumulate() for g in gradients] optimizer.apply_gradients(zip(gradients, model.trainable_variables)) # smdistributed: Merge predictions and average losses across microbatches train_accuracy(labels, predictions.merge()) return loss.reduce_mean() for epoch in range(5): # Reset the metrics at the start of the next epoch train_accuracy.reset_states() for images, labels in train_ds: loss = train_step(images, labels) accuracy = train_accuracy.result()

不支持的框架功能

库不支持以下 TensorFlow 功能:

  • 当前不支持 tf.GradientTape()。您可以改用 Optimizer.get_gradients()Optimizer.compute_gradients() 来计算梯度。

  • 目前不支持 tf.train.Checkpoint.restore() API。对于检查点操作,请改用 smp.CheckpointManager,它提供相同的 API 和功能。请注意,smp.CheckpointManager 的检查点还原应在第一步之后进行。