본문 바로가기
모두의 딥러닝/강의자료 정리

Lab-10-2: mnis-cnn

by cvlab_김준수 2025. 2. 13.

실제 CNN 구조를 만들어보자!

 

import torch
import torch.nn as nn

# [입력 만들기]
inputs = torch.Tensor(1, 1, 28, 28) # 배치 크기 1 / 채널 1 / 높이 28 / 너비 28

# [Convolution Layer(1) 만들기]
conv1 = nn.Conv2d(1, 32, 3, padding = 1) # 입력 채널 1 / 출력채널 32 / 커널 크기 3x3 / 패딩 1 / stride는 기본값인 1로

# [pooling 레이어 만들기]
pool = nn.MaxPool2d(2) # 커널 크기 2x2

# [Convolution Layer(2) 만들기]
conv2 = nn.Conv2d(32, 64, 3, padding = 1) # 입력 채널 32 / 출력채널 64 / 커널 크기 3x3 / 패딩 1 / stride는 기본값인 1로




# [CNN에 입력 넣기] : 합성공층 1 -> 풀링 -> 합성곱층 2 -> 풀
out = conv1(inputs) 
out.shape # torch.Size([1, 32, 28, 28]) -> 1개의 배치(이미지). 32의 채널, 28의 높이, 28의 넓

out = pool(out)
out.shape # torch.Size([1, 32, 14, 14])

out = conv2(out)
out.shape # torch.Size([1, 64, 14, 14])

out = pool(out)
out.shape # torch.Size([1, 64, 7, 7])



# [FC 연결을 위한 1차원 평탄화 과정] 
out.size(0) # 1
out = out.view(out.size(0), -1) 
out.shape # torch.Size([1, 3136]) -> 64*7*7 = 3136


# [FC 연결]
fc = nn.Linear(3136, 10)
out = fc(out)
out.shape # torch.Size([1, 10])

 

 

 

 

MNIST dataset을 통해서 CNN 모델을 학습 시키고, 테스트 해보자.

 

# [라이브러리 가져오기]
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init


# [GPU 사용 설정] --------------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu' # 쿠다가 있으면 쿠다 없으면 cpu

torch.manual_seed(777)


#[parameter 결정] --------------------------------------------------------------
learning_rate = 0.001
training_epochs = 15
batch_size = 100


#[MNIST dataset 만들기] --------------------------------------------------------
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)


#[dataset 가져오기] ------------------------------------------------------------
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)


#[학습모델 만들기: CNN] --------------------------------------------------------
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

    self.layer2 = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

    self.fc = nn.Linear(7 * 7 * 64, 10, bias=True) # 입력값, 출력값, 바이어스 사용여부
    torch.nn.init.xavier_uniform_(self.fc.weight) # 가중치 초기화

  def forward(self, x):
    out = self.layer1(x)
    out = self.layer2(out)
    
    out = out.view(out.size(0), -1) # 배치사이즈만큼 펼치기
    out = self.fc(out)
    return out



#[학습모델 만들기: CNN] --------------------------------------------------------
model = CNN().to(device)
criterion = nn.CrossEntropyLoss().to(device) # 손실함수: 크로스 엔트로피
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 최적화: Adam


#[학습시키기] ------------------------------------------------------------------
total_batch = len(data_loader)

for epoch in range(training_epochs):
  avg_cost = 0

  for X, Y in data_loader: # X는 이미지, Y는 라벨
    X = X.to(device)
    Y = Y.to(device)

    optimizer.zero_grad() # 기울기 누적 방지
    hypothesis = model(X) # 모델이 입력 데이터 X에 대한 예측값(가설, hypothesis)을 생성

    cost = criterion(hypothesis, Y) # 오차 계산
    cost.backward() # 역전파
    optimizer.step() # 가중치 업데이트

    avg_cost += cost / total_batch # 전체 배치에 대한 평균 손실

  print('[Epoch: {}] cost = {}'.format(epoch + 1, avg_cost))
print("학습 완료")

 

 

테스트

# [모델 테스트하기]
with torch.no_grad():
  X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
  Y_test = mnist_test.test_labels.to(device)

  prediction = model(X_test)
  correct_prediction = torch.argmax(prediction, 1) == Y_test
  accuracy = correct_prediction.float().mean()
  print('Accuracy:', accuracy.item())