---
license: apache-2.0
language:
- zh
---
# UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization
<div style='display:flex; gap: 0.25rem; '><a href='https://uni-poll.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://huggingface.co/spaces/X1A/UniPoll'><img src='https://img.shields.io/badge/Huggingface-Demo-yellow'></a><a href='https://github.com/X1AOX1A/UniPoll'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2306.06851'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
The official repository of the paper [UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization](https://arxiv.org/abs/2306.06851).
## Model Card for UniPoll
### Model Description
- **Developed by:** [https://liyixia.me](https://liyixia.me);
- **Model type:** Encoder-Decoder;
- **Language(s) (NLP):** Chinese;
- **License:** apache-2.0
### Model Source
- **Paper:** [UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization](https://arxiv.org/abs/2306.06851).
### Training Details
- Please refer to the [paper](https://arxiv.org/abs/2306.06851) and [Github](https://github.com/X1AOX1A/UniPoll).
## Uses
```python
import logging
from typing import List, Tuple
from transformers import AutoConfig
from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration
import jieba
from functools import partial
from transformers import BertTokenizer
class T5PegasusTokenizer(BertTokenizer):
    """BertTokenizer variant that pre-segments Chinese text with jieba.

    A word produced by jieba that exists in the vocabulary is kept as a
    single token; anything else falls back to the parent WordPiece
    tokenization.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Deterministic jieba segmentation (statistical HMM disabled).
        self.pre_tokenizer = partial(jieba.cut, HMM=False)

    def _tokenize(self, text, *arg, **kwargs):
        tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:
                tokens.append(word)
            else:
                tokens.extend(super()._tokenize(word))
        return tokens
def load_model(model_path):
    """Load the UniPoll mT5 model and its jieba-aware tokenizer.

    Args:
        model_path: directory (or hub id) holding the pretrained weights.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    tokenizer = T5PegasusTokenizer.from_pretrained(model_path)
    model = MT5ForConditionalGeneration.from_pretrained(
        model_path, config=AutoConfig.from_pretrained(model_path)
    )
    return model, tokenizer
def wrap_prompt(post, comments=None):
    """Build the UniPoll generation prompt.

    Args:
        post: the social-media post text.
        comments: optional reader comments; any falsy value (``None`` or
            ``""``) means "no comments available". A default of ``None``
            makes the argument optional (backward-compatible).

    Returns:
        The Chinese prompt string asking the model to generate a
        ``<title>`` and ``<choices>``, with the post and, when present,
        the comments appended after ``[SEP]`` separators.
    """
    # `not comments` already covers the empty string, so the original
    # extra `comments == ""` test was redundant.
    if not comments:
        return "生成 <title> 和 <choices>: [SEP] {post}".format(post=post)
    return "生成 <title> 和 <choices>: [SEP] {post} [SEP] {comments}".format(
        post=post, comments=comments
    )
def generate(query, model, tokenizer, num_beams=4):
    """Run beam-search generation on *query* and decode the top sequence.

    Args:
        query: the already-wrapped prompt string.
        model: a seq2seq transformers model exposing ``generate``.
        tokenizer: tokenizer matching the model.
        num_beams: beam width for beam search.

    Returns:
        The decoded output text (special tokens stripped).
    """
    input_ids = tokenizer(query, return_tensors="pt")["input_ids"]
    sequences = model.generate(input_ids, num_beams=num_beams, max_length=100)
    decoded = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return decoded[0]
def post_process(raw_output: str) -> Tuple[str, str]:
    """Split the model's raw output into a poll title and a choice list.

    The expected shape is ``"<title> ... <choices> a <c> b"``. When the
    tags are missing, the tag-stripped output serves as both the title
    and the single choice. Spaces (word separators from tokenization)
    are removed from the final strings — except in the tags-out-of-order
    branch, which returns early exactly as the original did.
    """

    def _fallback(text: str) -> Tuple[str, List[str]]:
        # Use the tag-stripped output as both title and sole choice.
        cleaned = text.replace("<title>", "").replace("<choices>", "").strip()
        return cleaned, [cleaned]

    title_pos = raw_output.find("<title>")
    choices_pos = raw_output.find("<choices>")

    if title_pos == -1 or choices_pos == -1:
        logging.debug(f"missing title/choices, same title and choices will be used.\nraw_output: {raw_output}")
        title, choices = _fallback(raw_output)
    elif title_pos > choices_pos:
        logging.debug(f"idx1>idx2, same title and choices will be used.\nraw_output: {raw_output}")
        # Early return: this branch keeps internal spaces, matching the
        # original control flow.
        return _fallback(raw_output)
    else:
        # e.g. "你 觉得 线 上 复试 公平 吗" / "公平 <c> 不 公平"
        title = raw_output[title_pos + len("<title>"):choices_pos].strip()
        choices_str = raw_output[choices_pos + len("<choices>"):].strip()
        choices = [piece.strip() for piece in choices_str.split("<c>")]

    title = title.replace(" ", "")
    choices = [choice.replace(" ", "") for choice in choices]
    return title, choices
if __name__ == "__main__":
    model_path = "./UniPoll-t5"

    # Example input: a post plus optional comments (pass None to omit).
    post = "#线上复试是否能保障公平# 高考延期惹的祸,考研线上复试,那还能保证公平吗?"
    comments = "这个世界上本来就没有绝对的公平。你可以说一个倒数第一考了第一,但考上了他也还是啥都不会。也可以说他会利用一切机会达到目的,反正结果就是人家考的好,你还找不出来证据。线上考试,平时考倒数的人进了年级前十。平时考试有水分,线上之后,那不就是在水里考?"

    # Pipeline: load → prompt → generate → parse.
    model, tokenizer = load_model(model_path)
    raw_output = generate(wrap_prompt(post, comments), model, tokenizer)
    title, choices = post_process(raw_output)

    print("Raw output:", raw_output)
    print("Processed title:", title)
    print("Processed choices:", choices)
```
## Citation
```
@misc{li2023unipoll,
title={UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization},
author={Yixia Li and Rong Xiang and Yanlin Song and Jing Li},
year={2023},
eprint={2306.06851},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
## Contact Information
If you have any questions or inquiries related to this research project, please feel free to contact:
- Yixia Li: liyixia@me.com