app_script = r"""
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer
model_path = "Ozaii/TinyWali1.1B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Put the model in evaluation mode on the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
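
# Optional tweak (an assumption, not part of the original setup): on a CUDA
# GPU the model could instead be loaded in half precision to roughly halve
# memory use, e.g.
#   model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)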
# Generation parameters and the response function, with simple context management
def generate_response(user_input, chat_history):
    max_context_length = 750   # Maximum number of prompt tokens to keep
    max_response_length = 150  # Maximum number of new tokens to generate

    # Rebuild the prompt from the accumulated chat history
    prompt = ""
    for message in chat_history:
        if message[0] is not None:
            prompt += f"User: {message[0]}\n"
        if message[1] is not None:
            prompt += f"Assistant: {message[1]}\n"
    prompt += f"User: {user_input}\nAssistant:"
    # Left-truncate the prompt so it stays within the context budget
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
    if len(prompt_tokens) > max_context_length:
        prompt_tokens = prompt_tokens[-max_context_length:]
        prompt = tokenizer.decode(prompt_tokens, clean_up_tokenization_spaces=True)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
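
    # Alternative (a sketch, not the original approach): Hugging Face tokenizers
    # can do this left-truncation in a single call:
    #   tokenizer.truncation_side = "left"
    #   inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
    #                      max_length=max_context_length).to(device)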
    # Generate a response; do_sample=True is required for temperature, top_k,
    # and top_p to take effect (otherwise decoding is greedy and they are ignored)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=len(inputs.input_ids[0]) + max_response_length,  # Bound context plus response length
            min_length=45,
            do_sample=True,
            temperature=0.7,         # Slightly higher temperature for more diverse responses
            top_k=30,
            top_p=0.9,               # Allow a bit more randomness
            repetition_penalty=1.1,  # Mild repetition penalty
            no_repeat_ngram_size=3,  # Avoid repeating any 3-gram verbatim
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
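
    # Alternative (a sketch, not the original approach): decode only the newly
    # generated tokens rather than splitting the full transcript on "Assistant:":
    #   new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    #   reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()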
    # Keep only the assistant's latest reply
    assistant_response = response.split("Assistant:")[-1].strip()
    # Drop everything after the first newline so trailing fragments are removed
    assistant_response = assistant_response.split('\n')[0].strip()

    # Append the interaction to the chat history
    chat_history.append((user_input, assistant_response))

    # Return the updated history for both the chatbox and the state
    return chat_history, chat_history


def restart_chat():
    # Clear both the chatbox and the stored history
    return [], []
# Build the Gradio interface
with gr.Blocks() as chat_interface:
    gr.Markdown("<h1><center>W.AI Chat Nikker xD</center></h1>")
    chat_history = gr.State([])
    with gr.Column():
        chatbox = gr.Chatbot()
        with gr.Row():
            user_input = gr.Textbox(show_label=False, placeholder="Summon Wali Here...")
            submit_button = gr.Button("Send")
            restart_button = gr.Button("Restart")

    submit_button.click(
        generate_response,
        inputs=[user_input, chat_history],
        outputs=[chatbox, chat_history],
    )
    restart_button.click(
        restart_chat,
        inputs=[],
        outputs=[chatbox, chat_history],
    )

# Launch the Gradio interface with a public share link
chat_interface.launch(share=True)
"""
# Write the embedded script out as app.py
with open("app.py", "w") as f:
    f.write(app_script)
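
# To try the generated app (assuming gradio, torch, and transformers are
# installed in the current environment), run it from a shell:
#
#   python app.py
#
# With share=True, launch() also prints a temporary public *.gradio.live URL,
# so the chat can be tested from another device while the process runs.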