Published on

PyTorch 2.x PTQ - TorchAO INT4로 2배 가속

Authors

서빙 비용을 줄이려면 결국 같은 GPU에서 더 많은 토큰을 뽑아내야 합니다. 가장 빠르게 체감되는 레버 중 하나가 PTQ(Post-Training Quantization)이며, PyTorch 2.x에서는 TorchAO를 통해 INT4까지 비교적 간단히 내려갈 수 있습니다. 특히 LLM 계열에서 weight-only INT4는 정확도 손실을 제한하면서도 메모리 대역폭 병목을 줄여 대략 1.3배~2배 수준의 처리량 개선을 기대할 수 있습니다(모델 구조, 시퀀스 길이, 커널/하드웨어에 따라 편차 큼).

이 글은 다음을 목표로 합니다.

  • PyTorch 2.x에서 TorchAO로 INT4 PTQ를 적용하는 최소 경로
  • 실제로 “2배 가속”에 가까워지기 위한 조건(커널, 컴파일, 배치/시퀀스)
  • 정확도/품질 저하를 제어하는 체크리스트

운영 환경에서 지연/타임아웃이 함께 문제라면, 인프라 레벨의 타임아웃/리트라이 폭주도 같이 잡아야 전체 체감이 좋아집니다. 관련해서는 gRPC MSA에서 데드라인·재시도 폭주 막는 법, GPU 서빙 레이어 이슈라면 KServe + Istio에서 GPU 모델 503·타임아웃 해결도 함께 참고하면 좋습니다.

TorchAO INT4 PTQ가 “잘 먹히는” 이유

대부분의 추론 워크로드(특히 LLM)는 다음 병목 중 하나에 걸립니다.

  • 연산량(FLOPs) 병목: 작은 배치, 짧은 시퀀스, 커널이 잘 최적화된 경우
  • 메모리 대역폭 병목: 큰 모델, 긴 시퀀스, KV 캐시/가중치 로딩이 지배적인 경우

INT4 weight-only는 보통 가중치 메모리 풋프린트를 크게 줄여 메모리 대역폭 병목을 완화합니다. 활성값(activation)은 FP16/BF16로 유지하는 경우가 많아 품질 저하를 제한할 수 있고, 구현 난이도도 INT8 activation까지 가는 것보다 낮습니다.

다만 “2배”가 항상 나오지는 않습니다. 다음 조건이 맞을수록 유리합니다.

  • 모델이 크고(가중치 로딩이 크고), 배치/시퀀스가 커서 메모리 대역폭이 지배적일 때
  • INT4 GEMM 커널이 잘 붙을 때(하드웨어/라이브러리/빌드 조합)
  • PyTorch 2.x의 컴파일 경로(torch.compile)가 양자화된 모듈을 잘 최적화할 때

준비물: 버전과 환경 체크

TorchAO는 PyTorch 2.x와 함께 빠르게 진화하고 있어, “코드가 맞는데도 커널이 안 붙는” 상황이 생각보다 자주 나옵니다. 먼저 아래를 확인하세요.

  • CUDA와 GPU 아키텍처(예: Ampere, Ada, Hopper)
  • PyTorch 버전(가능하면 최신 2.x)
  • TorchAO 설치 및 버전

설치 예시는 아래처럼 시작합니다(환경에 따라 다를 수 있음).

pip install -U torch
pip install -U torchao

런타임에서 버전 확인도 해두면 디버깅이 빨라집니다.

import torch

print(torch.__version__)
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")

핵심 개념: INT4 PTQ에서 선택해야 하는 것들

TorchAO로 INT4를 적용할 때, 실무에서 가장 자주 부딪히는 선택지는 아래 3가지입니다.

1) Weight-only vs Activation까지 양자화

  • weight-only INT4: 가중치만 INT4, activation은 FP16/BF16 유지
    • 장점: 품질/안정성 유리, 적용 쉬움
    • 단점: activation 메모리/연산은 그대로
  • activation까지 내리는 방식: 더 큰 이득 가능하지만 튜닝 난이도와 품질 리스크가 커짐

이 글은 “가장 ROI 좋은” weight-only INT4에 집중합니다.

2) Per-channel(또는 group-wise) 스케일링

INT4는 표현 범위가 좁아 스케일링이 중요합니다.

  • per-tensor: 단순하지만 정확도 손실이 커질 수 있음
  • per-channel 또는 group-wise: 더 정교, 정확도 유리(대신 메타데이터/연산 비용 증가)

LLM 계열에서는 group-wise(예: group size 128) 같은 구성으로 균형을 잡는 경우가 흔합니다.

3) 커널/컴파일 경로

같은 양자화라도 커널이 다르면 성능이 크게 갈립니다.

  • eager 실행에서는 커널 선택이 제한될 수 있음
  • torch.compile로 그래프 최적화가 붙으면 더 좋은 경로로 떨어질 수 있음

따라서 “양자화만 했다”로 끝내지 말고, 반드시 벤치마크로 확인해야 합니다.

실전: TorchAO로 INT4 PTQ 적용하기

아래 코드는 “모델의 Linear 계열(예: attention/MLP projection)을 INT4 weight-only로 바꾸는” 전형적인 흐름을 보여줍니다. 실제 TorchAO API는 버전에 따라 함수/모듈명이 바뀔 수 있으니, 핵심은 구조(준비 → 변환 → 검증 → 컴파일/벤치)로 이해하는 게 안전합니다.

예시 모델 준비

데모를 위해 간단한 MLP를 쓰지만, 실제로는 Transformer 블록의 nn.Linear들을 대상으로 합니다.

import torch
import torch.nn as nn

class TinyMLP(nn.Module):
    def __init__(self, d=4096):
        super().__init__()
        self.fc1 = nn.Linear(d, 4*d, bias=False)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(4*d, d, bias=False)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

model = TinyMLP(d=2048).cuda().eval().to(torch.float16)

TorchAO INT4 변환(개념 코드)

아래는 “TorchAO를 이용해 Linear를 INT4 weight-only로 치환”하는 전형적인 패턴입니다. 실제 환경에서는 TorchAO의 권장 함수(예: quantize_ 또는 replace_linear 류)를 사용하게 됩니다.

import torch

# TorchAO API는 버전별로 달라질 수 있습니다.
# 핵심은: (1) weight-only INT4 설정 (2) Linear 모듈 치환 입니다.

def quantize_linear_modules_to_int4(model: torch.nn.Module):
    import torchao

    # 예: group-wise INT4 설정을 만든다고 가정
    # (정확한 옵션명은 설치된 torchao 문서를 확인하세요)
    config = {
        "dtype": "int4_weight_only",
        "group_size": 128,
        "per_channel": True,
    }

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # 개념적으로는 여기서 module을 INT4 Linear로 교체합니다.
            # torchao의 실제 API에 맞게 치환하세요.
            pass

    return model

qmodel = quantize_linear_modules_to_int4(model)

위 코드에서 pass로 둔 이유는, TorchAO가 빠르게 바뀌기 때문입니다. 하지만 실무에서 중요한 건 다음을 반드시 점검하는 것입니다.

  • 변환 후 state_dict 크기가 줄었는지
  • 변환 후 forward가 정상 동작하는지
  • 변환 후 실제 커널이 INT4 경로를 타는지(벤치로 확인)

변환 검증: 출력 차이와 품질 확인

PTQ는 학습 없이 바꾸는 만큼, “정확도/품질 저하를 어디까지 허용할지”를 수치로 관리해야 합니다.

import torch

torch.manual_seed(0)
x = torch.randn(8, 128, 2048, device="cuda", dtype=torch.float16)

with torch.no_grad():
    y_fp16 = model(x)
    y_int4 = qmodel(x)

# 상대 오차(간단 지표)
rel_err = (y_fp16 - y_int4).float().norm() / (y_fp16.float().norm() + 1e-12)
print("relative error:", rel_err.item())

LLM이라면 perplexity, MMLU 같은 태스크 지표 또는 최소한 샘플 프롬프트 세트의 회귀 테스트를 붙이는 게 안전합니다.

“2배 가속”에 가까워지려면: 벤치마크 방법

체감 성능은 대부분 latencythroughput으로 봅니다.

  • latency: 단일 요청의 지연(특히 batch=1)
  • throughput: 초당 토큰/샘플 처리량(배치가 커질수록 중요)

아래는 매우 단순한 벤치 템플릿입니다.

import time
import torch

def bench(model, x, iters=200, warmup=50):
    model.eval()
    with torch.no_grad():
        # warmup
        for _ in range(warmup):
            _ = model(x)
        torch.cuda.synchronize()

        t0 = time.time()
        for _ in range(iters):
            _ = model(x)
        torch.cuda.synchronize()
        t1 = time.time()

    return (t1 - t0) / iters

x = torch.randn(8, 128, 2048, device="cuda", dtype=torch.float16)

t_fp16 = bench(model, x)
t_int4 = bench(qmodel, x)

print("fp16 ms:", t_fp16 * 1000)
print("int4 ms:", t_int4 * 1000)
print("speedup:", t_fp16 / t_int4)

여기서 중요한 팁은 다음입니다.

  • 반드시 torch.cuda.synchronize()로 측정 구간을 고정하세요.
  • 입력 크기(배치, 시퀀스)를 실제 트래픽과 비슷하게 잡으세요.
  • 단일 설정으로 결론 내리지 말고, batch=1batch>1을 분리해서 보세요.

torch.compile과 함께 쓰기

PyTorch 2.x의 강점 중 하나가 torch.compile입니다. 양자화 모델도 컴파일이 잘 붙으면 커널/그래프 최적화로 이득이 더 커질 수 있습니다.

import torch

compiled_qmodel = torch.compile(qmodel, mode="max-autotune")

t_int4_compiled = bench(compiled_qmodel, x)
print("int4 compiled ms:", t_int4_compiled * 1000)

주의할 점도 있습니다.

  • 컴파일은 워밍업 비용이 큽니다. 서버 스타트업에 영향을 줍니다.
  • 동적 shape가 많으면 컴파일 캐시가 깨지거나 재컴파일이 잦아질 수 있습니다.
  • 일부 양자화 모듈은 컴파일에서 그래프 브레이크가 날 수 있습니다.

서빙에서 타임아웃이 민감하면, 컴파일 워밍업을 별도 단계로 분리하거나(예: 헬스체크 전에), 타임아웃/리트라이 정책을 조정해야 합니다. 운영 관점의 타임아웃 점검은 Cloud Run 504 Timeout 원인·해결 9가지도 같이 보면 도움이 됩니다.

정확도 손실을 줄이는 체크리스트

INT4 PTQ는 “속도 vs 품질”의 줄다리기입니다. 아래 순서로 접근하면 시행착오를 줄일 수 있습니다.

1) 먼저 INT8 또는 FP8 경로로 안전하게 기준선 만들기

가능하다면 INT8(또는 FP8)로 먼저 내려서 품질 저하 패턴을 파악한 뒤 INT4로 가는 것이 안정적입니다.

2) 민감 레이어 제외(exclude list)

모든 Linear를 한 번에 내리기보다, 품질에 민감한 레이어를 제외하는 방식이 잘 먹힙니다.

  • 임베딩/LM head
  • 특정 attention projection
  • 첫/마지막 블록

3) 그룹 크기 조정

group_size를 작게 하면 정확도는 좋아질 수 있지만 메타데이터/오버헤드가 늘 수 있습니다. 반대로 크게 하면 성능은 좋아질 수 있으나 정확도가 흔들릴 수 있습니다.

4) 캘리브레이션 데이터(있다면) 사용

PTQ라도 캘리브레이션 샘플(대표 입력)을 써서 스케일을 더 안정적으로 잡는 방식이 있습니다. 특히 activation까지 건드리는 경우에는 사실상 필수입니다.

성능이 안 나올 때 흔한 원인

“INT4로 바꿨는데 빨라지지 않는다”는 케이스는 보통 아래 중 하나입니다.

1) 커널이 INT4로 떨어지지 않음

  • 특정 GPU 아키텍처에서 INT4 경로가 제한적
  • 라이브러리 빌드 옵션/버전 불일치
  • 모델 구조가 커널이 기대하는 레이아웃과 다름

해결은 결국 “벤치 + 프로파일”입니다. Nsight Systems/Compute 또는 PyTorch profiler로 GEMM 커널이 무엇으로 실행되는지 확인하세요.

2) 배치가 너무 작아 메모리 이득이 체감되지 않음

batch=1의 디코딩 루프에서는 KV 캐시 접근, 스케줄링 오버헤드, 작은 GEMM이 지배적이라 INT4 이득이 제한될 수 있습니다. 반대로 prefill(긴 시퀀스 입력)이나 배치가 커지면 이득이 커질 가능성이 큽니다.

3) 디코더 외 병목(토크나이저, 네트워크, 후처리)

모델만 빨라져도, 전체 p95는 토크나이저/네트워크/큐잉에서 결정될 수 있습니다. 특히 gRPC에서 데드라인/재시도 설정이 잘못되면, 모델이 빨라져도 “폭주”로 p95가 무너집니다. 이 부분은 앞서 언급한 gRPC MSA에서 데드라인·재시도 폭주 막는 법을 같이 점검하세요.

운영 적용 패턴: 안전한 롤아웃

INT4 PTQ는 모델 동작이 미묘하게 달라질 수 있으므로, 운영에서는 다음 패턴이 안전합니다.

  • 카나리 배포: 트래픽 일부만 INT4로 보내고 품질/지연/에러율 비교
  • A/B 로그: 동일 프롬프트에 대해 FP16과 INT4를 병렬로 돌려 차이를 수집(비용은 들지만 매우 확실)
  • 폴백 경로: 특정 입력/길이에서 품질이 흔들리면 FP16로 폴백

또한 서빙 레이어(Istio, Envoy, KServe 등)에서는 타임아웃/리트라이 정책이 모델 지연 변화에 민감합니다. INT4로 빨라지면 리트라이를 줄일 여지가 생기기도 하지만, 반대로 워밍업/재컴파일 구간에서 일시적으로 느려질 수 있어 배포 직후 알람이 뜰 수 있습니다.

정리

  • PyTorch 2.x에서 TorchAO 기반 INT4 weight-only PTQ는 학습 없이도 큰 폭의 메모리 절감과 처리량 개선을 노릴 수 있는 현실적인 선택지입니다.
  • “2배 가속”은 자동으로 따라오지 않습니다. 커널이 제대로 붙는지, 워크로드가 메모리 병목인지, torch.compile로 최적화가 가능한지를 벤치마크로 확인해야 합니다.
  • 품질은 exclude list, group size, 캘리브레이션(가능하면)로 제어하고, 운영에서는 카나리/폴백으로 리스크를 낮추는 게 좋습니다.

다음 단계로는 실제 LLM(예: Transformer 기반)에서 어떤 Linear를 우선 양자화할지, prefill/decoding을 분리해 측정하는 벤치 스크립트, 그리고 프로파일링으로 INT4 커널 경로를 확인하는 방법까지 확장하면 “가속이 왜 안 나오는지”를 빠르게 좁힐 수 있습니다.