Spaces:
Sleeping
Sleeping
reversing last update
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
3 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, LlamaForCausalLM
|
4 |
import spacy
|
5 |
import nltk
|
6 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
@@ -32,8 +31,6 @@ import uuid
|
|
32 |
import time
|
33 |
import asyncio
|
34 |
import aiohttp
|
35 |
-
import torch
|
36 |
-
from dotenv import load_dotenv
|
37 |
print("***************************************************************")
|
38 |
|
39 |
st.set_page_config(
|
@@ -47,8 +44,6 @@ st.set_page_config(
|
|
47 |
|
48 |
st.set_option('deprecation.showPyplotGlobalUse',False)
|
49 |
|
50 |
-
HF_TOKEN = st.secrets['HF_TOKEN']
|
51 |
-
|
52 |
class QuestionGenerationError(Exception):
|
53 |
"""Custom exception for question generation errors."""
|
54 |
pass
|
@@ -90,7 +85,7 @@ def load_model(modelname):
|
|
90 |
# Load Spacy Model
|
91 |
@st.cache_resource
|
92 |
def load_nlp_models():
|
93 |
-
nlp = spacy.load("
|
94 |
s2v = sense2vec.Sense2Vec().from_disk('s2v_old')
|
95 |
return nlp, s2v
|
96 |
|
@@ -103,13 +98,6 @@ def load_qa_models():
|
|
103 |
spell = SpellChecker()
|
104 |
return similarity_model, spell
|
105 |
|
106 |
-
@st.cache_resource
|
107 |
-
def load_llm_model():
|
108 |
-
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
|
109 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
110 |
-
model = LlamaForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16, device_map="auto")
|
111 |
-
return tokenizer, model
|
112 |
-
|
113 |
with st.sidebar:
|
114 |
select_model = st.selectbox("Select Model", ("T5-large","T5-small"))
|
115 |
if select_model == "T5-large":
|
@@ -121,10 +109,6 @@ similarity_model, spell = load_qa_models()
|
|
121 |
context_model = similarity_model
|
122 |
sentence_model = similarity_model
|
123 |
model, tokenizer = load_model(modelname)
|
124 |
-
# llm_tokenizer, llm_model = load_llm_model()
|
125 |
-
llm_tokenizer, llm_model = "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct"
|
126 |
-
pipe = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer, max_new_tokens=200)
|
127 |
-
|
128 |
# Info Section
|
129 |
def display_info():
|
130 |
st.sidebar.title("Information")
|
@@ -334,65 +318,7 @@ def get_word_type(word):
|
|
334 |
doc = nlp(word)
|
335 |
return doc[0].pos_
|
336 |
|
337 |
-
def generate_text_with_llama(prompt):
|
338 |
-
full_prompt = f"""[INST] {prompt} [/INST]"""
|
339 |
-
result = pipe(prompt, temperature=0.7, do_sample=True)[0]['generated_text']
|
340 |
-
# Extract the generated part after the prompt
|
341 |
-
# return result.split('[/INST]')[-1].strip()
|
342 |
-
return result
|
343 |
-
|
344 |
-
async def generate_options_with_llm(answer, context, question, n=4):
|
345 |
-
prompt = f"""Given the following context, question, and correct answer, generate {n-1} incorrect but plausible answer options. The options should be:
|
346 |
-
1. Contextually related to the given context
|
347 |
-
2. Grammatically consistent with the question
|
348 |
-
3. Different from the correct answer
|
349 |
-
4. Not explicitly mentioned in the given context
|
350 |
-
|
351 |
-
Context: {context}
|
352 |
-
Question: {question}
|
353 |
-
Correct Answer: {answer}
|
354 |
-
|
355 |
-
Provide the options in a comma-separated list.
|
356 |
-
"""
|
357 |
-
|
358 |
-
try:
|
359 |
-
response = await asyncio.to_thread(generate_text_with_llama, prompt)
|
360 |
-
options = [option.strip() for option in response.split(',')]
|
361 |
-
options = [option for option in options if option.lower() != answer.lower()]
|
362 |
-
print(f"\n\nLLM Options are: {options}\n\n")
|
363 |
-
return options[:n-1] # Ensure we only return n-1 options
|
364 |
-
except Exception as e:
|
365 |
-
st.error(f"Error generating options with LLM: {e}")
|
366 |
-
return []
|
367 |
-
|
368 |
-
|
369 |
async def generate_options_async(answer, context, question, n=4):
|
370 |
-
options = [answer]
|
371 |
-
|
372 |
-
# Generate options using the language model
|
373 |
-
llm_options = await generate_options_with_llm(answer, context, question, n)
|
374 |
-
options.extend(llm_options)
|
375 |
-
|
376 |
-
# If we don't have enough options, fall back to previous methods
|
377 |
-
if len(options) < n:
|
378 |
-
semantic_options = await generate_semantic_options(answer, context, question, n - len(options))
|
379 |
-
options.extend(semantic_options)
|
380 |
-
|
381 |
-
# If we still don't have enough options, use the fallback method
|
382 |
-
while len(options) < n:
|
383 |
-
fallback_options = await get_fallback_options(answer, context)
|
384 |
-
for option in fallback_options:
|
385 |
-
if option not in options and ensure_grammatical_consistency(question, answer, option):
|
386 |
-
options.append(option)
|
387 |
-
if len(options) == n:
|
388 |
-
break
|
389 |
-
|
390 |
-
# Shuffle the options
|
391 |
-
random.shuffle(options)
|
392 |
-
|
393 |
-
return options
|
394 |
-
|
395 |
-
async def generate_semantic_options(answer, context, question, n=4):
|
396 |
try:
|
397 |
options = [answer]
|
398 |
|
@@ -409,7 +335,7 @@ async def generate_semantic_options(answer, context, question, n=4):
|
|
409 |
for word in context_words:
|
410 |
if get_word_type(word) == answer_type:
|
411 |
similarity = get_semantic_similarity(answer, word)
|
412 |
-
if 0.
|
413 |
similar_words.append((word, similarity))
|
414 |
|
415 |
# Sort by similarity (descending) and take top n-1
|
@@ -519,16 +445,13 @@ async def generate_questions_async(text, num_questions, context_window_size, num
|
|
519 |
st.error(f"An unexpected error occurred: {str(e)}")
|
520 |
return []
|
521 |
|
522 |
-
async def process_batch(batch, keywords, context_window_size, num_beams
|
523 |
questions = []
|
524 |
for text in batch:
|
525 |
keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
|
526 |
for keyword, context in keyword_sentence_mapping.items():
|
527 |
question = await generate_question_async(context, keyword, num_beams)
|
528 |
-
|
529 |
-
options = await generate_options_async(keyword, context, question)
|
530 |
-
else:
|
531 |
-
options =await generate_semantic_options(keyword, context, question)
|
532 |
overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
|
533 |
if overall_score >= 0.5:
|
534 |
questions.append({
|
@@ -604,7 +527,6 @@ def assess_question_quality(context, question, answer):
|
|
604 |
return overall_score, relevance_score, complexity_score, spelling_correctness
|
605 |
|
606 |
def main():
|
607 |
-
# load_dotenv()
|
608 |
# Streamlit interface
|
609 |
st.title(":blue[Question Generator System]")
|
610 |
session_id = get_session_id()
|
@@ -654,7 +576,7 @@ def main():
|
|
654 |
start_time = time.time()
|
655 |
with st.spinner("Generating questions..."):
|
656 |
try:
|
657 |
-
state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords
|
658 |
if not state['generated_questions']:
|
659 |
st.warning("No questions were generated. The text might be too short or lack suitable content.")
|
660 |
else:
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
|
3 |
import spacy
|
4 |
import nltk
|
5 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
|
31 |
import time
|
32 |
import asyncio
|
33 |
import aiohttp
|
|
|
|
|
34 |
print("***************************************************************")
|
35 |
|
36 |
st.set_page_config(
|
|
|
44 |
|
45 |
st.set_option('deprecation.showPyplotGlobalUse',False)
|
46 |
|
|
|
|
|
47 |
class QuestionGenerationError(Exception):
|
48 |
"""Custom exception for question generation errors."""
|
49 |
pass
|
|
|
85 |
# Load Spacy Model
|
86 |
@st.cache_resource
|
87 |
def load_nlp_models():
|
88 |
+
nlp = spacy.load("en_core_web_md")
|
89 |
s2v = sense2vec.Sense2Vec().from_disk('s2v_old')
|
90 |
return nlp, s2v
|
91 |
|
|
|
98 |
spell = SpellChecker()
|
99 |
return similarity_model, spell
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
with st.sidebar:
|
102 |
select_model = st.selectbox("Select Model", ("T5-large","T5-small"))
|
103 |
if select_model == "T5-large":
|
|
|
109 |
context_model = similarity_model
|
110 |
sentence_model = similarity_model
|
111 |
model, tokenizer = load_model(modelname)
|
|
|
|
|
|
|
|
|
112 |
# Info Section
|
113 |
def display_info():
|
114 |
st.sidebar.title("Information")
|
|
|
318 |
doc = nlp(word)
|
319 |
return doc[0].pos_
|
320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
async def generate_options_async(answer, context, question, n=4):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
try:
|
323 |
options = [answer]
|
324 |
|
|
|
335 |
for word in context_words:
|
336 |
if get_word_type(word) == answer_type:
|
337 |
similarity = get_semantic_similarity(answer, word)
|
338 |
+
if 0.3 < similarity < 0.8: # Adjust these thresholds as needed
|
339 |
similar_words.append((word, similarity))
|
340 |
|
341 |
# Sort by similarity (descending) and take top n-1
|
|
|
445 |
st.error(f"An unexpected error occurred: {str(e)}")
|
446 |
return []
|
447 |
|
448 |
+
async def process_batch(batch, keywords, context_window_size, num_beams):
|
449 |
questions = []
|
450 |
for text in batch:
|
451 |
keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
|
452 |
for keyword, context in keyword_sentence_mapping.items():
|
453 |
question = await generate_question_async(context, keyword, num_beams)
|
454 |
+
options = await generate_options_async(keyword, context, question)
|
|
|
|
|
|
|
455 |
overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
|
456 |
if overall_score >= 0.5:
|
457 |
questions.append({
|
|
|
527 |
return overall_score, relevance_score, complexity_score, spelling_correctness
|
528 |
|
529 |
def main():
|
|
|
530 |
# Streamlit interface
|
531 |
st.title(":blue[Question Generator System]")
|
532 |
session_id = get_session_id()
|
|
|
576 |
start_time = time.time()
|
577 |
with st.spinner("Generating questions..."):
|
578 |
try:
|
579 |
+
state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords))
|
580 |
if not state['generated_questions']:
|
581 |
st.warning("No questions were generated. The text might be too short or lack suitable content.")
|
582 |
else:
|