import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dictionary to store loaded models and tokenizers
loaded_models = {}

# List of available models (ensure these are correct and accessible)
models = [
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-0.5B-Instruct"
]


def load_all_models():
    """
    Pre-loads all models and their tokenizers into memory.
    """
    for model_name in models:
        if model_name not in loaded_models:
            try:
                logger.info(f"Loading model: {model_name}")
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name).to(
                    "cuda" if torch.cuda.is_available() else "cpu")
                loaded_models[model_name] = (model, tokenizer)
                logger.info(f"Successfully loaded {model_name}")
            except Exception as e:
                logger.error(f"Failed to load model {model_name}: {e}")

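# Note (assumption, not part of the original app): the 7B checkpoint is large for
# full-precision loading. On a memory-constrained GPU, one hedged alternative is
# to load the weights in half precision, e.g.
#   AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
# Whether this is appropriate depends on the available hardware; the default
# full-precision load above is kept as-is.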
def get_model_response(model_name, message):
    """
    Generates a response from the specified model given a user message.
    """
    try:
        model, tokenizer = loaded_models[model_name]
        inputs = tokenizer(message, return_tensors="pt").to(model.device)

        # Generate a response; max_new_tokens bounds only the generated text,
        # so long prompts are not silently cut off the way max_length can cause.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                top_p=0.95,
                top_k=50
            )

        # Decode only the newly generated tokens so the prompt is not echoed back.
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )
        return response
    except KeyError:
        logger.error(f"Model {model_name} not found in loaded_models.")
        return f"Error: Model {model_name} not loaded."
    except Exception as e:
        logger.error(f"Error generating response from {model_name}: {e}")
        return f"Error generating response: {e}"


def chat(message, history1, history2, model1, model2):
    """
    Handles the chat interaction by getting responses from both models
    and updating their respective histories.
    """
    response1 = get_model_response(model1, message)
    response2 = get_model_response(model2, message)

    history1 = history1 or []
    history2 = history2 or []

    # gr.Chatbot expects (user_message, bot_response) pairs, so append one pair
    # per turn rather than separate "speaker" entries.
    history1.append((message, response1))
    history2.append((message, response2))

    return history1, history2


# Initialize vote counts
vote_counts = {"model1": 0, "model2": 0}


def upvote_vote(model1, model2):
    """
    Increments the vote count for Model 1 and returns updated counts.
    """
    vote_counts["model1"] += 1
    return f"Votes - {model1.split('/')[-1]}: {vote_counts['model1']}, {model2.split('/')[-1]}: {vote_counts['model2']}"


def downvote_vote(model1, model2):
    """
    Increments the vote count for Model 2 and returns updated counts.
    """
    vote_counts["model2"] += 1
    return f"Votes - {model1.split('/')[-1]}: {vote_counts['model1']}, {model2.split('/')[-1]}: {vote_counts['model2']}"


def clear_chat():
    """
    Clears both chat histories and resets vote counts.
    """
    global vote_counts
    vote_counts = {"model1": 0, "model2": 0}
    return [], [], "Votes - 0, 0"


# Pre-load all models before building the Gradio interface
load_all_models()

with gr.Blocks() as demo:
    gr.Markdown("# πŸ€– Model Comparison Space")

    # Dropdowns for selecting models
    with gr.Row():
        model1_dropdown = gr.Dropdown(choices=models, label="Model 1", value=models[0])
        model2_dropdown = gr.Dropdown(choices=models, label="Model 2", value=models[1])

    # Separate chatboxes for each model
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model 1 Chat")
            chatbot1 = gr.Chatbot(label=f"{models[0].split('/')[-1]} Chat History")
        with gr.Column():
            gr.Markdown("### Model 2 Chat")
            chatbot2 = gr.Chatbot(label=f"{models[1].split('/')[-1]} Chat History")

    # Input textbox for user message
    msg = gr.Textbox(label="πŸ’¬ Your Message", placeholder="Type your message here...")

    # Buttons for upvote, downvote, and clearing the chat
    with gr.Row():
        upvote = gr.Button("πŸ‘ Upvote Model 1")
        downvote = gr.Button("πŸ‘Ž Downvote Model 2")
        clear = gr.Button("🧹 Clear Chat")

    # Textbox to display vote counts
    vote_text = gr.Textbox(label="πŸ† Vote Counts", value="Votes - 0, 0", interactive=False)

    # Define interactions
    msg.submit(
        chat,
        inputs=[msg, chatbot1, chatbot2, model1_dropdown, model2_dropdown],
        outputs=[chatbot1, chatbot2]
    )

    upvote.click(
        upvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text
    )

    downvote.click(
        downvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text
    )

    clear.click(
        clear_chat,
        outputs=[chatbot1, chatbot2, vote_text]
    )

if __name__ == "__main__":
    demo.launch(share=True)
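# Usage note (assumption): share=True asks Gradio to open a temporary public URL.
# To serve only locally, a minimal sketch is:
#   demo.launch(server_name="0.0.0.0", server_port=7860)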