Artificial Intelligence/머신러닝-딥러닝

파이토치 분산 학습하기 (Pytorch Distributed Training)

roytravel 2022. 11. 12. 16:07

 

1. 분산 학습 개요

최근 몇 년간 Large Language Model을 만드는 추세가 계속해서 이어지고 있다. 이런 거대 모델의 경우 파라미터를 역전파로 업데이트하기 위해 많은 양의 메모리와 컴퓨팅 파워가 필요하다. 따라서 여러 프로세서에 분산시켜 모델을 학습하는 분산 학습이 필요하다. 분산 학습을 통해 CPU 또는 GPU 상의 학습 속도 향상을 이룰 수 있다. 많은 사람들이 사용하는 딥러닝 라이브러리인 파이토치에서 이런 분산 학습을 돕는 아래 API들이 있다.

 

1. torch.multiprocessing

  • 여러 파이썬 프로세스를 생성하는 역할. 일반적으로 CPU나 GPU 코어 수 만큼 프로세스 생성 가능

2. torch.distributed

  • 분산 학습을 진행할 수 있도록 각 프로세스 간의 통신을 가능하게 하는 일종의 IPC 역할

3. torch.utils.data.distributed.DistributedSampler

  • 학습 데이터셋을 프로세스 수 만큼 분할해 분산 학습 세션의 모든 프로세스가 동일한 양의 데이터로 학습하도록 만드는 역할.
  • 프로세스 수만큼 나누기 위해 world_size라는 인자를 사용. 

4. torch.nn.parallel.DistributedDataParallel

해당 API는 내부적으로 5가지 동작이 이뤄짐

  • 분산 환경에서 각 프로세스마다 고유한 모델 사본이 생성됨.
  • 고유 모델 사본 별 자체 옵티마이저를 갖고, 전역 이터레이션과 동기화됨
  • 각 분산 학습 이터레이션에서 개별 loss를 통해 기울기가 계산되고, 각 프로세스의 기울기 평균을 구함
  • 평균 기울기는 매개변수를 조정하는 각 모델 복사본에 전역으로 역전파됨.
  • 전역 역전파 때문에 모든 모델의 매개변수는 이터레이션마다 동일하도록 자동으로 동기화됨. 

 

2. 분산 학습 루틴 정의

그렇다면 실제로 위 API를 이용해 어떻게 분산 학습 루틴을 정의할 수 있을까? 아래 예시 코드를 통해 알아보자.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import torch.multiprocessing as mp
import torch.distributed as dist

import os
import time
import argparse

class ConvNet(nn.Module):
	pass
    
def train(gpu_num, args):
    rank = args.machine_id * args.num_gpu_processes + gpu_num                        
    dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank) 
    torch.manual_seed(0)
    model = ConvNet()
    torch.cuda.set_device(gpu_num)
    model.cuda(gpu_num)
    criterion = nn.NLLLoss().cuda(gpu_num) # nll is the negative likelihood loss
    
    train_dataset = datasets.MNIST('../data', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1302,), (0.3069,))]))  
                                       
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=rank
    )
    train_dataloader = torch.utils.data.DataLoader(
       dataset=train_dataset,
       batch_size=args.batch_size,
       shuffle=False,            
       num_workers=0,
       pin_memory=True,
       sampler=train_sampler)
       
    optimizer = optim.Adadelta(model.parameters(), lr=0.5)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu_num])
    model.train()
    for epoch in range(args.epochs):
        for b_i, (X, y) in enumerate(train_dataloader):
            X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
            ... 
            if b_i % 10 == 0 and gpu_num==0:
                print (...)
            
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-machines', default=1, type=int,)
    parser.add_argument('--num-gpu-processes', default=1, type=int)
    parser.add_argument('--machine-id', default=0, type=int)
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--batch-size', default=128, type=int)
    args = parser.parse_args()
    
    args.world_size = args.num_gpu_processes * args.num_machines                
    os.environ['MASTER_ADDR'] = '127.0.0.1'              
    os.environ['MASTER_PORT'] = '8892'      
    start = time.time()
    mp.spawn(train, nprocs=args.num_gpu_processes, args=(args,))
    print(f"Finished training in {time.time()-start} secs")
    
if __name__ == '__main__':
    main()

 

ConvNet 클래스는 코드가 길어지지 않도록 한 것으로 모종의 컨볼루션 레이어가 들어있다고 가정한다. 분산 학습의 루틴이 되는 로직의 핵심은 train 함수에 들어있다. 살펴보면 가장 먼저 첫번째로 rank가 할당된다. rank는 전체 분산시스템에서 프로세스 순서를 지칭하는 것이다. 예를 들어 4-CPU를 가진 시스템 2개가 있다면 8개의 프로세스를 생성할 수 있고 각 프로세스는 0~7번까지 고유 번호를 가질 것이다. rank를 구하기 위해 사용하는 식은 $rank = n*4 + k$이다. 여기서 n은 시스템 번호(0, 1)이 되고, k는 프로세스 번호(0, 1, 2, 3)이다. 참고로 rank가 0인 프로세스만 학습에 대한 로깅을 출력한다. 그 이유는 rank가 0인 프로세스가 다른 프로세스와의 통신의 주축이 되기 때문이다. 만약 그렇지 않다면 프로세스 개수만큼 로그가 출력될 것이다.

 

두 번째로 dist.init_process_group 메서드가 보인다. 이는 분산 학습을 진행하는 각 프로세스 간의 통신을 위해 사용한다.  정확히는 매개변수로 들어가는 backend 인자가 그 역할을 한다. pytorch에서 지원하는 backend 인자에는 Gloo, NCCL, MPI가 있다. 간단히 언급하면 주로 Gloo는 CPU 분산 학습에 NCCL은 GPU 분산 학습에, MPI는 고성능 분산 학습에 사용된다. init method는 각 프로세스가 서로 탐색하는 방법으로 URL이 입력되며 기본값으로 env://이 설정된다. world_size는 분산 학습에 사용할 전체 프로세스 수다. world_size 수 만큼 전체 학습 데이터셋 수가 분할된다.

 

세 번째로 torch.utils.data.distributed.DistributedSampler다. world_size 수 만큼 데이터셋을 분할하고 모든 프로세스가 동일한 양의 데이터셋을 갖도록 한다. 이후에 DataLoader가 나오는데 shuffle=False로 설정한 것은 프로세스 간 처리할 데이터셋의 중복을 피하기 위함이다. 

 

네 번째로 torch.nn.parallel.DistributedDataParallel다. 분산 환경에서 사용할 각각의 모델 복사본을 생성한다. 생성된 각 모델 복사본은 각자 옵티마이저를 갖고, loss function으로부터 기울기를 계산하고 rank 0을 가지는 프로세스와 통신하여 기울기의 평균을 구하고 rank 0을 갖는 프로세스로부터 평균 기울기를 받아 역전파를 수행한다. 이 DistributedDataparallel을 사용하면 각각 독립된 프로세스를 생성하므로 파이썬 속도 한계의 원인인 GIL 제약이 사라져 모델 학습 속도를 늘릴 수 있다는 장점이 있다. 

 

다섯 번째로 MASTER_ADDR은 rank가 0인 프로세스를 실행하는 시스템의 IP 주소를 의미하고 MASTER_PORT는 그 장치에서 사용 가능한 PORT를 의미한다. 이는 앞서 언급했듯 rank 0인 시스템이 모든 백엔드 통신을 설정하기 때문에 다른 프로세스들이 호스팅 시스템을 찾기 위해 사용한다. IP에 따라 local이나 remote를 설정해 사용할 수 있다. 

 

마지막으로 데이터로더에 pin_memory=True는 데이터셋이 로딩 된 장치(ex: CPU)에서 다양한 장치(GPU)로 데이터를 빠르게 전송할 수 있도록 한다. 예를 들어 데이터셋이 CPU가 사용하는 고정된 page-lock 메모리 영역에 할당되어 있다면 GPU는 이 CPU의 page-lock 메모리 영역을 참조하여 학습도중 필요한 데이터를 복사해 사용한다. pin_memory=True는 참고로 학습 루틴 상에서 non_blocking=True라는 매개변수와 함께 동작한다. 결과적으로 GPU 학습 속도가 향상되는 효과를 가져온다.