import os
import json
import torch
[docs]class ConfigParser(object):
"""
use to parse the user defined parameters and use these to modify the
pipeline's parameter setting.
值得注意的是,目前各阶段的参数是放置于同一个 dict 中的,因此需要编程时保证命名空间不冲突。
config 优先级:命令行 > config file > default config
"""
def __init__(self, task, model, dataset, config_file=None,
saved_model=True, train=True, other_args=None, hyper_config_dict=None):
"""
Args:
task, model, dataset (str): 用户在命令行必须指明的三个参数
config_file (str): 配置文件的文件名,将在项目根目录下进行搜索
other_args (dict): 通过命令行传入的其他参数
"""
self.config = {}
self._parse_external_config(task, model, dataset, saved_model, train, other_args, hyper_config_dict)
self._parse_config_file(config_file)
self._load_default_config()
self._init_device()
def _parse_external_config(self, task, model, dataset,
saved_model=True, train=True, other_args=None, hyper_config_dict=None):
if task is None:
raise ValueError('the parameter task should not be None!')
if model is None:
raise ValueError('the parameter model should not be None!')
if dataset is None:
raise ValueError('the parameter dataset should not be None!')
# 目前暂定这三个参数必须由用户指定
self.config['task'] = task
self.config['model'] = model
self.config['dataset'] = dataset
self.config['saved_model'] = saved_model
self.config['train'] = False if task == 'map_matching' else train
if other_args is not None:
# TODO: 这里可以设计加入参数检查,哪些参数是允许用户通过命令行修改的
for key in other_args:
self.config[key] = other_args[key]
if hyper_config_dict is not None:
# 超参数调整时传入的待调整的参数,优先级低于命令行参数
for key in hyper_config_dict:
self.config[key] = hyper_config_dict[key]
def _parse_config_file(self, config_file):
if config_file is not None:
# TODO: 对 config file 的格式进行检查
if os.path.exists('./{}.json'.format(config_file)):
with open('./{}.json'.format(config_file), 'r') as f:
x = json.load(f)
for key in x:
if key not in self.config:
self.config[key] = x[key]
else:
raise FileNotFoundError(
'Config file {}.json is not found. Please ensure \
the config file is in the root dir and is a JSON \
file.'.format(config_file))
def _load_default_config(self):
# 首先加载 task config
with open('./libcity/config/task_config.json', 'r') as f:
task_config = json.load(f)
if self.config['task'] not in task_config:
raise ValueError(
'task {} is not supported.'.format(self.config['task']))
task_config = task_config[self.config['task']]
# check model and dataset
if self.config['model'] not in task_config['allowed_model']:
raise ValueError('task {} do not support model {}'.format(
self.config['task'], self.config['model']))
model = self.config['model']
# 加载 dataset、executor、evaluator 的模块
if 'dataset_class' not in self.config:
self.config['dataset_class'] = task_config[model]['dataset_class']
if self.config['task'] == 'traj_loc_pred' and 'traj_encoder' not in self.config:
self.config['traj_encoder'] = task_config[model]['traj_encoder']
if self.config['task'] == 'eta' and 'eta_encoder' not in self.config:
self.config['eta_encoder'] = task_config[model]['eta_encoder']
if 'executor' not in self.config:
self.config['executor'] = task_config[model]['executor']
if 'evaluator' not in self.config:
self.config['evaluator'] = task_config[model]['evaluator']
# 对于 LSTM RNN GRU 使用的都是同一个类,只是 RNN 模块不一样而已,这里做一下修改
if self.config['model'].upper() in ['LSTM', 'GRU', 'RNN']:
self.config['rnn_type'] = self.config['model']
self.config['model'] = 'RNN'
# if self.config['dataset'] not in task_config['allowed_dataset']:
# raise ValueError('task {} do not support dataset {}'.format(
# self.config['task'], self.config['dataset']))
# 接着加载每个阶段的 default config
default_file_list = []
# model
default_file_list.append('model/{}/{}.json'.format(self.config['task'], self.config['model']))
# dataset
default_file_list.append('data/{}.json'.format(self.config['dataset_class']))
# executor
default_file_list.append('executor/{}.json'.format(self.config['executor']))
# evaluator
default_file_list.append('evaluator/{}.json'.format(self.config['evaluator']))
# 加载所有默认配置
for file_name in default_file_list:
with open('./libcity/config/{}'.format(file_name), 'r') as f:
x = json.load(f)
for key in x:
if key not in self.config:
self.config[key] = x[key]
# 加载数据集config.json
with open('./raw_data/{}/config.json'.format(self.config['dataset']), 'r') as f:
x = json.load(f)
for key in x:
if key == 'info':
for ik in x[key]:
if ik not in self.config:
self.config[ik] = x[key][ik]
else:
if key not in self.config:
self.config[key] = x[key]
def _init_device(self):
use_gpu = self.config.get('gpu', True)
gpu_id = self.config.get('gpu_id', 0)
if use_gpu:
torch.cuda.set_device(gpu_id)
self.config['device'] = torch.device(
"cuda:%d" % gpu_id if torch.cuda.is_available() and use_gpu else "cpu")
[docs] def get(self, key, default=None):
return self.config.get(key, default)
def __getitem__(self, key):
if key in self.config:
return self.config[key]
else:
raise KeyError('{} is not in the config'.format(key))
def __setitem__(self, key, value):
self.config[key] = value
def __contains__(self, key):
return key in self.config
# 支持迭代操作
def __iter__(self):
return self.config.__iter__()