---
license: apache-2.0
language:
- zh
---

# UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization

<div style='display:flex; gap: 0.25rem; '><a href='https://uni-poll.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://huggingface.co/spaces/X1A/UniPoll'><img src='https://img.shields.io/badge/Huggingface-Demo-yellow'></a><a href='https://github.com/X1AOX1A/UniPoll'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2306.06851'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>

This is the official model repository for the paper [UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization](https://arxiv.org/abs/2306.06851).

## Model Card for UniPoll

### Model Description

- **Developed by:** [Yixia Li](https://liyixia.me)
- **Model type:** Encoder-Decoder
- **Language(s) (NLP):** Chinese
- **License:** Apache-2.0

### Model Source

- **Paper:** [UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization](https://arxiv.org/abs/2306.06851)

### Training Details

- Please refer to the [paper](https://arxiv.org/abs/2306.06851) and the [GitHub repository](https://github.com/X1AOX1A/UniPoll).

## Uses
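
The snippet below shows the intended end-to-end usage: load the UniPoll checkpoint (an mT5-style encoder-decoder) and its tokenizer, wrap a social-media post (plus optional comments) into the prompt format used during training, generate a raw sequence, and post-process it into a poll title and a list of choices.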

```python
import logging
from typing import List, Tuple

from transformers import AutoConfig
from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration

from utils import T5PegasusTokenizer  # provided in the GitHub repository


def load_model(model_path):
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = T5PegasusTokenizer.from_pretrained(model_path)
    model = MT5ForConditionalGeneration.from_pretrained(model_path, config=config)
    return model, tokenizer


def wrap_prompt(post, comments):
    # Build the generation prompt; comments are optional.
    if not comments:
        prompt = "生成 <title> 和 <choices>: [SEP] {post}"
        return prompt.format(post=post)
    else:
        prompt = "生成 <title> 和 <choices>: [SEP] {post} [SEP] {comments}"
        return prompt.format(post=post, comments=comments)


def generate(query, model, tokenizer, num_beams=4):
    tokens = tokenizer(query, return_tensors="pt")["input_ids"]
    output = model.generate(tokens, num_beams=num_beams, max_length=100)
    output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return output_text


def post_process(raw_output: str) -> Tuple[str, List[str]]:
    def same_title_choices(raw_output):
        # Fallback: use the whole output as both the title and the single choice.
        raw_output = raw_output.replace("<title>", "")
        raw_output = raw_output.replace("<choices>", "")
        return raw_output.strip(), [raw_output.strip()]

    def split_choices(choices_str: str) -> List[str]:
        choices = choices_str.split("<c>")
        choices = [choice.strip() for choice in choices]
        return choices

    if "<title>" in raw_output and "<choices>" in raw_output:
        index1 = raw_output.index("<title>")
        index2 = raw_output.index("<choices>")
        if index1 > index2:
            logging.debug(f"idx1>idx2, same title and choices will be used.\nraw_output: {raw_output}")
            return same_title_choices(raw_output)
        title = raw_output[index1 + 7:index2].strip()  # e.g. "你 觉得 线 上 复试 公平 吗"
        choices_str = raw_output[index2 + 9:].strip()  # e.g. "公平 <c> 不 公平"
        choices = split_choices(choices_str)           # e.g. ["公平", "不 公平"]
    else:
        logging.debug(f"missing title/choices, same title and choices will be used.\nraw_output: {raw_output}")
        title, choices = same_title_choices(raw_output)

    def remove_blank(string):
        return string.replace(" ", "")

    title = remove_blank(title)
    choices = [remove_blank(choice) for choice in choices]
    return title, choices


if __name__ == "__main__":
    model_path = "./UniPoll-t5"

    # Input post and (optional) comments; pass None or "" when there are no comments.
    post = "#线上复试是否能保障公平# 高考延期惹的祸,考研线上复试,那还能保证公平吗?"
    comments = "这个世界上本来就没有绝对的公平。你可以说一个倒数第一考了第一,但考上了他也还是啥都不会。也可以说他会利用一切机会达到目的,反正结果就是人家考的好,你还找不出来证据。线上考试,平时考倒数的人进了年级前十。平时考试有水分,线上之后,那不就是在水里考?"

    model, tokenizer = load_model(model_path)       # load model and tokenizer
    query = wrap_prompt(post, comments)             # wrap prompt
    raw_output = generate(query, model, tokenizer)  # generate raw output
    title, choices = post_process(raw_output)       # post-process

    print("Raw output:", raw_output)
    print("Processed title:", title)
    print("Processed choices:", choices)
```
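
The script above imports `T5PegasusTokenizer` from the repository's `utils` module, which ships with the [GitHub code](https://github.com/X1AOX1A/UniPoll). If you only have the model weights, the sketch below is a minimal stand-in following the common T5-PEGASUS convention (jieba pre-segmentation on top of `BertTokenizer`); it is an assumption about the tokenizer's behaviour rather than the repository's exact implementation, so prefer the original `utils` module when available. Running the example also requires `torch`, `transformers`, and (for this sketch) `jieba`.

```python
# Minimal stand-in for utils.T5PegasusTokenizer (assumption: a standard
# T5-PEGASUS-style tokenizer; the repository's implementation may differ).
import jieba
from transformers import BertTokenizer


class T5PegasusTokenizer(BertTokenizer):
    """BertTokenizer with jieba pre-segmentation, as commonly used by Chinese T5-PEGASUS models."""

    def __init__(self, *args, pre_tokenizer=lambda x: jieba.cut(x, HMM=False), **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_tokenizer = pre_tokenizer

    def _tokenize(self, text, *args, **kwargs):
        split_tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:
                # Keep whole words that already exist in the vocabulary.
                split_tokens.append(word)
            else:
                # Fall back to WordPiece for out-of-vocabulary words.
                split_tokens.extend(super()._tokenize(word))
        return split_tokens
```

For the example post above, the inline comments illustrate an output of the form `<title> 你 觉得 线 上 复试 公平 吗 <choices> 公平 <c> 不 公平`, which `post_process` turns into the title `你觉得线上复试公平吗` and the choices `["公平", "不公平"]`; actual generations may differ.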

## Citation

```bibtex
@misc{li2023unipoll,
      title={UniPoll: A Unified Social Media Poll Generation Framework via Multi-Objective Optimization},
      author={Yixia Li and Rong Xiang and Yanlin Song and Jing Li},
      year={2023},
      eprint={2306.06851},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

## Contact Information

If you have any questions or inquiries related to this research project, please feel free to contact:

- Yixia Li: [email protected]