본문 바로가기
AI/Paper Review

[논문 리뷰] DoRA: Weight-Decomposed Low-Rank Adaptation

by 벵자민 2024. 2. 19.
728x90

논문 링크 : https://arxiv.org/abs/2402.09353

 

DoRA: Weight-Decomposed Low-Rank Adaptation

Among the widely used parameter-efficient finetuning (PEFT) methods, LoRA and its variants have gained considerable popularity because of avoiding additional inference costs. However, there still often exists an accuracy gap between these methods and full

arxiv.org

 

현재 주목받는 효과적인 파라미터 최적화 방법인 LoRA는 저순위 분해(low-rank subset)를 통해 LLM의 가중치 행렬을 업데이트하여 계산 효율성을 높이고 추가적인 연산 비용 없이 모델 성능 향상을 가져오지만, 완전한 Fine-tuning에 비해 정확도 격차가 존재하는 한계점이 있습니다.

 

 이러한 한계를 극복하고자 DoRA(Weight-Decomposed Low-Rank Adaptation)라는 새로운 방법의 논문이 발표되었습니다. DoRA는 기존 훈련된 가중치 행렬을 크기(magnitude)와 방향(direction) 벡터로 분해하여 파라미터 효율성을 유지하면서, 방향 업데이트(Directional updates)에만 LoRA를 이용하여 학습 능력을 향상시킵니다.

 

결과적으로 DoRA는 LoRA보다 더 높은 정확도를 달성하면서도 추가적인 연산 비용 없이 안정적으로 학습이 가능합니다. 다양한 언어 모델과 과제에서 실험을 통해 DoRA가 LoRA를 지속적으로 뛰어넘는 성능을 보였습니다.

 

Introduction

Intro에서는 기존 Full Fine-Tuning된 모델의 발전으로 NLP 분야에 상당한 발전이 있었으며 모델과 데이터셋의 규모가 확장됨에 따라 FT(Fine-tuning)의 비용이 엄청나게 커졌고, 이 문제를 해결하기 위해 PEFT(parameter-efficient fine-tuning, 2019) 방법, 그 중에서도 LoRA의 간편함과 효능에 대해서 언급하지만 여전히 FT와의 성능 격차가 존재하며 이를 해결하기 위해  DoRA(Weight-Decomposed Low-Rank Adaptation)방법을 소개하고있습니다. 

DoRA(Weight-Decomposed Low-Rank Adaptation) 구조

 

DoRA는 사전 학습된 가중치를 크기(magnitude)와 방향(direction) 성분으로 분해하여 파라미터의 효율성을 유지하면서도 학습 과정을 단순화합니다. 크기는 전체적으로 조정되지만 방향(Direction)은 LoRA를 사용하여 효율적으로 업데트됩니다. 이를 통해 FT(Full Fine-tuning)과 유사한 학습 효과를 달성하면서도 컴퓨터 자원 사용량을 줄여주는 효과가 있다고 합니다.

LoRA 대비 성능 항상은 LLaMA-7B/13B, LLaVA-7B, VL-BART/ViLBERT 모델로 상식추론, 시각적 지시 조정, 이미지/영상/텍스트 이해에 대해서 평가하였고 각각 다음과 같이 향상 했습니다.

  • 상식 추론
    • LLaMA-7B 모델에서 +3.4% 향상
    • LLaMA-13B 모델에서 +1.0% 향상
  • 시각적 지시 조정
    • LLaVA-7B 모델에서 +0.6% 향상
  • 이미지/영상/텍스트 이해
    • VL-BART 모델에서 +0.9% 향상
    • ViLBERT 모델에서 +1.9% 향상

 

Related Works, Pattern Analysis of LoRA and FT

위 두 섹션에서는 PEFT 및  LoRA 알고리즘의 작동 원리에 대한 설명으로, 이 글에서는 목차형식으로만 언급하겠습니다.

 

1. LoRA 개요 및 작동 방식

  • LoRA는 사전 학습된 모델(LLM 또는 vision transformer)을 특정(종종 더 작은) 데이터 세트에 더 적합하도록 수정하는 머신 러닝 기법입니다.
  • 이는 모델의 매개변수의 작고 저순위 하위 집합만 조정하여 수행됩니다.
  • 이 접근 방식은 대규모 모델을 작업별 데이터에 대해 효율적으로 파인튜닝할 수 있게 하여 파인튜닝에 필요한 계산 비용과 시간을 크게 줄여줍니다.

2. LoRA의 장점

  • GPU 메모리 사용량 감소
  • 모델 파라미터 수 감소
  • 빠른 파인튜닝 속도
  • 모델 성능 향상 가능성

3. LoRA의 한계점 

  • LoRA는 모든 방향으로 동일한 방식(방향과 크기 업데이트를 비례하여 학습하므로 패턴이 단순함)으로 업데이트를 적용하기 때문에 방향 업데이트에 대한 제어가 부족합니다.
  • FT(Full Fine-tuning)에 비해 낮은 성능

4. 기타 Weight Decomposition Analysis

  • Weight Normalization 언급하며 인사이트를 얻음.

 

Method

1. Weight-Decomposed Low-Rank Adaptation

Weight decomposition에서 얻은 인사이트를 바탕으로 DoRA를 소개하고있습니다. 계속 반복되는 말이지만, 두 가지 직관을 가지고 방법을 제안했다고합니다. 첫번째로 기존 LoRA가 크기와 방향 모두에 대해 학습해야 하는걸 DoRA에서는 방향 성분 학습에만 집중하도록 하면서 작업을 단순화 하는 것. 둘째는 Weight를 분해함으로써 안정적인 directional updates가 이루어지는 것.

위 Weight normalization 수식처럼 Weight 행렬을 크기와 방향성으로 분해합니다. 여기서 DoRA는 Weight Normalization과 다르게 두 부분을 처음부터 학습하지 않고, 이미 훈련된 Weight를 시작점으로 하여 초기화 문제를 피합니다. 

위 식에서 ∆V는 두 low-rank 행렬 B와 A를 곱하여 학습한 incremental directional 업데이트이며, 밑줄친 매개변수는 trainable 매개변수를 나타냅니다. 행렬 ℝd×r 와 A ∈ ℝr×k 는 LoRA의 전략에 따라 초기화되며, fine-tuning 전에 W'가 W0와 동일하도록 보장합니다. 또한, DoRA는 추론 전에 사전 훈련된 가중치와 병합될 수 있으므로 추가적인 지연을 도입하지 않습니다.

 

FT, LoRA, DoRA 각각 Weight matrix의 크기 및 방향 차이 시각화

위 전략으로 학습시켰을때, 그림처럼 FT와 DoRA는 LoRA와 달리 음의 기울기를 보입니다. 이는 FT가 이미 다양한 Downsteam task 작업에 적합한 상당한 지식을 갖고 있기 때문이라고 추론합니다. 

 

2. Gradient Analysis of DoRA

이 섹션에서는 DoRA의 Gradient를 도출하고 ∆V의 최적화에 어떻게 도움이 되는지와 DoRA의 학습 패턴을 설명하고 있습니다.

 

여기서부턴 수식이 너무 많이나와 추후에 다시 이어서 작성하겠습니다..

 

 

Implementation of LoRA and DoRA layers in PyTorch

추가적으로 논문에서는 아직 모델이나 코드가 공개되지 않았다고하는데 DoRA를 코드로 구현한 github을 발견했습니다.

https://github.com/catid/dora/tree/main

 

GitHub - catid/dora: Implementation of DoRA

Implementation of DoRA. Contribute to catid/dora development by creating an account on GitHub.

github.com

 

Pytorch로 DoRA 적용 방법을 구현한 코드입니다.

import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# This layer is dropped into your pre-trained PyTorch model where nn.Linear is used
class DoRALayer(nn.Module):
    def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
        super().__init__()

        if weight is not None:
            self.weight = nn.Parameter(weight, requires_grad=False)
        else:
            self.weight = nn.Parameter(torch.Tensor(d_out, d_in), requires_grad=False)

        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = nn.Parameter(torch.Tensor(d_out), requires_grad=False)

        # m = Magnitude column-wise across output dimension
        self.m = nn.Parameter(self.weight.norm(p=2, dim=0, keepdim=True))
        
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.lora_A = nn.Parameter(torch.randn(d_out, rank)*std_dev)
        self.lora_B = nn.Parameter(torch.zeros(rank, d_in))

    def forward(self, x):
        lora = torch.matmul(self.lora_A, self.lora_B)
        adapted = self.weight + lora
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)
        norm_adapted = adapted / column_norm
        calc_weights = self.m * norm_adapted
        return F.linear(x, calc_weights, self.bias)


class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.layer1(x)
        return x

# Generating synthetic data
def generate_data(num_samples=100, input_dim=10):
    X = torch.randn(num_samples, input_dim)
    y = torch.sum(X, dim=1, keepdim=True)  # Simple relationship for demonstration
    return X, y

# Training function
def train(model, criterion, optimizer, data_loader, epochs=5):
    model.train()
    for epoch in range(epochs):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        #print(f"Epoch {epoch+1}, Loss: {loss.item()}")



def replace_linear_with_dora(model):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Get the input and output dimensions of the current nn.Linear layer
            d_in = module.in_features
            d_out = module.out_features

            # Create a new DoRALayer with the same dimensions
            setattr(model, name, DoRALayer(d_out=d_out, d_in=d_in, weight=module.weight.data.clone(), bias=module.bias.data.clone()))
        else:
            # Recursively apply this function to submodules
            replace_linear_with_dora(module)

def print_model_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params}")

# Main script
if __name__ == "__main__":
    input_dim, output_dim = 10, 1
    model = SimpleModel(input_dim, output_dim)
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001)

    X, y = generate_data(num_samples=1000, input_dim=input_dim)
    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

    print_model_parameters(model)

    train(model, criterion, optimizer, data_loader, epochs=100)

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        inputs, targets = next(iter(data_loader))
        predictions = model(inputs)
        loss = criterion(predictions, targets)
        print(f"Final Evaluation Loss: {loss.item()}")

    replace_linear_with_dora(model)

    print_model_parameters(model)

    # Continue training with the Dora model
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
    print("Continuing training with DoRA layers...")
    train(model, criterion, optimizer, data_loader, epochs=5)  # Continue training

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        inputs, targets = next(iter(data_loader))
        predictions = model(inputs)
        loss = criterion(predictions, targets)
        print(f"Final (DoRA) Evaluation Loss: {loss.item()}")

 

 

 

728x90