"""Gradio chat app backed by a 4-bit quantized causal language model.

For each incoming message the app extracts candidate topics from the recent
conversation (spaCy named entities plus RAKE key phrases), keeps the ones a
zero-shot classifier scores as relevant, looks them up on Wikipedia, and
passes the page summaries to the model as the system message of a
JSON-encoded chat transcript.
"""
import json
import os
import time

import gradio as gr
import nltk
import spacy
import torch
import wikipediaapi
from rake_nltk import Rake
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# --- Named-entity recognition -------------------------------------------
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# --- Key-phrase extraction (the NLTK data is fetched before Rake is built)
nltk.download('stopwords')
nltk.download('punkt')
r = Rake()

# --- Wikipedia client -----------------------------------------------------
wiki_wiki = wikipediaapi.Wikipedia('Organika (cmatthewbrown@gmail.com)', 'en')

## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )

# --- Topic relevance scoring ----------------------------------------------
# Use a pipeline as a high-level helper
topic_model = pipeline(
    "zero-shot-classification",
    model="valhalla/distilbart-mnli-12-9",
    device=0,
)
#model = pipeline("text-generation", model="Organika/StarCoder-7B-WoW-JSON_1", device=0)

# --- Chat model, loaded with 4-bit quantization ---------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_name = "umm-maybe/StarCoder-7B-WoW-JSON_1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# Entity types that are skipped when collecting lookup topics.
_SKIPPED_ENTITY_LABELS = frozenset(
    {"DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"}
)
_MAX_ENTITIES = 3        # stop scanning entities once this many topics exist
_MAX_KEYPHRASES = 3      # top-ranked RAKE phrases to consider
_MIN_TOPIC_SCORE = 0.5   # zero-shot score below which a topic is ignored
_HISTORY_TURNS = 3       # most recent turns included in the prompt
_MAX_ATTEMPTS = 3        # generation retries before giving up


def generate_text(prompt):
    """Sample a continuation of ``prompt`` from the chat model.

    Args:
        prompt: Text to continue.

    Returns:
        The string decoded from the first generated sequence. Callers in
        this file strip the prompt from it with ``str.replace``.
    """
    # Follow the model's own device rather than hard-coding "cuda": the model
    # is placed by device_map="auto".
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs,
        do_sample=True,
        max_new_tokens=200,
        temperature=0.6,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.1
    )
    return tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False)


def _candidate_topics(chat_text):
    """Return lower-cased, de-duplicated entities and key phrases found in ``chat_text``."""
    topics = []
    for ent in nlp(chat_text).ents:
        if len(topics) == _MAX_ENTITIES:
            break
        # Span.label is an integer ID; the readable name compared here is
        # Span.label_.
        if ent.text.isnumeric() or ent.label_ in _SKIPPED_ENTITY_LABELS:
            continue
        # Normalize before the duplicate check so "Merlin" and "merlin"
        # collapse into one topic.
        name = ent.text.lower()
        if name not in topics:
            topics.append(name)

    r.extract_keywords_from_text(chat_text)
    for phrase in r.get_ranked_phrases()[:_MAX_KEYPHRASES]:
        phrase = phrase.lower()
        if phrase not in topics:
            topics.append(phrase)
    return topics


def _wikipedia_context(chat_text, topics):
    """Return concatenated Wikipedia summaries for the topics scored relevant to ``chat_text``."""
    # The zero-shot pipeline rejects an empty label list, so bail out early.
    if not topics:
        return ""

    ranking = topic_model(chat_text, topics, multi_label=True)
    summaries = []
    # The pipeline reports labels re-sorted by score, so each score must be
    # paired with the label returned next to it, not with the input order.
    for entity, score in zip(ranking['labels'], ranking['scores']):
        if score < _MIN_TOPIC_SCORE:
            continue
        print(f'# Looking up {entity} on Wikipedia... ', end='')
        wiki_page = wiki_wiki.page(entity)
        if not wiki_page.exists():
            print("not found.")
            continue
        print("page found... ")
        entsum = wiki_page.summary
        # Summaries containing these phrases are treated as disambiguation
        # pages and carry no usable context.
        if "may refer to" in entsum or "may also refer to" in entsum:
            print(" ambiguous, skipping.")
            continue
        summaries.append(entsum + '\n\n')
    return "".join(summaries)


def _parse_reply(prompt, completion):
    """Return the assistant text encoded in ``completion``, or ``None`` if it cannot be parsed."""
    body = completion.lstrip()
    # The prompt ends inside an open JSON object; the first '}' in the
    # completion is taken as the end of that object.
    end = body.find('}') + 1
    if end <= 0:
        return None
    try:
        messages = json.loads(prompt + body[:end] + ']')
    except json.JSONDecodeError:
        # Malformed model output: let the caller retry instead of crashing.
        return None
    reply = messages[-1]
    if reply.get('role') != 'assistant':
        return None
    text = reply.get('content')
    return text if isinstance(text, str) else None


def merlin_chat(message, history):
    """Answer ``message`` using the chat model and Wikipedia-derived context.

    Args:
        message: The user's new message.
        history: Earlier turns; each turn is indexed as ``turn[0]`` (user
            text) and ``turn[1]`` (assistant text). Only the last three
            turns are used.

    Returns:
        The assistant's reply, or "🤔" when no attempt yields a usable
        reply.
    """
    recent_turns = history[-_HISTORY_TURNS:]

    # Plain-text transcript used for topic extraction and for detecting
    # replies that only repeat the conversation.
    chat_text = "".join(
        f"{turn[0]}\n\n{turn[1]}\n\n" for turn in recent_turns
    ) + f"{message}\n"

    context = _wikipedia_context(chat_text, _candidate_topics(chat_text))

    chat_list = [{'role': 'system', 'content': context}]
    for turn in recent_turns:
        chat_list.append({"role": "user", "content": turn[0]})
        chat_list.append({"role": "assistant", "content": turn[1]})
    chat_list.append({'role': 'user', 'content': message})

    # Drop the closing ']' and open an assistant message for the model to
    # complete.
    prompt = json.dumps(chat_list)[:-1] + ",{\"role\": \"assistant\", \"content\": \""
    print(f"PROMPT: {prompt}")

    for _ in range(_MAX_ATTEMPTS):
        response = generate_text(prompt).replace(prompt, "")
        print(f"COMPLETION: {response}")  # so we can see it in logs
        reply = _parse_reply(prompt, response)
        # Skip unparsable output, empty replies, and replies that already
        # appear verbatim in the conversation.
        if reply is None or reply in chat_text:
            continue
        return reply
    return "🤔"


gr.ChatInterface(merlin_chat).launch()