Engineering

Jul 28, 2023

Digging into the bitsandbytes issue

  • Jeongseok Kang

    Researcher

Backend.AI is a popular choice for developing large language models (LLMs) because it makes it easy to run large clusters and distributed workloads. As a result, we get a lot of feedback and requests from customers, and today I'd like to share how we solved one of them.

On April 4, 2023, we received a report that certain packages raise an error when run in the container environment provided by the NGC Catalog [1] (NVIDIA GPU Cloud). The NGC Catalog is a list of containers [2] with environments optimized for developing AI/ML, metaverse, and high-performance computing applications. Because it is operated and distributed directly by NVIDIA, it is highly trusted and is considered the standard for CUDA environments in particular. An issue in this environment is therefore a risk that many users could run into, so we decided to address it as a high priority.

Reproducing the problem

I first reproduced the issue to determine the exact cause. In this case, I was running ViperGPT [3], developed at Columbia University, and encountered an error in a package called bitsandbytes. ViperGPT depends on bitsandbytes, as shown below.

    accelerate==0.18.0
    backoff==2.2.1
    bitsandbytes==0.38.1
    cityscapesscripts==2.2.1
    git+https://github.com/openai/CLIP.git
    decord==0.6.0
    dill==0.3.6
    ...

I was able to reproduce the problem by simply importing bitsandbytes.

The execution environment used the nvcr.io/nvidia/pytorch:22.05-py3 image.

    $ pip install bitsandbytes  # 0.37.1
    $ python
    >>> import bitsandbytes
    ===================================BUG REPORT===================================
    Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
    ================================================================================
    CUDA exception! Error code: OS call failed or operation not supported on this OS
    CUDA exception! Error code: initialization error
    CUDA SETUP: CUDA runtime path found: /home/work/data/miniconda3/envs/vipergpt/lib/libcudart.so
    /home/work/data/miniconda3/envs/vipergpt/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...
      warn(msg)
    CUDA SETUP: Detected CUDA version 116
    CUDA SETUP: Loading binary /home/work/data/miniconda3/envs/vipergpt/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...
    /home/work/data/miniconda3/envs/vipergpt/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.
      warn("The installed version of bitsandbytes was compiled without GPU support. "

bitsandbytes traverses all the CUDA devices installed in the execution environment and checks their Compute Capability [4]. To do so, it first counts the CUDA devices through libcuda.so, as in the function below. We noticed that the error occurs when cuDeviceGetCount() [5] is called. The error was 304, CUDA_ERROR_OPERATING_SYSTEM.

    import ctypes as ct

    # check_cuda_result() is a helper defined elsewhere in bitsandbytes.
    def get_compute_capabilities(cuda):
        """
        1. find libcuda.so library (GPU driver) (/usr/lib)
           init_device -> init variables -> call function by reference
        2. call extern C function to determine CC
           (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
        3. Check for CUDA errors
           https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
        # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
        """
        nGpus = ct.c_int()
        cc_major = ct.c_int()
        cc_minor = ct.c_int()
        device = ct.c_int()

        # This is the call that returned error 304.
        check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
        ccs = []
        for i in range(nGpus.value):
            check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
            ref_major = ct.byref(cc_major)
            ref_minor = ct.byref(cc_minor)
            # 2. call extern C function to determine CC
            check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
            ccs.append(f"{cc_major.value}.{cc_minor.value}")

        return ccs
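
The function above relies on check_cuda_result, which is not shown in the excerpt. As a rough illustration only, a helper of that kind could turn a non-zero driver return code into a readable message with the driver API's cuGetErrorString. The name cuda_error_message and its return convention are assumptions made for this sketch, not the bitsandbytes implementation.

```python
import ctypes as ct


def cuda_error_message(cuda, result):
    """Return None on success, otherwise a description of the CUDA error.

    `cuda` is expected to be a handle to the driver library, for example
    ct.CDLL("libcuda.so"). `result` is the integer returned by a driver call.
    This is a hypothetical helper written for illustration.
    """
    if result == 0:  # CUDA_SUCCESS
        return None

    error_str = ct.c_char_p()
    # Driver API signature: cuGetErrorString(CUresult error, const char **pStr)
    status = cuda.cuGetErrorString(result, ct.byref(error_str))
    if status != 0 or error_str.value is None:
        return f"unknown CUDA error {result}"
    return error_str.value.decode()


# Usage on a machine with an NVIDIA driver (not executed here):
#   cuda = ct.CDLL("libcuda.so")
#   cuda.cuInit(0)
#   n = ct.c_int()
#   print(cuda_error_message(cuda, cuda.cuDeviceGetCount(ct.byref(n))))
```

Returning the message instead of printing it leaves the decision to warn, raise, or fall back to the CPU library with the caller.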

What is bitsandbytes?

Since the advent of the Transformer, language models have shown large performance gains, and it has become a trend to increase model size by stacking more Transformer blocks. As a result, a large number of GPUs is required not only to train a model but also to serve it. For example, serving GPT-3 with 175B parameters requires eight 80GB A100 GPUs, each costing about $15,000. This is a huge burden not only for individuals but also for enterprises and research institutes, which is why there is a lot of research on making models lighter for inference.

Image source: A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes (Hugging Face)

bitsandbytes is the open-source implementation of LLM.int8() [6], work by Tim Dettmers, a PhD candidate at the University of Washington, with Facebook AI Research (now Meta AI). It has been shown to reduce model size while maintaining performance. It does this by applying vector-wise quantization, which treats each vector independently when computing matrix products, and by mixing 8-bit and 16-bit computation so that important vectors are kept in 16-bit to minimize loss. It has been integrated into Hugging Face's Transformers library and is used in a variety of models including [Llama2](https://github.com/facebookresearch/llama-recipes/blob/cd82118b74d2fd739bd6227af33b661d04a97406/requirements.txt#L6), [QLoRA](https://github.com/artidoro/qlora/blob/6c6fc4653abd17ce550f48878a24c7bd8772e98a/requirements.txt#L1), [KoAlpaca](https://github.com/Beomi/KoAlpaca/blob/4596f882957d286b4d60559b97dcf783822d23f5/webui/requirements.txt#L5), and [KULLM](https://github.com/nlpai-lab/KULLM/blob/b7a78b62ed6cd9d83c51ad5a92a9dd40b9f35998/requirements.txt#L4).
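
To make the vector-wise idea more concrete, here is a minimal pure-Python sketch of absmax quantization with one scale per row of the input and one scale per column of the weight. It is an illustration under simplifying assumptions and not the bitsandbytes implementation: it runs on plain lists, the helper names are made up for this example, and it omits the mixed-precision step that keeps important (outlier) vectors in 16-bit.

```python
def quantize_rows(matrix):
    """Absmax-quantize each row into the int8 range [-127, 127].

    Returns the quantized rows and the per-row scale factors.
    """
    q_rows, scales = [], []
    for row in matrix:
        absmax = max(abs(v) for v in row) or 1.0  # avoid dividing by zero
        scale = 127.0 / absmax
        q_rows.append([round(v * scale) for v in row])
        scales.append(scale)
    return q_rows, scales


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


def int8_matmul(x, w):
    """Approximate x @ w using vector-wise 8-bit quantization."""
    xq, x_scales = quantize_rows(x)               # one scale per row of x
    wq_t, w_scales = quantize_rows(transpose(w))  # one scale per column of w
    out = []
    for x_row, sx in zip(xq, x_scales):
        out_row = []
        for w_col, sw in zip(wq_t, w_scales):
            acc = sum(a * b for a, b in zip(x_row, w_col))  # integer dot product
            out_row.append(acc / (sx * sw))                 # dequantize
        out.append(out_row)
    return out


x = [[1.0, -2.0], [0.5, 4.0]]
w = [[0.5, 1.0], [-1.0, 2.0]]
print(int8_matmul(x, w))  # close to the exact product [[2.5, -3.0], [-3.75, 8.5]]
```

Because every row and column gets its own scale, one large value only reduces the precision of its own vector rather than of the whole matrix.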

Identifying the cause

Now that we've located and reproduced the problem, it's time to get to the bottom of it. I looked for similar cases but couldn't find any. To make matters worse, cuInit() itself returned successfully, which made the cause even harder to pinpoint.

    import ctypes

    count = ctypes.c_int()

    libcuda = ctypes.CDLL("libcuda.so")
    libcuda.cuInit(0)  # 0 (CUDA_SUCCESS)
    libcuda.cuDeviceGetCount(ctypes.byref(count))  # 304 (CUDA_ERROR_OPERATING_SYSTEM)

    libcudart = ctypes.CDLL("libcudart.so")
    libcudart.cudaGetDeviceCount(ctypes.byref(count))  # 304 (CUDA_ERROR_OPERATING_SYSTEM)

I filed an issue on the GitHub repo (TimDettmers/bitsandbytes#264) for advice, and was told to update the package to the latest version and try again. After updating to version 0.38.0.post1, which was the latest at the time, I tested again, and the same problem occurred. I couldn't afford to lose too much time, so I decided to switch gears and remove the offending part.

Image source: Greco-Roman Mythology in Comics (Ghana Publishers)

Troubleshooting

My first approach was to use CUDA-Python [7]. CUDA-Python is the CUDA Python Low-Level Bindings package officially distributed by NVIDIA. I had used it before and found it useful, so it came to mind immediately and I decided to install and test it.

    $ pip install cuda-python

    from cuda import cuda
    from cuda import cudart

    cuda.cuInit(0)  # (<CUresult.CUDA_SUCCESS: 0>,)
    cudart.cudaGetDeviceCount()  # (<cudaError_t.cudaSuccess: 0>, 1)

Fortunately, cudart.cudaGetDeviceCount() worked fine, and I proceeded to test integrating it into bitsandbytes. However, calling torch.cuda.is_available() after calling cuda.cuInit(0) resulted in an error. This was because torch.cuda.is_available() calls cudaGetDeviceCount() internally.

    from cuda import cuda, cudart

    cuda.cuInit(0)  # (<CUresult.CUDA_SUCCESS: 0>,)
    cudart.cudaGetDeviceCount()  # (<cudaError_t.cudaSuccess: 0>, 1)

    import bitsandbytes
    # ...
    # /opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py:82: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 304: OS call failed or operation not supported on this OS (Triggered internally at /opt/pytorch/pytorch/c10/cuda/CUDAFunctions.cpp:109.)
    #   return torch._C._cuda_getDeviceCount() > 0
    # ...

The problem seemed to be back to square one. I took a breath and calmly reread the error log above. Then something caught my eye.

torch._C._cuda_getDeviceCount() > 0

This line shows that bitsandbytes was already using PyTorch internally, which means it had a dependency on PyTorch. To be precise, bitsandbytes had a dependency on lion-pytorch, which in turn had a dependency on PyTorch. And PyTorch already had an interface to the CUDA functions, which I decided to take advantage of this time.

Fortunately, all of the CUDA functions used by bitsandbytes existed in PyTorch. I made the following changes to the functions that were previously called via libcuda.so and libcudart.so.

| libcuda/libcudart | torch |
| --- | --- |
| libcuda.cuDeviceGetCount() | torch.cuda.device_count() |
| libcuda.cuDeviceGet() | torch.cuda.device() |
| libcuda.cuDeviceComputeCapability() | torch.cuda.get_device_capability() |
| libcudart.cudaRuntimeGetVersion() | torch.version.cuda |
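
As a rough illustration of this mapping, the device query can be written with PyTorch calls only. This is a sketch under assumptions, not the code merged into bitsandbytes: the function names are chosen for this example, PyTorch is assumed to be installed, and the import is done lazily inside each function so that loading the module does not itself import PyTorch.

```python
def get_compute_capabilities():
    """Return the "major.minor" compute capability of each visible CUDA device."""
    import torch  # assumption: PyTorch is installed in the environment

    ccs = []
    for index in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(index)
        ccs.append(f"{major}.{minor}")
    return ccs


def get_cuda_version_string():
    """Return e.g. "116" for CUDA 11.6, or None for a CPU-only PyTorch build."""
    import torch

    if torch.version.cuda is None:
        return None
    major, minor = torch.version.cuda.split(".")[:2]
    return f"{major}{minor}"
```

With this approach the device count and CUDA version come from the same CUDA runtime that PyTorch itself uses, so bitsandbytes no longer needs to load libcuda.so or libcudart.so directly.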

After verifying that the change worked, I opened a PR in the GitHub repository (TimDettmers/bitsandbytes#375) so that the fix could be included in a released version of the package.

Postscript

On July 14, 2023, about two months after I opened the PR, the patch was merged into the main branch and included in version 0.40.1.

I was also able to get some feedback from the author, Tim Dettmers, whose thoughts and philosophy are evident in this short article. This experience taught me more about the LLM ecosystem. It was also the first time in a long while that I felt the fun of open source work. I think its appeal is that we can collaborate beyond spatial constraints and learn from each other's ideas. We run an open source version of Backend.AI alongside an enterprise version, and we will always strive to provide a better experience for both users and developers.

This post is automatically translated from Korean

Footnotes

  1. NVIDIA GPU Cloud

  2. The NGC catalog hosts containers for AI/ML, metaverse, and HPC applications that are performance-optimized, tested, and ready to deploy on GPU-powered on-prem, cloud, and edge systems.

  3. ViperGPT: Visual Inference via Python Execution for Reasoning, March 14, 2023.

  4. https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability

  5. https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74

  6. LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale, November 10, 2022.

  7. https://developer.nvidia.com/cuda-python
