Jak podaje definicja: Kwantyzacja to nieodwracalne nieliniowe odwzorowanie statyczne zmniejszające dokładność danych przez ograniczenie ich zbioru wartości.
W tym poście dowiesz się jak zaaplikować tę technikę transformacji dla modeli stworzonych w PyTorch i jak wpływa ona na wydajność oraz precyzję modelu.

W wersji PyTorch 1.3 wprowadzono możliwość kwantyzacji modelu do precyzji INT8 (wartość stałoprzecinkowa zapisana na 8 bitach). Jak podają autorzy PyTorch jest to na razie implementacja eksperymentalna, jednakże mi udało się bez problemów zaaplikować ją do przykładowego modelu.

Domyślnie wszystkie wagi tensorów, z których zbudowane są warstwy sieci neuronowych zapisywane są w formacie Float32 (FP32) - jest to typ zmiennoprzecinkowy zajmujący 32 bity. Dane zapisane przy użyciu FP32 zajmują 4 razy więcej pamięci niż INT8. Ta pierwsza cecha wpływa na możliwości używania dużych modeli na urządzeniach o skromniejszej ilości dostępnej pamięci.

Operacje stałoprzecinkowe (szczególnie dodawanie i mnożenie) są także o wiele mniej kosztowne w wywołaniu od zmiennoprzecinkowych, licząc ile cykli zegara taktującego procesor potrzebnych jest do ich wykonania. Trend rynkowy, taki jak przykładowo technologia Intel-a nazwana"Deep Learning Boost" wspierają także operacji modeli sieci neuronowych, ale tylko jeżeli chodzi o operacje INT8.

Patrząc na powyższe argumenty widać zatem, że do zastosowań gdzie liczy się mały rozmiar zajmowanej pamięci a także wyśrubowana wydajność modeli, transformacja sieci z typu FP32 to INT8 może okazać się zbawienna.

Prostota

Sama koncepcja transformacji jest niezwykle prosta.

Żeby zapisać N wag F32 dla danej warstwy sieci potrzebne jest N wag INT8 wraz z pojedynczą parą liczb scale (skala) i zeropoint (punkt zera).

Pożądaną cechą transformacji jest fakt, żeby równanie:

$$FP32value \approx (INT8value - zeropoint) * scale$$

było obarczone minimalnym możliwym błędem. By dokonać tej transformacji w formie optymalnej transformowana sieć musi zebrać statystyki na reprezentatywnym zbiorze danych. Dzięki temu można określić dla każdej warstwy zakres możliwych przyjmowanych wartości i dobrać optymalne wartości transformacji.

Błąd

Operacja kwantyzacji obarczona jest pewnym błędem, w związku z czym wynikowa sieć nie jest idealnym odwzorowaniem sieci źródłowej. Z tego powodu zmienia się też jej dokładność. Dla każdego przypadku takiej transformacji należy przeprowadzić analizę utraty skuteczności modelu, by właściwie podjąć decyzję czy nowe wyniki są akceptowalne dla adresowanego problemu.

Z mojej praktyki wynika, że utrata skuteczności jest "niewielka", jednakże każda domena rządzi się swoimi prawami i definiuje co to właściwie znaczy "niewielka utrata".

Mówiąc o błędzie należy też wspomnieć o dwóch możliwościach generowania zoptymalizowanej wersji sieci:

  • post-training (po treningu sieci)
  • quantization-aware training (wykonanie treningu sieci, który bezpośrednio adresuje temat przygotowania sieci do kwantyzacji)

Oba terminy wyjaśnię przy okazji prezentacji kodu.

Kod (wersja statyczna)

Kwantyzacja statyczna po treningu to proces składający się z następujących kroków:

  • opcjonalne złączenie warstw sieci celem polepszenia wydajności kwantyzacji
  • określenie parametrów kwantyzacji
  • przygotowania sieci do zbierania statystyk rozpiętości wartości dla warstw sieci
  • zebranie statystyk na reprezentatywnym zbierze danych
  • dokonanie operacji konwersji na typ INT8
  • (weryfikacja wyniku)

[Przedstawiany kod, dostosowany do potrzeb tego postu, pochodzi z mojego repozytorium na GitHub: bwosh/torch-quantization].
Zakładając, że mamy model z wyuczonymi wagami, przykładowo:

model = MyModel()
model.load_state_dict(torch.load('weights.pth'))

framework wymusza określenie parametrów kwantyzacji. Tutaj użyjemy ich domyślnych wartości:

import torch.quantization as quantization
model.qconfig = quantization.default_qconfig

następnie, opcjonalnie, możemy dokonać złączenia operacji : Conv2d, BatchNorm, ReLU w różnych wariantach - ja, posiadając w moim modelu warstwy konwolucji z następującymi funkcjami aktywacji ReLU, łączę je następująco:

pmodel = quantization.fuse_modules(model, 
     [['conv1','relu1'],
      ['conv2','relu2'],
      ['conv3','relu3']])

następnie dokonuję operacji przygotowania do procesu kwantyzacji - polega on na automatycznej wymianie warstw sieci odpowiednikami, które oprócz wykonywania właściwych im operacji zbierają statystyki odnośnie przetwarzanych wartości:

qmodel = quantization.prepare(pmodel,
          {"":quantization.default_qconfig})

następnie posiadając reprezentatywny dataset oraz obiekt ładujący dane w batch-ach (DataLoader) iteruję wszystkie próbki:

for img, _ in dataloader: 
    output = qmodel(img)

gdy operacja się kończy, na podstawie zebranych danych ostateczna operacja konwersji może zostać wykonana:

qmodel = quantization.convert(qmodel)

i to tyle. Model jest przekonwertowany - teraz należałyby zmierzyć jego poprawność - w zależności od modelu będzie to inny kod, więc nie będę go tu przytaczał - po szczegóły zapraszam do repozytorium kodu.

Kod (wersja quantization-aware training)

Optymalizacja procesu kwantyzacja może być zaadresowana wcześniej niż po zakończeniu treningu. Sprytny zabieg, którego dokonać można podczas samego treningu sprawi, że skuteczność modelu skwantyzowanego będzie wyższa niż w poprzednim, statycznym przykładzie.

Owy sprytny zabieg polega na symulacji operacji kwantyzacji podczas przetwarzania danych przez sieć. Cała operacja dokonywana jest na typie danych FP32, przez co proces optymalizacji modelu i liczenie gradientów nadal działa. Symulacja kwantyzacji sprawia, że cały model jest do pewnego momentu w stanie skompensować straty wynikające z utraty precyzji na pojedynczych warstwach.

Dokumentacja PyTorch podaje, że symulacja polega na przeprowadzeniu operacji:

x_out = ( clamp(
		    round(x/scale + zero_point), 
            quant_min, 
            quant_max
           ) - zero_point
        ) * scale
       

Posiadając model, wstępne wagi oraz dataset do przeprowadzenia treningu można rozpocząć prace od załadowania modelu i ustalenia konfiguracji operacji kwantyzacji:

model = MyModel()
model.load_state_dict(torch.load('weights.pth'))
model.qconfig = quantization.get_default_qat_qconfig()

kolejnym krokiem, podobnie jak w przypadku statycznym jest złączenie niektórych warstw:

pmodel = quantization.fuse_modules(model, 
		[['conv1','relu1'],
        ['conv2','relu2'],
        ['conv3','relu3']])

następną operacją jest konwersja warstw sieci do ich odpowiedników, przeprowadzających operację symulacji kwantyzacji:

qmodel = torch.quantization.prepare_qat(pmodel, inplace=True)

po tej konfiguracji nastąpić powinien proces nauki sieci, którego nie będę tu przytaczał. By zakończyć proces konwersji, a tym samym zamienić warstwy udające kwantyzację ich prawdziwymi odpowiednikami, należy dokonać finalnej operacji:

qmodel.eval()
qmodel = quantization.convert(qmodel)

Tym samym otrzymujemy model, który wyuczony w treningu z symulowaną kwantyzacją powinien osiągać wynik wyższy niż w wersji statycznej.

Moje wyniki

Dla przykładowo przeprowadzonych testów na zbiorze MNIST, używając kodu z repozytorium, moje modele uzyskały odpowiednio:

Model Celność(Accuracy) Rozmiar pliku wag (bajty) Czas inferencji (ms)
Oryginalny 96.5% 13595 5.28
Post-training quantization 95.0% 5792 2.76
Quantization-aware training 95.5% 6527 2.69

GitHub

Testując kwantyzację w PyTorch 1.3 stworzyłem repozytorium

bwosh/torch-quantization ,

które wykonuje trening sieci i kwantyzację wraz z podsumowaniami wyników. Zapraszam do zapoznania się z pełnym kodem!