import torch


def generate_text(input_text, model, tokenizer, *, max_new_tokens=5, return_labels=False):
    """Generate a short continuation of ``input_text`` after a ``[LABEL]`` marker.

    The prompt is ``input_text + ' [LABEL]'``. It is tokenized, and
    ``model.generate`` is asked to sample (nucleus sampling, ``top_p=0.95``)
    up to ``max_new_tokens`` tokens past the end of the prompt.

    Args:
        input_text: Text to condition on.
        model: Object exposing ``generate(input_ids, attention_mask=..., ...)``
            (presumably a Hugging Face causal LM -- confirm against callers).
        tokenizer: Object exposing ``encode(text, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=...)``.
        max_new_tokens: How many tokens may be generated beyond the prompt.
        return_labels: If False (default), return the full decoded output.
            If True, return the comma-separated labels parsed from the newly
            generated tokens only.

    Returns:
        The decoded output of the first sequence (prompt included, special
        tokens kept). If ``return_labels`` is True, a list of non-empty label
        strings instead.
    """
    prompt = input_text + ' [LABEL]'
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)
    # input_ids is (batch, seq_len): len() would return the batch size (1),
    # so the sequence length must come from the last dimension.
    prompt_len = input_ids.shape[-1]
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=prompt_len + max_new_tokens,
        do_sample=True,
        top_p=0.95,
    )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if not return_labels:
        return generated

    # Parse labels from the continuation only, so the prompt text does not
    # leak into the first label. NOTE(review): this assumes generate() echoes
    # the prompt tokens first (decoder-only models) -- confirm for the model used.
    continuation = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=False)
    labels = [piece.replace('[LABEL]', '').strip() for piece in continuation.split(',')]
    return [label for label in labels if label]


def classify_text(input_text, model, tokenizer):
    """Run ``model`` on the tokenized ``input_text`` and post-process the output.

    Args:
        input_text: Text to classify.
        model: Callable as ``model(input_ids, attention_mask=...)``.
        tokenizer: Object exposing ``encode(text, return_tensors='pt')``.

    Returns:
        Whatever ``post_process_labels`` returns for the model output.
    """
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)
    # Inference only: skip building an autograd graph to save memory and time.
    with torch.no_grad():
        result = model(input_ids, attention_mask=attention_mask)
    return post_process_labels(result)


def post_process_labels(results):
    """Turn raw classification output into a list of selected labels.

    Placeholder: the raw ``results`` object is currently returned wrapped in a
    one-element list. No thresholding or filtering is applied yet.

    Args:
        results: The raw output of the classification model.

    Returns:
        ``[results]``.
    """
    # TODO: apply a per-label score threshold to select the valid themes.
    return [results]