Softmax란?
다중 클래스를 분류할 때 사용하는 활성화 함수로, 각 클래스에 대한 확률을 도출해준다.
위와 같은 수식을 사용하며, 지수 함수를 사용해서 확률을 도출한다.
지수함수가 사용되는 이유는 미분이 가능하도록 하게 함이며, 입력값 중 큰 값은 더 크게 작은 값은 더 작게 만들어 입력벡터가 더 잘 구분되게 하기 위함이다.
위의 사진과 같이 각 클래스에 따른 확률을 도출하며, 모든 확률을 더하면 1이 나오게된다. 또한, 오차는 크로스 엔트로피 손실함수를 사용해서 모델을 학습시킨다. 이를 코드로 작성하면 아래와 같다.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# For reproducibility
torch.manual_seed(1)
# (1) 데이터 로드
xy = np.loadtxt('data-04-zoo.csv', delimiter=',', dtype=np.float32)
x_train = torch.FloatTensor(xy[:, 0:-1]) # 특성 데이터 (입력값)
y_train = torch.LongTensor(xy[:, -1]).squeeze() # 정답 라벨 (마지막 열)
# [첫 다섯 개 샘플]
# tensor([[1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 4., 0., 0., 1.],
# [1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 4., 1., 0., 1.],
# [0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0.],
# [1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 4., 0., 0., 1.],
# [1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 4., 1., 0., 1.]])
# [첫 다섯 개 정답]
# tensor([0, 0, 3, 0, 0])
# 클래스의 수
nb_classes = 7
# (2) 모델 정의
class SoftmaxClassifierModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(16, nb_classes) # 입력 특성 16개, 출력 클래스 7개로 구성된 선형 계층
def forward(self, x):
return self.linear(x) # 순전파에서 선형 계층을 통과시켜 출력값 반환
# 모델 인스턴스 생성
model = SoftmaxClassifierModel()
# (3) 옵티마이저 설정
optimizer = optim.SGD(model.parameters(), lr=0.1) # 확률적 경사 하강법(SGD) 옵티마이저, 학습률 0.1
# (4) 학습 설정
nb_epochs = 1000 # 총 학습 에포크 수
for epoch in range(nb_epochs + 1):
# 순전파: 모델을 통해 예측값을 계산
prediction = model(x_train)
# 손실 계산: 크로스 엔트로피 손실 함수 적용
cost = F.cross_entropy(prediction, y_train)
# 역전파 및 가중치 업데이트
optimizer.zero_grad() # 이전 단계의 기울기 초기화
cost.backward() # 손실 함수의 기울기 계산
optimizer.step() # 옵티마이저를 통해 가중치 업데이트
# 주기적으로 손실 출력 (100 에포크마다)
if epoch % 100 == 0:
print(f'Epoch {epoch:4d}/1000 Cost: {cost.item():.6f}')
'모두의 딥러닝 > 강의자료 정리' 카테고리의 다른 글
Lab 07-2: MNIST Introduction (1) | 2025.01.26 |
---|---|
Lab 07-1: Tips(Overfitting, Preprocessing Data) (2) | 2025.01.21 |
Lab_06(준비): Cross Entropy 손실함수 (3) | 2025.01.20 |
Lab_05: Logistic Regression (1) | 2025.01.16 |
Lab_04-2: Loading Data(mini batch) (0) | 2025.01.15 |