Source code for libcity.data.dataset.map_matching_dataset

import pickle

from libcity.data.dataset import AbstractDataset
import os
import pandas as pd
import numpy as np
from logging import getLogger
from libcity.utils.dataset import parse_time
from libcity.utils.GPS_utils import dist, angle2radian
from libcity.utils.utils import ensure_dir
import networkx as nx


class UnionSet:
    """A minimal union-find (disjoint set) over the integers 0..n-1."""

    def __init__(self, n):
        self.n = n
        self.lst = list(range(n))

    def find(self, index):
        # path compression: re-point every visited index at its root
        if index != self.lst[index]:
            self.lst[index] = self.find(self.lst[index])
        return self.lst[index]

    def union(self, index1, index2):
        self.lst[self.find(index1)] = self.find(index2)

    def print(self):
        print(self.lst)
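

# Illustrative sketch, not part of the original module: UnionSet is a plain
# union-find. Indices merged via union() resolve to the same representative,
# which _load_geo_and_rel below exploits to collapse shared road endpoints.
def _demo_union_set():
    s = UnionSet(4)
    s.union(0, 1)                      # merge the sets containing 0 and 1
    s.union(1, 2)                      # now 0, 1, 2 share a single root
    assert s.find(0) == s.find(2)
    assert s.find(0) != s.find(3)      # 3 was never merged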


class MapMatchingDataset(AbstractDataset):
    """
    Base class for map matching datasets.
    """

    def __init__(self, config):
        # config and dataset name
        self.config = config
        self.dataset = self.config.get('dataset', '')

        # logger
        self._logger = getLogger()

        # features
        self.with_time = config.get('with_time', True)  # whether the input trajectories carry timestamps
        self.delta_time = config.get('delta_time', True)  # True: time offsets in seconds; False: datetime.datetime
        self.with_rd_speed = ('speed' in config['rel']['geo'].keys())

        # cache
        self.cache_dataset = self.config.get('cache_dataset', True)
        self.parameters_str = str(self.dataset) + '_' + str(self.delta_time)
        self.cache_file_name = os.path.join('./libcity/cache/dataset_cache/',
                                            'map_matching_{}.pkl'.format(self.parameters_str))
        self.cache_file_folder = './libcity/cache/dataset_cache/'
        ensure_dir(self.cache_file_folder)

        # ensure dataset
        self.data_path = './raw_data/' + self.dataset + '/'
        if not os.path.exists(self.data_path):
            raise ValueError("Dataset {} does not exist! Please ensure the path "
                             "'./raw_data/{}/' exists!".format(self.dataset, self.dataset))

        # related file names
        self.geo_file = self.config.get('geo_file', self.dataset)
        self.rel_file = self.config.get('rel_file', self.dataset)
        self.dyna_file = self.config.get('dyna_file', self.dataset)
        self.usr_file = self.config.get('usr_file', self.dataset)
        self.truth_file = self.config.get('truth_file', self.dataset + '_truth')

        # result
        self.trajectory = None
        self.rd_nwk = None
        self.route = None

        # load the 5 atom files (unless a cached result exists)
        if not self.cache_dataset or not os.path.exists(self.cache_file_name):
            if os.path.exists(self.data_path + self.rel_file + '.rel'):
                if os.path.exists(self.data_path + self.geo_file + '.geo'):
                    self._load_geo_and_rel()
                else:
                    raise ValueError('Not found .geo file!')
            else:
                raise ValueError('Not found .rel file!')
            if os.path.exists(self.data_path + self.usr_file + '.usr'):
                self._load_usr()
            else:
                raise ValueError('Not found .usr file!')
            if os.path.exists(self.data_path + self.dyna_file + '.dyna'):
                self._load_dyna()
            else:
                raise ValueError('Not found .dyna file!')
            if os.path.exists(self.data_path + self.truth_file + '.dyna'):
                self._load_truth_dyna()
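
    # Illustrative sketch, not part of the original module: the config keys
    # __init__ actually reads. All values below are hypothetical; in LibCity
    # the config object is normally assembled by the pipeline.
    #
    #     config = {
    #         'dataset': 'Seattle',                       # expects ./raw_data/Seattle/
    #         'with_time': True,                          # samples carry timestamps
    #         'delta_time': True,                         # seconds since first sample
    #         'cache_dataset': True,                      # pickle to dataset_cache/
    #         'geo': {'including_types': ['LineString']},
    #         'rel': {'including_types': ['geo'],
    #                 'geo': {'distance': 'num', 'speed': 'num'}},  # 'speed' => with_rd_speed
    #         'dyna': {'including_types': ['trajectory'],
    #                  'trajectory': {'entity_id': 'usr_id'}},
    #     }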

    def _load_geo_and_rel(self):
        """
        Load the .geo file, format: [geo_id, type, coordinates, properties (several columns)].
        Load the .rel file, format: [rel_id, type, origin_id, destination_id, properties (several columns)];
        the .rel file encodes the road network topology.

        Returns:
            self.rd_nwk: networkx.DiGraph
        """
        # init road network, which is the result of this function
        self.rd_nwk = nx.DiGraph(name="road network")

        # load geo and rel file
        geofile = pd.read_csv(self.data_path + self.geo_file + '.geo')
        relfile = pd.read_csv(self.data_path + self.rel_file + '.rel')
        geo_num = geofile.shape[0]

        # check for type geo in the rel file and type LineString in the geo file
        if self.config['rel']['including_types'] != ['geo']:
            raise ValueError('.rel file should include geo type in Map Matching task!')
        if self.config['geo']['including_types'] != ['LineString']:
            raise ValueError('.geo file should include LineString type in Map Matching task!')

        # get property columns
        columns = relfile.columns.tolist()[4:]

        # use a UnionSet to merge road endpoints into shared junction nodes
        node_set = UnionSet(2 * geo_num)
        for index, row in relfile.iterrows():
            # origin and destination
            from_id = int(row[2])
            to_id = int(row[3])
            node_set.union(from_id, to_id + geo_num)

        # generate the directed graph: one edge per road segment (geo row)
        for index, row in geofile.iterrows():
            geo_id = int(row['geo_id'])
            coordinate = eval(row['coordinates'])
            origin_node = node_set.find(geo_id + geo_num)
            dest_node = node_set.find(geo_id)
            if origin_node not in self.rd_nwk.nodes:
                self.rd_nwk.add_node(origin_node, lon=coordinate[0][0], lat=coordinate[0][1])
            if dest_node not in self.rd_nwk.nodes:
                self.rd_nwk.add_node(dest_node, lon=coordinate[1][0], lat=coordinate[1][1])

            # add edge
            self.rd_nwk.add_edge(origin_node, dest_node)
            feature_dct = dict()
            for i, column in enumerate(columns):
                feature_dct[column] = row[i + 4]
            if 'distance' not in feature_dct.keys():
                # fall back to the great-circle distance between the endpoints
                feature_dct['distance'] = dist(
                    angle2radian(self.rd_nwk.nodes[origin_node]['lat']),
                    angle2radian(self.rd_nwk.nodes[origin_node]['lon']),
                    angle2radian(self.rd_nwk.nodes[dest_node]['lat']),
                    angle2radian(self.rd_nwk.nodes[dest_node]['lon'])
                )
            feature_dct['geo_id'] = geo_id
            self.rd_nwk.edges[origin_node, dest_node].update(feature_dct)

        # logger
        self._logger.info("Loaded file " + self.geo_file + '.geo, num_roads=' + str(geo_num))
        self._logger.info("Loaded file " + self.rel_file + '.rel, num_nodes=' + str(len(self.rd_nwk)))
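
    # Illustrative sketch, not part of the original module: the endpoint
    # bookkeeping above. Road i owns two UnionSet slots: i + geo_num (its
    # start) and i (its end). A rel row (origin_id=0, destination_id=1) says
    # road 0 ends where road 1 starts, so with geo_num = 2:
    #
    #     node_set = UnionSet(2 * 2)
    #     node_set.union(0, 1 + 2)                  # end of road 0 == start of road 1
    #     node_set.find(0) == node_set.find(3)      # True: one shared junction node
    #
    # Each geo row then becomes one directed edge find(i + geo_num) -> find(i).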

    def _load_usr(self):
        """
        Load the .usr file, format: [usr_id].

        Returns:
            list: self.usr_lst, the list of user ids
        """
        usrfile = pd.read_csv(self.data_path + self.usr_file + '.usr')
        self.usr_lst = []
        for index, row in usrfile.iterrows():
            self.usr_lst.append(row[0])
        self._logger.info("Loaded file " + self.usr_file + '.usr, num_users=' + str(len(self.usr_lst)))

    def _load_dyna(self):
        """
        Load the .dyna file, format: [dyna_id, type, time, entity_id, coordinates].
        self.with_time marks whether the trajectories carry time information.

        Returns:
            dict: self.trajectory, {usr_id: {traj_id: np.ndarray of GPS samples}}
        """
        dynafile = pd.read_csv(self.data_path + self.dyna_file + '.dyna')
        if self.config['dyna']['including_types'] != ['trajectory']:
            raise ValueError('.dyna file should include trajectory type in Map Matching task!')
        if self.config['dyna']['trajectory']['entity_id'] != 'usr_id':
            raise ValueError('entity_id should be usr_id in Map Matching task!')
        self.trajectory = {}
        self.multi_traj = 'traj_id' in dynafile.keys()
        for index, row in dynafile.iterrows():
            dyna_id = row['dyna_id']
            usr_id = row['entity_id']
            traj_id = row['traj_id'] if self.multi_traj else 0
            time = row['time']
            coordinate = eval(row['coordinates'])
            if usr_id not in self.usr_lst:
                raise ValueError('entity_id %d should be in usr_ids in Map Matching task!' % usr_id)
            point = [dyna_id] + coordinate
            if self.with_time:
                point.append(parse_time(time))
            self.trajectory.setdefault(usr_id, {}).setdefault(traj_id, []).append(point)

        # turn absolute timestamps into offsets (seconds) from each trajectory's first sample
        if self.delta_time and self.with_time:
            for usr_id, usr_value in self.trajectory.items():
                for traj_id, trajectory in usr_value.items():
                    t0 = trajectory[0][3]
                    trajectory[0][3] = 0
                    for i in range(1, len(trajectory)):
                        trajectory[i][3] = (trajectory[i][3] - t0).seconds

        for key, value in self.trajectory.items():
            for key_i, value_i in value.items():
                self.trajectory[key][key_i] = np.array(value_i)
        self._logger.info("Loaded file " + self.dyna_file + '.dyna, num of GPS samples=' + str(dynafile.shape[0]))
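
    # Illustrative sketch, not part of the original module: the layout
    # _load_dyna produces. With with_time and delta_time both True, each row
    # is [dyna_id, lon, lat, seconds since the trajectory's first sample]:
    #
    #     self.trajectory = {
    #         usr_id: {
    #             traj_id: np.array([[0, 116.32, 39.98, 0],
    #                                [1, 116.33, 39.99, 12]]),  # 12 s later
    #         }
    #     }
    #
    # With with_time=False the time column is absent; with delta_time=False
    # it holds datetime.datetime objects instead of second offsets.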

    def _load_truth_dyna(self):
        """
        Load the ground-truth .dyna file, format: one rel_id (or a group of rel_ids) per row.

        Returns:
            dict: self.route, {usr_id: {traj_id: np.ndarray of [dyna_id, location] rows}}
        """
        # open file
        truth_dyna = pd.read_csv(self.data_path + self.truth_file + '.dyna')

        # result of the function
        self.route = {}

        # the truth file and the dyna file must agree on whether traj_id exists
        multi_traj = 'traj_id' in truth_dyna.keys()
        if multi_traj != self.multi_traj:
            raise ValueError('cannot match traj_id in route file and dyna file')

        # set route
        for index, row in truth_dyna.iterrows():
            dyna_id = row['dyna_id']
            usr_id = row['entity_id']
            traj_id = row['traj_id'] if multi_traj else 0
            location = row['location']

            # check usr
            if usr_id not in self.usr_lst:
                raise ValueError('usr_id %d should be in usr_ids in Map Matching task!' % usr_id)

            self.route.setdefault(usr_id, {}).setdefault(traj_id, []).append([dyna_id, location])

        for key, value in self.route.items():
            for key_i, value_i in value.items():
                self.route[key][key_i] = np.array(value_i)
        self._logger.info("Loaded file " + self.truth_file + '.dyna, route length=' + str(truth_dyna.shape[0]))
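
    # Illustrative sketch, not part of the original module: the ground truth
    # mirrors the trajectory layout. Each row is [dyna_id, location], where
    # location is the geo_id of the road segment the sample was matched to:
    #
    #     self.route = {usr_id: {traj_id: np.array([[0, 17],
    #                                               [1, 17],
    #                                               [2, 23]])}}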

    def get_data(self):
        """
        Return the train, validation and test data.
        For map matching, the train and validation data are None.

        Returns:
            tuple: (None, None, dict), where the dict holds:
                'trajectory': dict of np.ndarray, rows are (dyna_id, lon, lat, time) if with_time else (dyna_id, lon, lat)
                'rd_nwk': networkx.DiGraph, the road network
                'route': dict of np.ndarray, the ground truth (None if no truth file was loaded)
                'multi_traj': bool, whether a user may have multiple trajectories
        """
        if self.cache_dataset and os.path.exists(self.cache_file_name):
            self._logger.info('Loading ' + self.cache_file_name)
            with open(self.cache_file_name, 'rb') as f:
                res = pickle.load(f)
            self.multi_traj = res['multi_traj']
            return None, None, res
        res = dict()
        res['trajectory'] = self.trajectory
        res['rd_nwk'] = self.rd_nwk
        res['route'] = self.route
        res['multi_traj'] = self.multi_traj
        with open(self.cache_file_name, 'wb') as f:
            pickle.dump(res, f)
        self._logger.info('Saved at ' + self.cache_file_name)
        return None, None, res

    def get_data_feature(self):
        """
        Return a dict of features describing the dataset.

        Returns:
            dict: the feature dictionary of the dataset
        """
        res = dict()
        res['with_time'] = self.with_time
        res['with_rd_speed'] = self.with_rd_speed
        res['delta_time'] = self.delta_time
        res['multi_traj'] = self.multi_traj
        return res
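

# Illustrative sketch, not part of the original module: direct use of the
# dataset class. In LibCity this is normally driven by the pipeline; the
# config passed in is hypothetical and its dataset must exist under
# ./raw_data/ for the call to succeed.
def _demo_map_matching_dataset(config):
    dataset = MapMatchingDataset(config)
    _, _, data = dataset.get_data()                # train/valid splits are None
    print(data['rd_nwk'].number_of_edges())        # edges in the road network graph
    print(list(data['trajectory'].keys()))         # user ids with GPS tracks
    print(dataset.get_data_feature())              # flags the model will consume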