# SimpleAES/inference/infer_single.py
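"""Single-essay inference: generate Qwen feedback, then score it with a Longformer head.

Pipeline sketch (get_cefr_stats, extract_feedback_keys_values and create_train_input
come from utils/create_app and are assumed to behave as used below):
1. compute CEFR stats for the essay,
2. build a feedback prompt and generate feedback with the Qwen model,
3. parse the feedback into components,
4. feed essay plus feedback into the Longformer regressor and round its logits to scores.
"""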
import re
import time

import torch
from transformers import GenerationConfig

from utils.data_utils import *
from utils.prompts import *
from create_app import *  # provides the loaded tokenizers/models (QWEN_*, LONGFORMER_*)

def replace_single_newlines(text):
    # Escape lone newlines as the literal marker "\n\n"; the lookarounds
    # leave existing blank lines (double newlines) untouched.
    return re.sub(r'(?<!\n)\n(?!\n)', '\\\\n\\\\n', text)
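# Illustrative behaviour: replace_single_newlines("Intro line\nBody line") returns
# 'Intro line\\n\\nBody line'. The regex writes the two-character escape "\n\n"
# rather than real newlines, and the paragraph counting below splits on exactly
# those escapes.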

def generate_full_prompt(topic, essay, cefr_stat):
    # Normalise line breaks, then compute word/paragraph stats for the prompt.
    essay = replace_single_newlines(essay)
    paragraph_cnt = len(essay.replace('\\n\\n', '\\n').split('\\n'))
    word_cnt = len(essay.split())
    stat = f"The essay has {word_cnt} words and {paragraph_cnt} paragraphs.\n"
    # Instruction layout: system prompt, topic, essay, CEFR analysis, stats, task suffix.
    full_prompt = (
        feedback_prompt_1
        + "\n ##{{PROMPT}}\n```\n" + topic + "\n```\n"
        + " ##{{ESSAY}}\n```\n" + essay + "\n```\n"
        + "##CEFR Analysis\n" + str(cefr_stat) + "\n\n"
        + "##{{STATS}}\n" + stat + '\n'
        + feedback_prompt_2
    )
    return full_prompt
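# Rendered prompt shape (feedback_prompt_1/feedback_prompt_2 come from
# utils.prompts and are not reproduced here):
#
#   <feedback_prompt_1>
#    ##{{PROMPT}}
#   ```
#   <topic>
#   ```
#    ##{{ESSAY}}
#   ```
#   <essay>
#   ```
#   ##CEFR Analysis
#   <cefr_stat>
#
#   ##{{STATS}}
#   The essay has N words and M paragraphs.
#
#   <feedback_prompt_2>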

def generate_and_score_essay(topic, essay):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tokenizers and models are loaded once in create_app and reused here.
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL

    print("Analysing CEFR")
    cefr_results = get_cefr_stats(essay)

    print("Generating prompt")
    full_prompt = generate_full_prompt(topic=topic, essay=essay, cefr_stat=cefr_results)

    # Re-apply the same newline normalisation so the paragraph count matches
    # the stats embedded in the prompt.
    essay = replace_single_newlines(essay)
    paragraph_cnt = len(essay.replace('\\n\\n', '\\n').split('\\n'))

    # Render the chat template; enable_thinking=False turns off Qwen3's
    # thinking mode so the model answers directly.
    text = QWEN_TOKENIZER.apply_chat_template(
        [{"role": "user", "content": full_prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # Left padding keeps the prompt flush against the generated tokens,
    # which is what decoder-only models expect.
    inputs = QWEN_TOKENIZER(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        padding_side='left',
    ).to(device)
    print("Tokenized")

    start = time.time()
    gen_config = GenerationConfig(
        max_new_tokens=850,  # cut way down from 1500
        do_sample=True,
        top_k=20,
        top_p=0.9,
        temperature=0.7,
        eos_token_id=QWEN_TOKENIZER.eos_token_id,
        pad_token_id=QWEN_TOKENIZER.eos_token_id,
    )
    # inference_mode disables autograd bookkeeping; use_cache enables the KV cache.
    with torch.inference_mode():
        outputs = QWEN_MODEL.generate(
            **inputs,
            generation_config=gen_config,
            use_cache=True,
            return_dict_in_generate=False,
        )
    print("Generated", time.time() - start)

    # Drop the prompt tokens; decode only the newly generated continuation.
    generated_ids = outputs[0][inputs.input_ids.shape[1]:]
    full_feedback = QWEN_TOKENIZER.decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,  # fix spaces/newlines
    )
    # The model is expected to wrap its feedback in a single {...} block.
    # A greedy match (first "{" to last "}") avoids truncating the capture
    # when the feedback itself contains braces.
    output_match = re.search(r"\{(.*)\}", full_feedback, re.DOTALL)
    response = output_match.group(1).strip() if output_match else full_feedback
    feedback_components = dict(extract_feedback_keys_values(response))
    feedback_components['word_count'] = len(essay.split())
    feedback_components['paragraph_count'] = paragraph_cnt
    feedback_components['cefr_stat'] = cefr_results

    # Build the scoring input from the original essay plus the parsed feedback.
    score_input = create_train_input({
        'topic': topic,
        'essay': essay,
        'Corrected_essay': feedback_components.get('Corrected_essay', ''),
        'TR_feedback': feedback_components.get('TR_feedback', ''),
        'CC_feedback': feedback_components.get('CC_feedback', ''),
        'LR_feedback': feedback_components.get('LR_feedback', ''),
        'GRA_feedback': feedback_components.get('GRA_feedback', ''),
        'word_count': feedback_components.get('word_count', ''),
        'paragraph_count': feedback_components.get('paragraph_count', ''),
        'cefr_stat': feedback_components.get('cefr_stat', ''),
    })
    print("Built scorer input")

    score_inputs = LONGFORMER_TOKENIZER(
        score_input,
        return_tensors="pt",
        max_length=2048,
        truncation=True,
        padding=True,
    ).to(device)
    LONGFORMER_MODEL.eval()
    with torch.no_grad():
        outputs = LONGFORMER_MODEL(**score_inputs)  # full model output (logits, etc.)
    # One regression logit per criterion; round each to the nearest band score.
    scores = [round(x) for x in outputs['logits'].cpu().numpy()[0]]
    print("Scores computed")
    return scores, feedback_components
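

# Minimal smoke test, assuming create_app has already populated the global
# models on import; the topic and essay below are illustrative placeholders,
# not data from the project.
if __name__ == "__main__":
    sample_topic = "Some people think that schools should teach practical skills. Discuss."
    sample_essay = (
        "Schools teach many subjects.\nSome are practical, some are not.\n\n"
        "In my view, both matter."
    )
    scores, feedback = generate_and_score_essay(sample_topic, sample_essay)
    print("Scores:", scores)
    for key, value in feedback.items():
        print(f"{key}: {value}")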