import os
import subprocess
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
import gradio as gr
import tempfile
import torch
import requests

from huggingface_hub import HfApi, whoami
from gradio_huggingfacehub_search import HuggingfaceHubSearch


###########

import threading
from queue import Queue, Empty

def stream_output(pipe, queue):
    """Read output from pipe and put it in the queue."""
    for line in iter(pipe.readline, b''):
        queue.put(line.decode('utf-8').rstrip())
    pipe.close()

def run_command(command, env_vars):
    # Create process with pipes for stdout and stderr
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        #bufsize=1,
        universal_newlines=False,
        env=env_vars,
    )
    
    # Create queues to store output
    stdout_queue = Queue()
    stderr_queue = Queue()
    
    # Create and start threads to read output
    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout, stdout_queue))
    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr, stderr_queue))
    stdout_thread.daemon = True
    stderr_thread.daemon = True
    stdout_thread.start()
    stderr_thread.start()

    output_stdout = ""
    output_stderr = ""
    # Monitor output in real-time
    while process.poll() is None:
        # Check stdout
        try:
            stdout_line = stdout_queue.get_nowait()
            print(f"STDOUT: {stdout_line}")
            output_stdout += stdout_line + "\n"
        except Empty:
            pass
            
        # Check stderr
        try:
            stderr_line = stderr_queue.get_nowait()
            print(f"STDERR: {stderr_line}")
            output_stderr += stderr_line + "\n"
        except Empty:
            pass
    
    # Wait for the reader threads to finish, then drain any lines still queued
    stdout_thread.join()
    stderr_thread.join()
    while not stdout_queue.empty():
        stdout_line = stdout_queue.get_nowait()
        print(f"STDOUT: {stdout_line}")
        output_stdout += stdout_line + "\n"
    while not stderr_queue.empty():
        stderr_line = stderr_queue.get_nowait()
        print(f"STDERR: {stderr_line}")
        output_stderr += stderr_line + "\n"

    return (process.returncode, output_stdout, output_stderr)

###########

def guess_base_model(ft_model_id):
    """Return the base model ID declared in the fine-tuned model's `base_model:` tag on the Hub."""
    res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
    res.raise_for_status()
    for tag in res.json().get("tags", []):
        if tag.startswith("base_model:"):
            return tag.split(":")[-1]
    raise Exception("Cannot guess the base model, please enter it manually")


def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
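    """Extract a LoRA adapter from a fine-tuned model and upload it to the Hub.

    Steps: validate the user's token, resolve the base model if it was not given,
    run mergekit-extract-lora into a temporary directory, create a
    `<username>/LoRA-<model_name>` repo, and upload the extracted adapter to it.
    """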
    # validate the OAuth token
    if oauth_token is None or not oauth_token.token:
        raise gr.Error("You must be logged in")
    try:
        whoami(oauth_token.token)
    except Exception:
        raise gr.Error("You must be logged in (invalid or expired token)")

    model_name = ft_model_id.split('/')[-1]

    # "outputs" is the parent directory for the per-request temporary directory created below
    os.makedirs("outputs", exist_ok=True)

    try:
        api = HfApi(token=oauth_token.token)

        if not base_model_id:
            base_model_id = guess_base_model(ft_model_id)
            print("guess_base_model", base_model_id)
        
        with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
            device = "cuda" if torch.cuda.is_available() else "cpu"
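            # With rank left at its default this builds a command along the lines of
            # (hypothetical repo IDs):
            #   mergekit-extract-lora someuser/my-finetune someorg/base-model <tmpdir> --rank=32 --device=cuda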
            cmd = [
                "mergekit-extract-lora",
                ft_model_id,
                base_model_id,
                outputdir,
                f"--rank={rank}",
                f"--device={device}"
            ]
            print("cmd", cmd)
            env_vars = dict(os.environ, HF_TOKEN=oauth_token.token)
            returncode, output_stdout, output_stderr = run_command(cmd, env_vars)
            print("returncode", returncode)
            print("output_stdout", output_stdout)
            print("output_stderr", output_stderr)
            if returncode != 0:
                raise Exception(f"Error converting to LoRA PEFT {output_stderr}")
            print("Model converted to LoRA PEFT successfully!")
            print(f"Converted model path: {outputdir}")

            # Check output dir
            if not os.listdir(outputdir):
                raise Exception("Output directory is empty!")

            # Create repo
            username = whoami(oauth_token.token)["name"]
            new_repo_url = api.create_repo(repo_id=f"{username}/LoRA-{model_name}", exist_ok=True, private=private_repo)
            new_repo_id = new_repo_url.repo_id
            print("Repo created successfully!", new_repo_url)

            # Upload files
            api.upload_folder(
                folder_path=outputdir,
                path_in_repo="",
                repo_id=new_repo_id,
            )
            print("Uploaded", outputdir)

        return (
            f'<h1>βœ… DONE</h1><br/><br/>Find your repo here: <a href="{new_repo_url}" target="_blank" style="text-decoration:underline">{new_repo_id}</a>'
        )
    except Exception as e:
        return (f"<h1>❌ ERROR</h1><br/><br/>{e}")


css="""/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""
# Create Gradio interface
with gr.Blocks(css=css) as demo: 
    gr.Markdown("You must be logged in.")
    gr.LoginButton(min_width=250)

    ft_model_id = HuggingfaceHubSearch(
        label="Fine tuned model repository",
        placeholder="Fine tuned model",
        search_type="model",
    )

    base_model_id = HuggingfaceHubSearch(
        label="Base model repository (optional)",
        placeholder="If empty, it will be guessed from repo tags",
        search_type="model",
    )

    rank = gr.Dropdown(
        ["16", "32", "64", "128"],
        label="LoRA rank",
        info="Higher the rank, better the result, but heavier the adapter",
        value="32",
        filterable=False,
        visible=True
    )

    private_repo = gr.Checkbox(
        value=False,
        label="Private Repo",
        info="Create a private repo under your username."
    )

    iface = gr.Interface(
        fn=process_model,
        inputs=[
            ft_model_id,
            base_model_id,
            rank,
            private_repo,
        ],
        outputs=[
            gr.Markdown(label="output"),
        ],
        title="Convert fine tuned model into LoRA with mergekit-extract-lora",
        description="The space takes a fine tuned model, a base model, then make a PEFT-compatible LoRA adapter based on the difference between 2 models.<br/><br/>NOTE: Each conversion takes about <b>5 to 20 minutes</b>, depending on how big the model is.",
        api_name=False
    )

# Launch the interface
demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)