Torch Train with Steps

Using steps

Instead of looping over epochs, train for a fixed number of optimization steps and restart the dataloader iterator whenever it runs out. This decouples validation frequency and total compute from the dataset size.

forward

import torch

def forward_batch(batch, model, loss_fn, device):
    """
    Forward one batch.
    Returns: (loss, accuracy) as tensors.
    """
    X, y = batch
    X, y = X.to(device), y.to(device)
    y_hat = model(X)
    l = loss_fn(y_hat, y)

    # Classification accuracy of this batch
    preds = y_hat.argmax(1)
    accuracy = torch.mean((preds == y).float())

    return l, accuracy
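
As a quick check, here is a minimal sketch of calling forward_batch; the linear model, tensor shapes, and random data below are illustrative assumptions, not from the post:

# Hypothetical smoke test for forward_batch; model and shapes are made up
model = torch.nn.Linear(10, 3)          # toy 3-class classifier
loss_fn = torch.nn.CrossEntropyLoss()
batch = (torch.randn(8, 10), torch.randint(0, 3, (8,)))
l, acc = forward_batch(batch, model, loss_fn, device="cpu")
print(l.item(), acc.item())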

Training loop (total_steps)

def train_and_eval(total_steps, valid_steps,
                   model, optimizer, loss_fn, scheduler,
                   train_dataloader, val_dataloader, device):
    """
    total_steps is roughly len(train_dataloader) * epochs.
    Validation runs every valid_steps optimization steps.
    """
    model.train()
    train_iter = iter(train_dataloader)
    for step in range(total_steps):
        # Fetch one batch; restart the iterator when the epoch is exhausted
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_dataloader)
            batch = next(train_iter)

        l, acc = forward_batch(batch, model, loss_fn, device)
        batch_loss = l.item()       # per-step metrics, useful for logging
        batch_accuracy = acc.item()

        l.backward()
        optimizer.step()
        # scheduler.step()  # optional, e.g. for per-step LR schedules
        optimizer.zero_grad()

        if (step + 1) % valid_steps == 0:
            valid_accuracy = val_epoch(model, loss_fn, val_dataloader, device)
            model.train()  # val_epoch switches to eval mode; switch back
            print(f"step {step + 1}: val_acc = {valid_accuracy:.4f}")
            # Plotting and richer metric logging can be added here
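
A sketch of how this loop might be driven; the epoch count, validation interval, optimizer, and scheduler choices below are placeholder assumptions, not from the post:

# Hypothetical driver (all names and values are illustrative)
epochs = 3
total_steps = len(train_dataloader) * epochs
valid_steps = 100  # validate every 100 optimization steps

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

train_and_eval(total_steps, valid_steps,
               model, optimizer, loss_fn, scheduler,
               train_dataloader, val_dataloader,
               device="cuda" if torch.cuda.is_available() else "cpu")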

val_epoch reuses forward_batch to keep the code concise

def val_epoch(model, loss_fn, dataloader, device):
    model.eval()
    running_loss, running_accuracy = 0.0, 0.0
    with torch.no_grad():
        for batch in dataloader:
            l, accuracy = forward_batch(batch, model, loss_fn, device)
            running_loss += l.item()  # tracked in case you also want to log loss
            running_accuracy += accuracy.item()
    return running_accuracy / len(dataloader)
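
One caveat: averaging per-batch accuracies gives a smaller final batch the same weight as full batches. When the validation loader is not created with drop_last=True, a sample-weighted variant (a sketch, not from the post) is exact:

@torch.no_grad()
def val_epoch_weighted(model, loss_fn, dataloader, device):
    # Weight each batch by its size, so a short last batch does not skew the mean
    model.eval()
    correct, total = 0.0, 0
    for batch in dataloader:
        _, accuracy = forward_batch(batch, model, loss_fn, device)
        batch_size = batch[1].size(0)
        correct += accuracy.item() * batch_size
        total += batch_size
    return correct / total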
This post is licensed under CC BY 4.0 by the author.