1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyperparameters.
epoch = 32        # number of passes over the dataset
batch_size = 64   # samples per mini-batch
lr = 0.01         # learning rate; NOTE(review): Adam's default is 1e-3 -- confirm 0.01 is intended
...  # NOTE(review): further settings were elided in the original source
# Per-sample preprocessing pipeline:
#   ToTensor  -> float tensor (values scaled to [0, 1] for PIL / uint8 input)
#   Resize    -> shorter side resized to 256 pixels, aspect ratio preserved
#   Normalize -> (x - 0.5) / 0.5 per channel, i.e. [0, 1] mapped to [-1, 1]
# NOTE(review): three mean/std entries assume 3-channel (RGB) input -- confirm.
# NOTE(review): Resize(256) fixes only the shorter side. If the images do not
# all share one aspect ratio, the resulting tensors differ in shape. The
# default collate then cannot stack them into a batch. In that case use
# Resize((256, 256)) or add a crop -- confirm against the dataset.
# Compose takes a single list of transforms; any extra transforms belong in it.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# NOTE(review): the constructor arguments were elided (`...`) in the original;
# presumably the transform pipeline `trans` is passed here -- confirm.
dataset = Dataset(...)

# Shuffled mini-batches of `batch_size` samples, loaded in the main process
# (num_workers=0 means no worker subprocesses).
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
# Model, training objective and optimizer.
net = Net()            # NOTE(review): `Net` is defined outside this snippet
net = net.to(device)   # place the parameters on the selected device
loss = nn.CrossEntropyLoss()  # criterion: unnormalized class scores vs. targets
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
import os

# One checkpoint is written per epoch. Create the target directory up front
# so torch.save does not fail with FileNotFoundError.
os.makedirs("./checkpoints", exist_ok=True)

# The loop variable is `ep` rather than `epoch`, so the hyperparameter `epoch`
# is not overwritten. Rebinding it would make a re-run of this loop (e.g. in a
# notebook) train for fewer epochs.
for ep in range(epoch):
    # Training mode for dropout / batch-norm; a no-op unless the model was
    # switched to eval mode elsewhere.
    net.train()
    for images, labels in dataloader:
        # Inputs must be on the same device as the model's parameters.
        images = images.to(device)
        labels = labels.to(device)

        # Gradients accumulate in .grad across backward() calls unless cleared.
        optimizer.zero_grad()
        y_hat = net(images)
        batch_loss = loss(y_hat, labels)
        batch_loss.backward()
        optimizer.step()
        ...  # NOTE(review): elided in the original; batch- vs epoch-level placement unknown
    torch.save(net.state_dict(), f"./checkpoints/{ep}_weights.pth")
|