이번에는 파이토치로 선형 회귀를 해보겠습니다.
선형 회귀는 주어진 데이터 집합에 대해, 종속 변수와 독립 변수들 사이의 선형 관계를 모델링합니다.
문제
아래 표와 같이 어떤 차량이 이동하는데 1시간 후 70km, 2시간 후 140km, 3시간 후 210km를 이동했다면,
4시간 후에는 몇 km를 이동할지에 대해 예측해봅시다.
이동 시간 (H) | 이동 거리 (Km) |
1 | 70 |
2 | 140 |
3 | 210 |
4 | ? |
선형 회귀
위의 식은 독립 변수가 하나 이므로
선형 회귀 모델은 $y = wx + b$ 형태를 갖습니다. ( w는 가중치 b는 편향입니다. )
본격적으로 위의 문제를 파이토치 코드로 나타내 간단한 선형회귀를 해보겠습니다.
먼저, 필요한 라이브러리를 불러오겠습니다.
import torch
import torch.nn as nn
import torch.optim as optim
데이터
위의 이동 시간과 이동 거리를 각각 훈련 데이터로 표현하면 아래와 같습니다.
# 데이터
X = torch.FloatTensor([[1], [2], [3]]) # 이동 시간
Y = torch.FloatTensor([[70], [140], [210]]) # 이동 거리
가중치, 편향
다음으로 모델의 가중치와 편향을 0으로 초기화 해줍니다.
가중치와 편향은 모델이 학습 과정 중에 스스로 갱신할 것입니다.
# 가중치, 편향 초기화
W = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
손실함수, 최적화 기법
손실 함수는 평균 제곱 오차 (Mean Squared Error, MSE) 그리고 최적화 기법은 가장 기본적인 경사 하강법을 사용하겠습니다.
# 손실 함수, 최적화 기법
loss_fn = nn.MSELoss()
optimizer = optim.SGD([W,b], lr = 0.01)
(이번 포스팅의 목적은 기본적인 선형 회귀를 한번 해보는 것이기에, 손실 함수나 최적화 기법에는 주목하지 않겠습니다.)
모델 훈련
모델의 훈련은 예측 => 손실 계산 => 역전파를 통한 기울기 계산 => 파라미터 갱신 의 과정으로 이뤄집니다.
코드로 나타내면 아래와 같습니다.
# 모델 훈련
num_epochs = 3000
for epoch in range(num_epochs):
model_pred = W * X + b
loss = loss_fn(model_pred, Y)
optimizer.zero_grad() # Gradient 초기화
loss.backward() # 역전파를 통한 기울기 계산
optimizer.step() # Parameter Update
( 주의! optimizer.zero_grad()를 호출하여 모델 매개변수의 변화도를 재설정해야 합니다. 기본적으로 변화도는 더해지기(add up) 때문에 중복 계산을 막기 위해 반복할 때마다 명시적으로 0으로 설정합니다. )
결과
모델을 훈련시킨 후 파라미터를 확인해보겠습니다.
print(W, b)
# 출력 결과
tensor([69.9916], requires_grad=True) tensor([0.0190], requires_grad=True)
파라미터가 구해졌으니 모델을 통해 위의 문제의 결과를 예측해봅시다.
x_test = 4 # 4시간 이동
model_pred = W * x_test + b
print(model_pred)
# 출력 결과
tensor([279.9856], grad_fn=<AddBackward0>)
우리가 만든 모델은 위의 자동차가 4시간 후에는 279.9856km 를 이동할 것으로 예측했습니다.
이번 포스팅에서는 아주 간단히 파이토치를 활용해 선형회귀를 해봤습니다.
앞으로도 블로그에 파이토치에 관련해 최대한 이렇게 쉬운 내용부터 점차 심화된 내용까지 글을 써보려고 합니다.
읽어주셔서 감사합니다. ^^
'머신러닝, 딥러닝 ML, DL > Pytorch' 카테고리의 다른 글
[Pytorch] Autograd, 자동 미분 ( requires_grad, backward(), 예시 ) (2) | 2021.10.02 |
---|---|
[Pytorch] 파이토치, 텐서(Tensor)란 (텐서 속성, 텐서 초기화) (0) | 2021.09.24 |
[Pytorch] 파이토치로 GAN(Generative Adversarial Net) 구현하고 MNIST 손글씨 이미지 생성하기 (0) | 2021.06.30 |