TabTransformer 算法的输入和输出接口 - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

TabTransformer 算法的输入和输出接口

TabTransformer 对表格数据进行操作,行代表观测值,一列代表目标变量或标签,其余列代表特征。

的 SageMaker AI 实现 TabTransformer 支持用于训练和推理的 CSV:

  • 对于训练 ContentType,有效的输入必须是文本/ csv。

  • 要进行推理 ContentType,有效的输入必须是文本 /csv。

注意

对于 CSV 训练,算法假定目标变量在第一列中,而 CSV 没有标头记录。

对于 CSV 推理,算法假定 CSV 输入没有标签列。

训练数据、验证数据和类别特征的输入格式

请注意如何格式化训练数据,以便输入到 TabTransformer 模型中。您必须提供包含训练和验证数据的 Amazon S3 存储桶的路径。您还可以包含类别特征列表。请使用 trainingvalidation 通道来提供您的输入数据。您也可以只使用 training 通道。

使用 trainingvalidation 通道

您可以通过两条 S3 路径来提供输入数据,一条用于 training 通道,一条用于 validation 通道。每个 S3 路径可以是指向一个或多个 CSV 文件的 S3 前缀,也可以是指向一个特定 CSV 文件的完整 S3 路径。目标变量应位于 CSV 文件的第一列。预测器变量(特征)应位于其余列。如果为trainingvalidation通道提供了多个 CSV 文件,则 TabTransformer 算法会将这些文件连接起来。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。

如果您的预测器包含类别特征,则可以在与您的训练数据文件相同的位置,提供一个名为 categorical_index.json 的 JSON 文件。如果您为类别特征提供 JSON 文件,则您的 training 通道必须指向 S3 前缀而不是特定 CSV 文件。此文件应包含一个 Python 字典,其中的键是字符串 "cat_index_list",值是唯一整数列表。值列表中的每个整数都应指示训练数据 CSV 文件中对应分类特征的列索引。每个值都应为正整数(大于零,因为零表示目标值),小于 Int32.MaxValue (2147483647),并且小于列的总数。只应有一个类别索引 JSON 文件。

仅使用 training 通道

您也可以通过单个 S3 路径,为 training 通道提供输入数据。此 S3 路径指向的目录中应包含一个名为 training/ 的子目录,而该子目录中包含一个或多个 CSV 文件。您可以选择在相同位置添加另一个名为 validation/ 的子目录,该子目录同样包含一个或多个 CSV 文件。如果未提供验证数据,则会随机采样 20% 的训练数据作为验证数据。如果您的预测器包含类别特征,则可以在与您的数据子目录相同的位置,提供一个名为 categorical_index.json 的 JSON 文件。

注意

对于 CSV 训练输入模式,供算法使用的内存总量(实例计数乘以 InstanceType 中的可用内存)必须能够容纳训练数据集。