import json
import os
import time

import gradio as gr

# spaCy for named-entity recognition over the chat history.
import spacy
spacy.cli.download('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

# NLTK data and RAKE for keyphrase extraction.
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from rake_nltk import Rake
r = Rake()

# Wikipedia client; the user-agent string identifies this app to the Wikimedia API.
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('Organika (cmatthewbrown@gmail.com)', 'en')
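# Pipeline overview: spaCy entities and RAKE keyphrases are extracted from the
# recent chat, scored for topical relevance with a zero-shot classifier, and
# the matching Wikipedia summaries are injected as a system message before the
# StarCoder chat model generates a reply.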
## ctransformers disabled for now
# from ctransformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "Colby/StarCoder-3B-WoW-JSON",
#     model_file="StarCoder-3B-WoW-JSON-ggml.bin",
#     model_type="gpt_bigcode"
# )

# Zero-shot classifier used to score how relevant each extracted topic is to the chat.
from transformers import pipeline
topic_model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
#model = pipeline("text-generation", model="Organika/StarCoder-7B-WoW-JSON_1", device=0)
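# The zero-shot pipeline casts classification as natural language inference:
# the chat text is the premise and each candidate label is slotted into a
# hypothesis (by default "This example is {}."), so no fine-tuning on the
# topic labels themselves is needed.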
# Load the chat model with 4-bit NF4 quantization so the 7B model fits on a single GPU.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model_name = "umm-maybe/StarCoder-7B-WoW-JSON_1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")
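# Rough footprint (back-of-the-envelope): ~7B parameters at 4 bits is about
# 3.5 GB of weights, versus ~14 GB in fp16, so the model fits comfortably on a
# single 16 GB GPU. Double quantization shaves off a bit more by also
# quantizing the quantization constants.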
def generate_text(prompt):
    """Sample a completion from the quantized model; returns prompt + completion."""
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        inputs,
        do_sample=True,
        max_new_tokens=200,
        temperature=0.6,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.1
    )
    results = tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False)
    return results
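# Illustrative call (hypothetical prompt; merlin_chat below builds the real one):
#   generate_text('[{"role": "user", "content": "Hi"},{"role": "assistant", "content": "')
# The return value includes the prompt itself, so callers strip it off to get
# just the completion.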
def merlin_chat(message, history):
    # Flatten the last three turns into plain text (for entity/keyword
    # extraction) and into a message list (for the JSON prompt).
    chat_text = ""
    chat_list = []
    for turn in history[-3:]:
        chat_text += f"{turn[0]}\n\n{turn[1]}\n\n"
        chat_list.append({"role": "user", "content": turn[0]})
        chat_list.append({"role": "assistant", "content": turn[1]})
    chat_text += f"{message}\n"

    # Collect up to three named entities, skipping numeric and date-like ones.
    doc = nlp(chat_text)
    ents_found = []
    if doc.ents:
        for ent in doc.ents:
            if len(ents_found) == 3:
                break
            # .label_ is the string form of the entity type (.label is an integer hash)
            if ent.text.isnumeric() or ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
                continue
            ent_text = ent.text.lower()
            if ent_text in ents_found:
                continue
            ents_found.append(ent_text)

    # Add up to three RAKE keyphrases that aren't already in the entity list.
    r.extract_keywords_from_text(chat_text)
    for phrase in r.get_ranked_phrases()[:3]:
        phrase = phrase.lower()
        if phrase not in ents_found:
            ents_found.append(phrase)
    # Score each candidate topic's relevance to the conversation and pull in
    # Wikipedia summaries for the relevant ones as grounding context.
    context = ""
    if ents_found:
        result = topic_model(chat_text, ents_found, multi_label=True)
        # The pipeline returns labels sorted by descending score, so pair them
        # up rather than indexing scores by the original ents_found order.
        for entity, score in zip(result['labels'], result['scores']):
            if score < 0.5:
                continue
            print(f'# Looking up {entity} on Wikipedia... ', end='')
            wiki_page = wiki_wiki.page(entity)
            if wiki_page.exists():
                print("page found... ")
                entsum = wiki_page.summary
                if "may refer to" in entsum or "may also refer to" in entsum:
                    print(" ambiguous, skipping.")
                    continue
                else:
                    context += entsum + '\n\n'
            else:
                print("not found.")
    # Prepend the Wikipedia context as a system message, append the new user
    # message, then serialize to JSON. Dropping the closing ']' and opening an
    # unterminated assistant message leaves the model to fill in the content.
    system_msg = {'role': 'system', 'content': context}
    chat_list.insert(0, system_msg)
    user_msg = {'role': 'user', 'content': message}
    chat_list.append(user_msg)
    prompt = json.dumps(chat_list)[:-1] + ",{\"role\": \"assistant\", \"content\": \""
    print(f"PROMPT: {prompt}")
    # Up to three attempts to get a well-formed, non-repeating reply.
    for attempt in range(3):
        #result = generate_text(prompt, model_path, parameters, headers)
        #result = model(prompt, return_full_text=False, max_new_tokens=256, temperature=0.8, repetition_penalty=1.1)
        #response = result[0]['generated_text']
        result = generate_text(prompt)
        response = result.replace(prompt, "")
        print(f"COMPLETION: {response}")  # so we can see it in logs
        cleanStr = response.lstrip()
        # find() returns -1 when no '}' is present, so end == 0 means no JSON object
        end = cleanStr.find('}') + 1
        if end <= 0:
            continue
        cleanStr = cleanStr[:end]
        # Close the JSON array that the prompt opened, then parse the whole chat.
        messageStr = prompt + cleanStr + ']'
        try:
            messages = json.loads(messageStr)
        except json.JSONDecodeError:
            continue  # malformed completion; retry
        last_msg = messages[-1]
        if last_msg['role'] != 'assistant':
            continue
        msg_text = last_msg['content']
        if chat_text.find(msg_text) >= 0:
            continue  # reject verbatim repeats of earlier turns
        return msg_text
    return "🤔"
gr.ChatInterface(merlin_chat).launch()