import importlib
import logging
import datetime
import os
import sys
import numpy as np
import random
import torch
def get_executor(config, model, data_feature):
    """
    Create the executor named by ``config['executor']``.

    The class is looked up in the ``libcity.executor`` package and
    instantiated as ``cls(config, model, data_feature)``.

    Args:
        config(ConfigParser): config; ``config['executor']`` is the class name
        model(AbstractModel): model
        data_feature(dict): feature of the data

    Returns:
        AbstractExecutor: the loaded executor

    Raises:
        AttributeError: if ``libcity.executor`` has no attribute with the
            configured name
    """
    try:
        executor_class = getattr(importlib.import_module('libcity.executor'),
                                 config['executor'])
    except AttributeError as err:
        raise AttributeError('executor is not found') from err
    # Instantiate outside the ``try`` so that an AttributeError raised by the
    # executor's own constructor surfaces as-is instead of being misreported
    # as a missing executor.
    return executor_class(config, model, data_feature)
# Sub-packages searched (in order) for the model class of each task.
_MODEL_MODULES_BY_TASK = {
    'traj_loc_pred': (
        'libcity.model.trajectory_loc_prediction',
    ),
    'traffic_state_pred': (
        'libcity.model.traffic_flow_prediction',
        'libcity.model.traffic_speed_prediction',
        'libcity.model.traffic_demand_prediction',
        'libcity.model.traffic_od_prediction',
        'libcity.model.traffic_accident_prediction',
    ),
    'map_matching': (
        'libcity.model.map_matching',
    ),
    'road_representation': (
        'libcity.model.road_representation',
    ),
    'eta': (
        'libcity.model.eta',
    ),
}


def get_model(config, data_feature):
    """
    Create the model named by ``config['model']`` for ``config['task']``.

    The task selects which ``libcity.model.*`` sub-packages are searched.
    They are searched in the order listed in ``_MODEL_MODULES_BY_TASK``, and
    the first one defining the name wins. The class is instantiated as
    ``cls(config, data_feature)``.

    Args:
        config(ConfigParser): config; reads ``config['task']`` and
            ``config['model']``
        data_feature(dict): feature of the data

    Returns:
        AbstractModel: the loaded model

    Raises:
        AttributeError: ``'task is not found'`` for an unknown task, or
            ``'model is not found'`` if no searched sub-package defines the
            configured model name
    """
    task = config['task']
    if task not in _MODEL_MODULES_BY_TASK:
        raise AttributeError('task is not found')
    model_name = config['model']
    for module_name in _MODEL_MODULES_BY_TASK[task]:
        try:
            model_class = getattr(importlib.import_module(module_name),
                                  model_name)
        except AttributeError:
            continue  # not defined here; try the next sub-package
        # Construct outside the ``try`` so an AttributeError raised by the
        # model's own constructor is not mistaken for "not in this module".
        return model_class(config, data_feature)
    raise AttributeError('model is not found')
def get_evaluator(config):
    """
    Create the evaluator named by ``config['evaluator']``.

    The class is looked up in the ``libcity.evaluator`` package and
    instantiated as ``cls(config)``.

    Args:
        config(ConfigParser): config; ``config['evaluator']`` is the class name

    Returns:
        AbstractEvaluator: the loaded evaluator

    Raises:
        AttributeError: if ``libcity.evaluator`` has no attribute with the
            configured name
    """
    try:
        evaluator_class = getattr(importlib.import_module('libcity.evaluator'),
                                  config['evaluator'])
    except AttributeError as err:
        raise AttributeError('evaluator is not found') from err
    # Instantiate outside the ``try`` so that an AttributeError raised by the
    # evaluator's own constructor is not misreported as a missing evaluator.
    return evaluator_class(config)
# Case-insensitive ``log_level`` config values -> logging levels.
_LOG_LEVELS = {
    'info': logging.INFO,
    'debug': logging.DEBUG,
    'error': logging.ERROR,
    'warning': logging.WARNING,
    'critical': logging.CRITICAL,
}


def get_logger(config, name=None):
    """
    Get a Logger that writes to a per-run log file and to stdout.

    The log file is ``./libcity/log/{exp_id}-{model}-{dataset}-{time}.log``.
    Handlers installed by a previous call on the same logger are removed
    (and closed) first, so calling this repeatedly does not duplicate output.

    Args:
        config(ConfigParser): config; reads ``exp_id``, ``model``, ``dataset``
            and the optional ``log_level`` (default ``'INFO'``; unknown values
            fall back to INFO)
        name: specified name passed to ``logging.getLogger`` (``None`` gives
            the root logger)

    Returns:
        Logger: logger
    """
    log_dir = './libcity/log'
    os.makedirs(log_dir, exist_ok=True)
    log_filename = '{}-{}-{}-{}.log'.format(config['exp_id'],
                                            config['model'], config['dataset'],
                                            get_local_time())
    logfilepath = os.path.join(log_dir, log_filename)
    logger = logging.getLogger(name)
    log_level = config.get('log_level', 'INFO')
    logger.setLevel(_LOG_LEVELS.get(log_level.lower(), logging.INFO))

    # Detach handlers added by an earlier call; otherwise every message would
    # be emitted once per call and old log files would stay open.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, '_libcity_handler', False):
            logger.removeHandler(old_handler)
            old_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(logfilepath)
    console_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        handler._libcity_handler = True  # marks handlers owned by get_logger
        logger.addHandler(handler)
    logger.info('Log directory: %s', log_dir)
    return logger
def get_local_time():
    """
    Return the current local time as a filename-friendly string.

    Returns:
        str: ``datetime.datetime.now()`` rendered with the format
        ``'%b-%d-%Y_%H-%M-%S'`` (the month abbreviation ``%b`` follows the
        current locale).
    """
    return datetime.datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
def ensure_dir(dir_path):
    """Make sure the directory exists, creating it (and parents) if needed.

    Uses ``exist_ok=True`` instead of check-then-create, so a concurrent
    process creating the same directory does not cause ``FileExistsError``.

    Args:
        dir_path (str): directory path
    """
    os.makedirs(dir_path, exist_ok=True)
def trans_naming_rule(origin, origin_rule, target_rule):
    """
    Convert an identifier from one naming convention to another.

    Only ``upper_camel_case`` -> ``under_score_rule`` is supported. The first
    character is lower-cased. Every later upper-case character becomes ``'_'``
    plus its lower-case form, so ``'GRUModel'`` becomes ``'g_r_u_model'``.

    Args:
        origin (str): identifier written in the source convention
        origin_rule (str): source convention name
        target_rule (str): target convention name

    Returns:
        str: the converted identifier

    Raises:
        NotImplementedError: for any other pair of conventions
    """
    # TODO: the input is assumed to already follow origin_rule; not validated.
    if origin_rule != 'upper_camel_case' or target_rule != 'under_score_rule':
        raise NotImplementedError(
            'trans naming rule only support from upper_camel_case to '
            'under_score_rule')
    pieces = []
    for index, char in enumerate(origin):
        if index == 0:
            pieces.append(char.lower())
        elif char.isupper():
            pieces.append('_' + char.lower())
        else:
            pieces.append(char)
    return ''.join(pieces)
def preprocess_data(data, config):
    """
    Split a series into (input, output) sliding windows and a train/test split.

    A window of ``input_window + output_window`` consecutive steps is taken at
    every start position that fits completely. A series of length T therefore
    yields ``max(T - input_window - output_window + 1, 0)`` samples. The
    first ``int(num_samples * (train_rate + eval_rate))`` samples form the
    training part; the remainder forms the test part.

    Args:
        data: array-like of shape (T, ...)
        config: mapping with optional ``train_rate`` (0.7), ``eval_rate``
            (0.1), ``input_window`` (12) and ``output_window`` (3)

    Returns:
        tuple of np.ndarray: ``trainx`` (train_size, input_window, ...),
        ``trainy`` (train_size, output_window, ...), ``testx``
        (test_size, input_window, ...), ``testy`` (test_size, output_window, ...)
    """
    train_rate = config.get('train_rate', 0.7)
    eval_rate = config.get('eval_rate', 0.1)
    input_window = config.get('input_window', 12)
    output_window = config.get('output_window', 3)
    data = np.asarray(data)
    # "+ 1" keeps the last complete window, which ``range(T - in - out)``
    # used to drop.
    num_samples = max(len(data) - input_window - output_window + 1, 0)
    starts = np.arange(num_samples)[:, None]  # (num_samples, 1)
    # Fancy indexing with a (num_samples, window) index grid gives
    # (num_samples, window, ...) directly, even when num_samples == 0.
    x = data[starts + np.arange(input_window)]
    y = data[starts + input_window + np.arange(output_window)]
    train_size = int(num_samples * (train_rate + eval_rate))
    trainx, testx = x[:train_size], x[train_size:]
    trainy, testy = y[:train_size], y[train_size:]
    return trainx, trainy, testx, testy
def set_random_seed(seed):
    """
    Reset the random number generators to a fixed seed.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU RNG and all
    CUDA device RNGs, then sets ``torch.backends.cudnn.deterministic = True``.

    Args:
        seed(int): seed value
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True