libcity.data.dataset.dataset_subclass.astgcn_dataset¶
-
class
libcity.data.dataset.dataset_subclass.astgcn_dataset.
ASTGCNDataset
(config)[source]¶ Bases:
libcity.data.dataset.traffic_state_point_dataset.TrafficStatePointDataset
-
_generate_input_data
(df)[source]¶ 根据全局参数len_closeness/len_period/len_trend切分输入,产生模型需要的输入
- Parameters
df (np.ndarray) – 输入数据, shape: (len_time, …, feature_dim)
- Returns
- tuple contains:
sources(np.ndarray): 模型输入数据, shape: (num_samples, Tw+Td+Th, …, feature_dim)
targets(np.ndarray): 模型输出数据, shape: (num_samples, Tp, …, feature_dim)
- Return type
tuple
-
_get_sample_indices
(data_sequence, label_start_idx)[source]¶ 根据全局参数len_closeness/len_period/len_trend找到数据预测目标数据 段: [label_start_idx: label_start_idx+output_window)
- Parameters
data_sequence (np.ndarray) – 输入数据,shape: (len_time, …, feature_dim)
label_start_idx (int) – the first index of predicting target, 预测开始的时间片的索引
- Returns
- tuple contains:
trend_sample: 输入数据1, (len_trend * self.output_window, …, feature_dim)
period_sample: 输入数据2, (len_period * self.output_window, …, feature_dim)
closeness_sample: 输入数据3, (len_closeness * self.output_window, …, feature_dim)
target: 输出数据, (self.output_window, …, feature_dim)
- Return type
tuple
-
_search_data
(sequence_length, label_start_idx, num_for_predict, num_of_depend, units)[source]¶ 根据全局参数len_closeness/len_period/len_trend找到数据索引的位置
- Parameters
sequence_length (int) – 历史数据的总长度
label_start_idx (int) – 预测开始的时间片的索引
num_for_predict (int) – 预测的时间片序列长度
num_of_depend (int) – len_trend/len_period/len_closeness
units (int) – trend/period/closeness的长度(以小时为单位)
- Returns
起点-终点区间段的数组,list[(start_idx, end_idx)]
- Return type
list
-