cuda version checks

Changed files:
- README.md (+24 -2)
- float8_quantize.py (+22 -6)
- requirements.txt (+2 -1)
README.md CHANGED

@@ -1,4 +1,4 @@
-# Flux
+# Flux FP8 (true) Matmul Implementation with FastAPI
 
 This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. The API can be run via command-line arguments.
 
@@ -13,12 +13,34 @@ This repository contains an implementation of the Flux model
 
 ## Installation
 
+This repo _requires_ at least PyTorch built against CUDA 12.4 and an Ada GPU with fp8 support; otherwise `torch._scaled_mm` will throw a CUDA error saying it isn't supported. To install with conda/mamba:
+
+```bash
+mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
+mamba activate flux-fp8-matmul-api
+
+# or with conda
+conda create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
+conda activate flux-fp8-matmul-api
+
+# or with the nightly build (which is what I am using); switch 'mamba' to 'conda' if you are using conda
+mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch-nightly -c nvidia
+mamba activate flux-fp8-matmul-api
+
+# or with pip
+python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+# or pip nightly
+python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
+```
+
 To install the required dependencies, run:
 
 ```bash
-pip install -r requirements.txt
+python -m pip install -r requirements.txt
 ```
 
+If you get errors installing `torch-cublas-hgemm`, feel free to comment it out in requirements.txt; it isn't strictly necessary, but it speeds up inference for non-fp8 linear layers.
+
 ## Usage
 
 You can run the API server using the following command:
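Before installing everything, it can save time to confirm the GPU and PyTorch build actually satisfy these requirements. The following is a minimal preflight sketch, not part of the repo: `check_fp8_support` is a hypothetical helper, and the sm_89 threshold assumes fp8 tensor-core support begins with Ada-generation GPUs.

```python
import torch

def check_fp8_support() -> None:
    # Older PyTorch builds do not expose the fp8 matmul primitive at all.
    if not hasattr(torch, "_scaled_mm"):
        raise RuntimeError(f"torch {torch.__version__} lacks torch._scaled_mm")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    major, minor = torch.cuda.get_device_capability()
    # fp8 needs compute capability 8.9 (Ada) or newer (Hopper is 9.0).
    if (major, minor) < (8, 9):
        raise RuntimeError(f"GPU compute capability {major}.{minor} has no fp8 support")
    print(f"OK: torch {torch.__version__}, CUDA {torch.version.cuda}, sm_{major}{minor}")

if __name__ == "__main__":
    check_fp8_support()
```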
float8_quantize.py CHANGED

@@ -9,8 +9,21 @@ from torchao.float8.float8_utils import (
 from torch.nn import init
 import math
 from torch.compiler import is_compiling
+from torch import __version__
+from torch.version import cuda
 
-
+IS_TORCH_2_4 = __version__ >= (2, 4) and __version__ < (2, 5)
+LT_TORCH_2_4 = __version__ < (2, 4)
+if LT_TORCH_2_4:
+    if not hasattr(torch, "_scaled_mm"):
+        raise RuntimeError(
+            "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
+        )
+CUDA_VERSION = float(cuda) if cuda else 0
+if CUDA_VERSION < 12.4:
+    raise RuntimeError(
+        f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later; got torch version {__version__} and CUDA version {cuda}."
+    )
 try:
     from cublas_ops import CublasLinear
 except ImportError:
@@ -244,19 +257,22 @@ class F8Linear(nn.Module):
         x = self.quantize_input(x)
 
         prev_dims = x.shape[:-1]
-
         x = x.view(-1, self.in_features)
 
         # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
-        return torch._scaled_mm(
+        out = torch._scaled_mm(
             x,
             self.float8_data.T,
-            self.input_scale_reciprocal,
-            self.scale_reciprocal,
+            scale_a=self.input_scale_reciprocal,
+            scale_b=self.scale_reciprocal,
             bias=self.bias,
             out_dtype=self.weight.dtype,
             use_fast_accum=True,
-        )
+        )
+        if IS_TORCH_2_4:
+            out = out[0]
+        out = out.view(*prev_dims, self.out_features)
+        return out
 
     @classmethod
     def from_linear(
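For readers unfamiliar with `torch._scaled_mm`, here is a standalone sketch of the call pattern this diff settles on. It is a private API whose details shift between releases, so treat this as illustrative only: it assumes torch >= 2.4 on an fp8-capable GPU, and the shapes and unit scales are made up rather than taken from the repo.

```python
import torch

M, K, N = 16, 64, 32
device = "cuda"

a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)

# Per-tensor dequantization scales must be float32 scalar tensors on device.
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

# cuBLASLt wants the second operand column-major, hence the transpose;
# this is the same reason the module passes self.float8_data.T above.
out = torch._scaled_mm(
    a,
    b.T,
    scale_a=scale_a,
    scale_b=scale_b,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
# torch 2.4 returns a tuple (hence the out[0] branch in the diff above);
# newer builds return the tensor directly.
if isinstance(out, tuple):
    out = out[0]
print(out.shape)  # torch.Size([16, 32])
```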
requirements.txt CHANGED

@@ -12,4 +12,5 @@ sentencepiece
 click
 accelerate
 quanto
-pydash
+pydash
+pybase64
|