---
license: apache-2.0
language:
- zh
---
|
|
|
|
|
# UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization |
|
|
|
|
|
<div style='display:flex; gap: 0.25rem; '><a href='https://uni-poll.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://huggingface.co/spaces/X1A/UniPoll'><img src='https://img.shields.io/badge/Huggingface-Demo-yellow'></a><a href='https://github.com/X1AOX1A/UniPoll'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2306.06851'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div> |
|
|
|
|
|
This is the official repository for the paper [UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization](https://arxiv.org/abs/2306.06851).
|
|
|
|
|
|
|
|
## Model Card for UniPoll |
|
|
|
|
|
### Model Description |
|
|
|
|
|
- **Developed by:** [Yixia Li](https://liyixia.me)
- **Model type:** Encoder-Decoder
- **Language(s) (NLP):** Chinese
- **License:** apache-2.0
|
|
|
|
|
### Model Source |
|
|
|
|
|
- **Paper:** [UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization](https://arxiv.org/abs/2306.06851). |
|
|
|
|
|
### Training Details |
|
|
|
|
|
- Please refer to the [paper](https://arxiv.org/abs/2306.06851) and the [GitHub repository](https://github.com/X1AOX1A/UniPoll).
|
|
|
|
|
## Uses |
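The example below runs the full inference pipeline: load the T5-PEGASUS-style tokenizer and the MT5 checkpoint, wrap the post (and optional comments) in the task prompt, generate with beam search, and post-process the raw output into a poll title and a list of choices.

The script expects a local copy of the checkpoint at `./UniPoll-t5`. If the files need to be fetched from the Hugging Face Hub first, a sketch like the following should work; note that `X1A/UniPoll` is an assumed repo id (taken from the demo Space) and may differ from the actual model repository:

```python
# Minimal sketch for fetching the checkpoint from the Hugging Face Hub.
# NOTE: "X1A/UniPoll" is an assumed repo id; replace it with the actual
# model repository if it differs.
from huggingface_hub import snapshot_download

model_path = snapshot_download(repo_id="X1A/UniPoll")  # downloads to the local cache
```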
|
|
|
|
|
```python
import logging
from functools import partial
from typing import List, Tuple

import jieba
from transformers import AutoConfig, BertTokenizer
from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration


class T5PegasusTokenizer(BertTokenizer):
    """BertTokenizer with a jieba pre-tokenizer, as used by Chinese T5-PEGASUS."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_tokenizer = partial(jieba.cut, HMM=False)

    def _tokenize(self, text, *args, **kwargs):
        split_tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:  # keep whole words that are in the vocabulary
                split_tokens.append(word)
            else:                   # fall back to WordPiece for out-of-vocabulary words
                split_tokens.extend(super()._tokenize(word))
        return split_tokens


def load_model(model_path):
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = T5PegasusTokenizer.from_pretrained(model_path)
    model = MT5ForConditionalGeneration.from_pretrained(model_path, config=config)
    return model, tokenizer


def wrap_prompt(post, comments):
    # The prompt prefix "生成 <title> 和 <choices>:" means "generate <title> and <choices>:".
    if not comments:
        prompt = "生成 <title> 和 <choices>: [SEP] {post}"
        return prompt.format(post=post)
    else:
        prompt = "生成 <title> 和 <choices>: [SEP] {post} [SEP] {comments}"
        return prompt.format(post=post, comments=comments)


def generate(query, model, tokenizer, num_beams=4):
    tokens = tokenizer(query, return_tensors="pt")["input_ids"]
    output = model.generate(tokens, num_beams=num_beams, max_length=100)
    output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return output_text


def post_process(raw_output: str) -> Tuple[str, List[str]]:
    """Split the raw generation into a poll title and a list of choices."""

    def same_title_choices(raw_output):
        # Fallback: strip the markers and reuse the whole output as both title and choice.
        raw_output = raw_output.replace("<title>", "")
        raw_output = raw_output.replace("<choices>", "")
        return raw_output.strip(), [raw_output.strip()]

    def split_choices(choices_str: str) -> List[str]:
        choices = choices_str.split("<c>")  # individual choices are separated by "<c>"
        return [choice.strip() for choice in choices]

    if "<title>" in raw_output and "<choices>" in raw_output:
        index1 = raw_output.index("<title>")
        index2 = raw_output.index("<choices>")
        if index1 > index2:
            logging.debug(f"idx1>idx2, same title and choices will be used.\nraw_output: {raw_output}")
            return same_title_choices(raw_output)
        title = raw_output[index1 + 7:index2].strip()  # e.g. "你 觉得 线 上 复试 公平 吗"
        choices_str = raw_output[index2 + 9:].strip()  # e.g. "公平 <c> 不 公平"
        choices = split_choices(choices_str)           # e.g. ["公平", "不 公平"]
    else:
        logging.debug(f"missing title/choices, same title and choices will be used.\nraw_output: {raw_output}")
        title, choices = same_title_choices(raw_output)

    def remove_blank(string):
        # The tokenizer decodes with spaces between Chinese tokens; remove them.
        return string.replace(" ", "")

    title = remove_blank(title)
    choices = [remove_blank(choice) for choice in choices]
    return title, choices


if __name__ == "__main__":
    model_path = "./UniPoll-t5"

    # Input post and (optional) comments; pass comments=None or "" if unavailable.
    post = "#线上复试是否能保障公平# 高考延期惹的祸,考研线上复试,那还能保证公平吗?"
    comments = "这个世界上本来就没有绝对的公平。你可以说一个倒数第一考了第一,但考上了他也还是啥都不会。也可以说他会利用一切机会达到目的,反正结果就是人家考的好,你还找不出来证据。线上考试,平时考倒数的人进了年级前十。平时考试有水分,线上之后,那不就是在水里考?"

    model, tokenizer = load_model(model_path)       # load model and tokenizer
    query = wrap_prompt(post, comments)             # wrap the input in the task prompt
    raw_output = generate(query, model, tokenizer)  # generate the raw poll text
    title, choices = post_process(raw_output)       # split into title and choices

    print("Raw output:", raw_output)
    print("Processed title:", title)
    print("Processed choices:", choices)
```
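The raw generation follows the template `<title> ... <choices> ...`, with individual options separated by `<c>`; `post_process` strips these markers and removes the spaces the tokenizer inserts between Chinese tokens. For the example post above, the printed result would look roughly like this (illustrative only; the actual generation depends on the checkpoint and decoding settings):

```
Raw output: <title> 你 觉得 线 上 复试 公平 吗 <choices> 公平 <c> 不 公平
Processed title: 你觉得线上复试公平吗
Processed choices: ['公平', '不公平']
```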
|
|
|
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex
@misc{li2023unipoll,
  title={UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization},
  author={Yixia Li and Rong Xiang and Yanlin Song and Jing Li},
  year={2023},
  eprint={2306.06851},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
|
|
|
|
|
## Contact Information |
|
|
|
|
|
If you have any questions or inquiries related to this research project, please feel free to contact: |
|
|
|
|
|
- Yixia Li: [email protected] |