Optimizing Your Model for Inference with PyTorch Quantization
Editor’s Note: Jerry is a speaker for ODSC East 2022. Be sure to check out his talk, “Quantization in PyTorch,” to learn more about PyTorch quantization!
Quantization is a common technique for making a model run faster at inference time, with a lower memory footprint and lower power consumption, without changing the model architecture. In this blog post, we will briefly introduce what quantization is and how to apply it to your PyTorch models.
What is Quantization
Typical neural networks run in 32-bit floating point (float32) precision. Both the activation and weight Tensors are stored in float32, and computations are performed in float32 as well. Quantization reduces the precision of the model to a more compact data type, such as 8-bit integer (int8), that needs less memory for storage and allows faster computation. Taking int8 as an example, after we quantize the model, both activation and weight Tensors can be stored in int8, which uses one byte per value instead of four. The computations are then performed in int8, which is typically more efficient than float32 computation.
We can view quantization as a form of model compression. It is lossy, because the lower precision data type has less dynamic range and resolution. We therefore need to trade off the accuracy of the model against the speed, memory, and power consumption benefits that quantization brings.
How to Use PyTorch Quantization
How do we obtain a quantized model from a floating point model? There are two ways in general:
- Post Training Quantization: After we have a trained model, we convert it to a quantized model. This is typically easy to apply, but some types of models may lose accuracy.
- Quantization Aware Training: During training, we insert fake quantization operators into the model to simulate the effects of quantization. After training, we convert the model with fake quantization operators into a quantized model. This is harder to apply than post training quantization since it requires retraining the model, but it typically gives better accuracy.
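Quantization Aware Training can also be done in FX Graph Mode. The following is a minimal sketch. It assumes a recent PyTorch release that provides `QConfigMapping`, `get_default_qat_qconfig`, and `prepare_qat_fx` with a required `example_inputs` argument. The tiny `nn.Sequential` model, the random data, and the 20-step loop are placeholders for a real model and a real fine-tuning schedule.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

torch.manual_seed(0)

# Placeholder model; substitute your own trained float model.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
example_inputs = (torch.randn(8, 16),)
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("fbgemm"))

# Insert fake quantization modules; QAT preparation expects the model in training mode.
model_qat = prepare_qat_fx(model.train(), qconfig_mapping, example_inputs)

# A few steps on random data stand in for real fine-tuning.
optimizer = torch.optim.SGD(model_qat.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
for step in range(20):
    inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
    optimizer.zero_grad()
    loss = loss_fn(model_qat(inputs), targets)
    loss.backward()
    optimizer.step()

# Convert the fake-quantized model into an int8 model for inference.
model_int8 = convert_fx(model_qat.eval())
with torch.no_grad():
    out = model_int8(torch.randn(2, 16))
print(out.shape)
```

The conversion step is the same `convert_fx` call used for post training quantization. Only the preparation (fake quantization instead of plain observers) and the training loop differ.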
Since we are reducing the precision of the Tensors, we need to establish a mapping between the float32 Tensor and the quantized Tensor. A typical mapping function is an affine transformation. For example, to quantize a float32 Tensor to an int8 Tensor, we divide each float32 value by a `scale`, round it, add a `zero_point`, and clamp the result to the int8 range, i.e. `q = clamp(round(x / scale) + zero_point, -128, 127)`. Therefore, we need to figure out the `scale` and `zero_point` parameters for each Tensor that we want to quantize. In general, Post Training Quantization follows this process:
- Prepare: We insert observers into the model to observe the statistics of a Tensor, for example, its min/max values.
- Calibration: We run the model on some representative sample data, which allows the observers to record the Tensor statistics.
- Convert: Based on the calibrated model, we compute the quantization parameters for the mapping function and convert the floating point operators to quantized operators.
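To make the affine mapping and the prepare/calibrate/convert steps concrete, here is a toy sketch that carries them out by hand on a single tensor. The helper names (`minmax_qparams`, `quantize`, `dequantize`) are hypothetical. PyTorch's real observers, such as `MinMaxObserver`, handle more cases, so treat this as an illustration of the idea rather than PyTorch's implementation.

```python
import torch

def minmax_qparams(min_val, max_val, qmin=-128, qmax=127):
    """Compute an affine (scale, zero_point) pair mapping [min_val, max_val] onto [qmin, qmax]."""
    # Extend the range to include 0.0 so that zero is exactly representable.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = max((max_val - min_val) / (qmax - qmin), 1e-8)  # guard against an all-zero tensor
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Divide by scale, round, add zero_point, then clamp to the int8 range.
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, qmin, qmax).to(torch.int8)

def dequantize(q, scale, zero_point):
    # Map back to float32; the result matches the input only up to rounding error.
    return (q.to(torch.float32) - zero_point) * scale

torch.manual_seed(0)
calibration_data = [torch.randn(4, 16) for _ in range(10)]

# "Prepare" + "Calibration": a toy observer that tracks the running min/max.
running_min, running_max = float("inf"), float("-inf")
for batch in calibration_data:
    running_min = min(running_min, batch.min().item())
    running_max = max(running_max, batch.max().item())

# "Convert": derive the quantization parameters and map float32 -> int8.
scale, zero_point = minmax_qparams(running_min, running_max)
x = calibration_data[0]
q = quantize(x, scale, zero_point)
x_hat = dequantize(q, scale, zero_point)
max_error = (x - x_hat).abs().max().item()
print(f"scale={scale:.5f}, zero_point={zero_point}, max abs error={max_error:.5f}")
```

For values inside the observed range, the round-trip error is at most half a quantization step (`scale / 2`). This is one reason representative calibration data matters: values outside the observed range get clamped and lose much more.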
Currently, PyTorch offers two different ways of quantization: Eager Mode Quantization and FX Graph Mode Quantization. Here I’ll show an example using FX Graph Mode Quantization to quantize a resnet50 model from torchvision:
```python
import copy

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torchvision.models import resnet50

fp32_model = resnet50().eval()
model = copy.deepcopy(fp32_model)

# `qconfig` means quantization configuration; it specifies how we should
# observe the activations and weights of an operator.
# `qconfig_dict` specifies the `qconfig` for each operator in the model:
# we can specify `qconfig` for certain types of modules,
# for a specific submodule in the model,
# or for some functional calls in the model,
# and we can set `qconfig` to None to skip quantization for some operators.
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

# `prepare_fx` inserts observers into the model based on the configuration in `qconfig_dict`.
# Recent PyTorch releases (1.13 and later, as far as we know) also require an `example_inputs` tuple.
example_inputs = (torch.randn(1, 3, 224, 224),)
model_prepared = prepare_fx(model, qconfig_dict, example_inputs)

# Calibration runs the model with some sample data, which allows the observers to record
# statistics of the activations and weights of the operators.
# Random tensors are a placeholder here; use representative data in practice.
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
with torch.no_grad():
    for batch in calibration_data:
        model_prepared(batch)

# `convert_fx` converts a calibrated model to a quantized model: it inserts quantize and
# dequantize operators and swaps floating point operators for quantized operators.
model_quantized = convert_fx(copy.deepcopy(model_prepared))

# Benchmark (`%timeit` is an IPython/Jupyter magic command).
x = torch.randn(1, 3, 224, 224)
%timeit fp32_model(x)
%timeit model_quantized(x)
```
Output:
38.7 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8.65 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
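As we can see, the quantized model achieved around a 4.5x speedup (38.7 ms vs. 8.65 ms per run) over the original float32 model in this measurement. The exact speedup depends on the hardware and the quantization backend.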
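`%timeit` only works in IPython or Jupyter. In a plain Python script, a small timing helper does the same job. The sketch below uses `time.perf_counter` with a few warm-up runs under `torch.no_grad()`. The helper name `benchmark_ms` and the small stand-in model are our own. Pass `fp32_model` and `model_quantized` from the example above to compare them, and expect the ratio to differ somewhat from the `%timeit` figures, since the measurement method differs.

```python
import time

import torch
import torch.nn as nn

def benchmark_ms(model, x, warmup=5, iters=50):
    """Return the mean latency of model(x) in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):  # warm-up runs are excluded from the measurement
            model(x)
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        elapsed = time.perf_counter() - start
    return elapsed / iters * 1e3

# Stand-in model; substitute fp32_model and model_quantized from the example above.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()).eval()
x = torch.randn(1, 3, 64, 64)
latency = benchmark_ms(model, x)
print(f"{latency:.3f} ms per run")
```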
In the above example, we used `qconfig_dict` to control how the model is quantized; the empty string key sets the global configuration. We can also use `qconfig_dict` to choose the type of quantization for each individual operator used in the model, or to skip quantizing it.
How to Do Numerical Debugging after Quantization
After we quantize the model, we may get a great speedup, but the accuracy could suffer because we quantized too many operators. How do we find the operators that are most sensitive to quantization, so that we can skip quantizing them and recover the accuracy?
We can use PyTorch's Numeric Suite to measure the impact of quantization on the activations and weights of the model. It comes in two variants, Eager Mode Numeric Suite and FX Graph Mode Numeric Suite, and it has three features:
- Compare the quantization loss for weights
- Compare the cumulative quantization loss for activations
- Compare the per-operator quantization loss for activations
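The weight comparison that follows measures quantization loss with SQNR (signal-to-quantization-noise ratio). SQNR is the ratio, in decibels, between the norm of the original tensor and the norm of the quantization error: `20 * log10(||x|| / ||x - x_hat||)`. The sketch below computes it for a random stand-in weight tensor quantized with `torch.quantize_per_tensor` at two step sizes. The helper names and the 4x coarser scale are our own illustration. We believe the formula matches what `torch.ao.ns.fx.utils.compute_sqnr` computes, but check the source of your PyTorch version.

```python
import torch

def sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Signal-to-quantization-noise ratio in dB; larger means less quantization loss."""
    signal = torch.linalg.vector_norm(x)
    noise = torch.linalg.vector_norm(x - y)
    return 20 * torch.log10(signal / noise)

def quant_dequant_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Symmetric per-tensor int8 quantization (zero_point = 0), then back to float.
    return torch.quantize_per_tensor(x, scale, 0, torch.qint8).dequantize()

torch.manual_seed(0)
weight = torch.randn(64, 32) * 0.1            # stand-in for a layer's weight
fine_scale = weight.abs().max().item() / 127  # uses the full int8 range
coarse_scale = 4 * fine_scale                 # a coarser grid loses more information

sqnr_fine = sqnr(weight, quant_dequant_int8(weight, fine_scale)).item()
sqnr_coarse = sqnr(weight, quant_dequant_int8(weight, coarse_scale)).item()
print(f"fine: {sqnr_fine:.1f} dB, coarse: {sqnr_coarse:.1f} dB")
```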
Below is a simple example that uses FX Graph Mode Numeric Suite to compare the quantization loss for the weights of the resnet50 model.
```python
import torch
import torch.ao.ns._numeric_suite_fx as ns
from torch.ao.ns.fx.utils import compute_sqnr

# Compare the weights of the float model and the quantized model.
# Note: when comparing weights in models with Conv-BN for PTQ, we need to compare
# weights after Conv-BN fusion for a proper comparison. Because of this, we use
# `model_prepared` instead of `fp32_model` as the float reference.

# Extract conv and linear weights from corresponding parts of the two models and
# save them in `resnet50_wt_compare_dict`.
resnet50_wt_compare_dict = ns.extract_weights(
    'fp32',           # string name for model A
    model_prepared,   # model A
    'int8',           # string name for model B
    model_quantized,  # model B
)

# Calculate the SQNR between each pair of weights.
# SQNR is a measure of quantization loss: a large SQNR value means the quantization loss is small.
ns.extend_logger_results_with_comparison(
    resnet50_wt_compare_dict,  # results object to modify in place
    'fp32',                    # string name of model A (from the previous step)
    'int8',                    # string name of model B (from the previous step)
    compute_sqnr,              # the function used to compare two tensors
    'sqnr',                    # the name under which to store the results
)

# Massage the data into a format that is easy to graph and print.
# Note: there is no util function for this, since use cases may differ per user.
# Note: there is a lot of debugging data, and it is challenging to print all of it
# and fit it on a laptop screen. It is up to the user to decide which data is useful.
resnet50_wt_to_print = []
for idx, (layer_name, v) in enumerate(resnet50_wt_compare_dict.items()):
    resnet50_wt_to_print.append([
        idx,
        layer_name,
        v['weight']['int8'][0]['prev_node_target_type'],
        v['weight']['int8'][0]['values'][0].shape,
        v['weight']['int8'][0]['sqnr'][0],
    ])

%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-whitegrid')  # named 'seaborn-whitegrid' before Matplotlib 3.6

# A simple line graph.
def plot(xdata, ydata, xlabel, ylabel, title):
    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.plot(xdata, ydata)

# Plot the SQNR between fp32 and int8 weights for each layer.
# Note: easier-to-read charts (bar charts, etc.) may be explored later; for now,
# a line chart plus a table is good enough.
plot([x[0] for x in resnet50_wt_to_print], [x[4] for x in resnet50_wt_to_print], 'idx', 'sqnr', 'weights, idx to sqnr')
```
After plotting the SQNR for all the weights, we can find the layer with the lowest SQNR, which has the largest quantization error, and skip quantizing that layer by changing the `qconfig_dict` settings. By gradually skipping quantization for the layers that are most sensitive to quantization, we can find the right balance between accuracy and speed for our use case.
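One way to automate this search is a small helper that ranks layers by SQNR and builds a `qconfig_dict` that skips the `k` worst ones. The helper `skip_most_sensitive` below is hypothetical, and so are the example rows, which stand in for real Numeric Suite results. The helper assumes that the layer names in `resnet50_wt_to_print` match the model's submodule names. Verify this for your model, because Conv-BN fusion can change which module owns a weight.

```python
def skip_most_sensitive(rows, qconfig, k):
    """Return a qconfig_dict that quantizes globally but skips the k lowest-SQNR layers.

    Each row is assumed to follow the layout of `resnet50_wt_to_print`:
    [idx, layer_name, op_type, weight_shape, sqnr].
    """
    # Sort ascending by SQNR: the lowest SQNR (largest quantization error) comes first.
    ranked = sorted(rows, key=lambda row: float(row[4]))
    skipped = [row[1] for row in ranked[:k]]
    return {"": qconfig, "module_name": [(name, None) for name in skipped]}

# Hypothetical rows standing in for real Numeric Suite results.
example_rows = [
    [0, "conv1", "Conv2d", (64, 3, 7, 7), 35.2],
    [1, "layer1.0.conv1", "Conv2d", (64, 64, 1, 1), 21.7],
    [2, "fc", "Linear", (1000, 2048), 40.1],
]
placeholder_qconfig = "global-qconfig"  # stand-in for get_default_qconfig("fbgemm")
config = skip_most_sensitive(example_rows, placeholder_qconfig, k=1)
print(config)  # {'': 'global-qconfig', 'module_name': [('layer1.0.conv1', None)]}
```

With real results, you would pass the returned dictionary to `prepare_fx`, recalibrate, convert, and re-measure accuracy, increasing `k` until the accuracy is acceptable.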
Summary
In this article, we talked about quantization, a common technique to optimize a model for inference, and also the tools provided in PyTorch to quantize a model and debug quantization errors to recover the accuracy of the model. For a more in-depth overview of this topic, please check out my upcoming talk this April at ODSC East 2022: Quantization in PyTorch. We’ll have a more comprehensive walkthrough of the tools you can use in PyTorch Quantization.