已实现的Dataset类¶
已实现的Dataset类的功能介绍如下
AbstractDataset所有数据集的基类。注意这是抽象类,不能直接使用
TrajectoryDataset所有轨迹位置预测任务的基类。会根据设定的
window_size与cut_method来对获取到的轨迹记录进行切割,将原始数据集中的数月长的轨迹切割为符合单次出行时间/距离长度的子轨迹。在完成切割后,会根据设定的traj_encoder参数调用相应的轨迹时空特征编码器来对轨迹进行特征提取并生成模型输入。TrafficStateDataset用于交通状态预测任务的基类之一。注意这是抽象类,不能直接使用。默认情况下,
input_window的数据被用来预测output_window对应的数据。此类生成的Batch 对象包含两个键,分别是X和y。此处的input_window和output_window是data的参数,点击此处查看细节。TrafficStateCPTDataset用于交通状态预测任务的另一个基类。注意这是抽象类,不能直接使用。一些交通预测模型通过对接近度/周期/趋势的建模来实现预测。默认情况下,
len_closeness/len_period/len_trend的数据被用来预测当前时刻的数据(单步预测)。此类生成的Batch对象包含4个键,分别是X,y,X_ext和y_ext。此处的len_closeness/len_period/len_trend是data的参数, 点击此处查看细节。TrafficStatePointDataset一个继承了
TrafficStateDataset的类,用于交通状态预测,该类适用于空间维度是一维的数据集(即基于点/基于段/基于区域的数据集)。该类生成的Batch对象中的张量形状是3维的,即space_dim,time_dim,feature_dim(空间维度、时间维度、特征维度)。TrafficStateGridDataset一个继承了
TrafficStateDataset的类,用于交通状态预测,该类适用于基于网格的数据集。这个类生成的Batch对象中的张量形状是3维还是4维取决于参数use_row_column。如果设置use_row_column=True,那么4个维度是grid_row_dim,grid_column_dim,time_dim,feature_dim(网格行数、网格列数、时间维度、特征维度)。否则,3个维度是space_dim、time_dim、feature_dim(空间维度、时间维度、特征维度),在这种情况下,网格被重新编号为一维的。TrafficStateOdDataset一个继承了
TrafficStateDataset的类,用于交通状态预测,该类适用于基于OD的数据集,即起点和终点。这个类生成的Batch对象中的张量形状是4维的,即origin_dim,destination_dim,time_dim,feature_dim(起点维度、终点维度,时间维度、特征维度)。TrafficStateGridOdDataset一个继承了
TrafficStateDataset的类,用于交通状态预测,该类适用于基于网格的OD数据集。这个类生成的Batch对象中的张量形状是4维或6维,取决于参数use_row_column。如果设置use_row_column=True,那么张量拥有6个维度,分别是origin_grid_row_dim,origin_grid_column_dim,destination_grid_row_dim,destination_grid_column_dim,time_dim,feature_dim(起点网格行数、起点网格列数、终点网格行数、终点网格列数、时间维度、特征维度)。否则,张量拥有4个维度,分别是origin_dim,destination_dim,time_dim,feature_dim(起点空间维度、终点空间维度、时间维度、特征维度),在这种情况下,二维网格被重新编号为一维的。MapMatchingDataset所有地图匹配任务的基类。该类生成一个包含3个键的字典:
rd_nwk,trajectory和route,分别代表路网、GPS样本的轨迹和真实轨迹。如果delta_time=True,trajectory将包括一个time列,表示轨迹时间读秒。delta_time是数据集的参数,详见此处。标准数据输入的介绍见此处。ETADataset所有ETA任务的基类。
_load_dyna函数会读取轨迹信息。_encode_traj函数会根据设定的eta_encoder参数调用相应的轨迹时空特征编码器来对轨迹进行特征提取。提取到的特征会被分为训练集、验证集、测试集,来生成模型输入。