DSDUDEd committed on
Commit
c00d91a
·
verified ·
1 Parent(s): 1f289ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import gradio as gr
4
+ from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
5
+
6
# Conversation history is persisted to this JSON file.
MEMORY_FILE = "cass_memory.json"
# Upper bound on the number of message pairs to keep.
MAX_MEMORY = 50

# The tokenizer and the model weights are both loaded from the same identifier.
model_name = "DSDUDEd/Cass-Beta1.3"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
14
+
15
# Load memory if it exists. A missing or malformed file falls back to an empty
# history so that a damaged memory file cannot stop the app from starting.
try:
    with open(MEMORY_FILE, "r", encoding="utf-8") as f:
        memory = json.load(f)
except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
    memory = []

# The rest of the script unpacks every entry as a (user, reply) pair, so drop
# anything that does not have that shape instead of failing later.
if not isinstance(memory, list):
    memory = []
memory = [pair for pair in memory if isinstance(pair, list) and len(pair) == 2]
21
+
22
def save_memory():
    """Persist the most recent ``MAX_MEMORY`` message pairs to ``MEMORY_FILE``.

    The JSON is first written to a temporary sibling file and then moved over
    the real file with ``os.replace``. If the process is interrupted while
    writing, the previous memory file is left intact rather than truncated.
    """
    tmp_path = MEMORY_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(memory[-MAX_MEMORY:], f, indent=2)
    os.replace(tmp_path, MEMORY_FILE)
25
+
26
def chat_with_cass(user_input):
    """Generate Cass's reply to ``user_input`` and record the exchange.

    The prompt is the stored history rendered as ``User: ... Cass: ...`` turns
    followed by the new user message. The reply is appended to the module-level
    ``memory`` list and written to disk through ``save_memory()``.

    Args:
        user_input: Text submitted by the user.

    Returns:
        The ``memory`` list of (user, reply) pairs. It is returned as a single
        value because the handler is wired to exactly one output component
        (the Chatbot); returning two values would hand the Chatbot a tuple
        instead of the history.
    """
    # A blank submission would only add an empty turn to the history.
    if not user_input or not user_input.strip():
        return memory

    # Budget for the reply itself, independent of how long the prompt is.
    max_new_tokens = 150

    # Combine memory into context (only the pairs that are also persisted).
    turns = [f"User: {u} Cass: {c}" for u, c in memory[-MAX_MEMORY:]]
    turns.append(f"User: {user_input} Cass:")
    input_text = " ".join(turns)

    inputs = tokenizer(input_text, return_tensors="pt")

    # Keep only the most recent prompt tokens so that prompt + reply fit in the
    # model's context window; older turns are dropped first.
    context_window = getattr(model.config, "n_positions", 1024)
    prompt_budget = max(1, context_window - max_new_tokens)
    inputs = {name: tensor[:, -prompt_budget:] for name, tensor in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        # max_length would count the prompt tokens too, leaving no room for a
        # reply once the history grows; max_new_tokens limits only the reply.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens generated after the prompt. Splitting the full
    # text on "Cass:" picks the wrong segment when the user text contains it
    # or the model keeps writing further turns.
    generated = outputs[0][prompt_length:]
    reply = tokenizer.decode(generated, skip_special_tokens=True)
    # Stop at the point where the model starts inventing the next user turn.
    new_reply = reply.split("User:")[0].strip()

    # Add to memory and save
    memory.append((user_input, new_reply))
    save_memory()

    return memory
50
+
51
# Gradio Interface
with gr.Blocks() as demo:
    # The chatbot starts out showing the history loaded at startup.
    chatbot = gr.Chatbot(value=memory)
    msg = gr.Textbox(label="You")
    clear = gr.Button("Clear")

    # Submitting the textbox runs the chat handler and sends its result to the chatbot.
    msg.submit(fn=chat_with_cass, inputs=[msg], outputs=[chatbot])
    # The Clear button replaces the chatbot's value with an empty list.
    # NOTE(review): this does not reset `memory` or the memory file, so earlier
    # turns come back with the next reply; confirm that is intended.
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)

demo.launch()