Source code for libcity.data.dataset.staggcn_dataset

import torch
from tqdm import tqdm
import numpy as np
import pandas as pd
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
import pickle
import os

from libcity.data.dataset import TrafficStatePointDataset
# from libcity.data.dataset import TrafficStateGridDataset


"""
主要功能是定义了一种计算语义邻接矩阵的方法,并缓存到dataset_cache/,并通过get_data_feature返回
注意这里要求邻接矩阵构造成0-1矩阵,在STAGGCNDataset.json中进行了设置
STAGGCNDataset既可以继承TrafficStatePointDataset,也可以继承TrafficStateGridDataset以处理网格数据
修改成TrafficStateGridDataset时,只需要修改:
1.TrafficStatePointDataset-->TrafficStateGridDataset
2.self.use_row_column = False, 可以加到self.parameters_str中
3.计算 DTW 邻接矩阵前需要 reshape 为 (time_len, num_nodes, feature)
"""


[docs]class STAGGCNDataset(TrafficStatePointDataset): def __init__(self, config): super().__init__(config) self.points_per_hour = 3600 // self.time_intervals # 每小时的时间片数 self.period = 7 * 24 * self.points_per_hour # 一周的时间点数目,间隔为5min,用于求dtw_edge_index self.edge_index = self.get_edge_index() self.load_from_local = self.config.get('load_from_local', True) cache_path = './libcity/cache/dataset_cache/dtw_edge_index_' + self.dataset + '.npz' if self.load_from_local and os.path.exists(cache_path): # 提前算好了dtw_edge_index,并从本地导入 with open(cache_path, 'rb') as f: self.dtw_edge_index = pickle.load(f) else: # 临时求dtw_edge_index (临时求会耗时很久) self.dtw_edge_index = self.get_dtw_edge_index() with open(cache_path, 'wb') as f: pickle.dump(self.dtw_edge_index, f) # 返回语义邻接边集(该部分直接截取自源码,做为函数直接调用) # 根据.dyna文件求取语义邻接矩阵,通过调用edge_index_func将其转化为语义邻接边集
[docs] def get_dtw_edge_index(self): i = 0 for filename in self.data_files: if i == 0: df = self._load_dyna(filename) # (len_time, node_num, feature_dim) else: df = np.concatenate((df, self._load_dyna(filename)), axis=0) i += 1 df = df[:, :, 0] line = df.shape[0] order = np.arange(line).reshape(line, 1) df = np.concatenate((df, order), axis=1) df = pd.DataFrame(df) df['symbol'] = df[self.num_nodes] % self.period for i in tqdm(range(self.period)): df_i = df[df['symbol'] == i] values_i = df_i.values[:, :-1] mean_i = np.mean(values_i, axis=0)[np.newaxis, :] if i == 0: mean = mean_i else: mean = np.concatenate((mean, mean_i), axis=0) mean = mean.T dtw_matrix = np.zeros((self.num_nodes, self.num_nodes)) for index_x in tqdm(range(self.num_nodes)): for index_y in range(index_x, self.num_nodes): x = mean[index_x] y = mean[index_y] distance, _ = fastdtw(x, y, dist=euclidean) dtw_matrix[index_x][index_y] = distance for i in range(self.num_nodes): for j in range(0, i): dtw_matrix[i][j] = dtw_matrix[j][i] std = np.std(dtw_matrix) dtw_matrix = dtw_matrix / std dtw_matrix = np.exp(-1 * dtw_matrix) dtw_threshold = 0.83 count_min, count_max = self.num_nodes, 0 count_zero = 0 count_avg = 0 matrix = np.identity(self.num_nodes) for i in range(self.num_nodes): dtw_count_i = 0 for j in range(self.num_nodes): if dtw_matrix[i][j] > dtw_threshold: dtw_count_i += 1 matrix[i][j] = 1 count_avg += dtw_count_i if dtw_count_i == 1: count_zero += 1 if dtw_count_i > count_max: count_max = dtw_count_i if dtw_count_i < count_min: count_min = dtw_count_i return self.edge_index_func(matrix)
# 返回空间邻接边集 # 根据.geo文件求取空间邻接矩阵,通过调用edge_index_func将其转化为空间邻接边集
[docs] def get_edge_index(self): return self.edge_index_func(self.adj_mx)
# 用于将邻接矩阵转化为邻接边集
[docs] def edge_index_func(self, matrix): # print(matrix, matrix.max(), matrix.min()) a, b = [], [] for i in range(matrix.shape[0]): for j in range(matrix.shape[1]): if matrix[i][j] == 1: # matrix是0-1矩阵 a.append(i) b.append(j) edge = [a, b] edge_index = torch.tensor(edge, dtype=torch.long) return edge_index
[docs] def get_data_feature(self): """ 返回数据集特征,scaler是归一化方法,adj_mx是邻接矩阵,num_nodes是点的个数, feature_dim是输入数据的维度,output_dim是模型输出的维度 Returns: dict: 包含数据集的相关特征的字典 """ return {"scaler": self.scaler, "adj_mx": self.adj_mx, "num_nodes": self.num_nodes, "feature_dim": self.feature_dim, "output_dim": self.output_dim, "ext_dim": self.ext_dim, "edge_index": self.edge_index, # 将空间邻接边集作为data_feature返回 "dtw_edge_index": self.dtw_edge_index, # 将语义邻接边集作为data_feature返回 "num_batches": self.num_batches}