100rabhsah commited on
Commit
6feb7b4
·
1 Parent(s): 3269b9d

app.py analyse function fix 4

Browse files
Files changed (1) hide show
  1. src/app.py +12 -12
src/app.py CHANGED
@@ -42,17 +42,17 @@ st.set_page_config(
42
  layout="wide"
43
  )
44
 
45
- # Initialize session state
46
  if 'model' not in st.session_state:
47
- st.session_state.model = None
48
  if 'tokenizer' not in st.session_state:
49
- st.session_state.tokenizer = None
50
  if 'preprocessor' not in st.session_state:
51
- st.session_state.preprocessor = None
52
 
53
  def load_model():
54
  """Load the trained model and tokenizer."""
55
- if st.session_state.model is None:
56
  # Initialize model
57
  model = HybridFakeNewsDetector(
58
  bert_model_name=BERT_MODEL_NAME,
@@ -71,24 +71,24 @@ def load_model():
71
  # Load the filtered state dict
72
  model.load_state_dict(filtered_state_dict, strict=False)
73
  model.eval()
74
- st.session_state.model = model
75
 
76
  # Initialize tokenizer
77
- st.session_state.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
78
 
79
  # Initialize preprocessor
80
- st.session_state.preprocessor = TextPreprocessor()
81
 
82
  def predict_news(text):
83
  """Predict if the given news is fake or real."""
84
- if st.session_state.model is None:
85
  load_model()
86
 
87
  # Preprocess text
88
- processed_text = st.session_state.preprocessor.preprocess_text(text)
89
 
90
  # Tokenize
91
- encoding = st.session_state.tokenizer.encode_plus(
92
  processed_text,
93
  add_special_tokens=True,
94
  max_length=MAX_SEQUENCE_LENGTH,
@@ -100,7 +100,7 @@ def predict_news(text):
100
 
101
  # Get prediction
102
  with torch.no_grad():
103
- outputs = st.session_state.model(
104
  encoding['input_ids'],
105
  encoding['attention_mask']
106
  )
 
42
  layout="wide"
43
  )
44
 
45
+ # Initialize session state variables
46
  if 'model' not in st.session_state:
47
+ st.session_state['model'] = None
48
  if 'tokenizer' not in st.session_state:
49
+ st.session_state['tokenizer'] = None
50
  if 'preprocessor' not in st.session_state:
51
+ st.session_state['preprocessor'] = None
52
 
53
  def load_model():
54
  """Load the trained model and tokenizer."""
55
+ if st.session_state['model'] is None:
56
  # Initialize model
57
  model = HybridFakeNewsDetector(
58
  bert_model_name=BERT_MODEL_NAME,
 
71
  # Load the filtered state dict
72
  model.load_state_dict(filtered_state_dict, strict=False)
73
  model.eval()
74
+ st.session_state['model'] = model
75
 
76
  # Initialize tokenizer
77
+ st.session_state['tokenizer'] = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
78
 
79
  # Initialize preprocessor
80
+ st.session_state['preprocessor'] = TextPreprocessor()
81
 
82
  def predict_news(text):
83
  """Predict if the given news is fake or real."""
84
+ if st.session_state['model'] is None:
85
  load_model()
86
 
87
  # Preprocess text
88
+ processed_text = st.session_state['preprocessor'].preprocess_text(text)
89
 
90
  # Tokenize
91
+ encoding = st.session_state['tokenizer'].encode_plus(
92
  processed_text,
93
  add_special_tokens=True,
94
  max_length=MAX_SEQUENCE_LENGTH,
 
100
 
101
  # Get prediction
102
  with torch.no_grad():
103
+ outputs = st.session_state['model'](
104
  encoding['input_ids'],
105
  encoding['attention_mask']
106
  )