|
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration |
|
import torch |
|
import gradio as gr |
|
from datasets import load_dataset |
|
|
|
|
|
import os |
|
import csv |
|
from gradio import inputs, outputs |
|
import huggingface_hub |
|
from huggingface_hub import Repository, hf_hub_download, upload_file |
|
from datetime import datetime |
|
|
|
|
|
import fastapi |
|
|
|
from typing import List, Dict |
|
import httpx |
|
import pandas as pd |
|
import datasets as ds |
|
|
|
# Feature flag: when truthy, chat exchanges are persisted to the CSV memory file.
UseMemory = True

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
def SaveResult(text, outputfileName):
    """Append one line of text to a log file, creating the file if needed.

    Newlines inside ``text`` are replaced by spaces so that every call adds
    exactly one line. When the file does not exist yet it is created and the
    header line ``time, message, text`` is written first.

    Args:
        text: The string to record.
        outputfileName: Path of the file to append to.

    Returns:
        None.
    """
    print("Saving: " + text + " to " + outputfileName)

    # Check before opening: mode "a" creates a missing file, which would make
    # a later existence test useless for deciding whether to emit the header.
    needs_header = not os.path.exists(outputfileName)

    # Explicit UTF-8 keeps non-ASCII chat text writable regardless of locale.
    with open(outputfileName, "a", encoding="utf-8") as f:
        if needs_header:
            f.write("time, message, text\n")
        f.write(text.replace("\n", " ") + "\n")
|
|
|
|
|
def store_message(name: str, message: str, outputfileName: str):
    """Append a chat record to the CSV log and return the whole log.

    If the file does not exist it is created with a ``time,message,name``
    header and two "System" greeting rows. A new row is appended only when
    both ``name`` and ``message`` are non-empty; both are stripped of
    surrounding whitespace before being written.

    Args:
        name: Value for the ``name`` column.
        message: Value for the ``message`` column.
        outputfileName: Path of the CSV file.

    Returns:
        A pandas DataFrame with every row of the file, sorted by the first
        column (the timestamp string) in descending order.
    """
    fieldnames = ["time", "message", "name"]

    if not os.path.exists(outputfileName):
        # Header and seed rows go through the csv module as well, so they use
        # the same delimiter and quoting as the rows appended below. Writing
        # them by hand with ", " produced column names and values that
        # started with a space once parsed.
        # newline="" is what the csv module requires for files it writes to.
        with open(outputfileName, "w", newline="", encoding="utf-8") as csvfile:
            seed_writer = csv.writer(csvfile)
            seed_writer.writerow(fieldnames)
            seed_writer.writerow([str(datetime.now()), "Welcome to Chatback!", "System"])
            seed_writer.writerow([str(datetime.now()), "How can I assist you today?", "System"])

    if name and message:
        # UTF-8 matches the default encoding pd.read_csv uses to read it back.
        with open(outputfileName, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow(
                {"time": str(datetime.now()), "message": message.strip(), "name": name.strip()}
            )

    df = pd.read_csv(outputfileName)
    return df.sort_values(df.columns[0], ascending=False)
|
|
|
|
|
# Hub checkpoint id shared by the model and its tokenizer.
mname = "facebook/blenderbot-400M-distill"

# Loaded once at import time, model first and tokenizer second; both names
# are module-level globals.
model, tokenizer = (
    loader.from_pretrained(mname)
    for loader in (BlenderbotForConditionalGeneration, BlenderbotTokenizer)
)
|
|
|
def take_last_tokens(inputs, note_history, history, max_tokens=128):
    """Trim an encoded conversation to its most recent ``max_tokens`` tokens.

    When ``inputs['input_ids']`` holds more than ``max_tokens`` positions,
    only the trailing ``max_tokens`` entries of the first sequence are kept
    in ``input_ids`` and ``attention_mask``, the first two
    ``'</s> <s>'``-separated segments are removed from ``note_history[0]``,
    and the first element of ``history`` is dropped. Otherwise all three
    arguments are returned unchanged.

    Args:
        inputs: Mapping with 2-D ``'input_ids'`` and ``'attention_mask'``
            tensors; it is updated in place when trimming happens.
        note_history: One-element list holding the joined conversation text.
        history: List of conversation entries.
        max_tokens: Maximum number of trailing tokens to keep (default 128).

    Returns:
        The tuple ``(inputs, note_history, history)``.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        # A slice of the form [-0:] would silently keep everything.
        raise ValueError("max_tokens must be a positive integer")

    if inputs['input_ids'].shape[1] > max_tokens:
        for key in ('input_ids', 'attention_mask'):
            # Keep the first sequence only, as a (1, max_tokens) tensor.
            # Slicing preserves the tensor's dtype and device.
            inputs[key] = inputs[key][:1, -max_tokens:]
        segments = note_history[0].split('</s> <s>')
        note_history = ['</s> <s>'.join(segments[2:])]
        history = history[1:]
    return inputs, note_history, history
|
|
|
def add_note_to_history(note, note_history):
    """Return the conversation text with ``note`` appended as a new segment.

    Args:
        note: Text of the newest utterance.
        note_history: List of existing text segments (may be empty).

    Returns:
        A one-element list whose string is every entry of ``note_history``
        followed by ``note``, joined with ``'</s> <s>'``.
    """
    # Build a new sequence rather than appending to the argument, so the
    # caller's list is not modified as a side effect.
    return ['</s> <s>'.join([*note_history, note])]
|
|
|
# Display strings for the app.
title = "💬ChatBack🧠💾"

# Two lines: a summary of the app, then a link to the underlying model.
description = (
    "Chatbot With persistent memory dataset allowing multiagent system AI to access a shared dataset as memory pool with stored interactions.\n"
    "Current Best SOTA Chatbot: https://huggingface.co/facebook/blenderbot-400M-distill?text=Hey+my+name+is+ChatBack%21+Are+you+ready+to+rock%3F "
)
|
|
|
def get_base(filename):
    """Return the path of ``filename`` inside this script's directory.

    The directory and the resulting path are printed as a debugging aid.

    Args:
        filename: File name (or relative path) to locate next to the script.
            Leading path separators are ignored.

    Returns:
        The script directory joined with ``filename``.
    """
    basedir = os.path.dirname(__file__)
    print(basedir)

    # Join with a real path separator: plain string concatenation glued the
    # file name onto the directory name ("/srv/app" + "x.csv" gave
    # "/srv/appx.csv"). Leading separators are stripped so that a caller
    # passing "/x.csv" still resolves under basedir instead of the root.
    loadPath = os.path.join(basedir, filename.lstrip("/\\"))
    print(loadPath)
    return loadPath
|
|
|
def chat(message, history):
    """Generate a reply to ``message`` and optionally log the exchange.

    The previous turns in ``history`` and the new message are joined with
    ``'</s> <s>'``, tokenized, trimmed by ``take_last_tokens`` and passed to
    ``model.generate``. The decoded reply is appended to the history as a
    ``(message, reply)`` pair. When ``UseMemory`` is truthy the exchange is
    also written to ``ChatbotMemory3.csv`` through ``store_message``.

    Args:
        message: The user's new message.
        history: List of previous ``(message, reply)`` pairs, or a falsy
            value for a new conversation.

    Returns:
        A tuple ``(history, df, basedir)``: the updated history, the
        DataFrame returned by ``store_message`` (an empty DataFrame when
        memory is disabled), and the path returned by ``get_base`` for the
        CSV file (``None`` when memory is disabled).
    """
    history = history or []
    if history:
        # Flatten the (message, reply) pairs into one separator-joined string.
        flattened = [str(part) for turn in history for part in (turn[0], turn[1])]
        history_useful = ['</s> <s>'.join(flattened)]
    else:
        history_useful = []

    history_useful = add_note_to_history(message, history_useful)
    inputs = tokenizer(history_useful, return_tensors="pt")
    inputs, history_useful, history = take_last_tokens(inputs, history_useful, history)
    reply_ids = model.generate(**inputs)
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    history_useful = add_note_to_history(response, history_useful)

    # The last two segments are the message just sent and the fresh reply.
    list_history = history_useful[0].split('</s> <s>')
    history.append((list_history[-2], list_history[-1]))

    df = pd.DataFrame()
    # Bound up front so the return below is valid even when memory is off;
    # previously this path raised UnboundLocalError.
    basedir = None

    if UseMemory:
        outputfileName = 'ChatbotMemory3.csv'
        # NOTE(review): store_message's parameters are (name, message, ...),
        # so the user's text is stored under "name" and the reply under
        # "message" -- confirm this mapping is intended.
        df = store_message(message, response, outputfileName)
        basedir = get_base(outputfileName)

    return history, df, basedir
|
|
|
|
|
# Gradio UI: a text box and button feed chat(); the outputs are the updated
# conversation state, the CSV log as a table, and the CSV file itself.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>🍰Gradio chatbot backed by dataframe CSV memory🎨</center></h1>")

    # Row 1: user input and the trigger button.
    with gr.Row():
        # `value` is the Blocks component keyword for the initial text; the
        # `default` keyword used before comes from the legacy input classes.
        t1 = gr.Textbox(lines=1, value="", label="Chat Text:")
        b1 = gr.Button("Respond and Retrieve Messages")

    # Row 2: per-session conversation state and the message table.
    with gr.Row():
        s1 = gr.State([])
        df1 = gr.Dataframe(wrap=True, max_rows=1000, overflow_row_behaviour="paginate")

    # Row 3: the CSV file output plus an empty Markdown component.
    with gr.Row():
        file = gr.File(label="File")
        s2 = gr.Markdown()

    # chat(message, history) returns (history, dataframe, file path).
    b1.click(fn=chat, inputs=[t1, s1], outputs=[s1, df1, file])

# Started at import time, with debug mode and error display enabled.
demo.launch(debug=True, show_error=True)
|
|