在人工智能的训练中,可能存在很多参数,其中一些是默认参数(如数据的路径,保存checkpoint的训练轮数,batchsize等等),也有一些是针对某个训练使用的参数(如读取数据的分辨率,读取的类别条件)。在过去的代码中,一个普遍的解决方法是:使用BaseOption, TrainOption, 和TestOption。其中TrainOption和TestOption继承BaseOption。而Openai的代码提供了一个新的思路。很难直接说谁的更好,记录一下。
1. 前置条件
首先要知道在python中,字典可以在被**后作为函数的参数传入,函数会自行在被传入的众多参数里面找到自己要的。例如说:
Big_parameter_sets = dict{a=1, b=2, c=3}
def Small_function(a, b):
return a+b
output = Small_function(**Big_parameter_sets) 可以发现Small_function这个函数并没有完全用到abc三个参数,总之,字典的键要和参数名同名,且是该函数的母集。
2. 具体过程
在主函数里,代码这样得到参数的:
args = create_argparser().parse_args()可以看到create_argparser这个函数返回了一个parser,然后这个parser的类再调用parse_args函数,然后得到传递的参数。于是我们找到create_argparser这个函数。
def create_argparser():
defaults = dict(
data_dir="",
schedule_sampler="uniform",
lr=1e-4,
weight_decay=0.0,
lr_anneal_steps=0,
batch_size=1,
microbatch=-1,
ema_rate="0.9999",
log_interval=10,
save_interval=10000,
resume_checkpoint="",
use_fp16=False,
fp16_scale_growth=1e-3,
)
defaults.update(sr_model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser其中的defaults这个字典,有一个神经网络最基础的一些参数:数据路径,学习率,权重缩减……
接下来,使用字典类的一个方法:update。这个方法有点像合并两个字典,字典来自于一个叫做sr_model_and_diffusion_defaults的函数,于是再找到这个函数。
def sr_model_and_diffusion_defaults():
res = model_and_diffusion_defaults()
res["large_size"] = 256
res["small_size"] = 64
arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
for k in res.copy().keys():
if k not in arg_names:
del res[k]
return res这里一个函数model_and_diffusion_defaults的输出被赋给了res这个变量,因为嵌套了太多函数,这个函数就不再写在这里,这个函数的本质也是直接返回了一个字典。之后又把256赋给了”large_size”这个键,把64赋给了”small_size“这个键。接下来一句话比较少见:
arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
这句话的目的就是 获得sr_create_model_and_diffusion这个函数的所有参数(使用[0]的原因)。这个函数配合接下来的for循环的目的是删除res 中所有不是sr_create_model_and_diffusion的参数的键还有它的值。
总的来说,这里的目的是:因为我们最终需要跑的是sr_create_model_and_diffusion这个函数,而model_and_diffusion_defaults这个函数提供的参数和我们跑的函数间是交集的关系,因此删除这些键值对可以减少内存的占用,但是我个人认为如果字典只有几十或者十几个键值对,这种操作其实没有很大的必要性。
于是我们返回res,并把它和defaults合并在一起。
parser = argparse.ArgumentParser()
parser = add_dict_to_argparser(parser, defaults)
这句话创造了一个parser,第二句话对这个parser进行了一些改动,因为字典终究是字典,还要想办法把字典的值一一赋值给parser才行,进入到这个函数:
def add_dict_to_argparser(parser, default_dict):
for k, v in default_dict.items():
v_type = type(v)
if v is None:
v_type = str
elif isinstance(v, bool):
v_type = str2bool
parser.add_argument(f"--{k}", default=v, type=v_type)其中应该只有一个问题比较难看懂,就是str2bool,这个主要是为了防止当参数是bool类型的时候,只要你输入了,输入会变成字符串,然后不管是输入了什么,都会表示为True。因此需要用这个将输入的True or False转换成参数的。最后,只需要在使用的时候,再将args转换成字典就好了:
model, diffusion = sr_create_model_and_diffusion(
**args_to_dict(args, sr_model_and_diffusion_defaults().keys())
)
def args_to_dict(args, keys):
return {k: getattr(args, k) for k in keys}最后这个args_to_dict函数其实是一个字典的推导式。它的意思是,对于传入的keys(也就是sr_model_and_diffusion_defaults的所有参数,是个列表)k,从parser的args中找到对应的参数值,建立从k到这个参数值的新映射,其实就是把parser的形态再转换回来。