Update app_flash1.py
app_flash1.py  (+15 −14)
@@ -18,8 +18,11 @@ print(f"🔧 Using device: {device} (CPU-only mode)")
 # Model Definition
 # ===========================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self
+    def __init__(self):
         super().__init__()
+        input_dim = 1536   # GPT-2 mean+max pooled embeddings
+        hidden_dim = 1024
+        output_dim = 1536
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
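Note: the old "def __init__(self" was truncated mid-signature (a SyntaxError), and the fixed version hardcodes the layer sizes instead of taking them as arguments, which is what lets the loader in the last hunk call GemmaTrainer() with no parameters. For context, a minimal sketch of the full module as the surrounding hunks imply it; the fc3 layer and the shape of forward are assumptions based on the context lines, not confirmed by the diff:

    import torch.nn as nn
    from flashpack import FlashPackMixin  # assumed import path for the mixin named in the diff

    class GemmaTrainer(nn.Module, FlashPackMixin):
        def __init__(self):
            super().__init__()
            input_dim = 1536   # GPT-2 mean+max pooled embeddings (2 x 768)
            hidden_dim = 1024
            output_dim = 1536
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_dim, hidden_dim)
            self.fc3 = nn.Linear(hidden_dim, output_dim)

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)  # matches the context lines in the next hunk
            return x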
@@ -33,14 +36,14 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         x = self.fc3(x)
         return x

-
 # ===========================
 # Encoder
 # ===========================
-def build_encoder(model_name="gpt2", max_length=128):
+def build_encoder(model_name="gpt2", max_length: int = 128):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+
     embed_model = AutoModel.from_pretrained(model_name).to(device)
     embed_model.eval()

@@ -48,11 +51,11 @@ def build_encoder(model_name="gpt2", max_length=128):
     def encode(prompt: str) -> torch.Tensor:
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                            padding="max_length", max_length=max_length).to(device)
-
-        mean_pool =
-        max_pool, _ =
-        return torch.cat([mean_pool, max_pool], dim=1).cpu()
-
+        last_hidden = embed_model(**inputs).last_hidden_state
+        mean_pool = last_hidden.mean(dim=1)
+        max_pool, _ = last_hidden.max(dim=1)
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # doubled embedding
+
     return tokenizer, embed_model, encode

 # ===========================
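The old encode body had dangling assignments (mean_pool = and max_pool, _ = with no right-hand side); the rewrite is also what makes the 1536 in the first hunk line up: GPT-2's hidden size is 768, and concatenating a mean-pool and a max-pool over the token axis doubles it. A quick shape check, assuming the build_encoder above (wrapping the call in torch.no_grad() is an extra precaution the diff itself does not add; the GPT-2 forward pass would otherwise track gradients):

    import torch

    tokenizer, embed_model, encode = build_encoder("gpt2", max_length=128)

    with torch.no_grad():              # inference only; skip autograd bookkeeping
        emb = encode("a cat sat on a mat")

    print(emb.shape)  # expected: torch.Size([1, 1536]) = 768 (mean) + 768 (max)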
@@ -80,14 +83,14 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
     def log_fn(msg):
         logs.append(msg)
         print(msg)
-
+
     log_fn("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
     log_fn(f"✅ Loaded {len(dataset)} samples")

     tokenizer, embed_model, encode_fn = build_encoder("gpt2")

-    #
+    # Encode dataset embeddings
     s_list, l_list = [], []
     for i, item in enumerate(dataset):
         s_list.append(encode_fn(item["short_prompt"]))
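The empty comment on old line 90 becomes a real label for the encoding loop. Presumably the per-sample [1, 1536] tensors are then stacked into training matrices; a hedged sketch of that step (the names S and L are illustrative, not from the file):

    # encode_fn returns a [1, 1536] tensor per prompt, so concatenating
    # along dim 0 yields [num_samples, 1536] matrices for training.
    S = torch.cat(s_list, dim=0)   # short-prompt embeddings (inputs)
    L = torch.cat(l_list, dim=0)   # long-prompt embeddings (targets)
    assert S.shape == L.shape and S.shape[1] == 1536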
@@ -122,7 +125,7 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
         chat = chat or []
         short_emb = encode_fn(prompt)
         mapped = model(short_emb.to(device)).cpu()
-        long_prompt = f"🌟 Enhanced prompt: {prompt}
+        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
         chat.append({"role": "user", "content": prompt})
         chat.append({"role": "assistant", "content": long_prompt})
         return chat
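The removed line here was an unterminated f-string (missing its closing quote), a SyntaxError the moment the module is imported; the replacement also makes the canned reply explicit that it is embedding-based rather than generated text.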
@@ -132,7 +135,6 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
 # ===========================
 # Lazy Load / Get Model
 # ===========================
-# ===========================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     local_model_path = "model.flashpack"

@@ -151,8 +153,7 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
         print(f"⚠️ Error accessing HF: {e}")
         return None, None, None, None

-
-    model = GemmaTrainer(input_dim=1536).from_flashpack(local_model_path)
+    model = GemmaTrainer().from_flashpack(local_model_path)
     model.eval()
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")

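With the constructor now parameterless, the stale input_dim=1536 argument is dropped. A sketch of the surrounding loader as the context lines suggest it, assuming huggingface_hub's hf_hub_download for the fetch; the filename check and return tuple are guesses, and only the lines shown in the diff are confirmed:

    import os
    from huggingface_hub import hf_hub_download

    def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
        local_model_path = "model.flashpack"
        if not os.path.exists(local_model_path):
            try:
                # fetch the packed weights from the Hub (assumed filename)
                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
            except Exception as e:
                print(f"⚠️ Error accessing HF: {e}")
                return None, None, None, None

        model = GemmaTrainer().from_flashpack(local_model_path)  # as in the diff
        model.eval()
        tokenizer, embed_model, encode_fn = build_encoder("gpt2")
        return model, tokenizer, embed_model, encode_fn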