100rabhsah committed on
Commit
2febec8
·
1 Parent(s): 6feb7b4

app.py analyse function fix 5

Browse files
Files changed (1) hide show
  1. src/app.py +36 -40
src/app.py CHANGED
@@ -42,53 +42,49 @@ st.set_page_config(
42
  layout="wide"
43
  )
44
 
45
# Initialize session state slots for the model pipeline so later code can
# test them against None before lazily loading the heavy objects.
for _key in ('model', 'tokenizer', 'preprocessor'):
    if _key not in st.session_state:
        st.session_state[_key] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
def load_model():
    """Lazily load the trained model, tokenizer, and preprocessor into session state.

    Only runs the expensive work when ``st.session_state['model']`` is still
    None; afterwards it populates the 'model', 'tokenizer', and 'preprocessor'
    session-state slots as a side effect (returns nothing).
    """
    if st.session_state['model'] is not None:
        # Already loaded for this session — nothing to do.
        return

    # Build the network with the project's configured hyperparameters.
    detector = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )

    # Load the trained weights onto CPU regardless of where they were saved.
    checkpoint = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))

    # Keep only the keys the current architecture knows about, so stale or
    # renamed entries in the checkpoint do not abort loading.
    known_keys = detector.state_dict()
    usable_weights = {name: tensor for name, tensor in checkpoint.items() if name in known_keys}

    # strict=False tolerates parameters missing from the filtered checkpoint.
    detector.load_state_dict(usable_weights, strict=False)
    detector.eval()
    st.session_state['model'] = detector

    # Tokenizer must match the BERT variant the model was trained with.
    st.session_state['tokenizer'] = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

    # Text-cleaning helper used before tokenization.
    st.session_state['preprocessor'] = TextPreprocessor()
81
 
82
  def predict_news(text):
83
  """Predict if the given news is fake or real."""
84
- if st.session_state['model'] is None:
85
- load_model()
 
86
 
87
  # Preprocess text
88
- processed_text = st.session_state['preprocessor'].preprocess_text(text)
89
 
90
  # Tokenize
91
- encoding = st.session_state['tokenizer'].encode_plus(
92
  processed_text,
93
  add_special_tokens=True,
94
  max_length=MAX_SEQUENCE_LENGTH,
@@ -100,7 +96,7 @@ def predict_news(text):
100
 
101
  # Get prediction
102
  with torch.no_grad():
103
- outputs = st.session_state['model'](
104
  encoding['input_ids'],
105
  encoding['attention_mask']
106
  )
 
42
  layout="wide"
43
  )
44
 
45
@st.cache_resource
def load_model_and_tokenizer():
    """Build the detector, load its trained weights, and return it with its tokenizer.

    Cached by Streamlit (``st.cache_resource``) so the model is constructed
    and deserialized only once per server process.

    Returns:
        tuple: ``(model, tokenizer)`` — the eval-mode HybridFakeNewsDetector
        and the matching BertTokenizer.
    """
    # Instantiate the architecture with the project's configured hyperparameters.
    detector = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )

    # Deserialize the trained weights onto CPU regardless of save device.
    checkpoint = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))

    # Drop checkpoint entries the current architecture does not declare,
    # so stale or renamed keys cannot abort loading.
    known_keys = detector.state_dict()
    usable_weights = {name: tensor for name, tensor in checkpoint.items() if name in known_keys}

    # strict=False additionally tolerates parameters absent from the checkpoint.
    detector.load_state_dict(usable_weights, strict=False)
    detector.eval()

    # Tokenizer must match the BERT variant the model was trained with.
    bert_tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

    return detector, bert_tokenizer
71
 
72
@st.cache_resource
def get_preprocessor():
    """Return a TextPreprocessor, cached once per server process."""
    preprocessor = TextPreprocessor()
    return preprocessor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def predict_news(text):
78
  """Predict if the given news is fake or real."""
79
+ # Get model, tokenizer, and preprocessor from cache
80
+ model, tokenizer = load_model_and_tokenizer()
81
+ preprocessor = get_preprocessor()
82
 
83
  # Preprocess text
84
+ processed_text = preprocessor.preprocess_text(text)
85
 
86
  # Tokenize
87
+ encoding = tokenizer.encode_plus(
88
  processed_text,
89
  add_special_tokens=True,
90
  max_length=MAX_SEQUENCE_LENGTH,
 
96
 
97
  # Get prediction
98
  with torch.no_grad():
99
+ outputs = model(
100
  encoding['input_ids'],
101
  encoding['attention_mask']
102
  )