# Source code for libcity.utils.argument_list

"""
store the arguments can be modified by the user
"""
import argparse

# Command-line options shared across runs. Each key is an option name (it is
# registered as ``--<name>`` by add_general_args) and each value is a spec with
# a type tag ('bool', 'int', 'float', 'str'), a default and a help string.
# Every default is None -- presumably so that "flag not given" can be told
# apart from an explicit value downstream; confirm against the config loader.
general_arguments = {
    "gpu": dict(type="bool", default=None, help="whether use gpu"),
    "gpu_id": dict(type="int", default=None, help="the gpu id to use"),
    "train_rate": dict(type="float", default=None, help="the train set rate"),
    "eval_rate": dict(type="float", default=None,
                      help="the validation set rate"),
    "batch_size": dict(type="int", default=None, help="the batch size"),
    "learning_rate": dict(type="float", default=None, help="learning rate"),
    "max_epoch": dict(type="int", default=None, help="the maximum epoch"),
    "dataset_class": dict(type="str", default=None,
                          help="the dataset class name"),
    "executor": dict(type="str", default=None,
                     help="the executor class name"),
    "evaluator": dict(type="str", default=None,
                      help="the evaluator class name"),
}

# Options registered by add_hyper_args, using the same spec layout as
# general_arguments (type tag, default, help string). Every entry here is
# also declared in general_arguments with an identical spec.
hyper_arguments = {
    "gpu": dict(type="bool", default=None, help="whether use gpu"),
    "gpu_id": dict(type="int", default=None, help="the gpu id to use"),
    "train_rate": dict(type="float", default=None, help="the train set rate"),
    "eval_rate": dict(type="float", default=None,
                      help="the validation set rate"),
    "batch_size": dict(type="int", default=None, help="the batch size"),
}


def str2bool(s):
    """Convert a command-line token to ``bool`` (usable as argparse ``type=``).

    A ``bool`` argument is returned unchanged. Otherwise ``s`` is lower-cased,
    so matching is case-insensitive: ``'yes'``/``'true'`` map to ``True`` and
    ``'no'``/``'false'`` map to ``False``.

    Raises:
        argparse.ArgumentTypeError: if the lower-cased token is none of the
            accepted spellings.
    """
    if isinstance(s, bool):
        return s
    token = s.lower()
    if token in ('yes', 'true'):
        return True
    if token in ('no', 'false'):
        return False
    raise argparse.ArgumentTypeError('bool value expected.')
def str2float(s):
    """Convert a command-line token to ``float`` (usable as argparse ``type=``).

    A ``float`` argument is returned unchanged; anything else is passed
    through ``float(s)``.

    Raises:
        argparse.ArgumentTypeError: if ``s`` cannot be converted to a float,
            either because it is a malformed string (``ValueError``) or
            because it is an unsupported type such as ``None``
            (``TypeError``). The underlying error is kept as ``__cause__``.
    """
    if isinstance(s, float):
        return s
    try:
        return float(s)
    except (TypeError, ValueError) as err:
        # Normalize both failure modes to the exception argparse reports
        # cleanly as a usage error, while keeping the original cause.
        raise argparse.ArgumentTypeError('float value expected.') from err
def _add_arguments_from_spec(parser, arguments):
    """Register one ``--<name>`` option on ``parser`` per entry of ``arguments``.

    Args:
        parser: an ``argparse.ArgumentParser`` (anything with a compatible
            ``add_argument`` method works).
        arguments: mapping of option name to a spec dict with the keys
            ``'type'``, ``'default'`` and ``'help'``. Supported type tags are
            ``'int'``, ``'bool'``, ``'float'``, ``'str'`` and ``'list of int'``.

    Raises:
        ValueError: if a spec carries an unsupported type tag. Failing loudly
            here is deliberate: silently skipping the entry would make the
            option quietly vanish from the command line.
    """
    # type tag -> extra keyword arguments passed to ``add_argument``
    converters = {
        'int': {'type': int},
        'bool': {'type': str2bool},
        'float': {'type': str2float},
        'str': {'type': str},
        'list of int': {'nargs': '+', 'type': int},
    }
    for name, spec in arguments.items():
        type_tag = spec['type']
        if type_tag not in converters:
            raise ValueError(
                'unsupported argument type {!r} for --{}'.format(type_tag, name))
        parser.add_argument('--{}'.format(name),
                            default=spec['default'],
                            help=spec['help'],
                            **converters[type_tag])


def add_general_args(parser):
    """Add every option declared in ``general_arguments`` to ``parser``."""
    _add_arguments_from_spec(parser, general_arguments)


def add_hyper_args(parser):
    """Add every option declared in ``hyper_arguments`` to ``parser``.

    NOTE(review): every name in ``hyper_arguments`` is also declared in
    ``general_arguments``. Calling both add_* functions on the same parser
    therefore raises under argparse's default conflict handler. Presumably
    callers use one or the other; confirm at the call sites.
    """
    _add_arguments_from_spec(parser, hyper_arguments)