Ozaii commited on
Commit
b8afd09
1 Parent(s): c272035

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Create app.py file
2
+ app_script = """
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+
7
+ # Load the model and tokenizer
8
+ model_path = "Ozaii/TinyWali1.1B"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
10
+ model = AutoModelForCausalLM.from_pretrained(model_path)
11
+
12
+ # Ensure the model is in evaluation mode and on the correct device
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
+ model.eval()
16
+
17
+ # Define Generation Parameters and Function with Enhanced Context Management
18
+ def generate_response(user_input, chat_history):
19
+ max_context_length = 750 # Specify the maximum context length
20
+ max_response_length = 150 # Specify the maximum response length
21
+
22
+ # Prepare the prompt with chat history
23
+ prompt = ""
24
+ for message in chat_history:
25
+ if message[0] is not None:
26
+ prompt += f"User: {message[0]}\n"
27
+ if message[1] is not None:
28
+ prompt += f"Assistant: {message[1]}\n"
29
+ prompt += f"User: {user_input}\nAssistant:"
30
+
31
+ # Ensure the context does not exceed the maximum context length
32
+ prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
33
+ if len(prompt_tokens) > max_context_length:
34
+ prompt_tokens = prompt_tokens[-max_context_length:]
35
+ prompt = tokenizer.decode(prompt_tokens, clean_up_tokenization_spaces=True)
36
+
37
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
38
+
39
+ # Generate response
40
+ with torch.no_grad():
41
+ outputs = model.generate(
42
+ inputs.input_ids,
43
+ max_length=len(inputs.input_ids[0]) + max_response_length, # Limit the maximum length for context and response
44
+ min_length=45,
45
+ temperature=0.7, # Slightly higher temperature for more diverse responses
46
+ top_k=30,
47
+ top_p=0.9, # Allow a bit more randomness
48
+ repetition_penalty=1.1, # Mild repetition penalty
49
+ no_repeat_ngram_size=3, # Ensure no repeated phrases
50
+ eos_token_id=tokenizer.eos_token_id,
51
+ pad_token_id=tokenizer.eos_token_id
52
+ )
53
+
54
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
+
56
+ # Post-process the assistant's response
57
+ assistant_response = response.split("Assistant:")[-1].strip()
58
+ # Ensure the response ends properly by stripping incomplete sentences
59
+ assistant_response = assistant_response.split('\\n')[0].strip()
60
+
61
+ # Append the interaction to the chat history
62
+ chat_history.append((user_input, assistant_response))
63
+
64
+ # Return the updated chat history
65
+ return chat_history, chat_history
66
+
67
+ def restart_chat():
68
+ return [], []
69
+
70
+ # Create Gradio Interface
71
+ with gr.Blocks() as chat_interface:
72
+ gr.Markdown("<h1><center>W.AI Chat Nikker xD</center></h1>")
73
+ chat_history = gr.State([])
74
+ with gr.Column():
75
+ chatbox = gr.Chatbot()
76
+ with gr.Row():
77
+ user_input = gr.Textbox(show_label=False, placeholder="Summon Wali Here...")
78
+ submit_button = gr.Button("Send")
79
+ restart_button = gr.Button("Restart")
80
+
81
+ submit_button.click(
82
+ generate_response,
83
+ inputs=[user_input, chat_history],
84
+ outputs=[chatbox, chat_history]
85
+ )
86
+
87
+ restart_button.click(
88
+ restart_chat,
89
+ inputs=[],
90
+ outputs=[chatbox, chat_history]
91
+ )
92
+
93
+ # Launch the Gradio interface
94
+ chat_interface.launch(share=True)
95
+ """
96
+ # Save the script to app.py
97
+ with open("app.py", "w") as f:
98
+ f.write(app_script)