File size: 1,710 Bytes
bf32e91
 
 
94dc208
 
bf32e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94dc208
bf32e91
 
 
 
 
 
 
94dc208
bf32e91
94dc208
bf32e91
 
94dc208
bf32e91
94dc208
bf32e91
 
 
 
 
 
 
 
993bb68
fd7017c
94dc208
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import gradio as gr
from transformers import AutoModel, AutoTokenizer

def get_code_generative_models():
    """Discover the models stored under ``./models``.

    Scans the ``models`` directory inside the current working directory and
    keeps every sub-directory that contains a ``config.json`` file.

    Returns:
        A list of ``(model_name, model_path)`` tuples sorted by model name.
        The list is empty when the ``models`` directory does not exist or
        holds no matching sub-directory.
    """
    models_dir = os.path.join(os.getcwd(), "models")
    if not os.path.isdir(models_dir):
        # Nothing to offer yet; let the caller start with an empty list.
        return []

    models = []
    # Sorted so the result does not depend on filesystem listing order.
    for model_name in sorted(os.listdir(models_dir)):
        model_path = os.path.join(models_dir, model_name)
        if not os.path.isdir(model_path):
            continue
        # Look for the file on disk instead of loading the model: listing
        # must not pay the cost of reading every model's weights, and a
        # loaded model object does not expose a list of its files.
        if os.path.isfile(os.path.join(model_path, "config.json")):
            models.append((model_name, model_path))
    return models

def model_inference(model_name, model_path, input_data):
    """Run the model stored at ``model_path`` on ``input_data``.

    Args:
        model_name: Display name of the model. Not used by the computation;
            kept in the signature so existing callers keep working.
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the model.
        input_data: Text to tokenize and feed to the model.

    Returns:
        A nested list: ``last_hidden_state`` sliced at index 0 of its second
        axis (the first token position for standard transformers outputs),
        converted with ``tolist()``.
    """
    # Local import: the file does not import torch at module level, but the
    # "pt" tensors requested below already require it to be installed.
    import torch

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    inputs = tokenizer(input_data, return_tensors="pt")
    # Only the output values are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    result = outputs.last_hidden_state[:, 0, :]
    return result.tolist()

def main():
    """Build and launch the Gradio demo.

    The UI offers a dropdown of the discovered models, a text input, a
    "Run" button and an output box. Choosing a model stores its path in a
    per-session ``gr.State``; pressing "Run" (or submitting the input box)
    passes the selected name, the stored path and the text to
    ``model_inference`` and shows the result.
    """
    models = get_code_generative_models()
    # Name -> path lookup for the dropdown handler; ``.get`` yields None
    # instead of raising when nothing is selected.
    model_paths = dict(models)

    with gr.Blocks() as demo:
        gr.Markdown("### Select Model and Input")
        with gr.Row():
            model_name = gr.Dropdown(label="Model", choices=[m[0] for m in models])
            input_data = gr.Textbox(label="Input")

        model_path = gr.State(None)

        def update_model_path(model_name):
            # Event handlers update components through their return value;
            # the State component is listed as the output below.
            return model_paths.get(model_name)

        # The path depends on the dropdown, so react to the dropdown changing.
        model_name.change(update_model_path, inputs=model_name, outputs=model_path)

        run_button = gr.Button("Run")
        output = gr.Textbox(label="Output")

        def infer(model_name, model_path, input_data):
            # ``model_path`` arrives as the State's value because the State
            # is passed as an input, not captured from the enclosing scope.
            if not model_path:
                return "Please select a model."
            return str(model_inference(model_name, model_path, input_data))

        # Trigger inference from explicit user actions. Listening to the
        # output box's own change event would never fire from user input
        # and would retrigger itself whenever the handler wrote to it.
        infer_inputs = [model_name, model_path, input_data]
        run_button.click(fn=infer, inputs=infer_inputs, outputs=output)
        input_data.submit(fn=infer, inputs=infer_inputs, outputs=output)

    demo.launch()

# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()