X1A
/

Chinese

UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization

This is the official repository for the paper "UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization".

Model Card for UniPoll

Model Description

  • Developed by: https://liyixia.me
  • Model type: Encoder-Decoder
  • Language(s) (NLP): Chinese
  • License: apache-2.0

Model Source

Training Details

Uses

import logging
from typing import List, Tuple
from transformers import AutoConfig
from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration

import jieba
from functools import partial
from transformers import BertTokenizer

class T5PegasusTokenizer(BertTokenizer):
    """``BertTokenizer`` subclass that pre-segments text with jieba.

    Each segment produced by jieba is kept as a single token when it is
    present in the vocabulary; otherwise it is delegated to
    ``BertTokenizer._tokenize``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Segmenter applied before the parent tokenizer: jieba.cut with HMM=False.
        self.pre_tokenizer = partial(jieba.cut, HMM=False)

    def _tokenize(self, text, *arg, **kwargs):
        """Return the list of tokens for ``text``.

        Extra positional and keyword arguments are accepted but not used.
        """
        # Bound once up front; zero-argument super() must be called directly
        # in the method body, not inside a nested scope.
        parent_tokenize = super()._tokenize
        pieces = []
        for segment in self.pre_tokenizer(text):
            if segment not in self.vocab:
                # Out-of-vocabulary segment: let the parent class break it up.
                pieces.extend(parent_tokenize(segment))
                continue
            pieces.append(segment)
        return pieces

def load_model(model_path):
    """Load the generation model and its tokenizer from ``model_path``.

    Args:
        model_path: Value passed unchanged to each ``from_pretrained`` call.

    Returns:
        A ``(model, tokenizer)`` tuple: an ``MT5ForConditionalGeneration``
        instance and a ``T5PegasusTokenizer`` instance.
    """
    # Same loading order as before: config, then tokenizer, then model.
    model_config = AutoConfig.from_pretrained(model_path)
    tokenizer = T5PegasusTokenizer.from_pretrained(model_path)
    return (
        MT5ForConditionalGeneration.from_pretrained(model_path, config=model_config),
        tokenizer,
    )

def wrap_prompt(post, comments):
    """Build the generation prompt from a post and optional comments.

    Args:
        post: Text of the social media post.
        comments: Text of the comments; ``None`` or an empty string means
            there are no comments.

    Returns:
        The prompt string. The comments segment (introduced by a second
        ``[SEP]``) is only present when ``comments`` is non-empty.
    """
    if comments:
        template = "生成 <title> 和 <choices>: [SEP] {post} [SEP] {comments}"
        return template.format(post=post, comments=comments)

    # No comments: emit the post-only form of the prompt.
    template = "生成 <title> 和 <choices>: [SEP] {post}"
    return template.format(post=post)

def generate(query, model, tokenizer, num_beams=4, max_length=100):
    """Generate the raw output text for a single prompt.

    Args:
        query: Prompt string (for example the result of ``wrap_prompt``).
        model: Object exposing ``generate(tokens, num_beams=..., max_length=...)``.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry,
            and exposing ``batch_decode``.
        num_beams: Value forwarded to ``model.generate``.
        max_length: Value forwarded to ``model.generate``. Defaults to 100,
            the value that was previously hard-coded.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    input_ids = tokenizer(query, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids, num_beams=num_beams, max_length=max_length)
    # A single query is encoded, so only the first decoded entry is returned.
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    
def post_process(raw_output: str) -> Tuple[str, List[str]]:
    """Split raw model output into a poll title and a list of choices.

    The expected layout is ``<title> TITLE <choices> A <c> B ...``, e.g.
    ``<title> 你 觉得 线 上 复试 公平 吗 <choices> 公平 <c> 不 公平`` yields
    ``("你觉得线上复试公平吗", ["公平", "不公平"])``.

    When either tag is missing, or ``<title>`` appears after ``<choices>``,
    the tags are removed and the remaining text is used both as the title
    and as the single choice.

    Args:
        raw_output: Decoded text produced by the model.

    Returns:
        A ``(title, choices)`` tuple. Space characters are removed from the
        title and from every choice in all cases, including the fallbacks.
    """
    title_tag = "<title>"
    choices_tag = "<choices>"
    choice_sep = "<c>"

    def same_title_choices(text: str) -> Tuple[str, List[str]]:
        # Fallback: drop both tags and reuse the remaining text for both fields.
        stripped = text.replace(title_tag, "").replace(choices_tag, "").strip()
        return stripped, [stripped]

    def split_choices(choices_str: str) -> List[str]:
        return [choice.strip() for choice in choices_str.split(choice_sep)]

    def remove_blank(string: str) -> str:
        return string.replace(" ", "")

    if title_tag in raw_output and choices_tag in raw_output:
        title_start = raw_output.index(title_tag)
        choices_start = raw_output.index(choices_tag)
        if title_start > choices_start:
            logging.debug(
                "idx1>idx2, same title and choices will be used.\nraw_output: %s",
                raw_output,
            )
            # Fall through (no early return) so blanks are removed here too,
            # matching the missing-tag fallback below.
            title, choices = same_title_choices(raw_output)
        else:
            title = raw_output[title_start + len(title_tag):choices_start].strip()
            choices = split_choices(raw_output[choices_start + len(choices_tag):].strip())
    else:
        logging.debug(
            "missing title/choices, same title and choices will be used.\nraw_output: %s",
            raw_output,
        )
        title, choices = same_title_choices(raw_output)

    return remove_blank(title), [remove_blank(choice) for choice in choices]
    
if __name__ == "__main__":
    # Location handed to load_model.
    model_path = "./UniPoll-t5"

    # Example input: the post text, plus comments text (optional; may be None).
    post = "#线上复试是否能保障公平# 高考延期惹的祸,考研线上复试,那还能保证公平吗?"
    comments = "这个世界上本来就没有绝对的公平。你可以说一个倒数第一考了第一,但考上了他也还是啥都不会。也可以说他会利用一切机会达到目的,反正结果就是人家考的好,你还找不出来证据。线上考试,平时考倒数的人进了年级前十。平时考试有水分,线上之后,那不就是在水里考?"

    # Pipeline: load model and tokenizer -> wrap prompt -> generate -> post-process.
    model, tokenizer = load_model(model_path)
    query = wrap_prompt(post, comments)
    raw_output = generate(query, model, tokenizer)
    title, choices = post_process(raw_output)

    # Show the raw generation followed by the parsed fields.
    for label, value in (
        ("Raw output:", raw_output),
        ("Processed title:", title),
        ("Processed choices:", choices),
    ):
        print(label, value)

Citation

@misc{li2023unipoll,
      title={UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization}, 
      author={Yixia Li and Rong Xiang and Yanlin Song and Jing Li},
      year={2023},
      eprint={2306.06851},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Contact Information

If you have any questions or inquiries related to this research project, please feel free to contact the authors.

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.