rahul7star committed
Commit c4c4bdc · verified · 1 Parent(s): 5bac7a5

Update app_flash.py

Files changed (1)
  1. app_flash.py +82 -54
app_flash.py CHANGED
@@ -7,7 +7,11 @@ from datasets import load_dataset
 import gradio as gr
 from transformers import AutoTokenizer, AutoModel
 
+# ============================================================
+# 🧠 Device setup
+# ============================================================
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"🔧 Using device: {device}")
 
 # ============================================================
 # 1️⃣ Define FlashPack model
@@ -25,79 +29,102 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         x = self.fc2(x)
         return x
 
-# ============================================================
-# 2️⃣ Load dataset
-# ============================================================
-dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
 
 # ============================================================
-# 3️⃣ Prepare tokenizer & embedding model
+# 2️⃣ Encode and train using GPU
 # ============================================================
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-tokenizer.pad_token = tokenizer.eos_token # FIX padding error
+@spaces.GPU(duration=60) # 10-minute GPU allocation window
+def train_flashpack_model():
+    # Load dataset
+    print("📦 Loading dataset...")
+    dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
 
-embed_model = AutoModel.from_pretrained("gpt2").to(device)
-embed_model.eval() # inference only
+    # Tokenizer setup
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer.pad_token = tokenizer.eos_token # ✅ Fix padding issue
 
-def encode_prompt(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=32).to(device)
-    with torch.no_grad():
-        return embed_model(**inputs).last_hidden_state.mean(dim=1)
+    # Base embedding model
+    embed_model = AutoModel.from_pretrained("gpt2").to(device)
+    embed_model.eval()
 
-# Encode all dataset prompts
-print("📦 Encoding dataset prompts...")
-short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
-long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
-print(f"✅ Encoded {len(dataset)} prompts")
+    def encode_prompt(prompt):
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            padding="max_length",
+            max_length=32
+        ).to(device)
+        with torch.no_grad():
+            return embed_model(**inputs).last_hidden_state.mean(dim=1)
 
-# ============================================================
-# 4️⃣ Train FlashPack model
-# ============================================================
-model = GemmaTrainer(input_dim=short_embeddings.shape[1], output_dim=long_embeddings.shape[1]).to(device)
-criterion = nn.MSELoss()
-optimizer = optim.Adam(model.parameters(), lr=1e-3)
-
-max_epochs = 500
-tolerance = 1e-4
-
-for epoch in range(max_epochs):
-    optimizer.zero_grad()
-    outputs = model(short_embeddings)
-    loss = criterion(outputs, long_embeddings)
-    loss.backward()
-    optimizer.step()
-    if loss.item() < tolerance:
-        print(f"✅ Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
-        break
-    if (epoch + 1) % 50 == 0:
-        print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")
+    # Encode dataset prompts
+    print("🔢 Encoding dataset into embeddings...")
+    short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
+    long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
+    print(f"✅ Encoded {len(dataset)} pairs")
+
+    # Train FlashPack model
+    model = GemmaTrainer(
+        input_dim=short_embeddings.shape[1],
+        output_dim=long_embeddings.shape[1]
+    ).to(device)
+
+    criterion = nn.MSELoss()
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    max_epochs = 500
+    tolerance = 1e-4
+
+    for epoch in range(max_epochs):
+        optimizer.zero_grad()
+        outputs = model(short_embeddings)
+        loss = criterion(outputs, long_embeddings)
+        loss.backward()
+        optimizer.step()
+        if loss.item() < tolerance:
+            print(f"✅ Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
+            break
+        if (epoch + 1) % 50 == 0:
+            print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")
+
+    # Save to Hugging Face Hub
+    FLASHPACK_REPO = "rahul7star/FlashPack"
+    model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
+    print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")
+
+    return model, dataset, embed_model, tokenizer, long_embeddings
 
-# ============================================================
-# 5️⃣ Save FlashPack model to Hub
-# ============================================================
-FLASHPACK_REPO = "rahul7star/FlashPack"
-model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
-print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")
 
 # ============================================================
-# 6️⃣ Load FlashPack model
+# 3️⃣ Run training once and load for inference
 # ============================================================
-loaded_model = model.from_flashpack(FLASHPACK_REPO)
+model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
+model.eval()
 
 # ============================================================
-# 7️⃣ Gradio interface
+# 4️⃣ Inference function for Gradio
 # ============================================================
+def encode_prompt(prompt):
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        padding="max_length",
+        max_length=32
+    ).to(device)
+    with torch.no_grad():
+        return embed_model(**inputs).last_hidden_state.mean(dim=1)
+
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []
-
-    # Encode user prompt
     short_emb = encode_prompt(user_prompt)
+
     with torch.no_grad():
-        long_emb = loaded_model(short_emb)
+        long_emb = model(short_emb)
 
-    # Find nearest matching long prompt in dataset (simple approach)
+    # Nearest match search
     cos = nn.CosineSimilarity(dim=1)
-    sims = cos(long_emb.repeat(len(long_embeddings),1), long_embeddings)
+    sims = cos(long_emb.repeat(len(long_embeddings), 1), long_embeddings)
     best_idx = sims.argmax()
    enhanced_prompt = dataset[best_idx]["long_prompt"]
 
@@ -105,8 +132,9 @@ def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history.append({"role": "assistant", "content": enhanced_prompt})
     return chat_history
 
+
 # ============================================================
-# 8️⃣ Gradio UI
+# 5️⃣ Gradio UI
 # ============================================================
 with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
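
For context, the lookup inside enhance_prompt above is a plain cosine-similarity nearest-neighbour search over the precomputed long_embeddings. A minimal standalone sketch of that step, with a hypothetical helper name and random tensors standing in for the real GPT-2 embeddings (neither is part of the commit):

import torch
import torch.nn as nn

def nearest_long_prompt_index(query_emb: torch.Tensor, bank: torch.Tensor) -> int:
    # query_emb: (1, D) projected embedding; bank: (N, D) stored long-prompt embeddings.
    cos = nn.CosineSimilarity(dim=1)
    # Broadcast the single query against every row of the bank and score each one.
    sims = cos(query_emb.repeat(len(bank), 1), bank)
    # The index of the highest-scoring row selects the long prompt to return.
    return int(sims.argmax())

# Illustrative usage with random stand-ins (D = 768, as for GPT-2 hidden states).
query = torch.randn(1, 768)
bank = torch.randn(100, 768)
print(nearest_long_prompt_index(query, bank))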