
# Experiments with Bitnet 1.5 (~ngmi~)

- Author: Joey00072 (@shxf0072)

Paper: arxiv.org/abs/2402.17764

(Ever wondered why they didn't show a loss curve in the paper?)

### It kind of works, but don't get hyped up

This is one more kind of quantization-aware training rather than something radically new. There is room for kernel fusion and CUDA magic, because the hypothesis is that a weight matrix containing only 1, 0, -1 needs only ADD and SUB operations, which are cheaper than an expensive MUL. While we can achieve a crazy speed-up at inference time, I am not sure the same can be done at training time, because the optimizer needs smooth changes in gradients and momentum.
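To make the ADD/SUB claim concrete, here is a minimal pure-Python sketch (no real kernel looks like this). The names `ternary_matvec` and `dense_matvec` are made up for illustration. It shows that a matrix-vector product with weights in {1, 0, -1} can be computed without a single multiplication and still matches the ordinary multiply-accumulate result.

```python
def ternary_matvec(w_rows, x):
    """Matrix-vector product for ternary weights using only add and subtract."""
    out = []
    for row in w_rows:
        acc = 0.0
        for w, xi in zip(row, x):
            if w == 1:
                acc += xi      # +1 weight: add the activation
            elif w == -1:
                acc -= xi      # -1 weight: subtract the activation
            # 0 weight: contributes nothing, so it is skipped
        out.append(acc)
    return out


def dense_matvec(w_rows, x):
    """Reference implementation with ordinary multiply-accumulate."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in w_rows]


W = [[1, 0, -1],
     [-1, 1, 1]]
x = [0.5, -2.0, 3.0]

result = ternary_matvec(W, x)
reference = dense_matvec(W, x)
print(result, reference)
```

The catch is that this only describes the forward matmul with already-quantized weights; the optimizer still updates full-precision weights during training.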

For questions/corrections, reach me over twitter/x or `00shxf@gmail.com`. I built the code in the midst of chaos, so if you find something to correct, let me know.

Here is the code, if that is all you are interested in

## Implementation Details

This is a bitnet implementation based on the PDF officially shared by the authors.

### Weight Quantization

Note: for the current experiments, all weights are stored in fp32/fp16 and quantized in the forward pass.

Here $\gamma$ = scale and $\widetilde{W}$ = w_quant

```
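from torch import Tensor

def weight_quant(w: Tensor) -> tuple[Tensor, Tensor]:
    # absmean scaling: 1 / mean(|W|), clamped to avoid division by zero
    scale: Tensor = 1.0 / w.abs().mean().clamp(min=1e-5)
    # round to the nearest of {-1, 0, 1}, then map back to the original range
    quant: Tensor = (w * scale).round().clamp(-1, 1) / scale
    # straight-through estimator: forward uses quant, backward sees identity
    w_quant: Tensor = w + (quant - w).detach()
    # split off the magnitude so w_quant holds 1, 0, -1 and scale is returned separately
    scale = w_quant.abs().max().detach()
    w_quant = w_quant / scale
    return w_quant, scale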
```

`w_quant` is a matrix whose entries are 1, 0 or -1.
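A quick sanity check, as a sketch: it repeats `weight_quant` from above so it runs on its own, feeds it a random matrix (the shape and seed are arbitrary choices for illustration), and checks two things. The returned values should sit on 1, 0, -1 (up to floating-point rounding from the `w + (quant - w)` trick), and gradients should still reach the full-precision weights through the straight-through estimator.

```python
import torch
from torch import Tensor


def weight_quant(w: Tensor) -> tuple[Tensor, Tensor]:
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    quant = (w * scale).round().clamp(-1, 1) / scale
    w_quant = w + (quant - w).detach()   # straight-through estimator
    scale = w_quant.abs().max().detach()
    w_quant = w_quant / scale
    return w_quant, scale


torch.manual_seed(0)
w = torch.randn(4, 8, requires_grad=True)   # stand-in for a full-precision weight

w_quant, scale = weight_quant(w)
rounded = w_quant.detach().round()
print(sorted(set(rounded.flatten().tolist())))   # the ternary levels that occur

# Backward pass: the quantizer behaves like the identity for gradients,
# so every full-precision weight receives a gradient.
(w_quant * scale).sum().backward()
print(w.grad)
```

Because the backward pass ignores the rounding, every element of `w.grad` comes out as 1 here, which is the "smooth" signal the optimizer needs.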

### Activation Quantization

This is from a precursor paper https://arxiv.org/pdf/2310.11453.pdf

```
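from torch import Tensor

def activation_quant(x: Tensor) -> Tensor:
    # per-token absmax over the feature (last) dimension, mapped onto the int8 range;
    # dim=-1 is the same axis as dim=1 for 2-D input and stays per-token for 3-D input
    scale: Tensor = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    # round to int8 levels, then map back to the original range
    y: Tensor = (x * scale).round().clamp(-128, 127) / scale
    # straight-through estimator: forward uses y, backward sees identity
    return x + (y - x).detach()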
```
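Here is a small sketch of what activation quantization does to actual numbers. It repeats the function so it runs on its own and uses `dim=-1`, which for the 2-D input here is the same axis as `dim=1`. The input values are made up. Each row (token) gets its own scale, the values land on integer levels within [-127, 127], and the error stays within half a quantization step.

```python
import torch
from torch import Tensor


def activation_quant(x: Tensor) -> Tensor:
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    y = (x * scale).round().clamp(-128, 127) / scale
    return x + (y - x).detach()   # straight-through estimator


# Two "tokens" with very different magnitudes.
x = torch.tensor([[0.1, -0.5, 2.0],
                  [10.0, -20.0, 5.0]])
x_q = activation_quant(x)

# Recompute the per-token scale to inspect the integer levels that were used.
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values
int_levels = (x_q * scale).round()
step = 1.0 / scale              # size of one quantization step per token
error = (x_q - x).abs()

print(int_levels)
print(error.max(dim=-1).values, step.squeeze(-1))
```

The large-magnitude token gets a coarser step than the small one, which is why the scale is computed per token instead of once for the whole tensor.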

### BitLinear

With these two functions we can implement BitLinear; this layer is a drop-in replacement for nn.Linear. There is a lot of room for optimization with kernel fusion, manually writing the backward pass, etc.

```
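import torch
import torch.nn as nn
from torch import Tensor

# RMSNorm is assumed to be defined elsewhere
# (e.g. your own module, or torch.nn.RMSNorm in PyTorch >= 2.4).
class BitLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rms_norm = RMSNorm(self.in_features)

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight
        x_norm = self.rms_norm(x)            # normalize activations first
        x_quant = activation_quant(x_norm)   # 8-bit activations (STE)
        w_quant, scale = weight_quant(w)     # ternary weights plus their scale
        # note: the nn.Linear bias, if present, is not applied here
        output = nn.functional.linear(x_quant, w_quant)
        return output * scale                # restore the weight magnitude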
```
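To show what "drop-in replacement" can look like in practice, here is a self-contained sketch. The `RMSNorm` below is a minimal stand-in I am assuming for illustration (not necessarily the one in my repo), and `swap_linear_for_bitlinear` is a hypothetical helper, not part of any library. It walks a model, replaces each `nn.Linear` with a `BitLinear` of the same shape, and runs a forward and backward pass.

```python
import torch
import torch.nn as nn
from torch import Tensor


class RMSNorm(nn.Module):
    """Minimal RMSNorm stand-in (assumption for this sketch)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


def activation_quant(x: Tensor) -> Tensor:
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    y = (x * scale).round().clamp(-128, 127) / scale
    return x + (y - x).detach()


def weight_quant(w: Tensor) -> tuple[Tensor, Tensor]:
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    quant = (w * scale).round().clamp(-1, 1) / scale
    w_quant = w + (quant - w).detach()
    scale = w_quant.abs().max().detach()
    return w_quant / scale, scale


class BitLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rms_norm = RMSNorm(self.in_features)

    def forward(self, x: Tensor) -> Tensor:
        x_quant = activation_quant(self.rms_norm(x))
        w_quant, scale = weight_quant(self.weight)
        return nn.functional.linear(x_quant, w_quant) * scale


def swap_linear_for_bitlinear(module: nn.Module) -> nn.Module:
    """Hypothetical helper: recursively replace nn.Linear children with BitLinear."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear) and not isinstance(child, BitLinear):
            bit = BitLinear(child.in_features, child.out_features, bias=False)
            bit.weight.data.copy_(child.weight.data)   # keep the existing init
            setattr(module, name, bit)
        else:
            swap_linear_for_bitlinear(child)
    return module


torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))
mlp = swap_linear_for_bitlinear(mlp)

x = torch.randn(2, 5, 16)   # (batch, sequence, features)
out = mlp(x)
out.sum().backward()        # gradients reach the full-precision weights
print(out.shape)
```

Note the helper drops the bias, since the `BitLinear` forward pass above never applies one. Swapping layers in an already trained model is only a wiring check; the quantized network still has to be trained this way to be useful.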

### Training code

You can find my pytorch implementation here https://github.com/joey00072/ohara/tree/master/experiments/bitnet.

I trained a 15.5M llama and a 15.6M bitnet llama (the bitnet RMSNorm layers add an extra 0.1M). You can check the loss curve: bitnet is worse than llama (and I am pretty sure 2-bit quantization-aware training would be better or the same). See the wandb run.

Right now bitnet is pretty useless for inference on current hardware; it is much better to train a model in bfloat16 and use a quantized version if you need it.

To get the crazy speed-up, the prerequisites for bitnet1.5 are big. We need someone to create a special chip that supports mixed precision with 2-bit. So the big assumption is that we will use a 2-bit model for inference, which means someone has to spend a lot (a lot) of money to build the chip and the software and to train a quantization-aware LLM. Not to mention that we don't know whether scaling laws hold for 1.5-bit (it's really 2-bit) quantization-aware training the way they do for normal training. None of this happens unless someone is ready to spend a shit ton of money.
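On the "1.5 bit is really 2 bit" point: three states need log2(3) ≈ 1.58 bits in theory, but the simplest practical storage is 2 bits per weight, four weights to a byte. Here is a minimal sketch of that packing; the function names and bit layout are my own choices for illustration, not a standard format.

```python
import math


def pack_ternary(values):
    """Pack ternary weights (-1, 0, 1) into bytes, four 2-bit codes per byte."""
    codes = [v + 1 for v in values]                # map -1, 0, 1 to 0, 1, 2
    padded = codes + [1] * (-len(codes) % 4)       # pad with the code for 0
    out = bytearray()
    for i in range(0, len(padded), 4):
        a, b, c, d = padded[i:i + 4]
        out.append(a | (b << 2) | (c << 4) | (d << 6))
    return bytes(out)


def unpack_ternary(packed, count):
    """Inverse of pack_ternary; `count` drops the padding at the end."""
    values = []
    for byte in packed:
        for shift in (0, 2, 4, 6):
            values.append(((byte >> shift) & 0b11) - 1)
    return values[:count]


weights = [1, 0, -1, -1, 0, 1, 1]
packed = pack_ternary(weights)
restored = unpack_ternary(packed, len(weights))

ideal_bits = math.log2(3)   # information-theoretic bits per ternary weight
print(len(packed), restored, round(ideal_bits, 2))
```

Seven weights fit in two bytes here, versus 14 bytes in fp16. One of the four 2-bit codes is never used, which is exactly the gap between 1.58 and 2 bits.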

BITNET ~~NGMI~~