mars_my committed on
Commit
23e0d46
·
1 Parent(s): 191c982
Files changed (1) hide show
  1. app.py +123 -42
app.py CHANGED
@@ -1,11 +1,19 @@
1
  import gradio as gr
2
  import torch
 
 
 
 
3
  from transformers.models.bert import BertForSequenceClassification, BertTokenizer
4
  from transformers.models.roberta import RobertaForSequenceClassification, RobertaTokenizer
5
 
 
 
 
 
6
  # torch.set_grad_enabled(False)
7
  print('Loading Models from HuggingFace...')
8
- # load by default
9
  name_en = "yuchuantian/AIGC_detector_env3"
10
  model_en = RobertaForSequenceClassification.from_pretrained(name_en)
11
  tokenizer_en = RobertaTokenizer.from_pretrained(name_en)
@@ -13,9 +21,6 @@ tokenizer_en = RobertaTokenizer.from_pretrained(name_en)
13
  name_en3 = "yuchuantian/AIGC_detector_env3short"
14
  model_en3 = RobertaForSequenceClassification.from_pretrained(name_en3)
15
 
16
- name_en5 = "yuchuantian/AIGC_detector_env2"
17
- model_en5 = RobertaForSequenceClassification.from_pretrained(name_en5)
18
-
19
  name_zh = "yuchuantian/AIGC_detector_zhv3"
20
  model_zh = BertForSequenceClassification.from_pretrained(name_zh)
21
  tokenizer_zh = BertTokenizer.from_pretrained(name_zh)
@@ -23,9 +28,6 @@ tokenizer_zh = BertTokenizer.from_pretrained(name_zh)
23
  name_zh4 = "yuchuantian/AIGC_detector_zhv3short"
24
  model_zh4 = BertForSequenceClassification.from_pretrained(name_zh4)
25
 
26
- name_zh6 = "yuchuantian/AIGC_detector_zhv2"
27
- model_zh6 = BertForSequenceClassification.from_pretrained(name_zh6)
28
-
29
  print('Model Loading from HuggingFace Complete!')
30
 
31
 
@@ -48,12 +50,6 @@ def predict_en3(text):
48
  res = predict_func(text, tokenizer_en, model_en3)
49
  return id2label[res['label']], res['score']
50
 
51
- def predict_en5(text):
52
- id2label = ['Human', 'AI']
53
- res = predict_func(text, tokenizer_en, model_en5)
54
- return id2label[res['label']], res['score']
55
-
56
-
57
  def predict_zh(text):
58
  id2label = ['人类', 'AI']
59
  res = predict_func(text, tokenizer_zh, model_zh)
@@ -64,14 +60,80 @@ def predict_zh4(text):
64
  res = predict_func(text, tokenizer_zh, model_zh4)
65
  return id2label[res['label']], res['score']
66
 
67
- def predict_zh6(text):
68
- id2label = ['人类', 'AI']
69
- res = predict_func(text, tokenizer_zh, model_zh6)
70
- return id2label[res['label']], res['score']
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- print(predict_en('Peking University is one of the best universities in the world.'))
74
 
 
75
  print(predict_zh('很高兴认识你!'))
76
 
77
 
@@ -93,13 +155,48 @@ with gr.Blocks() as demo:
93
  [Paper Link 论文链接](https://arxiv.org/abs/2305.18149)
94
 
95
  The loadable versions are as follows 可加载的检测器版本如下:
96
- English: [En-v3](https://huggingface.co/yuchuantian/AIGC_detector_env3) / [En-v3-short](https://huggingface.co/yuchuantian/AIGC_detector_env3short) / [En_v2](https://huggingface.co/yuchuantian/AIGC_detector_env2)
97
- Chinese: [Zh-v3](https://huggingface.co/yuchuantian/AIGC_detector_zhv3) / [Zh-v3-short](https://huggingface.co/yuchuantian/AIGC_detector_zhv3short) / [Zh_v2](https://huggingface.co/yuchuantian/AIGC_detector_zhv2)
98
 
99
  Acknowledgement 致谢
100
  We sincerely thank [Hello-SimpleAI](https://huggingface.co/spaces/Hello-SimpleAI/chatgpt-detector-single) for their code.
101
  """)
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  with gr.Tab("中文-V3"):
104
  gr.Markdown("""
105
  注意: 本检测器提供的结果仅供参考,应谨慎作为事实依据。
@@ -109,7 +206,6 @@ with gr.Blocks() as demo:
109
  label2 = gr.Textbox(lines=1, label='预测结果')
110
  score2 = gr.Textbox(lines=1, label='模型概率')
111
 
112
-
113
  with gr.Tab("中文-V3-短文本"):
114
  gr.Markdown("""
115
  注意: 本检测器提供的结果仅供参考,应谨慎作为事实依据。
@@ -119,16 +215,6 @@ with gr.Blocks() as demo:
119
  label4 = gr.Textbox(lines=1, label='预测结果')
120
  score4 = gr.Textbox(lines=1, label='模型概率')
121
 
122
- with gr.Tab("中文-V2"):
123
- gr.Markdown("""
124
- 注意: 本检测器提供的结果仅供参考,应谨慎作为事实依据。
125
- """)
126
- t6 = gr.Textbox(lines=5, label='文本',value="北京大学建立于1898年7月3日,初名京师大学堂,辛亥革命后于1912年改为北京大学。1938年更名为国立西南联合大学。1946年10月在北平复员。1952年成为以文理学科为主的综合性大学。")
127
- button6 = gr.Button("🚀 检测!")
128
- label6 = gr.Textbox(lines=1, label='预测结果')
129
- score6 = gr.Textbox(lines=1, label='模型概率')
130
-
131
-
132
  with gr.Tab("English-V3"):
133
  gr.Markdown("""
134
  Note: The results are for reference only; they could not be used as factual evidence.
@@ -147,25 +233,20 @@ with gr.Blocks() as demo:
147
  label3 = gr.Textbox(lines=1, label='Predicted Label')
148
  score3 = gr.Textbox(lines=1, label='Probability')
149
 
150
- with gr.Tab("English-V2"):
151
- gr.Markdown("""
152
- Note: The results are for reference only; they could not be used as factual evidence.
153
- """)
154
- t5 = gr.Textbox(lines=5, label='Text',value="Originated as the Imperial University of Peking in 1898, Peking University was China's first national comprehensive university and the supreme education authority at the time. Since the founding of the People's Republic of China in 1949, it has developed into a comprehensive university with fundamental education and research in both humanities and science. The reform and opening-up of China in 1978 has ushered in a new era for the University unseen in history.")
155
- button5 = gr.Button("🚀 Predict!")
156
- label5 = gr.Textbox(lines=1, label='Predicted Label')
157
- score5 = gr.Textbox(lines=1, label='Probability')
158
-
159
  button1.click(predict_en, inputs=[t1], outputs=[label1,score1])
160
  button2.click(predict_zh, inputs=[t2], outputs=[label2,score2])
161
  button3.click(predict_en3, inputs=[t3], outputs=[label3,score3])
162
  button4.click(predict_zh4, inputs=[t4], outputs=[label4,score4])
163
- button5.click(predict_en5, inputs=[t5], outputs=[label5,score5])
164
- button6.click(predict_zh6, inputs=[t6], outputs=[label6,score6])
165
 
166
  # Page Count
167
  gr.Markdown("""
168
  <center><a href='https://clustrmaps.com/site/1bsdc' title='Visit tracker'><img src='//clustrmaps.com/map_v2.png?cl=080808&w=a&t=tt&d=NXQdnwxvIm27veMbB5F7oHNID09nhSvkBRZ_Aji9eIA&co=ffffff&ct=808080'/></a></center>
169
  """)
170
 
 
 
171
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import uuid
4
+ import time
5
+ import threading
6
+ from datetime import datetime
7
  from transformers.models.bert import BertForSequenceClassification, BertTokenizer
8
  from transformers.models.roberta import RobertaForSequenceClassification, RobertaTokenizer
9
 
10
+ # 任务状态管理
11
+ task_results = {}
12
+ task_lock = threading.Lock()
13
+
14
  # torch.set_grad_enabled(False)
15
  print('Loading Models from HuggingFace...')
16
+ # load V3 models only
17
  name_en = "yuchuantian/AIGC_detector_env3"
18
  model_en = RobertaForSequenceClassification.from_pretrained(name_en)
19
  tokenizer_en = RobertaTokenizer.from_pretrained(name_en)
 
21
  name_en3 = "yuchuantian/AIGC_detector_env3short"
22
  model_en3 = RobertaForSequenceClassification.from_pretrained(name_en3)
23
 
 
 
 
24
  name_zh = "yuchuantian/AIGC_detector_zhv3"
25
  model_zh = BertForSequenceClassification.from_pretrained(name_zh)
26
  tokenizer_zh = BertTokenizer.from_pretrained(name_zh)
 
28
  name_zh4 = "yuchuantian/AIGC_detector_zhv3short"
29
  model_zh4 = BertForSequenceClassification.from_pretrained(name_zh4)
30
 
 
 
 
31
  print('Model Loading from HuggingFace Complete!')
32
 
33
 
 
50
  res = predict_func(text, tokenizer_en, model_en3)
51
  return id2label[res['label']], res['score']
52
 
 
 
 
 
 
 
53
  def predict_zh(text):
54
  id2label = ['人类', 'AI']
55
  res = predict_func(text, tokenizer_zh, model_zh)
 
60
  res = predict_func(text, tokenizer_zh, model_zh4)
61
  return id2label[res['label']], res['score']
62
 
 
 
 
 
63
 
64
+ # 异步任务处理函数
65
def process_task_async(task_id, text, model_type):
    """Run one detection task and record its outcome in ``task_results``.

    Meant to be used as the target of a background thread. The task entry
    moves to status ``"processing"`` first, then to ``"completed"`` (with a
    ``result`` dict holding ``label`` and ``score``) or to ``"error"`` (with
    the exception message in ``error``). Every access to ``task_results`` is
    made while holding ``task_lock``.

    Args:
        task_id: Key under which this task's state is stored.
        text: Text to classify.
        model_type: One of ``"en_v3"``, ``"en_v3_short"``, ``"zh_v3"``,
            ``"zh_v3_short"``; any other value is recorded as an error.

    Returns:
        None. The outcome is published through ``task_results`` only.
    """
    # Dispatch table: model type -> prediction function.
    predictors = {
        "en_v3": predict_en,
        "en_v3_short": predict_en3,
        "zh_v3": predict_zh,
        "zh_v3_short": predict_zh4,
    }

    try:
        with task_lock:
            # Keep the creation time of an entry that was registered before
            # this worker started instead of clobbering it.
            previous = task_results.get(task_id, {})
            task_results[task_id] = {
                "status": "processing",
                "created_at": previous.get("created_at", datetime.now().isoformat()),
                "result": None,
                "error": None,
            }

        # Non-string values can never match a known type; the isinstance
        # guard also avoids a TypeError from hashing unhashable input.
        predictor = predictors.get(model_type) if isinstance(model_type, str) else None
        if predictor is None:
            raise ValueError(f"Unknown model type: {model_type}")

        # Inference runs outside the lock so status queries are not blocked.
        outcome = predictor(text)

        with task_lock:
            entry = task_results[task_id]
            entry["status"] = "completed"
            entry["result"] = {
                "label": outcome[0],
                "score": float(outcome[1]),
            }
            entry["completed_at"] = datetime.now().isoformat()

    except Exception as e:
        # Thread boundary: nothing upstream can catch this, so the failure is
        # stored for the client to read. setdefault guarantees the error is
        # recorded even if the entry was never created.
        with task_lock:
            entry = task_results.setdefault(task_id, {"result": None})
            entry["status"] = "error"
            entry["error"] = str(e)
            entry["completed_at"] = datetime.now().isoformat()
101
+
102
+
103
def submit_task(text, model_type):
    """Submit a detection task to run in a background thread.

    The task entry is written to ``task_results`` *before* the worker thread
    is started, so a client that queries immediately after receiving the
    ``task_id`` always finds the task instead of racing against the worker's
    own registration.

    Args:
        text: Text to classify.
        model_type: Model selector passed through to ``process_task_async``
            (``"en_v3"``, ``"en_v3_short"``, ``"zh_v3"``, ``"zh_v3_short"``).

    Returns:
        dict with ``task_id`` (a UUID4 string), ``status`` (``"submitted"``)
        and a human-readable ``message``.
    """
    task_id = str(uuid.uuid4())

    # Register the task up front; the worker refreshes this entry when it
    # starts. Status stays within the existing vocabulary
    # (processing / completed / error) so polling clients see no new value.
    # NOTE(review): entries are never evicted from task_results, so the dict
    # grows for the lifetime of the process — consider a TTL-based cleanup.
    with task_lock:
        task_results[task_id] = {
            "status": "processing",
            "created_at": datetime.now().isoformat(),
            "result": None,
            "error": None,
        }

    # Daemon thread: it must not keep the process alive at shutdown.
    worker = threading.Thread(
        target=process_task_async,
        args=(task_id, text, model_type),
        daemon=True,
    )
    worker.start()

    return {
        "task_id": task_id,
        "status": "submitted",
        "message": "Task submitted successfully",
    }
120
+
121
+
122
def query_task_result(task_id):
    """Look up the current state of a previously submitted task.

    Args:
        task_id: Identifier returned by ``submit_task``.

    Returns:
        A shallow copy of the stored task entry with ``task_id`` added, or a
        dict with ``error`` set to ``"Task not found"`` (plus the queried
        ``task_id``) when no such task is known.
    """
    with task_lock:
        entry = task_results.get(task_id) if task_id in task_results else None
        # Copy while still holding the lock so the snapshot is consistent.
        snapshot = None if entry is None else {**entry, "task_id": task_id}

    if snapshot is None:
        return {
            "error": "Task not found",
            "task_id": task_id
        }
    return snapshot
134
 
 
135
 
136
+ print(predict_en('Peking University is one of the best universities in the world.'))
137
  print(predict_zh('很高兴认识你!'))
138
 
139
 
 
155
  [Paper Link 论文链接](https://arxiv.org/abs/2305.18149)
156
 
157
  The loadable versions are as follows 可加载的检测器版本如下:
158
+ English: [En-v3](https://huggingface.co/yuchuantian/AIGC_detector_env3) / [En-v3-short](https://huggingface.co/yuchuantian/AIGC_detector_env3short)
159
+ Chinese: [Zh-v3](https://huggingface.co/yuchuantian/AIGC_detector_zhv3) / [Zh-v3-short](https://huggingface.co/yuchuantian/AIGC_detector_zhv3short)
160
 
161
  Acknowledgement 致谢
162
  We sincerely thank [Hello-SimpleAI](https://huggingface.co/spaces/Hello-SimpleAI/chatgpt-detector-single) for their code.
163
  """)
164
 
165
+ with gr.Tab("异步API接口"):
166
+ gr.Markdown("""
167
+ ## 异步API接口使用说明
168
+
169
+ ### 1. 提交任务接口
170
+ - 函数: `submit_task(text, model_type)`
171
+ - 参数:
172
+ - text: 要检测的文本
173
+ - model_type: 模型类型 (en_v3, en_v3_short, zh_v3, zh_v3_short)
174
+ - 返回: task_id 和状态信息
175
+
176
+ ### 2. 查询结果接口
177
+ - 函数: `query_task_result(task_id)`
178
+ - 参数: task_id (任务ID)
179
+ - 返回: 任务状态和结果
180
+ """)
181
+
182
+ with gr.Row():
183
+ with gr.Column():
184
+ api_text = gr.Textbox(lines=5, label='文本内容', value="北京大学建立于1898年7月3日")
185
+ api_model = gr.Dropdown(
186
+ choices=["en_v3", "en_v3_short", "zh_v3", "zh_v3_short"],
187
+ value="zh_v3",
188
+ label="模型类型"
189
+ )
190
+ submit_btn = gr.Button("📤 提交任务")
191
+
192
+ with gr.Column():
193
+ task_id_input = gr.Textbox(label="任务ID", placeholder="输入任务ID查询结果")
194
+ query_btn = gr.Button("🔍 查询结果")
195
+
196
+ with gr.Row():
197
+ submit_output = gr.JSON(label="提交结果")
198
+ query_output = gr.JSON(label="查询结果")
199
+
200
  with gr.Tab("中文-V3"):
201
  gr.Markdown("""
202
  注意: 本检测器提供的结果仅供参考,应谨慎作为事实依据。
 
206
  label2 = gr.Textbox(lines=1, label='预测结果')
207
  score2 = gr.Textbox(lines=1, label='模型概率')
208
 
 
209
  with gr.Tab("中文-V3-短文本"):
210
  gr.Markdown("""
211
  注意: 本检测器提供的结果仅供参考,应谨慎作为事实依据。
 
215
  label4 = gr.Textbox(lines=1, label='预测结果')
216
  score4 = gr.Textbox(lines=1, label='模型概率')
217
 
 
 
 
 
 
 
 
 
 
 
218
  with gr.Tab("English-V3"):
219
  gr.Markdown("""
220
  Note: The results are for reference only; they could not be used as factual evidence.
 
233
  label3 = gr.Textbox(lines=1, label='Predicted Label')
234
  score3 = gr.Textbox(lines=1, label='Probability')
235
 
236
+ # 绑定事件
237
+ submit_btn.click(submit_task, inputs=[api_text, api_model], outputs=[submit_output])
238
+ query_btn.click(query_task_result, inputs=[task_id_input], outputs=[query_output])
239
+
 
 
 
 
 
240
  button1.click(predict_en, inputs=[t1], outputs=[label1,score1])
241
  button2.click(predict_zh, inputs=[t2], outputs=[label2,score2])
242
  button3.click(predict_en3, inputs=[t3], outputs=[label3,score3])
243
  button4.click(predict_zh4, inputs=[t4], outputs=[label4,score4])
 
 
244
 
245
  # Page Count
246
  gr.Markdown("""
247
  <center><a href='https://clustrmaps.com/site/1bsdc' title='Visit tracker'><img src='//clustrmaps.com/map_v2.png?cl=080808&w=a&t=tt&d=NXQdnwxvIm27veMbB5F7oHNID09nhSvkBRZ_Aji9eIA&co=ffffff&ct=808080'/></a></center>
248
  """)
249
 
250
+ # 启用队列模式以支持真正的异步处理
251
+ demo.queue()
252
  demo.launch()