Torch Modify Model
Torch Modify Model
Modify Model
A torch model is a callable object; you can access its layers like a list or tuple.
Convert the model to an nn.Sequential:
- model_seq = nn.Sequential(*model.children())
Get any layer's output
create_feature_extractor
1
2
3
4
5
from torchvision.models.feature_extraction import create_feature_extractor

# `return_nodes` maps a node (layer) name inside the model to the key under
# which that node's output is stored in the dict returned by the extractor.
model_ex = create_feature_extractor(
    model,
    return_nodes={"layer name": "key"},
)
hook
1
2
3
4
5
6
def forward_hook(model, X_in, X_out):
    """Forward-hook placeholder taking the hooked module, its input and its output."""


# Attach the hook to the layer; keep the handle so the hook can be detached.
handle = layer.register_forward_hook(forward_hook)
# do something
# Detach the hook once it is no longer needed.
handle.remove()
Introduction to hooks
register_forward_hook
1
2
3
4
def forward_hook(module, X_input, X_output):
    """No-op forward hook taking the module, its input and its output."""
    return None


# Registering returns a handle that can later be used to remove the hook.
handle = layer.register_forward_hook(forward_hook)
register_backward_hook
1
2
3
4
def backward_hook(module, grad_input, grad_output):
    """No-op backward hook.

    Args:
        module: the module the hook is registered on.
        grad_input: gradients with respect to the module's inputs.
        grad_output: gradients with respect to the module's outputs.

    Returns:
        None, which leaves ``grad_input`` unchanged.
    """


# `register_backward_hook` is deprecated; `register_full_backward_hook`
# accepts the same hook signature and is the supported replacement.
handle = layer.register_full_backward_hook(backward_hook)
register_param_hook
1
2
3
4
def param_hook(param):
    """No-op hook that is registered on a parameter below."""
    return None


# `param` is one of the model's parameters.
# NOTE(review): per the PyTorch docs, `Tensor.register_hook` calls the hook
# with the tensor's gradient — confirm the hook's argument name reflects that.
handle = param.register_hook(param_hook)
This post is licensed under CC BY 4.0 by the author.