import math
import os
import re

import torch
import transformers
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# CEFR level -> descriptor text that is inserted verbatim into the system
# prompt built by get_CEFR_prompt.
description = {
    "C2": "Has a good command of idiomatic expressions and colloquialisms with awareness of connotative levels of meaning. Can convey finer shades of meaning precisely by using, with reasonable accuracy, a wide range of modification devices. Can backtrack and restructure around a difficulty so smoothly that the interlocutor is hardly aware of it.",
    "C1": "Can express themselves fluently and spontaneously, almost effortlessly. Has a good command of a broad lexical repertoire allowing gaps to be readily overcome with circumlocutions. There is little obvious searching for expressions or avoidance strategies; only a conceptually difficult subject can hinder a natural, smooth flow of language.",
    "B2": "Can interact with a degree of fluency and spontaneity that makes regular interaction, and sustained relationships with users of the target language, quite possible without imposing strain on either party. Can highlight the personal significance of events and experiences, and account for and sustain views clearly by providing relevant explanations and arguments.",
    "B1": "Can communicate with some confidence on familiar routine and non-routine matters related to their interests and professional field. Can exchange, check and confirm information, deal with less routine situations and explain why something is a problem. Can express thoughts on more abstract, cultural topics such as films, books, music, etc.",
    "A2": "Can interact with reasonable ease in structured situations and short conversations, provided the other person helps if necessary. Can manage simple, routine exchanges without undue effort; can ask and answer questions and exchange ideas and information on familiar topics in predictable everyday situations.",
    "A1": "Can interact in a simple way but communication is totally dependent on repetition at a slower rate, rephrasing and repair. Can ask and answer simple questions, initiate and respond to simple statements in areas of immediate need or on very familiar topics.",
}


def format_context(context):
    """Render dialog turns as alternating ``A: ...`` / ``B: ...`` lines.

    Args:
        context: Sequence of mappings, each with a ``"text"`` entry. Turns at
            even positions are labelled ``A``, odd positions ``B``.

    Returns:
        The labelled turns joined with ``os.linesep``.
    """
    lines = []
    for position, turn in enumerate(context):
        speaker = "A" if position % 2 == 0 else "B"
        lines.append(speaker + ": " + turn["text"])
    return os.linesep.join(lines)


def get_messages(instruction, item, apply_chat_template, system_msg, next_speaker="A"):
    """Attach chat messages (and optionally rendered prompts) to ``item``.

    ``item`` is modified in place and also returned. It must provide
    ``item['context']`` (dialog turns) and ``item['response']``.

    Args:
        instruction: Task instruction placed at the top of the user message.
        item: Mapping describing one example; receives ``'messages'`` and,
            when a template callable is given, ``'prompt'`` and ``'text'``.
        apply_chat_template: Callable with the signature of a tokenizer's
            ``apply_chat_template`` or a falsy value to skip rendering.
        system_msg: When truthy, a system message naming ``next_speaker`` is
            prepended.
        next_speaker: Speaker label used in the system message.

    Returns:
        The same ``item`` object.
    """
    messages = []
    if system_msg:
        messages.append({"role": "system", "content": f"Only output {next_speaker}'s response."})
    messages.append({"role": "user", "content": f"{instruction}\nDialog:\n{format_context(item['context'])}\n"})
    messages.append({"role": "assistant", "content": f"{item['response']}"})
    item['messages'] = messages

    if apply_chat_template:
        # 'prompt' stops before the assistant turn and ends with the
        # generation prompt; 'text' is the full conversation including it.
        item['prompt'] = apply_chat_template(item['messages'][:-1], tokenize=False, add_generation_prompt=True)
        item['text'] = apply_chat_template(item['messages'], tokenize=False)
    return item


def get_CEFR_prompt(item, apply_chat_template=None):
    """Build the CEFR-conditioned messages/prompt for one dialog example.

    The next speaker is ``A`` when the context has an even number of turns,
    otherwise ``B``. A system message carrying the CEFR level and its
    descriptor from ``description`` is placed in front of the user and
    assistant messages.

    Args:
        item: Mapping with ``'context'``, ``'response'`` and ``'CEFR'`` (a key
            of ``description``). Modified in place.
        apply_chat_template: Optional callable with the signature of a
            tokenizer's ``apply_chat_template``. When omitted, only
            ``item['messages']`` is populated.

    Returns:
        The same ``item`` object.

    Raises:
        KeyError: If ``item['CEFR']`` is not a key of ``description``.
    """
    next_speaker = "A" if len(item['context']) % 2 == 0 else "B"
    level = item['CEFR']
    instruction = f"Given the dialog, provide a possible next turn of {next_speaker} that an English learner on CEFR level {level} could produce:"

    # Rendering is deferred until the system message is in place, so the
    # template is applied once instead of being rendered and then overwritten.
    item = get_messages(instruction, item, None, False, next_speaker)

    system_content = (
        f"Only output {next_speaker}'s response using language on CEFR level {level}. "
        f"This level is described as: {description[level]}. "
        "Do not include quote marks or a label like 'B:'"
    )
    item['messages'] = [{"role": "system", "content": system_content}] + item['messages']

    # The template is optional (default None); calling it unconditionally
    # raised TypeError for callers relying on the default.
    if apply_chat_template:
        item['prompt'] = apply_chat_template(item['messages'][:-1], tokenize=False, add_generation_prompt=True)
        item['text'] = apply_chat_template(item['messages'], tokenize=False)
    return item


def parse_response(response, format="A: "):
    """Return the part of ``response`` after the first ``format`` marker.

    If the marker does not occur, ``response`` is returned unchanged.
    """
    marker_at = response.find(format)
    if marker_at == -1:
        return response
    return response[marker_at + len(format):]


def generate(model, tokenizer, prompts, eos_token_id=None, max_new_tokens=128, batch_size=32, verbose=False, skip_special_tokens=True, do_sample=False, repetition_penalty=1.0, length_penalty=1.0, num_beams=1):
    """Generate continuations for ``prompts`` in batches and decode them.

    Prompts are left-padded to 512 tokens (longer prompts are truncated), and
    only the newly generated tokens are decoded.

    Args:
        model: Causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``.
        prompts: List of prompt strings; a single string is treated as a
            one-element list.
        eos_token_id: Token id(s) that stop generation. Defaults to the
            tokenizer's EOS id plus ``<|eot_id|>`` when the vocabulary
            contains that token.
        max_new_tokens: Maximum number of tokens generated per prompt.
        batch_size: Number of prompts per forward batch.
        verbose: Print the tokenized inputs and decoded outputs per batch.
        skip_special_tokens: Forwarded to ``tokenizer.batch_decode``.
        do_sample: Sample (temperature 0.6, top-p 0.9) instead of greedy/beam
            decoding.
        repetition_penalty: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.

    Returns:
        A single string when exactly one output was produced, otherwise the
        list of decoded outputs (empty list for no prompts).

    Note:
        ``tokenizer.padding_side`` is set to ``"right"`` when this function
        exits, regardless of its value on entry.
    """
    # A bare string would otherwise be sliced character by character.
    if isinstance(prompts, str):
        prompts = [prompts]

    if eos_token_id is None:
        eos_token_id = [tokenizer.eos_token_id]
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        # Tokenizers without <|eot_id|> yield None or the unknown-token id;
        # neither may be used as a stop token.
        if (
            eot_id is not None
            and eot_id != getattr(tokenizer, "unk_token_id", None)
            and eot_id not in eos_token_id
        ):
            eos_token_id.append(eot_id)

    model.eval()
    outputs = []
    # Decoder-only generation needs left padding so that new tokens directly
    # follow the prompt.
    tokenizer.padding_side = "left"
    try:
        for start in tqdm(range(0, len(prompts), batch_size), total=math.ceil(len(prompts) / batch_size), desc="Generate"):
            batch = prompts[start:start + batch_size]
            # NOTE(review): truncation uses the tokenizer's default truncation
            # side; prompts longer than 512 tokens may lose their tail.
            model_input = tokenizer(batch, return_tensors="pt", padding='max_length', truncation=True, max_length=512).to(device)
            if verbose:
                print(model_input)
            with torch.no_grad():
                token_ids = model.generate(
                    **model_input,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=eos_token_id,
                    do_sample=do_sample,
                    temperature=0.6 if do_sample else None,
                    top_p=0.9 if do_sample else None,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    num_beams=num_beams,
                )
            # Drop the (padded) prompt tokens; keep only the continuation.
            prompt_length = model_input['input_ids'].shape[1]
            decoded = tokenizer.batch_decode(token_ids[:, prompt_length:], skip_special_tokens=skip_special_tokens)
            outputs += decoded
            if verbose:
                print(decoded)
    finally:
        # Reset even when generation fails, so the tokenizer is not left in
        # left-padding mode.
        tokenizer.padding_side = "right"

    return outputs[0] if len(outputs) == 1 else outputs