File size: 4,058 Bytes
4f591e5
 
 
8a2aebb
33ff5ca
8a2aebb
48cf773
4f591e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a2aebb
48cf773
4f591e5
 
48cf773
4f591e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48cf773
 
 
 
 
 
 
 
 
 
 
 
4f591e5
 
8a2aebb
 
 
4f591e5
48cf773
4f591e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48cf773
4f591e5
 
 
 
 
 
 
 
6a923ef
4f591e5
 
 
48cf773
4f591e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import re

from utils.data_utils import *
from utils.prompts import *
import torch
from torch.cuda.amp import autocast
from create_app import *
from transformers import GenerationConfig
import time

def replace_single_newlines(text):
    r"""Turn each isolated line break in *text* into an escaped paragraph break.

    A newline with no newline directly before or after it is replaced by the
    four literal characters ``\n\n`` (backslash, ``n``, backslash, ``n``), not
    by real newline characters. Runs of two or more consecutive newlines are
    left exactly as they are.
    """
    lone_newline = re.compile(r'(?<!\n)\n(?!\n)')
    # In an re.sub template each "\\" yields one literal backslash, so this
    # inserts backslash-n-backslash-n.
    return lone_newline.sub('\\\\n\\\\n', text)

def generate_full_prompt(topic, essay, cefr_stat):
    r"""Assemble the feedback-generation prompt for a single essay.

    The essay is passed through ``replace_single_newlines``, which turns
    isolated newlines into the literal escaped marker ``\n\n``. The result is
    embedded between ``feedback_prompt_1`` and ``feedback_prompt_2`` together
    with the topic, the CEFR analysis and a one-line word/paragraph summary.

    Args:
        topic: The essay question / task text.
        essay: The raw essay text.
        cefr_stat: CEFR analysis result; embedded via ``str()``.

    Returns:
        The complete prompt string.
    """
    # Count words on the raw text. After escaping, a lone newline becomes the
    # non-whitespace marker \n\n and would glue the words around it into one
    # token (e.g. "end\nStart" -> 1 word instead of 2).
    word_cnt = len(essay.split())
    essay = replace_single_newlines(essay)
    # Paragraphs are separated either by the escaped marker added above or by
    # real blank lines, which replace_single_newlines leaves untouched. Pieces
    # that are empty or whitespace-only (e.g. after a trailing newline) are
    # not paragraphs.
    paragraph_cnt = sum(
        1 for piece in re.split(r'(?:\\n|\n)+', essay) if piece.strip()
    )
    stat = f"The essay has {word_cnt} words and {paragraph_cnt} paragraphs.\n"
    full_prompt = (
        feedback_prompt_1
        + "\n ##{{PROMPT}}\n```\n" + topic
        + "\n```\n ##{{ESSAY}}\n```\n" + essay
        + "\n```\n##CEFR Analysis\n" + str(cefr_stat)
        + "\n\n##{{STATS}}\n" + stat
        + '\n' + feedback_prompt_2
    )
    return full_prompt


def generate_and_score_essay(topic, essay):
    r"""Generate written feedback for an essay with Qwen, then score it with Longformer.

    Steps:
      1. Run ``get_cefr_stats`` on the raw essay.
      2. Build the prompt with ``generate_full_prompt`` and sample feedback
         from ``QWEN_MODEL`` (at most 850 new tokens).
      3. Take the text inside the first ``{...}`` span of the generation (or
         the whole generation if there is none) and parse it with
         ``extract_feedback_keys_values``.
      4. Add word/paragraph counts and the CEFR result, build the scorer input
         with ``create_train_input`` and run ``LONGFORMER_MODEL`` on it.

    Reads the module-level globals QWEN_TOKENIZER, QWEN_MODEL,
    LONGFORMER_TOKENIZER and LONGFORMER_MODEL; they must already be loaded
    (presumably by one of the star imports, e.g. ``create_app`` -- TODO confirm).

    Args:
        topic: The essay question / task text.
        essay: The raw essay text.

    Returns:
        ``(scores, feedback_components)``. ``scores`` holds the rounded values
        of the first row of the Longformer logits. ``feedback_components`` is
        the parsed feedback dict extended with ``word_count``,
        ``paragraph_count`` and ``cefr_stat``.
    """
    # The model globals are only read here, so no ``global`` statement is needed.
    print("Analysing CEFR")
    cefr_results = get_cefr_stats(essay)
    full_prompt = generate_full_prompt(topic=topic, essay=essay, cefr_stat=cefr_results)
    print("Generating prompt")

    # Count words on the raw essay: escaping turns a lone newline into the
    # non-whitespace marker \n\n, which would merge the neighbouring words.
    # NOTE(review): confirm the scorer was trained on counts computed this way.
    word_cnt = len(essay.split())
    essay = replace_single_newlines(essay)
    # Paragraph breaks are either the escaped marker or real blank lines (left
    # untouched by replace_single_newlines); empty/whitespace-only pieces, such
    # as the one after a trailing newline, are not paragraphs.
    paragraph_cnt = sum(
        1 for piece in re.split(r'(?:\\n|\n)+', essay) if piece.strip()
    )

    # Send inputs to the device the model actually lives on instead of assuming
    # "cuda whenever available", which breaks if the model was left on CPU.
    qwen_device = next(QWEN_MODEL.parameters()).device
    text = QWEN_TOKENIZER.apply_chat_template(
        [{"role": "user", "content": full_prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = QWEN_TOKENIZER(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        padding_side='left',
    ).to(qwen_device)
    print("Tokenized")
    start = time.time()
    gen_config = GenerationConfig(
        max_new_tokens=850,
        do_sample=True,
        top_k=20,
        top_p=0.9,
        temperature=0.7,
        eos_token_id=QWEN_TOKENIZER.eos_token_id,
        # EOS doubles as the pad token.
        pad_token_id=QWEN_TOKENIZER.eos_token_id,
    )
    with torch.inference_mode():
        generated = QWEN_MODEL.generate(
            **inputs,
            generation_config=gen_config,
            use_cache=True,
            return_dict_in_generate=False,
        )
    print("Generated", time.time() - start)
    # generate() returns prompt + continuation; keep only the new tokens.
    new_token_ids = generated[0][inputs.input_ids.shape[1]:]
    full_feedback = QWEN_TOKENIZER.decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Non-greedy: captures up to the FIRST closing brace.
    # NOTE(review): a "}" inside the feedback text would truncate it here.
    output_match = re.search(r"{(.*?)}", full_feedback, re.DOTALL)
    response = output_match.group(1).strip() if output_match else full_feedback
    feedback_components = dict(extract_feedback_keys_values(response))
    feedback_components['word_count'] = word_cnt
    feedback_components['paragraph_count'] = paragraph_cnt
    feedback_components['cefr_stat'] = cefr_results

    score_input = create_train_input({
        'topic': topic,
        'essay': essay,
        'Corrected_essay': feedback_components.get('Corrected_essay', ''),
        'TR_feedback': feedback_components.get('TR_feedback', ''),
        'CC_feedback': feedback_components.get('CC_feedback', ''),
        'LR_feedback': feedback_components.get('LR_feedback', ''),
        'GRA_feedback': feedback_components.get('GRA_feedback', ''),
        'word_count': word_cnt,
        'paragraph_count': paragraph_cnt,
        'cefr_stat': cefr_results,
    })
    print("input got")
    longformer_device = next(LONGFORMER_MODEL.parameters()).device
    score_inputs = LONGFORMER_TOKENIZER(
        score_input,
        return_tensors="pt",
        max_length=2048,
        truncation=True,
        padding=True,
    ).to(longformer_device)
    LONGFORMER_MODEL.eval()
    with torch.no_grad():
        score_outputs = LONGFORMER_MODEL(**score_inputs)
        logits = score_outputs['logits'].cpu().numpy()
    # A single essay was scored, so only the first row is relevant.
    scores = [round(x) for x in logits[0]]
    print("Score got")
    return scores, feedback_components