libcity.data.dataset.dataset_subclass.stg2seq_dataset
-
class
libcity.data.dataset.dataset_subclass.stg2seq_dataset.STG2SeqDataset(config)
Bases: libcity.data.dataset.traffic_state_point_dataset.TrafficStatePointDataset
-
_generate_data()
Loads the data files (.dyna/.grid/.od/.gridod) and the external data (.ext), fuses the two, and returns the result as X, y.
- Returns
- tuple containing:
x (np.ndarray): model input data, (num_samples, input_length, …, feature_dim)
y (np.ndarray): model output data, (num_samples, output_length, …, feature_dim)
- Return type
tuple
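A minimal sketch of how X, y can be built, assuming the readings are already loaded as a `(num_timesteps, num_nodes, feature_dim)` array and the external data as a `(num_timesteps, ext_dim)` array. The function name `make_xy`, the broadcast-and-concatenate fusion, and the one-step sliding window are assumptions for illustration, not this class's actual code.

```python
import numpy as np


def make_xy(dyna, ext, input_length, output_length):
    """Hypothetical sketch: fuse traffic data with external data, then cut
    sliding windows into (x, y) pairs.

    dyna: (num_timesteps, num_nodes, feature_dim) traffic readings
    ext:  (num_timesteps, ext_dim) external features shared by all nodes
    """
    num_timesteps, num_nodes, _ = dyna.shape
    if num_timesteps < input_length + output_length:
        raise ValueError("not enough timesteps for one (x, y) window")

    # Copy each timestep's external features to every node and append them
    # along the feature axis (assumed fusion strategy).
    ext_b = np.broadcast_to(ext[:, None, :], (num_timesteps, num_nodes, ext.shape[1]))
    data = np.concatenate([dyna, ext_b], axis=-1)

    xs, ys = [], []
    # Each sample uses input_length past steps to predict the next
    # output_length steps; in this sketch y keeps the external features too.
    for t in range(input_length, num_timesteps - output_length + 1):
        xs.append(data[t - input_length:t])
        ys.append(data[t:t + output_length])
    return np.stack(xs), np.stack(ys)
```

With 20 timesteps, `input_length=4` and `output_length=2`, the window start slides over 15 positions, so `x` has shape `(15, 4, num_nodes, feature_dim + ext_dim)`.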
-
_generate_train_val_test()
Loads the dataset, splits it into training, validation, and test sets, and caches the splits.
- Returns
- tuple containing:
x_train: (num_samples, input_length, …, feature_dim)
y_train: (num_samples, output_length, …, feature_dim)
x_val: (num_samples, input_length, …, feature_dim)
y_val: (num_samples, output_length, …, feature_dim)
x_test: (num_samples, input_length, …, feature_dim)
y_test: (num_samples, output_length, …, feature_dim)
- Return type
tuple
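The caching step might look like the sketch below, assuming the cache is a single compressed NumPy `.npz` archive keyed by split name. `save_cache` and the key names are hypothetical.

```python
import numpy as np


def save_cache(cache_file, x_train, y_train, x_val, y_val, x_test, y_test):
    """Hypothetical helper: write the six split arrays to one compressed .npz file."""
    # np.savez_compressed stores each keyword argument as a named array;
    # it appends ".npz" to a string path that lacks that suffix.
    np.savez_compressed(
        cache_file,
        x_train=x_train, y_train=y_train,
        x_val=x_val, y_val=y_val,
        x_test=x_test, y_test=y_test,
    )
    # Return the splits unchanged so the caller can use them right away.
    return x_train, y_train, x_val, y_val, x_test, y_test
```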
-
_load_cache_train_val_test()
Loads the previously cached training, validation, and test sets.
- Returns
- tuple containing:
x_train: (num_samples, input_length, …, feature_dim)
y_train: (num_samples, output_length, …, feature_dim)
x_val: (num_samples, input_length, …, feature_dim)
y_val: (num_samples, output_length, …, feature_dim)
x_test: (num_samples, input_length, …, feature_dim)
y_test: (num_samples, output_length, …, feature_dim)
- Return type
tuple
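A matching loader, assuming the splits were cached in one `.npz` archive under the six names listed above; `load_cache` and `KEYS` are hypothetical.

```python
import numpy as np

# Assumed key names inside the cache archive, in return order.
KEYS = ("x_train", "y_train", "x_val", "y_val", "x_test", "y_test")


def load_cache(cache_file):
    """Hypothetical helper: read the six split arrays back from an .npz cache."""
    # np.load on an .npz returns a lazy NpzFile; index every array while
    # the file is still open, then let the context manager close it.
    with np.load(cache_file) as cat_data:
        return tuple(cat_data[k] for k in KEYS)
```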
-
_load_rel()
Builds the adjacency matrix from the grid structure: each grid cell is adjacent to the 8 cells surrounding it.
- Returns
self.adj_mx, an N*N adjacency matrix
- Return type
np.ndarray
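A sketch of the 8-neighbour rule for a grid with `len_row` rows and `len_column` columns. The function name, row-major node numbering (cell `(i, j)` becomes node `i * len_column + j`), and the zero diagonal are assumptions; border and corner cells simply have fewer than 8 neighbours.

```python
import numpy as np


def grid_adjacency(len_row, len_column):
    """Hypothetical sketch: N*N adjacency for a grid, N = len_row * len_column."""
    n = len_row * len_column
    adj = np.zeros((n, n), dtype=np.float32)
    for i in range(len_row):
        for j in range(len_column):
            # Visit the 3x3 block centred on (i, j).
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    if di == 0 and dj == 0:
                        continue  # no self-loop in this sketch
                    ni, nj = i + di, j + dj
                    # Skip neighbours that fall outside the grid.
                    if 0 <= ni < len_row and 0 <= nj < len_column:
                        adj[i * len_column + j, ni * len_column + nj] = 1
    return adj
```

For a 3 by 3 grid, the centre cell (node 4) gets 8 neighbours, each edge cell 5, and each corner cell 3.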
-
_split_train_val_test(x, y, df=None)
Splits the data into training, validation, and test sets and caches them.
- Parameters
x (np.ndarray) – input data (num_samples, input_length, …, feature_dim)
y (np.ndarray) – output data (num_samples, output_length, …, feature_dim)
- Returns
- tuple containing:
x_train: (num_samples, input_length, …, feature_dim)
y_train: (num_samples, output_length, …, feature_dim)
x_val: (num_samples, input_length, …, feature_dim)
y_val: (num_samples, output_length, …, feature_dim)
x_test: (num_samples, input_length, …, feature_dim)
y_test: (num_samples, output_length, …, feature_dim)
- Return type
tuple
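A minimal sketch of a chronological split along the sample axis, assuming the proportions come from two ratios. `split_train_val_test`, `train_rate`, and `eval_rate` are assumed names, and the test set takes whatever remains. Samples are not shuffled here, so the test set holds the most recent windows.

```python
import numpy as np


def split_train_val_test(x, y, train_rate=0.7, eval_rate=0.1):
    """Hypothetical sketch: split x and y along axis 0 without shuffling."""
    num_samples = x.shape[0]
    num_train = round(num_samples * train_rate)
    num_val = round(num_samples * eval_rate)
    # The test split is everything after train + val, so the three parts
    # always cover all samples exactly once.
    val_end = num_train + num_val
    x_train, y_train = x[:num_train], y[:num_train]
    x_val, y_val = x[num_train:val_end], y[num_train:val_end]
    x_test, y_test = x[val_end:], y[val_end:]
    return x_train, y_train, x_val, y_val, x_test, y_test
```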
-