|
|
|
import gradio as gr |
|
import torch |
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel |
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned GPT-2 chat model.
model_path = 'redael/model_udc'

# Load tokenizer and model weights (downloads from the Hub on first run).
tokenizer = GPT2Tokenizer.from_pretrained(model_path)

model = GPT2LMHeadModel.from_pretrained(model_path)




# Prefer GPU when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

# Cast to fp16 only on CUDA — halves GPU memory; half precision is poorly
# supported for inference on CPU, so the CPU path stays in fp32.
if device.type == 'cuda':

    model = model.half()
|
|
|
def generate_response(prompt, model, tokenizer, max_length=100, num_beams=1, temperature=0.7, top_p=0.9, repetition_penalty=2.0):
    """Generate one assistant reply for *prompt*.

    Args:
        prompt: Either a bare user message or a full transcript already
            ending in "Assistant:" (as built by ``respond``).
        model / tokenizer: the GPT-2 LM head model and its tokenizer.
        max_length: total sequence length cap (prompt tokens included —
            NOTE(review): long prompts leave few tokens for the reply).
        num_beams, temperature, top_p, repetition_penalty: passed through
            to ``model.generate``.

    Returns:
        The cleaned reply text with any User:/Assistant: markers stripped.
    """
    # BUGFIX: GPT-2 tokenizers ship without a pad token, and padding=True
    # raises "Asking to pad but the tokenizer does not have a padding token".
    # Standard workaround: reuse EOS as the pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # BUGFIX: only add the chat frame when the caller has not already done
    # so — respond() passes a transcript ending in "Assistant:", and
    # re-wrapping it produced a doubled User:/Assistant: prefix.
    if not prompt.rstrip().endswith("Assistant:"):
        prompt = f"User: {prompt}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],  # explicit mask for padded input
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        # BUGFIX: temperature/top_p only take effect when sampling is on;
        # without do_sample=True they were silently ignored.
        do_sample=True,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        early_stopping=True
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the last "Assistant:" marker, then drop any
    # hallucinated extra turns the model may have generated.
    response = response.split("Assistant:")[-1].strip()
    clean_lines = [
        line for line in response.split('\n')
        if "User:" not in line and "Assistant:" not in line
    ]
    return ' '.join(clean_lines).strip()
|
|
|
def respond(message, history: list[tuple[str, str]]):
    """Gradio chat callback.

    Folds the system message, the prior (user, assistant) turns, and the
    incoming *message* into one plain-text transcript, then returns the
    model's next reply.
    """
    transcript = ["You are a friendly chatbot."]
    for past_user, past_assistant in history:
        transcript.append(f"User: {past_user}\nAssistant: {past_assistant}")
    transcript.append(f"User: {message}\nAssistant:")
    prompt = "\n".join(transcript)

    # Fixed decoding settings for the UI.
    return generate_response(
        prompt,
        model,
        tokenizer,
        max_length=100,
        temperature=0.7,
        top_p=0.9,
    )
|
|
|
|
|
# Minimal Gradio chat UI; ChatInterface calls `respond(message, history)`
# once per user turn and renders the returned string as the bot reply.
demo = gr.ChatInterface(

    respond

)



if __name__ == "__main__":

    demo.launch()
|
|