Experiments with Bitnet 1.5 (~ngmi~)

Paper: arxiv.org/abs/2402.17764

image/png (ever wondered why they didn't show the loss curve in the paper?)

It kind of works, but don't get hyped up

This is one more flavor of quantization-aware training rather than something radically new. There is room for kernel fusion and CUDA magic optimization, because the hypothesis is that the stored weight matrix contains only 1, 0, -1, which means we only need ADD/SUB operations instead of expensive MULs. While at inference time we can achieve crazy speed-ups, I am not sure the same can be done at training time, because the optimizer needs smooth changes in gradients and momentum.
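For intuition, here is a toy sketch (plain Python, nothing to do with real fused kernels; the function name is mine) of what a matvec looks like when every weight is -1, 0, or 1:

def ternary_matvec(w_rows: list[list[int]], x: list[float]) -> list[float]:
    # every weight is -1, 0, or 1, so each "multiply" is an add, a subtract, or a skip
    out = []
    for row in w_rows:
        acc = 0.0
        for w_ij, x_j in zip(row, x):
            if w_ij == 1:
                acc += x_j      # ADD
            elif w_ij == -1:
                acc -= x_j      # SUB
            # w_ij == 0: skip entirely, no MUL needed anywhere
        out.append(acc)
    return out

print(ternary_matvec([[1, 0, -1], [0, 1, 1]], [0.5, 2.0, -1.0]))  # [1.5, 1.0]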

For questions/corrections reach me on twitter/x or 00shxf@gmail.com. I build code in the midst of chaos, so if you find something to correct, let me know.

Here is the code if that is all you are interested in.

Implementation Details

This is a BitNet implementation based on the PDF officially shared by the authors.

image/png

Weight Quantization

Note: for the current experiments all weights are stored in fp32/16 and quantized in the forward pass.

image/png

Here $\gamma$ = scale and $\widetilde{W}$ = w_quant
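For reference, this is (as far as I can tell) the absmean quantization the image above describes: scale by the mean absolute value, round, and clip to $[-1, 1]$.

$$
\widetilde{W} = \mathrm{RoundClip}\left(\frac{W}{\gamma + \epsilon},\, -1,\, 1\right), \qquad
\mathrm{RoundClip}(x, a, b) = \max(a, \min(b, \mathrm{round}(x))), \qquad
\gamma = \frac{1}{nm}\sum_{ij} |W_{ij}|
$$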

from torch import Tensor

def weight_quant(w: Tensor) -> tuple[Tensor, Tensor]:
    # absmean quantization: scale by mean |w|, round to the nearest of {-1, 0, 1}
    scale: Tensor = 1.0 / w.abs().mean().clamp(min=1e-5)
    quant: Tensor = (w * scale).round().clamp(-1, 1) / scale
    # straight-through estimator: forward pass uses quant, backward pass sees w
    w_quant: Tensor = w + (quant - w).detach()
    # pull the magnitude out so w_quant holds only -1, 0, 1 and scale carries it
    scale = w_quant.abs().max().detach()
    w_quant = w_quant / scale
    return w_quant, scale

w_quant is a matrix containing only 1, 0, -1.
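Quick sanity check (assuming the weight_quant above is in scope): the returned matrix really holds only -1, 0, 1 and gradients still flow back to w through the straight-through estimator.

import torch

w = torch.randn(8, 8, requires_grad=True)
w_quant, scale = weight_quant(w)
print(torch.unique(w_quant.detach()))   # values drawn from {-1., 0., 1.}
(w_quant * scale).sum().backward()
print(w.grad is not None)               # True: w stays trainable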

Activation Quantization

This is from a precursor paper https://arxiv.org/pdf/2310.11453.pdf

image/png

from torch import Tensor

def activation_quant(x: Tensor) -> Tensor:
    # per-token absmax quantization to 8 bits (reduce over the last/hidden dim)
    scale: Tensor = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    y: Tensor = (x * scale).round().clamp(-128, 127) / scale
    # straight-through estimator: forward pass uses y, backward pass sees x
    return x + (y - x).detach()
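Minimal usage check (assuming the activation_quant above is in scope, and with made-up tensor shapes): the output keeps the input's shape and gradients pass straight through.

import torch

x = torch.randn(2, 4, 16, requires_grad=True)   # (batch, seq, hidden)
x_q = activation_quant(x)
print(x_q.shape)                 # torch.Size([2, 4, 16])
x_q.sum().backward()
print(torch.all(x.grad == 1.0))  # tensor(True): straight-through estimator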

BitLinear

Now with these two functions we can implement BitLinear; this layer is a drop-in replacement for nn.Linear. There is a lot of room for optimization with kernel fusion, manually writing the backward pass, etc.

import torch
import torch.nn as nn
from torch import Tensor

class BitLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # RMSNorm is assumed to be defined elsewhere
        # (e.g. torch.nn.RMSNorm in recent PyTorch, or your own implementation)
        self.rms_norm = RMSNorm(self.in_features)

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight
        x_norm = self.rms_norm(x)

        # quantize activations to 8 bits and weights to {-1, 0, 1}
        x_quant = activation_quant(x_norm)
        w_quant, scale = weight_quant(w)

        # ternary matmul, then rescale the output
        output = nn.functional.linear(x_quant, w_quant)
        return output * scale
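Toy usage sketch (my example, with made-up sizes; assumes an RMSNorm implementation such as torch.nn.RMSNorm is available): drop BitLinear in wherever you would normally use nn.Linear.

import torch
import torch.nn as nn

mlp = nn.Sequential(
    BitLinear(64, 256),
    nn.SiLU(),
    BitLinear(256, 64),
)
x = torch.randn(8, 64)
print(mlp(x).shape)   # torch.Size([8, 64])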

Training code

You can find my pytorch implementation here https://github.com/joey00072/ohara/tree/master/experiments/bitnet.

I trained a 15.5M-parameter llama and a 15.6M-parameter bitnet llama (bitnet's RMSNorm adds the extra 0.1M). You can check the loss curve: bitnet is worse than llama (and I am pretty sure 2-bit quantization-aware training will be better or the same). wandb run

image/png

Right now bitnet is pretty useless for inference on current hardware; it is much better to train the model in bfloat16 and use a quantized version if you need it. To get the crazy speed-ups, the prerequisites for bitnet 1.5 are big: we need someone to create a special chip with mixed precision that supports 2-bit weights. So the big assumption is that we will use 2-bit models for inference, which means someone will spend a lot (a lot) of money to build the chip, the software, and to train a quantization-aware LLM. Not to mention the fact that we don't know if the scaling laws for 1.5-bit (it's really 2-bit) quantized training hold the same as they do for normal training. Unless someone is ready to spend a shit ton of money.
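For what it's worth, here is a rough storage sketch (my own illustration, not from the paper; the function name is mine) of why "1.58-bit" is really 2-bit on anything you can buy today: ternary weights still cost 2 bits each, packed 4 per byte.

import torch

def pack_ternary(w_quant: torch.Tensor) -> torch.Tensor:
    # map {-1, 0, 1} to the 2-bit codes {0, 1, 2} and pack 4 weights per uint8
    codes = w_quant.flatten().to(torch.int64) + 1
    pad = (-codes.numel()) % 4
    codes = torch.cat([codes, codes.new_zeros(pad)]).view(-1, 4)
    shifts = torch.tensor([0, 2, 4, 6])
    return (codes << shifts).sum(dim=1).to(torch.uint8)

w_quant = torch.tensor([[1., -1., 0., 1.], [0., 0., 1., -1.]])
packed = pack_ternary(w_quant)
print(packed.numel(), "bytes for", w_quant.numel(), "weights")  # 2 bytes for 8 weights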

BITNET NGMI



( ˶ᵔ ᵕ ᵔ˶ ) Discuss on Twitter