import json
import logging
import random

from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains.summarize import load_summarize_chain
from langchain.output_parsers import StructuredOutputParser, ResponseSchema

logger = logging.getLogger(__name__)


class OpenAILLM:
    """Wrap an OpenAI chat model for summarising, chatting about and quizzing on a text.

    The uploaded text is split into chunks (``self.docs``).  Those chunks feed
    a map-reduce summary chain, and are written into the conversation memory
    so that the chat chain and the multiple-choice-question (MCQ) generator
    can refer to them.
    """

    def __init__(self,
                 temperature: float = 1.,
                 model_name: str = 'gpt-4',
                 mcq_question_number: int = 10,
                 mcq_false_answer_number: int = 3):
        """Create the model, the chains and the MCQ bookkeeping state.

        Args:
            temperature: Sampling temperature passed to ``ChatOpenAI``.
            model_name: Model name passed to ``ChatOpenAI``.
            mcq_question_number: Number of questions a quiz is scored
                against (the denominator used by ``get_mcq_score``).
            mcq_false_answer_number: Number of wrong options requested and
                kept for each generated question.
        """
        # Model-related instantiations
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)
        # The memory *class* is stored so a fresh instance can be created on reset.
        self.Memory = ConversationBufferMemory
        self.chain_summary = load_summarize_chain(self.llm, chain_type="map_reduce", verbose=True)
        self.chain_chat = ConversationChain(llm=self.llm, verbose=False, memory=self.Memory())

        # Other utils instantiation
        self.docs = []
        self.text_splitter = RecursiveCharacterTextSplitter()
        self.chat_document_intro = "Read the following document: "
        self.chat_message_begin = "What would you like to know about the uploaded document?"
        self.mcq_question_number = mcq_question_number
        self.mcq_false_answer_number = mcq_false_answer_number
        # The false-answer count is interpolated in both sentences so the
        # prompt never contradicts the configured value.
        self.mcq_intro = (
            "Generate a question, correct answer and "
            f"{self.mcq_false_answer_number} possible false answers from the inputted document. "
            "Make sure that it is unique from the ones you have generated before! "
            f"Only create {self.mcq_false_answer_number} possible false answers and a correct answer!\n"
        )
        self.mcq_answer_sheet = []
        self.mcq_query = None

    def upload_text(self, text: str) -> None:
        """Split ``text`` into chunks and store them as ``Document`` objects in ``self.docs``."""
        chunks = self.text_splitter.split_text(text)
        self.docs = [Document(page_content=chunk) for chunk in chunks]

    def is_text_uploaded(self) -> bool:
        """Return ``True`` when at least one document chunk is stored."""
        return bool(self.docs)

    def empty_text(self) -> None:
        """Drop the stored chunks, reset the chat memory and clear the MCQ answer sheet."""
        self.docs = []
        self.chain_chat.memory = self.Memory()
        self.mcq_answer_sheet = []

    def get_text_summary(self):
        """Run the summary chain over the stored chunks and return its result."""
        return self.chain_summary.run(self.docs)

    def start_chat(self):
        """Write the stored chunks into the chat memory.

        Each call appends the intro line and every chunk again; the memory is
        not cleared first.

        Returns:
            A tuple ``(memory_as_string, opening_message)``.
        """
        memory = self.chain_chat.memory
        # Add document to the system's context
        memory.save_context({"input": self.chat_document_intro}, {"output": ""})
        for doc in self.docs:
            memory.save_context({"input": doc.page_content}, {"output": ""})
        return str(memory), self.chat_message_begin

    def get_chat_response(self, user_input: str):
        """Send ``user_input`` to the chat chain and return the model's reply."""
        return self.chain_chat.predict(input=user_input)

    def start_mcq(self) -> None:
        """Build the MCQ prompt (``self.mcq_query``) and load the document into the chat memory."""
        # Instantiate response schema to define JSON output
        response_schemas = [
            ResponseSchema(name="question",
                           description="Question generated from provided document."),
            ResponseSchema(name="answer",
                           description="One correct answer for the asked question."),
            ResponseSchema(name="choices",
                           description=f"{self.mcq_false_answer_number} available false options for a multiple-choice question in comma separated."),
        ]
        output_format_instructions = StructuredOutputParser.from_response_schemas(
            response_schemas).get_format_instructions()

        # Define the prompt that will be used for MCQ questions
        prompt = PromptTemplate(
            template="{task_instructions}\n {output_format_instructions}",
            input_variables=["task_instructions", "output_format_instructions"]
        )

        # Get the MCQ query based on the prompt (by filling in the prompt values)
        self.mcq_query = prompt.format(task_instructions=self.mcq_intro,
                                       output_format_instructions=output_format_instructions)

        # Upload the document to the model
        self.start_chat()

    @staticmethod
    def _extract_json_object(response: str):
        """Return the JSON value spanning the first ``{`` to the last ``}`` of ``response``.

        This accepts replies wrapped in a Markdown code fence as well as
        bare JSON or JSON preceded/followed by whitespace.

        Raises:
            ValueError: If no ``{...}`` span exists or it is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        start = response.find("{")
        end = response.rfind("}")
        if start == -1 or end < start:
            raise ValueError("no JSON object found in model response")
        return json.loads(response[start:end + 1])

    def _parse_mcq_response(self, response: str):
        """Turn a raw model reply into ``(question, answers)``.

        ``answers`` holds the correct answer first, followed by at most
        ``self.mcq_false_answer_number`` non-empty false options.
        """
        payload = self._extract_json_object(response)
        question = payload["question"]
        raw_choices = payload["choices"]
        # The prompt asks for a comma-separated string, but a JSON list is
        # tolerated as well.
        if isinstance(raw_choices, str):
            candidates = raw_choices.split(',')
        else:
            candidates = list(raw_choices)
        cleaned = [str(candidate).strip() for candidate in candidates]
        false_answers = [option for option in cleaned if option][:self.mcq_false_answer_number]
        return question, [payload["answer"]] + false_answers

    def get_mcq_question(self, max_attempts: int = 5):
        """Ask the model for one MCQ item and record it on the answer sheet.

        Args:
            max_attempts: Maximum number of model calls made before giving
                up when a call or the parsing of its reply fails.

        Returns:
            A tuple ``(question, shuffled_options)``.

        Raises:
            ValueError: If ``max_attempts`` is smaller than 1.
            RuntimeError: If ``start_mcq`` has not been called yet, or if no
                usable reply was obtained within ``max_attempts`` tries.
        """
        if max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")
        if self.mcq_query is None:
            raise RuntimeError("start_mcq() must be called before get_mcq_question()")

        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = self.chain_chat.predict(input=self.mcq_query)
                question, answers = self._parse_mcq_response(response)
                break
            except Exception as error:
                # Broad on purpose: both API failures and malformed replies
                # are retried, but only a bounded number of times.
                last_error = error
                logger.warning("MCQ generation attempt %d/%d failed: %s",
                               attempt, max_attempts, error)
        else:
            raise RuntimeError(
                f"could not generate an MCQ question after {max_attempts} attempts"
            ) from last_error

        self.mcq_answer_sheet.append({
            "question": question,
            "answer": answers[0],
            "user_answer": None,
            "choices": answers
        })
        # random.sample over the full length returns a shuffled copy and
        # leaves the stored ``choices`` order (correct answer first) intact.
        return question, random.sample(answers, len(answers))

    def mcq_record_answer(self, answer) -> None:
        """Store ``answer`` as the user's answer to the most recent question.

        Raises:
            IndexError: If no question has been generated yet.
        """
        self.mcq_answer_sheet[-1]["user_answer"] = answer

    def get_mcq_score(self):
        """Return ``(score, score_perc)`` for the current answer sheet.

        ``score`` counts the entries whose ``user_answer`` equals ``answer``.
        ``score_perc`` is that count as a percentage of
        ``self.mcq_question_number``, rounded to two decimals; it is ``0.0``
        when ``mcq_question_number`` is not positive.
        """
        score = sum(sheet['answer'] == sheet['user_answer'] for sheet in self.mcq_answer_sheet)
        if self.mcq_question_number <= 0:
            return score, 0.0
        # Round after scaling: rounding the ratio first and then multiplying
        # by 100 reintroduces float noise (57/100 -> 56.99999999999999).
        score_perc = round(score / self.mcq_question_number * 100, 2)
        return score, score_perc