File size: 2,977 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
import os
import re
import sys

# Torch versions (dots removed) for which a prebuilt wheel is attempted.
torch_supported = [
    '211',
    '212',
    '220',
    '221',
]
# CUDA tags, compared against the part of torch.__version__ after the '+'.
cuda_supported = [
    'cu118',
    'cu121',
]
# Interpreter versions written as major and minor digits joined together.
python_supported = [
    '39',
    '310',
    '311',
]
# Project repository; wheel download links are built on top of it.
repo_url = 'https://github.com/chengzeyi/stable-fast'
# GitHub API endpoint queried for the assets of the "nightly" release tag.
api_url = 'https://api.github.com/repos/chengzeyi/stable-fast/releases/tags/nightly'
# Location of the nightly downloads, appended to repo_url.
path_url = '/releases/download/nightly'


def install_pip(arg: str) -> bool:
    """Run ``pip install -U <arg>`` using the current interpreter.

    The command is started without a shell, so shell metacharacters that
    end up in ``arg`` (for example from a version string obtained over the
    network) are passed to pip as plain text instead of being interpreted.

    Args:
        arg: Requirement, URL and/or extra pip options, separated by
            whitespace; quoting is honoured when splitting.

    Returns:
        True when pip exits with status 0, False otherwise (including when
        ``arg`` cannot be split because of unbalanced quotes).
    """
    import shlex
    import subprocess

    cmd = f'"{sys.executable}" -m pip install -U {arg}'
    print(f'Running: {cmd}')
    if os.name == 'nt':
        # With shell=False a string is handed directly to CreateProcess on
        # Windows, so cmd.exe never sees it and backslashes in paths survive.
        command = cmd
    else:
        try:
            extra_args = shlex.split(arg)
        except ValueError as err:
            # Unbalanced quotes: a shell would have rejected the command too.
            print(f'Invalid pip arguments: {err}')
            return False
        command = [sys.executable, '-m', 'pip', 'install', '-U', *extra_args]
    result = subprocess.run(command, shell=False, check=False, env=os.environ)
    return result.returncode == 0


def get_nightly():
    """Look up the version string of the current nightly release.

    Requests ``api_url`` and takes the name of the first listed asset; the
    version is the text between the first ``-`` and the next ``+`` in that
    name.

    Returns:
        The version string, or None when the request fails, the response
        is not usable JSON, no assets are listed, or the asset name does
        not match the expected pattern. Failures are reported on stdout
        rather than raised so the caller can fall back to another install
        method.
    """
    import requests

    try:
        response = requests.get(api_url, timeout=10)
    except requests.RequestException as err:
        # Timeouts, DNS and connection errors count as "version unknown".
        print(f'Failed to get nightly version: {err}')
        return None
    if response.status_code != 200:
        print('Failed to get nightly version')
        return None

    try:
        payload = response.json()
    except ValueError:
        # Body was not valid JSON.
        print('Failed to get nightly version')
        return None

    assets = payload.get('assets') if isinstance(payload, dict) else None
    if not assets:
        print('Failed to get nightly version')
        return None

    asset_name = assets[0].get('name') or ''
    match = re.search(r"-(.+?)\+", asset_name)
    if match is None:
        print('Failed to get nightly version')
        return None

    ver = match.group(1)
    print(f'Nightly version: {ver}')
    return ver


def install_stable_fast():
    """Install StableFast matching the running interpreter and installed torch.

    A prebuilt nightly wheel is used when the torch version and CUDA tag are
    both listed as supported and the nightly version can be determined;
    otherwise the package is installed from the ``main`` branch sources.
    The nightly version is only looked up when a wheel is actually a
    candidate.

    Raises:
        ValueError: If the Python version is not in ``python_supported`` or
            the platform is neither ``linux`` nor ``win32``.
    """
    import importlib

    import torch

    source_url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'

    python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
    if python_ver not in python_supported:
        raise ValueError(f'StableFast unsupported python: {python_ver} required {python_supported}')
    if sys.platform == 'linux':
        bin_url = 'manylinux2014_x86_64.whl'
    elif sys.platform == 'win32':
        bin_url = 'win_amd64.whl'
    else:
        raise ValueError(f'StableFast unsupported platform: {sys.platform}')

    # partition() instead of split(): a torch version without a '+<tag>'
    # suffix yields an empty cuda_ver and takes the source-install path
    # rather than failing on tuple unpacking.
    torch_ver, _, cuda_ver = torch.__version__.partition('+')
    torch_ver = torch_ver.replace('.', '')

    url = source_url
    if torch_ver not in torch_supported:
        print(f'StableFast unsupported torch: {torch_ver} required {torch_supported}')
        print('Installing from source...')
    elif cuda_ver not in cuda_supported:
        print(f'StableFast unsupported CUDA: {cuda_ver or "unknown"} required {cuda_supported}')
        print('Installing from source...')
    else:
        sf_ver = get_nightly()
        if sf_ver is None:
            print('StableFast cannot determine version')
            print('Installing from source...')
        else:
            print('Installing wheel...')
            file_url = f'stable_fast-{sf_ver}+torch{torch_ver}{cuda_ver}-cp{python_ver}-cp{python_ver}-{bin_url}'
            # Strip separators before joining so the URL has no doubled slash.
            url = '/'.join((repo_url.rstrip('/'), path_url.strip('/'), file_url))

    if not install_pip(url):
        print('StableFast install failed')
        return

    # The package appeared on disk after interpreter start-up; refresh the
    # import system's cached directory listings before importing it.
    importlib.invalidate_caches()
    import sfast
    print(f'StableFast installed: {sfast.__version__}')

# Script entry point: run the installer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    install_stable_fast()