# prompt_enhancer_flashpack_cpu.py import gc import torch import torch.nn as nn import torch.optim as optim from datasets import load_dataset import gradio as gr from transformers import AutoTokenizer, AutoModel from flashpack import FlashPackMixin # keep if your mixin provides save_flashpack from typing import Tuple # ============================================================ # ๐Ÿ–ฅ Force CPU mode (safe for HF Spaces / Kaggle) # ============================================================ device = torch.device("cpu") torch.set_num_threads(4) # reduce CPU contention in shared environments print(f"๐Ÿ”ง Forcing device: {device} (CPU-only mode)") # ============================================================ # 1๏ธโƒฃ Define FlashPack model # ============================================================ class GemmaTrainer(nn.Module, FlashPackMixin): def __init__(self, input_dim: int = 768, hidden_dim: int = 1024, output_dim: int = 768): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_dim, output_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x # ============================================================ # 2๏ธโƒฃ Utility: encode prompts (CPU-friendly) # ============================================================ def build_encoder(model_name="gpt2", max_length: int = 32): tokenizer = AutoTokenizer.from_pretrained(model_name) # Some GPT2 tokenizers have no pad token โ€” set eos as pad if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token embed_model = AutoModel.from_pretrained(model_name).to(device) embed_model.eval() @torch.no_grad() def encode(prompt: str) -> torch.Tensor: """ Encodes a single prompt and returns a CPU tensor of shape (1, hidden_size). Always returns a CPU tensor to avoid device juggling in downstream code. """ inputs = tokenizer( prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length, ).to(device) outputs = embed_model(**inputs).last_hidden_state.mean(dim=1) # (1, hidden) return outputs.cpu() return tokenizer, embed_model, encode # ============================================================ # 3๏ธโƒฃ Train FlashPack mapping (CPU-optimized) # ============================================================ def train_flashpack_model( dataset_name: str = "gokaygokay/prompt-enhancer-dataset", model_name: str = "gpt2", max_length: int = 32, subset_limit: int | None = None, # set to int to train on subset for quick runs push_to_hub: bool = True, hf_repo: str = "rahul7star/FlashPack", ) -> Tuple[GemmaTrainer, object, AutoModel, AutoTokenizer, torch.Tensor]: """ Returns: (trained_model, dataset, embed_model, tokenizer, long_embeddings) All tensors remain on CPU to be safe in CPU-only environments. """ # 1) Load dataset print("๐Ÿ“ฆ Loading dataset...") dataset = load_dataset(dataset_name, split="train") if subset_limit is not None and subset_limit > 0: print(f"โš ๏ธ Using subset of dataset: first {subset_limit} examples for fast iteration") dataset = dataset.select(range(min(subset_limit, len(dataset)))) # 2) Build tokenizer + encoder print("๐Ÿ”ง Setting up tokenizer & encoder...") tokenizer, embed_model, encode_fn = build_encoder(model_name=model_name, max_length=max_length) # 3) Encode dataset in a memory-friendly loop (returns CPU tensors) print("๐Ÿ”ข Encoding dataset into embeddings (CPU-friendly)...") short_list = [] long_list = [] for i, item in enumerate(dataset): short_list.append(encode_fn(item["short_prompt"])) long_list.append(encode_fn(item["long_prompt"])) # logging & GC every 100 items if (i + 1) % 100 == 0 or (i + 1) == len(dataset): print(f" โ†’ Encoded {i+1}/{len(dataset)} prompts") gc.collect() # Stack to single tensors on CPU short_embeddings = torch.vstack(short_list) # shape (N, hidden) long_embeddings = torch.vstack(long_list) print(f"โœ… Finished encoding: {short_embeddings.shape[0]} pairs, dim={short_embeddings.shape[1]}") # 4) Initialize GemmaTrainer (on CPU) model = GemmaTrainer( input_dim=short_embeddings.shape[1], hidden_dim=min(2048, int(short_embeddings.shape[1] * 2)), output_dim=long_embeddings.shape[1], ).to(device) # device is cpu # 5) Training loop (small-batch style to reduce memory pressure) criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=1e-3) max_epochs = 500 tolerance = 1e-4 batch_size = 64 # small batches on CPU n = short_embeddings.shape[0] print("๐Ÿš€ Training FlashPack mapper model (CPU). This may take some time...") for epoch in range(1, max_epochs + 1): model.train() epoch_loss = 0.0 # Shuffle indices each epoch perm = torch.randperm(n) for start in range(0, n, batch_size): idx = perm[start : start + batch_size] inputs = short_embeddings[idx].to(device) targets = long_embeddings[idx].to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_loss /= n if epoch % 10 == 0 or epoch == 1: print(f"Epoch {epoch:03d}/{max_epochs}, Loss={epoch_loss:.6f}") if epoch_loss < tolerance: print(f"โœ… Converged at epoch {epoch}, Loss={epoch_loss:.6f}") break # 6) Save model locally and optionally push to HF hub (robust) try: # If FlashPackMixin provides save_flashpack, use it: if hasattr(model, "save_flashpack"): print("๐Ÿ’พ Saving model with FlashPackMixin.save_flashpack()") model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=push_to_hub) else: # Fallback: simple torch.save path = "flashpack_model.pt" torch.save(model.state_dict(), path) print(f"๐Ÿ’พ Saved locally to {path}") if push_to_hub: try: from huggingface_hub import HfApi, HfFolder api = HfApi() token = HfFolder.get_token() api.upload_file(path_or_fileobj=path, path_in_repo=path, repo_id=hf_repo, token=token) print(f"๐Ÿš€ Uploaded model file to HF: {hf_repo}") except Exception as e: print("โš ๏ธ Could not push to HF Hub:", e) except Exception as e: print("โš ๏ธ Error while saving/pushing model:", e) print("โœ… Training done โ€” returning model and artifacts.") return model, dataset, embed_model, tokenizer, long_embeddings # ============================================================ # 4๏ธโƒฃ Build everything and prepare for inference # ============================================================ # For demo speed in CPU mode, you might want a subset_limit (e.g., 1000). # Set subset_limit=None to use full dataset. model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model( subset_limit=None, # change to a small int for faster testing push_to_hub=False, # toggle when you want to actually push ) model.eval() # Reusable encode function for inference (returns CPU tensor) @torch.no_grad() def encode_for_inference(prompt: str) -> torch.Tensor: inputs = tokenizer( prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=32, ).to(device) return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu() # ============================================================ # 5๏ธโƒฃ Enhance prompt function (nearest neighbor via cosine) # ============================================================ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history): chat_history = chat_history or [] # encode user prompt (CPU tensor) short_emb = encode_for_inference(user_prompt) # (1, dim) with torch.no_grad(): mapped = model(short_emb.to(device)).cpu() # (1, dim) # cosine similarity against dataset long embeddings cos = nn.CosineSimilarity(dim=1) # mapped.repeat(len(long_embeddings), 1) is heavy; do efficient matmul similarity: sims = (long_embeddings @ mapped.t()).squeeze(1) # normalize: sims / (||long|| * ||mapped||) long_norms = long_embeddings.norm(dim=1) mapped_norm = mapped.norm() sims = sims / (long_norms * (mapped_norm + 1e-12)) best_idx = int(sims.argmax().item()) enhanced_prompt = dataset[best_idx]["long_prompt"] chat_history.append({"role": "user", "content": user_prompt}) chat_history.append({"role": "assistant", "content": enhanced_prompt}) return chat_history # ============================================================ # 6๏ธโƒฃ Gradio UI # ============================================================ with gr.Blocks(title="Prompt Enhancer โ€“ FlashPack (CPU)", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # โœจ Prompt Enhancer (FlashPack mapper) Enter a short prompt, and the model will **expand it with details and creative context**. (This demo runs on CPU โ€” expect slower inference/training than GPU.) """ ) with gr.Row(): chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages") with gr.Column(scale=1): user_prompt = gr.Textbox( placeholder="Enter a short prompt...", label="Your Prompt", lines=3, ) temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature") max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens") send_btn = gr.Button("๐Ÿš€ Enhance Prompt", variant="primary") clear_btn = gr.Button("๐Ÿงน Clear Chat") send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot) user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot) clear_btn.click(lambda: [], None, chatbot) gr.Markdown( """ --- ๐Ÿ’ก **Tips:** - CPU mode: training and large-batch encodes can take a while. Use `subset_limit` in the training call for quick tests. - Increase *Temperature* for more creative outputs (not used in the nearest-neighbour mapper but kept for UI parity). """ ) # ============================================================ # 7๏ธโƒฃ Launch # ============================================================ if __name__ == "__main__": demo.launch(show_error=True)