Torch Train with Epoch

按 epoch 进行训练与评估

train

def train_epoch(model, trainer, loss_fn, dataloader, device):
    """Run one training pass over ``dataloader``.

    For every ``(inputs, targets)`` batch the tensors are moved to ``device``,
    gradients are cleared, the loss is computed, back-propagated, and the
    optimizer (``trainer``) takes a step.

    Returns:
        The mean of the per-batch ``loss.item()`` values, i.e. their sum
        divided by ``len(dataloader)`` (number of batches).
    """
    model.train()
    running_total = 0.0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        trainer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)
        running_total += batch_loss.item()
        batch_loss.backward()
        trainer.step()
    num_batches = len(dataloader)
    return running_total / num_batches

val

def val_epoch(model, loss_fn, dataloader, device):
    """Evaluate ``model`` for one epoch and return its classification accuracy.

    Args:
        model: network whose output is reduced with ``argmax(1)`` to get the
            predicted label (assumes output shape (batch, num_classes) —
            TODO confirm for other models).
        loss_fn: not used by this metric; kept so the signature matches
            callers such as ``train_and_eval``.
        dataloader: yields ``(X, y)`` batches and exposes ``.dataset``, whose
            length is taken as the total number of samples.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of samples whose prediction equals the target, or 0.0 when
        the dataset is empty.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            preds = model(X).argmax(1)
            # Accumulate the number of correct predictions across all batches
            # so the final ratio is weighted per sample, not per batch.
            correct += (preds == y).sum().item()
    total = len(dataloader.dataset)
    if total == 0:
        return 0.0
    return correct / total

val_epoch 的简洁写法：使用 forward_batch（见 train with steps）。注意该版本多了 trainer 参数，签名与上面的 val_epoch 不同。

def val_epoch(model, trainer, loss_fn, dataloader, device):
    """Evaluate ``model`` for one epoch using the ``forward_batch`` helper.

    ``forward_batch(batch, model, loss_fn, device)`` is defined outside this
    snippet (see "train with steps"); it returns a ``(loss, accuracy)`` pair
    whose elements support ``.item()`` (presumably scalar tensors — confirm).

    Args:
        trainer: unused here; kept for signature compatibility.

    Returns:
        The mean of the per-batch accuracies, i.e. the sum of
        ``accuracy.item()`` divided by ``len(dataloader)``; 0.0 when the
        dataloader has no batches.

    NOTE(review): this is an unweighted mean over batches, so a smaller final
    batch counts as much as a full one. This variant also takes five
    arguments, unlike the four-argument ``val_epoch`` called by
    ``train_and_eval``.
    """
    model.eval()
    running_accuracy = 0.0
    # A single no_grad scope covers the whole evaluation loop.
    with torch.no_grad():
        for batch in dataloader:
            _, accuracy = forward_batch(batch, model, loss_fn, device)
            running_accuracy += accuracy.item()
    num_batches = len(dataloader)
    if num_batches == 0:
        return 0.0
    return running_accuracy / num_batches

total

def train_and_eval(epoches, model, trainer, loss_fn,
                   train_dataloader, val_dataloader, device):
    """Train for ``epoches`` epochs, evaluating after each one, then plot.

    Each epoch calls ``train_epoch`` on ``train_dataloader`` and ``val_epoch``
    on ``val_dataloader``, recording one training loss and one validation
    metric. After the loop the last values are printed and both curves are
    plotted against the epoch index. If ``epoches <= 0`` nothing is trained,
    printed, or plotted.

    Returns:
        ``(loss_list, acc_list)``: per-epoch training loss and validation
        metric.
    """
    loss_list, acc_list = [], []
    for _ in tqdm(range(epoches)):
        loss_list.append(
            train_epoch(model, trainer, loss_fn, train_dataloader, device))
        # Calls the four-argument val_epoch(model, loss_fn, dataloader, device).
        acc_list.append(
            val_epoch(model, loss_fn, val_dataloader, device))
    if not loss_list:
        # No epochs ran: there is no last value to print and nothing to plot.
        return loss_list, acc_list
    print(f'loss:{loss_list[-1]:.4f} acc:{acc_list[-1]:.4f}')
    epoch_axis = list(range(len(loss_list)))
    plt.plot(epoch_axis, loss_list, label='loss')
    plt.plot(epoch_axis, acc_list, label='acc')
    plt.xlabel('epoch')
    plt.legend(fontsize='small', loc='best')
    plt.grid(True)
    plt.show()
    return loss_list, acc_list
This post is licensed under CC BY 4.0 by the author.