awacke1 commited on
Commit
004c842
1 Parent(s): 7859b33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
2
+ import torch
3
+ import gradio as gr
4
+
5
+ import os
6
+ import csv
7
+ from gradio import inputs, outputs
8
+ from datetime import datetime
9
+ import fastapi
10
+ from typing import List, Dict
11
+ import httpx
12
+ import pandas as pd
13
+ import datasets as ds
14
+ UseMemory=True
15
+
16
+ HF_TOKEN=os.environ.get("HF_TOKEN")
17
+
18
+ def SaveResult(text, outputfileName):
19
+ basedir = os.path.dirname(__file__)
20
+ savePath = outputfileName
21
+ print("Saving: " + text + " to " + savePath)
22
+ from os.path import exists
23
+ file_exists = exists(savePath)
24
+ if file_exists:
25
+ with open(outputfileName, "a") as f: #append
26
+ f.write(str(text.replace("\n"," ")))
27
+ f.write('\n')
28
+ else:
29
+ with open(outputfileName, "w") as f: #write
30
+ f.write(str("time, message, text\n")) # one time only to get column headers for CSV file
31
+ f.write(str(text.replace("\n"," ")))
32
+ f.write('\n')
33
+ return
34
+
35
+
36
+ def store_message(name: str, message: str, outputfileName: str):
37
+ basedir = os.path.dirname(__file__)
38
+ savePath = outputfileName
39
+ if name and message:
40
+ with open(savePath, "a") as csvfile:
41
+ writer = csv.DictWriter(csvfile, fieldnames=[ "time", "message", "name", ])
42
+ writer.writerow(
43
+ {"time": str(datetime.now()), "message": message.strip(), "name": name.strip() }
44
+ )
45
+ df = pd.read_csv(savePath)
46
+ df = df.sort_values(df.columns[0],ascending=False)
47
+ return df
48
+
49
+ mname = "facebook/blenderbot-400M-distill"
50
+ model = BlenderbotForConditionalGeneration.from_pretrained(mname)
51
+ tokenizer = BlenderbotTokenizer.from_pretrained(mname)
52
+
53
+ def take_last_tokens(inputs, note_history, history):
54
+ if inputs['input_ids'].shape[1] > 128:
55
+ inputs['input_ids'] = torch.tensor([inputs['input_ids'][0][-128:].tolist()])
56
+ inputs['attention_mask'] = torch.tensor([inputs['attention_mask'][0][-128:].tolist()])
57
+ note_history = ['</s> <s>'.join(note_history[0].split('</s> <s>')[2:])]
58
+ history = history[1:]
59
+ return inputs, note_history, history
60
+
61
+ def add_note_to_history(note, note_history):# good example of non async since we wait around til we know it went okay.
62
+ note_history.append(note)
63
+ note_history = '</s> <s>'.join(note_history)
64
+ return [note_history]
65
+
66
+ title = "💬ChatBack🧠💾"
67
+ description = """Chatbot With persistent memory dataset allowing multiagent system AI to access a shared dataset as memory pool with stored interactions.
68
+ Current Best SOTA Chatbot: https://huggingface.co/facebook/blenderbot-400M-distill?text=Hey+my+name+is+ChatBack%21+Are+you+ready+to+rock%3F """
69
+
70
+ def get_base(filename):
71
+ basedir = os.path.dirname(__file__)
72
+ loadPath = basedir + "\\" + filename
73
+ return loadPath
74
+
75
+ def chat(message, history):
76
+ history = history or []
77
+ if history:
78
+ history_useful = ['</s> <s>'.join([str(a[0])+'</s> <s>'+str(a[1]) for a in history])]
79
+ else:
80
+ history_useful = []
81
+
82
+ history_useful = add_note_to_history(message, history_useful)
83
+ inputs = tokenizer(history_useful, return_tensors="pt")
84
+ inputs, history_useful, history = take_last_tokens(inputs, history_useful, history)
85
+ reply_ids = model.generate(**inputs)
86
+ response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
87
+ history_useful = add_note_to_history(response, history_useful)
88
+ list_history = history_useful[0].split('</s> <s>')
89
+ history.append((list_history[-2], list_history[-1]))
90
+
91
+ df=pd.DataFrame()
92
+
93
+ if UseMemory:
94
+ outputfileName = 'ChatbotMemory.csv'
95
+ df = store_message(message, response, outputfileName) # Save to dataset
96
+ basedir = get_base(outputfileName)
97
+
98
+ return history, df, basedir
99
+
100
+
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("<h1><center>🍰Gradio chatbot backed by dataframe CSV memory🎨</center></h1>")
103
+
104
+ with gr.Row():
105
+ t1 = gr.Textbox(lines=1, default="", label="Chat Text:")
106
+ b1 = gr.Button("Respond and Retrieve Messages")
107
+
108
+ with gr.Row(): # inputs and buttons
109
+ s1 = gr.State([])
110
+ df1 = gr.Dataframe(wrap=True, max_rows=1000, overflow_row_behaviour= "paginate")
111
+ with gr.Row(): # inputs and buttons
112
+ file = gr.File(label="File")
113
+ s2 = gr.Markdown()
114
+
115
+ b1.click(fn=chat, inputs=[t1, s1], outputs=[s1, df1, file])
116
+
117
+ demo.launch(debug=True, show_error=True)