数据集适用的任务¶
在本节中,我们向您介绍我们提供的标准数据集和任务、模型之间的对应关系。 值得注意的是,某些数据集只能支持某个任务中的某些模型。
如果你觉得这个表格看起来不够方便,请点击这里在Github上查看此表格。
任务 | 数据集 | 支持的模型 | 备注 |
---|---|---|---|
交通流量预测 | TAXIBJ, PORTO, NYCTAXI_GRID, NYCBIKE, AUSTINRIDE, BIKEDC, BIKECHI, NYCBike20140409, NYCBike20160708, NYCBike20160809, NYCTaxi20140112, NYCTaxi20150103, NYCTaxi20160102, T_DRIVE20150206, T_DRIVE_SMALL | ACFM, STResNet, DSAN, ACFMCommon, STResNetCommon | 基于网格的数据 |
PEMSD3, PeMSD4, PeMSD7, PeMSD8, BEIJING_SUBWAY, M_DENSE, SHMETRO, HZMETRO, NYCTAXI_DYNA | RNN, Seq2Seq, FNN, AutoEncoder, AGCRN, ASTGCNCommon, MSTGCNCommon, STSGCN, CONVGCNCommon, ToGCN, MultiSTGCnetCommon, STNN, ASTGCN, MSTGCN, CONVGCN, DGCN, ResLSTM, MultiSTGCnet | 基于传感器点的数据 | |
TAXIBJ, PORTO, NYCTAXI_GRID, NYCBIKE, AUSTINRIDE, BIKEDC, BIKECHI, NYCBike20140409, NYCBike20160708, NYCBike20160809, NYCTaxi20140112, NYCTaxi20150103, NYCTaxi20160102, T_DRIVE20150206, T_DRIVE_SMALL | RNN, Seq2Seq, FNN, AutoEncoder, AGCRN, ASTGCNCommon, MSTGCNCommon, STSGCN, CONVGCNCommon, ToGCN, MultiSTGCnetCommon, STNN, ASTGCN, MSTGCN, CONVGCN, DGCN, ResLSTM, MultiSTGCnet | 需要简单的修改以用于网格数据, 见 Note 3. | |
M_DENSE | CRANN | 需要.ext 文件 |
|
NYCBike20160708, NYCTaxi20150103 | STDN | 需要.gridod 文件 |
|
交通速度预测 | METR_LA, LOS_LOOP, PeMSD4, PeMSD8, PEMSD7(M), PEMS_BAY, LOS_LOOP_SMALL, SZ_TAXI, LOOP_SEATTLE, Q_TRAFFIC, ROTTERDAM | RNN, Seq2Seq, FNN, AutoEncoder, DCRNN, STGCN, GWNET, MTGNN, STMGAT, TGCN, ATDM, HGCN, DKFN, STTN, GTS, GMAN, STAGGCN, TGCLSTM | 基于传感器点的数据 |
LOOP_SEATTLE | TGCLSTM | 见Note 2. | |
交通需求预测 | TAXIBJ, PORTO, NYCTAXI_GRID, NYCBIKE, AUSTINRIDE, BIKEDC, BIKECHI, NYCBike20140409, NYCBike20160708, NYCBike20160809, NYCTaxi20140112, NYCTaxi20150103, NYCTaxi20160102, T_DRIVE20150206, T_DRIVE_SMALL | DMVSTNet | 基于网格的数据 |
PEMSD3, PeMSD4, PeMSD7, PeMSD8, BEIJING_SUBWAY, M_DENSE, SHMETRO, HZMETRO, NYCTAXI_DYNA | RNN, Seq2Seq, FNN, AutoEncoder, CCRNN, STG2Seq | 基于传感器点的数据 | |
TAXIBJ, PORTO, NYCTAXI_GRID, NYCBIKE, AUSTINRIDE, BIKEDC, BIKECHI, NYCBike20140409, NYCBike20160708, NYCBike20160809, NYCTaxi20140112, NYCTaxi20150103, NYCTaxi20160102, T_DRIVE20150206, T_DRIVE_SMALL | CCRNN, STG2Seq | 需要简单的修改以用于网格数据, 见 Note 3. | |
OD矩阵预测 | NYCTAXI_OD | GEML | 基于点的OD的数据 |
NYC_TOD | GEML | 需要简单的修改以用于网格OD数据, 见 Note 4. | |
NYC_TOD | CSTN | 基于网格的OD的数据和.ext 文件 |
|
交通事故预测 | NYC_RISK, CHICAGO_RISK | GSNet | 基于网格的交通事故数据 |
轨迹下一跳预测 | Gowalla, BrightKite | FPMC, RNN, ST-RNN, ATST-LSTM, DeepMove, HST-LSTM, LSTPM, STAN | 轨迹数据 |
Fousquare, Instagram | FPMC, RNN, ST-RNN, ATST-LSTM, DeepMove, HST-LSTM, LSTPM, GeoSAN, STAN, SERM, CARA | 轨迹数据 | |
到达时间估计 | Chengdu_Taxi_Sample1 | DeepTTE | 轨迹数据 |
Beijing_Taxi_Sample | DeepTTE, TTPNet | 轨迹数据 | |
路网匹配 | Seattle, global | STMatching, IVMM, HMMM | 轨迹数据 |
路网表征学习 | bj_roadmap_edge | ChebConv, LINE | 路网数据 |
Note 1
加粗的数据集是我们推荐的数据集。
Note 2
对于TGCLSTM
,需要将dataset_class
设置为TrafficStatePointDataset
。 否则,默认的dataset_class=TGCLSTMDataset
只适用于数据集LOOP_SEATTLE
。
Note 3
以下是如何将用于基于点的数据的模型推广到基于网格的数据。
(1)如果模型使用的数据集类是TrafficStatePointDataset
,如AGCRN
、ASTGCNCommon
、CCRNN
等,可以直接在task_file.json
中将dataset_class
设为TrafficStateGridDataset
或通过自定义配置文件(--config_file
)设置。 然后将 TrafficStateGridDataset
的参数 use_row_column
设置为 False
。
(2)如果模型使用的数据集类是TrafficStatePointDataset
的子类,如ASTGCNDataset
、CONVGCNDataset
、STG2SeqDataset
等,可以修改数据集类的文件,使其继承TrafficStateGridDataset
取代当前的TrafficStatePointDataset
。 然后将函数__init__()
中的参数use_row_column
设置为False
。
样例(1):
修改前:
# task_config.json
"RNN": {
"dataset_class": "TrafficStatePointDataset",
},
# TrafficStateGridDataset.json
{
"use_row_column": true
}
修改后:
# task_config.json
"RNN": {
"dataset_class": "TrafficStateGridDataset",
},
# TrafficStateGridDataset.json
{
"use_row_column": false
}
样例(2):
修改前:
# task_config.json
"STG2Seq": {
"dataset_class": "STG2SeqDataset",
},
# STG2SeqDataset.json
{
"use_row_column": false
}
# stg2seq_dataset.py
from libcity.data.dataset import TrafficStatePointDataset
class STG2SeqDataset(TrafficStatePointDataset):
def __init__(self, config):
super().__init__(config)
pass
修改后:
# task_config.json
"STG2Seq": {
"dataset_class": "STG2SeqDataset",
},
# STG2SeqDataset.json
{
"use_row_column": false
}
# stg2seq_dataset.py
from libcity.data.dataset import TrafficStateGridDataset
class STG2SeqDataset(TrafficStateGridDataset):
def __init__(self, config):
super().__init__(config)
self.use_row_column = False
pass
Note 4
以下是如何将用于基于点的OD的数据的模型推广到基于网格的OD的数据。(跟Note 3是类似的)
(1)如果模型使用的数据集类是TrafficStateOdDataset
,如GEML
等,可以直接在task_file.json
中将dataset_class
设为TrafficStateGridOdDataset
或通过自定义配置文件(--config_file
)设置。 然后将 TrafficStateGridOdDataset
的参数 use_row_column
设置为 False
。
(2)如果模型使用的数据集类是TrafficStateOdDataset
的子类,可以修改数据集类的文件,使其继承TrafficStateGridOdDataset
取代当前的TrafficStateOdDataset
。 然后将函数__init__()
中的参数use_row_column
设置为False
。