rahul7star committed
Commit ee55050 · verified · 1 Parent(s): c4c7a5a

Update app_flash.py

Files changed (1)
  1. app_flash.py +12 -4
app_flash.py CHANGED
@@ -63,12 +63,20 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
 # 3️⃣ Load pretrained FlashPack model (skip training)
 # ============================================================
 def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
-    print(f"🔁 Loading FlashPack model from: {hf_repo}")
-    model = GemmaTrainer.load_flashpack(hf_repo)
-    model.eval()
-    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+    model = GemmaTrainer.from_flashpack(hf_repo)
+    tokenizer = model.tokenizer if hasattr(model, "tokenizer") else None
+    embed_model = model.embed_model if hasattr(model, "embed_model") else None
     return model, tokenizer, embed_model
 
+# def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
+#     print(f"🔁 Loading FlashPack model from: {hf_repo}")
+#
+#     model = GemmaTrainer.from_flashpack(hf_repo)
+#
+#     model.eval()
+#     tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+#     return model, tokenizer, embed_model
+
 
 # ============================================================
 # 4️⃣ Load Gemma text model for prompt enhancement
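
For reference, a minimal sketch of how the updated loader might be called, assuming app_flash exposes the load_flashpack_model and build_encoder functions shown in this diff. The fallback to build_encoder when the packed model carries no tokenizer, and the explicit model.eval() call, mirror the removed/commented-out code and are illustrative assumptions, not part of this commit.

# usage_sketch.py — hypothetical caller (assumes app_flash is importable as a module)
from app_flash import load_flashpack_model, build_encoder

# Load the pretrained FlashPack checkpoint; tokenizer/embed_model come back as None
# when the packed model does not bundle them (see the hasattr checks in the diff).
model, tokenizer, embed_model = load_flashpack_model("rahul7star/FlashPack")

if tokenizer is None or embed_model is None:
    # Assumed fallback: rebuild the GPT-2 encoder the same way the old loader did.
    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)

model.eval()  # inference mode, as in the previous version of the loader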