import threading

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "baidu/ERNIE-4.5-21B-A3B-Thinking"

# Load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)


@spaces.GPU(duration=120)
def chat(message, history):
    # Build the full conversation and render it with the model's chat template.
    messages = history + [{"role": "user", "content": message}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer(
        [text], add_special_tokens=False, return_tensors="pt"
    ).to(model.device)

    # Run generation on a background thread so tokens can be streamed as they arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.7,
    }
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Show the user's message immediately, then fill in the assistant reply token by token.
    new_history = history + [{"role": "user", "content": message}]
    new_history.append({"role": "assistant", "content": ""})
    yield new_history, ""

    generated_text = ""
    for new_token in streamer:
        generated_text += new_token
        new_history[-1]["content"] = generated_text
        yield new_history, ""
    thread.join()


with gr.Blocks(title="ERNIE Chat") as demo:
    gr.Markdown("# ERNIE-4.5-21B-A3B-Thinking Chat App")
    chatbot = gr.Chatbot(
        height=500,
        type="messages",
        show_copy_button=True,
        avatar_images=None,
    )
    msg = gr.Textbox(
        placeholder="Type your message here...",
        show_label=False,
        container=True,
        scale=7,
    )
    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")

    # Submitting the textbox streams the reply into the chatbot and clears the input.
    msg.submit(chat, [msg, chatbot], [chatbot, msg])
    clear_btn.click(lambda: ([], ""), None, [chatbot, msg], queue=False)

if __name__ == "__main__":
    demo.launch()