|
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration |
|
import torch |
|
import gradio as gr |
|
|
|
|
|
import os |
|
import csv |
|
from gradio import inputs, outputs |
|
import huggingface_hub |
|
from huggingface_hub import Repository, hf_hub_download, upload_file |
|
from datetime import datetime |
|
|
|
from typing import List, Dict |
|
import httpx |
|
import pandas as pd |
|
|
|
|
|
# --- Persistent chat memory -------------------------------------------------
# When UseMemory is true, the dataset repository below is cloned into
# DATA_DIRNAME and chat exchanges are appended to DATA_FILE inside the clone.
UseMemory = True

# Defined unconditionally so later code that reads these names (for example
# the interface's article text) keeps working when UseMemory is switched off.
DATASET_REPO_URL = "https://huggingface.co/datasets/awacke1/ChatbotMemory.csv"
DATASET_REPO_ID = "awacke1/ChatbotMemory.csv"
# Previously referenced but never defined, which made the download below
# fail with a NameError that the bare ``except`` silently swallowed.
DATA_DIRNAME = "data"
DATA_FILENAME = "ChatbotMemory.csv"
DATA_FILE = os.path.join(DATA_DIRNAME, DATA_FILENAME)
HF_TOKEN = os.environ.get("HF_TOKEN")

if UseMemory:
    # Best-effort prefetch of the CSV; a failure is reported but not fatal.
    try:
        hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=DATA_FILENAME,
            # The repository is a dataset; the default repo type is "model".
            repo_type="dataset",
            # No cache_dir/force_filename: the deprecated force_filename is
            # dropped, and the file is kept out of DATA_DIRNAME.
            # NOTE(review): Repository is expected to refuse cloning into a
            # non-empty folder that is not a git checkout - confirm.
        )
    except Exception as err:  # top-level boundary: report and carry on
        print(f"file not found: {err}")

    # Working clone used for reading and committing the memory file.
    repo = Repository(
        local_dir=DATA_DIRNAME, clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
    )
|
|
|
def get_df(name: str):
    """Load the ``train`` split of the dataset identified by *name*.

    Args:
        name: Dataset identifier passed straight to ``load_dataset``.

    Returns:
        Whatever ``load_dataset(name, split="train")`` returns.
    """
    # Fix: the builtin ``str`` type used to be passed here instead of ``name``.
    # NOTE(review): ``load_dataset`` is not imported in the visible part of
    # this file; it presumably comes from the ``datasets`` package - confirm
    # and add the import.
    dataset = load_dataset(name, split="train")
    return dataset
|
|
|
def store_message(name: str, message: str):
    """Append one timestamped row to the memory CSV and push it to the Hub.

    Nothing is written or pushed when either argument is empty. In every
    case the dataset is then loaded through ``get_df`` and printed.

    Args:
        name: Text stored in the ``name`` column (whitespace-stripped).
        message: Text stored in the ``message`` column (whitespace-stripped).

    Returns:
        Always the empty string.
    """
    if name and message:
        # newline="" lets the csv module control line endings itself, and an
        # explicit encoding keeps non-ASCII text (emoji, accents) writable
        # regardless of the platform's default locale.
        with open(DATA_FILE, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["time", "message", "name"])
            writer.writerow(
                {
                    "time": str(datetime.now()),
                    "message": message.strip(),
                    "name": name.strip(),
                }
            )
        # Push only after the ``with`` block has closed (and flushed) the file.
        repo.push_to_hub()

    # Debug aid: show the dataset as currently published.
    stored = get_df(DATASET_REPO_ID)
    print(stored)
    return ""
|
|
|
|
|
# Checkpoint name shared by the tokenizer and the generation model so the two
# always come from the same pretrained weights.
mname = "facebook/blenderbot-400M-distill"

# Both objects are loaded once at import time and reused for every request.
tokenizer = BlenderbotTokenizer.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)
|
|
|
def take_last_tokens(inputs, note_history, history, max_tokens=128):
    """Keep only the most recent ``max_tokens`` tokens of the model input.

    When the encoded input is longer than ``max_tokens``, the token tensors
    are cut down to their last ``max_tokens`` positions, the two oldest
    utterances are removed from the joined text history and the oldest entry
    is removed from ``history``. Shorter inputs are returned untouched.

    Args:
        inputs: Mapping with ``'input_ids'`` and ``'attention_mask'`` tensors
            indexed as (row, token). Updated in place when truncation occurs.
        note_history: One-element list holding the conversation joined with
            ``'</s> <s>'``.
        history: List of earlier turns.
        max_tokens: Maximum number of trailing tokens to keep (default 128,
            the previously hard-coded limit).

    Returns:
        Tuple ``(inputs, note_history, history)``.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        # A slice of ``-0:`` would silently keep everything.
        raise ValueError("max_tokens must be a positive integer")

    if inputs['input_ids'].shape[1] > max_tokens:
        # Slice the tensors directly (first row only, as before) instead of
        # round-tripping through Python lists, which discarded the tensors'
        # device and dtype.
        inputs['input_ids'] = inputs['input_ids'][:1, -max_tokens:]
        inputs['attention_mask'] = inputs['attention_mask'][:1, -max_tokens:]
        # Drop the two oldest utterances from the text history...
        utterances = note_history[0].split('</s> <s>')
        note_history = ['</s> <s>'.join(utterances[2:])]
        # ...and the oldest entry from the turn history.
        history = history[1:]
    return inputs, note_history, history
|
|
|
def add_note_to_history(note, note_history):
    """Return the text history extended with *note*.

    Args:
        note: Utterance to add at the end of the conversation.
        note_history: List of strings making up the conversation so far
            (empty for a new conversation). It is not modified.

    Returns:
        A new one-element list whose single string is every entry of
        ``note_history`` followed by ``note``, joined with ``'</s> <s>'``.
    """
    # Build a fresh sequence rather than appending to the caller's list, so
    # the argument is left exactly as it was passed in.
    joined = '</s> <s>'.join([*note_history, note])
    return [joined]
|
|
|
# Heading text for the app.
title = "💬ChatBack🧠💾"

# Longer blurb describing the app and linking to the underlying model card.
description = """Chatbot With persistent memory dataset allowing multiagent system AI to access a shared dataset as memory pool with stored interactions.
Current Best SOTA Chatbot: https://huggingface.co/facebook/blenderbot-400M-distill?text=Hey+my+name+is+ChatBack%21+Are+you+ready+to+rock%3F """
|
|
|
def chat(message, history):
    """Generate the bot's reply to *message* and record the exchange.

    Args:
        message: Text entered by the user.
        history: List of ``(user_text, bot_text)`` pairs from earlier turns,
            or a falsy value when the conversation is new.

    Returns:
        The updated history twice, as a two-element tuple.
    """
    separator = '</s> <s>'
    history = history or []

    # Flatten earlier turns into the single separator-joined string that is
    # fed to the tokenizer.
    if history:
        history_useful = [
            separator.join(str(turn[0]) + separator + str(turn[1]) for turn in history)
        ]
    else:
        history_useful = []
    history_useful = add_note_to_history(message, history_useful)

    inputs = tokenizer(history_useful, return_tensors="pt")
    # Trim over-long conversations before generation.
    inputs, history_useful, history = take_last_tokens(inputs, history_useful, history)

    reply_ids = model.generate(**inputs)
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]

    # Record the turn from the actual strings. Re-deriving it by splitting the
    # joined history lost the user's message whenever truncation had removed
    # it, and mangled it when the message itself contained the separator.
    history.append((message, response))

    # Persist the exchange; the helper's return value is not needed here.
    store_message(message, response)
    return history, history
|
|
|
# Build the Gradio UI around ``chat`` and start serving it.
gr.Interface(
    fn=chat,
    theme="huggingface",
    css=".footer {display:none !important}",
    inputs=["text", "state"],
    # ``chat`` returns exactly two values (chatbot display, state). The extra
    # "text" output that used to be listed here had no value to receive.
    outputs=["chatbot", "state"],
    title=title,
    allow_flagging="never",
    # Plain string: there are no placeholders, so no f-prefix is needed.
    description="Gradio chatbot backed by memory in a dataset repository.",
    article=f"The memory dataset for saves is [{DATASET_REPO_URL}]({DATASET_REPO_URL}) 🦃Thanks!🦃 Check out HF Datasets: https://huggingface.co/spaces/awacke1/FreddysDatasetViewer SOTA papers code and datasets on chat are here: https://paperswithcode.com/datasets?q=chat&v=lst&o=newest",
).launch(debug=True)
|
|