|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
from rdflib import Graph |
|
from datasets import load_dataset |
|
from transformers import pipeline |
|
from transformers import TextQueryProcessor, QuestionAnswerer |
|
from gradio import Interface |
|
|
|
|
|
# Top-level subject areas and the subfields each one covers.
SPECIALIZATIONS = {
    area: {"subfields": list(subfields)}
    for area, subfields in [
        ("Science", ("Physics", "Biology", "Chemistry")),
        ("History", ("Ancient", "Medieval", "Modern")),
        ("Art", ("Literature", "Visual", "Music")),
    ]
}
|
|
|
|
|
# One empty RDF graph per top-level specialization, intended to be filled in
# by learn() (currently a stub, so these stay empty).
# Iterate the dict directly instead of calling .keys() — same keys, idiomatic.
knowledge_graphs = {specialization: Graph() for specialization in SPECIALIZATIONS}
|
|
|
|
|
# Hugging Face checkpoint for each subfield's seq2seq model.
# NOTE(review): the original mapping listed checkpoints that do not exist on
# the Hub (e.g. "allenai/bart-large-cc2", "openai/music-gpt") or that are not
# encoder-decoder architectures and so cannot be loaded with
# AutoModelForSeq2SeqLM. Every entry below is a real seq2seq checkpoint;
# the keys (subfield names) are unchanged.
model_names = {
    "Physics": "facebook/bart-base",
    "Biology": "t5-small",
    "Chemistry": "facebook/bart-base",
    "Ancient": "facebook/bart-large-cnn",
    "Medieval": "t5-small",
    "Modern": "google/flan-t5-small",
    "Literature": "facebook/bart-large-cnn",
    "Visual": "facebook/bart-base",
    "Music": "t5-small",
}
|
|
|
# Eagerly load one seq2seq model per subfield.
# Iterate .items() so each checkpoint name is fetched once instead of doing a
# second dict lookup per key.
models = {
    specialization: AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    for specialization, checkpoint in model_names.items()
}
|
|
|
# Matching tokenizer for each subfield's model.
# Iterate .items() so each checkpoint name is fetched once instead of doing a
# second dict lookup per key.
tokenizers = {
    specialization: AutoTokenizer.from_pretrained(checkpoint)
    for specialization, checkpoint in model_names.items()
}
|
|
|
# Question-answering components used by interact()'s "Consult" fallback.
# NOTE(review): `TextQueryProcessor` and `QuestionAnswerer` are not classes
# exported by the `transformers` library — the import at the top of this file
# will raise ImportError at runtime. This likely needs to be rewritten on top
# of `pipeline("question-answering", ...)`; confirm the intended API before
# shipping.
qa_processor = TextQueryProcessor.from_pretrained("allenai/bart-large")

qa_model = QuestionAnswerer.from_pretrained("allenai/bart-large")
|
|
|
|
|
# Shared GPT-2 pipeline used when a prompt calls for free-form creative text.
generation_pipeline = pipeline(task="text-generation", model="gpt2", top_k=5)
|
|
|
|
|
# Gradio front-end. `fn` is wrapped in a lambda so that `interact` — defined
# further down this module — is looked up at call time; referencing it
# directly here raised NameError at import time.
# NOTE(review): the original declared the second input as "specialization",
# which is not a valid Gradio component shortcut; a plain text box is used so
# the app can launch. A Dropdown over the subfield names would be friendlier.
interface = Interface(
    fn=lambda text, specialization: interact(text, specialization),
    inputs=["text", "text"],
    outputs=["text"],
    title="AI Chatbot Civilization",
    description="Interact with a generation of chatbots!",
)
|
|
|
def interact(text, specialization):
    """Generate a reply to `text` using the model for `specialization`.

    Parameters
    ----------
    text : str
        The user's prompt.
    specialization : str
        A subfield key present in `models` / `tokenizers` (e.g. "Physics").

    Returns
    -------
    str or list
        The decoded model reply; the knowledge-graph QA answer when the
        model defers with "Consult"; or the raw text-generation pipeline
        output for prompts needing creative formatting.
    """
    # NOTE(review): the original instantiated an undefined `Chatbot` class
    # here (NameError at call time) and never used the result; removed.
    processed_prompt = process_prompt(text, specialization)

    # Encode, generate, and decode with this specialization's model pair.
    # The original called `.decode()` on the raw generate() output, which is
    # a tensor, not text — decoding must go through the tokenizer.
    input_ids = tokenizers[specialization](
        processed_prompt, return_tensors="pt"
    ).input_ids
    output_ids = models[specialization].generate(input_ids=input_ids)
    reply = tokenizers[specialization].decode(
        output_ids[0], skip_special_tokens=True
    )

    # The model can defer to the specialization's knowledge graph by
    # answering the literal token "Consult".
    if reply == "Consult":
        answer = qa_model(qa_processor(text, knowledge_graphs[specialization]))
        return answer["answer"]

    # Fall back to free-form GPT-2 generation for creative prompts.
    if need_creative_format(text):
        return generation_pipeline(text, max_length=50)

    return reply
|
|
|
def process_prompt(text, specialization):
    """Return `text` prepared for the given specialization.

    Pass-through stub: specialization- and subfield-specific preprocessing
    can be added here later without changing callers.
    """
    prepared = text
    return prepared
|
|
|
def need_creative_format(text):
    """Report whether `text` calls for free-form creative generation.

    Stub implementation: no prompt is currently treated as creative.
    """
    return False
|
|
|
def learn(data, specialization):
    """Fold `data` into the specialization's knowledge graph and model.

    Stub: no graph update or fine-tuning is performed yet.
    """
    return None
|
|
|
def mutate(chatbot):
    """Derive a new chatbot whose specialization may differ from `chatbot`'s.

    Stub: mutation is not implemented yet.
    """
    return None
|
|
|
|
|
# Seed population: one chatbot per top-level specialization.
# NOTE(review): `Chatbot` is not defined anywhere in this file — this line
# raises NameError at import time unless the class is provided elsewhere;
# confirm where `Chatbot` is meant to come from.
chatbots = [Chatbot(specialization) for specialization in SPECIALIZATIONS.keys()]
|
|
|
|
|
|