Spaces:
Sleeping
Sleeping
| # app.py | |
import os
from typing import Dict, List, Union

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertPreTrainedModel, BertModel
class BertForCLSClassification(BertPreTrainedModel):
    """BERT encoder followed by a linear head on the first-token ([CLS]) vector.

    ``forward`` returns the raw classifier logits (one column per
    ``config.num_labels``). ``labels`` is accepted only for signature
    compatibility with Hugging Face training utilities; no loss is computed.
    """

    def __init__(self, config):
        super().__init__(config)
        # Registration order (encoder, then head) is kept so checkpoint
        # parameter names ("bert.*", "classifier.*") line up on load.
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Hidden state at sequence position 0 for every item in the batch.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)
# =========================
# Settings: model location & optional private Hugging Face token
# =========================
# If the weights live directly in the Space directory (e.g. config.json,
# model.safetensors, tokenizer files), set MODEL_ID to "." to load them locally.
MODEL_ID = "TAIDE-EDU/task4-level-judgement"
# None when the environment variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# =========================
# Load model and tokenizer
# =========================
def load_model_and_tokenizer():
    """Load the tokenizer and classifier for ``MODEL_ID`` and list the label names.

    Returns:
        A ``(model, tokenizer, ordered_labels)`` tuple. ``model`` has been moved
        to ``device`` and switched to eval mode; ``ordered_labels[i]`` is the
        display name for class index ``i`` (or ``str(i)`` if the mapping lacks it).
    """
    # The token is only forwarded for Hub repos, never for a local "." directory.
    auth_kwargs = {"token": HF_TOKEN} if (HF_TOKEN and MODEL_ID != ".") else {}

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **auth_kwargs)
    model = BertForCLSClassification.from_pretrained(MODEL_ID, **auth_kwargs).to(device)
    model.eval()

    # Prefer the config's id2label; use the three course levels if it is missing/empty.
    label_map = getattr(model.config, "id2label", None)
    if not label_map:
        label_map = {0: "入門基礎", 1: "進階高階", 2: "流利精通"}

    # Materialize in index order so the displayed order is stable.
    ordered_labels = [label_map.get(idx, str(idx)) for idx in range(len(label_map))]
    return model, tokenizer, ordered_labels
model, tokenizer, ordered_labels = load_model_and_tokenizer()

# Maximum tokenized length fed to the model: the encoder's positional-embedding
# limit from its config, capped at 512. Falls back to 512 when the config does
# not expose a positive integer. No try/except is needed here: getattr with a
# default cannot raise, and a broad `except Exception: pass` would only hide bugs.
MAX_LEN = 512
_max_positions = getattr(model.config, "max_position_embeddings", None)
if isinstance(_max_positions, int) and _max_positions > 0:
    MAX_LEN = min(_max_positions, MAX_LEN)
# =========================
# Inference logic (task 0: proficiency-level distribution)
# =========================
def class_judgement(text: str) -> Union[Dict[str, float], str]:
    """Classify ``text`` and return its proficiency-level distribution for ``gr.Label``.

    Args:
        text: The article / question text to classify.

    Returns:
        A dict mapping every label in ``ordered_labels`` to its softmax
        probability, or a placeholder prompt string when ``text`` is empty,
        ``None`` or whitespace-only (``gr.Label`` also accepts a plain string).
    """
    if not text or not text.strip():
        return "(請輸入或從下方表格點選範例)"

    batch = tokenizer(
        [text],
        max_length=MAX_LEN,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        logits = model(**batch)
    probs = torch.softmax(logits, dim=-1)[0].tolist()

    # One entry per label, in label order; a label with no matching logit gets 0.0.
    return {
        lab: (probs[i] if i < len(probs) else 0.0)
        for i, lab in enumerate(ordered_labels)
    }
# =========================
# Example data (two items only)
# - reading comprehension: id=1
# - cloze: id=2
# =========================
headers = ["id", "範例"]

reading_article = (
    "文章:\n"
    "我家旁邊有一個公園,公園裡有高高的樹和很多花。早上常常有小鳥在唱歌,"
    "小孩也喜歡去那裡玩。下雨以後,空氣會變得很新鮮。我和媽媽有時會一起去公園走路,"
    "看天上的雲,覺得很快樂。\n\n"
    "題目:\n"
    "1.這篇文章主要在介紹什麼地方?\n"
    "(A)我家旁邊的公園\n"
    "(B)我學校附近的商店\n"
    "(C)我媽媽的工作地點\n"
    "(D)小孩最想去的玩具店\n"
)

cloze_text = (
    "陳老師在大學教書已經十年了,他一直認為,學生除了學會課本上的知識,"
    "更重要的是學會__1__和別人合作。每學期開始前,他都會__2__一個小組討論的主題,"
    "讓學生們分組進行研究。每個組員要分工合作,有的人負責收集資料,有的人負責__3__和報告,"
    "大家__4__幫忙,才能完成老師的要求。有時候小組成員之間會出現意見不同的情況,"
    "陳老師總是__5__他們耐心溝通,學會傾聽別人的想法。他相信,這樣的經歷__6__能提升學生的能力,"
    "__6__會對他們未來的工作和生活有很大幫助。\n"
    "1.\n(A)既然\n(B)如何\n(C)常常\n(D)但是\n"
    "2.\n(A)放棄\n(B)安排\n(C)要求\n(D)修理\n"
    "3.\n(A)整理\n(B)傳達\n(C)指導\n(D)販賣\n"
    "4.\n(A)故意\n(B)互相\n(C)分別\n(D)偶爾\n"
    "5.\n(A)鼓勵\n(B)懷疑\n(C)責罵\n(D)批評\n"
    "6.\n(A)不僅⋯⋯還⋯⋯\n(B)與其⋯⋯不如⋯⋯\n(C)要麼⋯⋯要麼⋯⋯\n(D)只有⋯⋯才⋯⋯\n"
)

# Full prompt text to put into the input box when a table row is clicked, keyed by row id.
reading_test_id2data_str: Dict[int, str] = {1: reading_article}
filling_test_id2data_str: Dict[int, str] = {2: cloze_text}

# Each table row is [id, example text], derived from the maps above so the two
# representations cannot drift apart (one row per table here: id=1 and id=2).
reading_samples = [[row_id, body] for row_id, body in reading_test_id2data_str.items()]
filling_samples = [[row_id, body] for row_id, body in filling_test_id2data_str.items()]

reading_test_df = pd.DataFrame(reading_samples, columns=headers)
filling_test_df = pd.DataFrame(filling_samples, columns=headers)
# =========================
# Table row-select handler
# =========================
def on_row_select(evt: gr.SelectData, df: pd.DataFrame) -> str:
    """Return the full example text to place into the input Textbox for the clicked row.

    Args:
        evt: Gradio select event; ``evt.index`` is either a plain index or a
            list/tuple whose first element is the row index.
        df: The DataFrame shown by the table that fired the event.

    Returns:
        The stored prompt for the row's ``id`` (reading-comprehension map first,
        then cloze map). If the id is in neither map, the row's own ``範例``
        cell is returned instead of raising ``KeyError``.
    """
    # Normalize evt.index to a single row index.
    row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    row = df.iloc[int(row_idx)]
    row_id = int(row["id"])

    if row_id in reading_test_id2data_str:
        return reading_test_id2data_str[row_id]  # reading comprehension
    if row_id in filling_test_id2data_str:
        return filling_test_id2data_str[row_id]  # cloze
    # Unknown id (e.g. table contents diverged from the maps): use the visible cell text.
    return str(row["範例"])
# =========================
# Gradio interface
# =========================
with gr.Blocks(title="class_judgement") as demo:
    gr.Markdown("## 華測會等級分類器(任務0 Demo)\n從下方**範例表**點選一列,系統會自動帶入並推論。")
    with gr.Row():
        inp = gr.Textbox(label="輸入文章", lines=12, placeholder="也可手動貼文後按下『送出』")
        btn = gr.Button("送出", variant="primary")
    with gr.Row():
        out0 = gr.Label(label="任務0:程度等級分布")

    # Reading-comprehension example table (one row).
    table_reading = gr.Dataframe(
        value=reading_test_df,
        headers=headers,
        row_count=(len(reading_test_df), "fixed"),
        col_count=len(headers),
        interactive=False,
        wrap=True,
        label="閱讀測驗範例表(點選帶入)",
    )
    # Clicking a row fills the textbox and then runs inference on it, which is what
    # the header Markdown promises; previously the select only filled the textbox.
    table_reading.select(
        on_row_select,
        inputs=table_reading,
        outputs=[inp],
    ).then(
        class_judgement,
        inputs=inp,
        outputs=[out0],
    )

    # Cloze example table (one row).
    table_filling = gr.Dataframe(
        value=filling_test_df,
        headers=headers,
        row_count=(len(filling_test_df), "fixed"),
        col_count=len(headers),
        interactive=False,
        wrap=True,
        label="克漏字填空範例表(點選帶入)",
    )
    table_filling.select(
        on_row_select,
        inputs=table_filling,
        outputs=[inp],
    ).then(
        class_judgement,
        inputs=inp,
        outputs=[out0],
    )

    # Manual input (or an edited example): "送出" runs task-0 inference only.
    btn.click(
        class_judgement,
        inputs=inp,
        outputs=[out0],
    )

# Once the layout is built, enable request queuing with at most 32 pending events.
demo.queue(max_size=32)

if __name__ == "__main__":
    demo.launch()