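"""Gradio chat app that grounds a StarCoder-based chat model with Wikipedia context.

Each incoming message is mined for named entities (spaCy) and keyword phrases
(RAKE); candidates that a zero-shot classifier scores as relevant are looked up
on Wikipedia, and their summaries are injected as a system message before the
model completes a JSON-formatted chat transcript.
"""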
import json
import gradio as gr
import os
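# spaCy provides named-entity recognition over the recent chat text.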
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')
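# NLTK corpora needed by RAKE: the stopword list and the punkt sentence tokenizer.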
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from rake_nltk import Rake
r = Rake()
import time
import wikipediaapi
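# Wikipedia client for fetching page summaries; the first argument is the
# user-agent string the wikipedia-api library requires.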
wiki_wiki = wikipediaapi.Wikipedia('Organika (cmatthewbrown@gmail.com)', 'en')
## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )
# Use a pipeline as a high-level helper
from transformers import pipeline
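# Zero-shot classifier used to score candidate topics for relevance to the conversation.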
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
#model = pipeline("text-generation", model="Organika/StarCoder-7B-WoW-JSON_1", device=0)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
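# Quantize the 7B model to 4-bit NF4 (with nested quantization) to cut GPU memory use.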
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model_name = "umm-maybe/StarCoder-7B-WoW-JSON_1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")
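# Sampling-based generation; note that decode() returns the prompt plus the completion.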
def generate_text(prompt):
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        inputs,
        do_sample=True,
        max_new_tokens=200,
        temperature=0.6,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.1
    )
    results = tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False)
    return results
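# Chat handler: mine the recent conversation for entities and keyphrases, keep the
# topics a zero-shot classifier rates as relevant, prepend their Wikipedia summaries
# as a system message, and have the model complete a JSON-formatted chat transcript.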
def merlin_chat(message, history):
    chat_text = ""
    chat_list = []
    # Keep only the last three turns (Gradio passes history as [user, assistant] pairs).
    for turn in history[-3:]:
        chat_text += f"{turn[0]}\n\n{turn[1]}\n\n"
        chat_list.append({"role": "user", "content": turn[0]})
        chat_list.append({"role": "assistant", "content": turn[1]})
    chat_text += f"{message}\n"
    doc = nlp(chat_text)
    ents_found = []
    if doc.ents:
        for ent in doc.ents:
            if len(ents_found) == 3:
                break
            # Skip numeric/quantitative entities; ent.label_ (not ent.label) is the string label.
            if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
                continue
            # Lowercase for de-duplication against the RAKE phrases below.
            if ent.text.lower() in ents_found:
                continue
            ents_found.append(ent.text.lower())
    r.extract_keywords_from_text(chat_text)
    for phrase in r.get_ranked_phrases()[:3]:
        phrase = phrase.lower()
        if phrase not in ents_found:
            ents_found.append(phrase)
    context = ""
    if ents_found:
        # The zero-shot pipeline returns labels re-sorted by descending score, so pair
        # labels with scores rather than indexing back into ents_found.
        topics = topic_model(chat_text, ents_found, multi_label=True)
        for entity, score in zip(topics['labels'], topics['scores']):
            if score < 0.5:
                continue
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                # Disambiguation pages are no use as grounding context.
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                else:
                    context += entsum + '\n\n'
            else:
                print("not found.")
    system_msg = {'role': 'system', 'content': context}
    chat_list.insert(0, system_msg)
    user_msg = {'role': 'user', 'content': message}
    chat_list.append(user_msg)
    # Serialize the chat to JSON, drop the closing bracket, and leave an assistant
    # turn open for the model to complete.
    prompt = json.dumps(chat_list)[:-1] + ',{"role": "assistant", "content": "'
    print(f"PROMPT: {prompt}")
    for attempt in range(3):
        #result = generate_text(prompt, model_path, parameters, headers)
        #result = model(prompt,return_full_text=False, max_new_tokens=256, temperature=0.8, repetition_penalty=1.1)
        #response = result[0]['generated_text']
        result = generate_text(prompt)
        # The decoded output still contains the prompt; strip it to isolate the completion.
        response = result.replace(prompt, "")
        print(f"COMPLETION: {response}")  # so we can see it in logs
        cleanStr = response.lstrip()
        #start = cleanStr.find('{')
        #if start<=0:
        #    continue
        # Truncate at the first closing brace; retry if the model never closed the object.
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        # Re-attach the prompt and close the JSON array, then parse; retry on malformed JSON.
        messageStr = prompt + cleanStr + ']'
        try:
            messages = json.loads(messageStr)
        except json.JSONDecodeError:
            continue
        candidate = messages[-1]
        if candidate['role'] != 'assistant':
            continue
        msg_text = candidate['content']
        # Reject replies that merely repeat text already present in the conversation.
        if chat_text.find(msg_text) >= 0:
            continue
        return msg_text
    # All attempts failed; shrug.
    return "🤔"
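# Wire the handler into Gradio's built-in chat UI.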
gr.ChatInterface(merlin_chat).launch()