1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
|
import torch from torch import nn from torch.nn import functional as F from torch import optim import torchvision from torchvision import transforms import matplotlib.pyplot as plt from import_file.utils import plot_image, plot_curve, one_hot
# Number of samples per mini-batch, shared by the training and test loaders.
batch_size = 512

# Single preprocessing pipeline reused by both loaders: convert the image to a
# tensor, then normalize with mean 0.1307 / std 0.3081
# (presumably the MNIST pixel statistics -- TODO confirm).
tf_compose_1 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, stored under ./data (downloaded if missing) and shuffled.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=tf_compose_1),
    batch_size=batch_size,
    shuffle=True,
)

# Test split: same preprocessing, iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data', train=False, download=True, transform=tf_compose_1),
    batch_size=batch_size,
    shuffle=False,
)

# Pull one training batch to sanity-check tensor shapes and the value range
# produced by the normalization.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
class Net(nn.Module):
    """Fully connected network with two ReLU hidden layers.

    All sizes are configurable; the defaults reproduce the original
    784 -> 256 -> 64 -> 10 layout, so ``Net()`` behaves as before.

    Args:
        in_features: Size of the last dimension of each input.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output units.
    """

    def __init__(self, in_features: int = 28 * 28, hidden1: int = 256,
                 hidden2: int = 64, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Run the three linear layers, applying ReLU after the first two.

        Args:
            x: Tensor whose last dimension equals ``in_features``; callers
                are expected to flatten images before passing them in.

        Returns:
            The output of the last linear layer (no activation applied),
            with last dimension ``num_classes``.
        """
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc3(h2)
# Model plus an SGD optimizer (with momentum) over all of its parameters.
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# One loss value per optimization step, plotted once training finishes.
train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Start every step from cleared gradients.
        optimizer.zero_grad()

        # Reshape each sample in the batch to a flat 784-element vector.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)

        # Regress the network outputs onto the one-hot encoded labels.
        y_onehot = one_hot(y)
        loss = F.mse_loss(out, y_onehot)

        loss.backward()
        optimizer.step()

        # Read the scalar once; it feeds both the history and the log line.
        step_loss = loss.item()
        train_loss.append(step_loss)
        if not batch_idx % 10:
            print(epoch, batch_idx, step_loss)

plot_curve(train_loss)
# Evaluate on the test set. Inference needs no gradients, so autograd is
# disabled to avoid building a graph for every batch. eval() changes nothing
# for a model made only of linear layers, but keeps this section correct if
# dropout or batch-norm layers are added later.
net.eval()
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        # Same flattening as in training: one 784-element vector per sample.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # The predicted class is the index of the largest output.
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("准确率:", acc)

# Visualise predictions for the first test batch (images stay unflattened
# for plotting; only the network input is reshaped).
x, y = next(iter(test_loader))
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
# Sample output of the script: "准确率: 0.8895" (i.e. accuracy 0.8895)
|