rahul7star committed (verified) · Commit 44b2be8 · Parent: 588725c

Update app_flash1.py

Files changed (1):
  1. app_flash1.py +15 -14
app_flash1.py CHANGED
@@ -18,8 +18,11 @@ print(f"🔧 Using device: {device} (CPU-only mode)")
 # Model Definition
 # ===========================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int = 1536, hidden_dim: int = 1024, output_dim: int = 1536):
+    def __init__(self):
         super().__init__()
+        input_dim = 1536   # GPT-2 mean+max pooled embeddings
+        hidden_dim = 1024
+        output_dim = 1536
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
@@ -33,14 +36,14 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         x = self.fc3(x)
         return x
 
-
 # ===========================
 # Encoder
 # ===========================
-def build_encoder(model_name="gpt2", max_length=128):
+def build_encoder(model_name="gpt2", max_length: int = 128):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+
     embed_model = AutoModel.from_pretrained(model_name).to(device)
     embed_model.eval()
 
@@ -48,11 +51,11 @@ def build_encoder(model_name="gpt2", max_length=128):
     def encode(prompt: str) -> torch.Tensor:
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                            padding="max_length", max_length=max_length).to(device)
-        hidden = embed_model(**inputs).last_hidden_state
-        mean_pool = hidden.mean(dim=1)
-        max_pool, _ = hidden.max(dim=1)
-        return torch.cat([mean_pool, max_pool], dim=1).cpu()
-
+        last_hidden = embed_model(**inputs).last_hidden_state
+        mean_pool = last_hidden.mean(dim=1)
+        max_pool, _ = last_hidden.max(dim=1)
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # doubled embedding
+
     return tokenizer, embed_model, encode
 
 # ===========================
@@ -80,14 +83,14 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
     def log_fn(msg):
         logs.append(msg)
         print(msg)
-
+
     log_fn("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
     log_fn(f"✅ Loaded {len(dataset)} samples")
 
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
-    # Only encode short+long embeddings
+    # Encode dataset embeddings
     s_list, l_list = [], []
     for i, item in enumerate(dataset):
         s_list.append(encode_fn(item["short_prompt"]))
@@ -122,7 +125,7 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
         chat = chat or []
         short_emb = encode_fn(prompt)
         mapped = model(short_emb.to(device)).cpu()
-        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
+        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
         chat.append({"role": "user", "content": prompt})
         chat.append({"role": "assistant", "content": long_prompt})
         return chat
@@ -132,7 +135,6 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
 # ===========================
 # Lazy Load / Get Model
 # ===========================
-# ===========================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     local_model_path = "model.flashpack"
 
@@ -151,8 +153,7 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
         print(f"⚠️ Error accessing HF: {e}")
         return None, None, None, None
 
-    # ⚡ Use input_dim=1536 (default)
-    model = GemmaTrainer(input_dim=1536).from_flashpack(local_model_path)
+    model = GemmaTrainer().from_flashpack(local_model_path)
     model.eval()
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
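Note (not part of the commit): the hard-coded input_dim = 1536 follows from the encoder. GPT-2's hidden size is 768, and encode concatenates a mean-pooled and a max-pooled vector, doubling it to 2 * 768 = 1536. A minimal standalone check, using the same gpt2 checkpoint and tokenizer settings as the script:

import torch
from transformers import AutoModel, AutoTokenizer

# Illustration only: confirm mean+max pooling over GPT-2 hidden states
# yields the 1536-dim vector that GemmaTrainer's input_dim expects.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
embed_model = AutoModel.from_pretrained("gpt2").eval()

inputs = tokenizer("a short prompt", return_tensors="pt",
                   truncation=True, padding="max_length", max_length=128)
with torch.no_grad():
    last_hidden = embed_model(**inputs).last_hidden_state  # (1, 128, 768)
mean_pool = last_hidden.mean(dim=1)                        # (1, 768)
max_pool, _ = last_hidden.max(dim=1)                       # (1, 768)
emb = torch.cat([mean_pool, max_pool], dim=1)              # (1, 1536)
print(emb.shape)  # torch.Size([1, 1536])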
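The training hunk cuts off right after the short-prompt encode; presumably the loop also encodes a paired long prompt (the field name below is an assumption, the diff only shows short_prompt) and the two lists are stacked into input/target tensors. A self-contained sketch of that pattern with a stand-in encoder:

import torch

def encode_stub(prompt: str) -> torch.Tensor:
    # stand-in for the real encode_fn; shape matches its (1, 1536) output
    return torch.randn(1, 1536)

# stand-in rows; "long_prompt" is an assumed field name
dataset = [{"short_prompt": "a cat", "long_prompt": "a fluffy tabby cat sunbathing"}]

s_list, l_list = [], []
for item in dataset:
    s_list.append(encode_stub(item["short_prompt"]))
    l_list.append(encode_stub(item["long_prompt"]))

X = torch.cat(s_list, dim=0)  # (N, 1536) inputs: short-prompt embeddings
Y = torch.cat(l_list, dim=0)  # (N, 1536) targets: long-prompt embeddings
print(X.shape, Y.shape)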
 
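Since the diff only shows fragments of GemmaTrainer, here is a minimal sketch of what the whole class plausibly looks like after this commit. fc3 is taken from the "x = self.fc3(x)" context line; the exact forward wiring between fc1 and fc3, and the FlashPackMixin base, are assumptions:

import torch
import torch.nn as nn

class GemmaTrainerSketch(nn.Module):  # FlashPackMixin omitted in this sketch
    def __init__(self):
        super().__init__()
        input_dim = 1536   # GPT-2 mean+max pooled embeddings
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.relu(self.fc1(x))  # assumed; the diff only shows the fc3 tail
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

print(GemmaTrainerSketch()(torch.randn(1, 1536)).shape)  # torch.Size([1, 1536])

The point of the refactor is visible in the last hunk: GemmaTrainer() now takes no constructor arguments, so get_flashpack_model can rebuild it with GemmaTrainer().from_flashpack(local_model_path) without threading dimension parameters through the loader.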