Spaces:
Sleeping
Sleeping
100rabhsah
commited on
Commit
·
769dd6f
1
Parent(s):
8a511c5
Add model files using Git LFS
Browse files- README.md +158 -12
- models/saved/final_model.pt +3 -0
- notebooks/fake_news_detection_colab.ipynb +859 -0
- requirements.txt +20 -0
- src/app.py +226 -0
- src/config/__pycache__/config.cpython-311.pyc +0 -0
- src/config/config.py +42 -0
- src/data/__pycache__/dataset.cpython-311.pyc +0 -0
- src/data/__pycache__/preprocessor.cpython-311.pyc +0 -0
- src/data/dataset.py +108 -0
- src/data/download_datasets.py +166 -0
- src/data/feature_extractor.py +82 -0
- src/data/preprocessor.py +89 -0
- src/models/__pycache__/hybrid_model.cpython-311.pyc +0 -0
- src/models/__pycache__/trainer.cpython-311.pyc +0 -0
- src/models/hybrid_model.py +87 -0
- src/models/trainer.py +165 -0
- src/train.py +161 -0
- src/visualization/plot_metrics.py +207 -0
README.md
CHANGED
|
@@ -1,12 +1,158 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hybrid Fake News Detection Model
|
| 2 |
+
|
| 3 |
+
A hybrid deep learning model for fake news detection using BERT and BiLSTM with attention mechanism. This project was developed as part of the Data Mining Laboratory course under the guidance of Dr. Kirti Kumari.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
This project implements a state-of-the-art fake news detection system that combines the power of BERT (Bidirectional Encoder Representations from Transformers) with BiLSTM (Bidirectional Long Short-Term Memory) and attention mechanisms. The model is designed to effectively identify fake news articles by analyzing their textual content and linguistic patterns.
|
| 8 |
+
|
| 9 |
+
## Data and Model Files
|
| 10 |
+
|
| 11 |
+
The project uses the following datasets and model files:
|
| 12 |
+
|
| 13 |
+
### Datasets
|
| 14 |
+
- Raw and processed datasets are available at: [Data Files](https://drive.google.com/drive/folders/1uFtWVEjqupSGV7_6sYAxPG52Je1MAigh?usp=sharing)
|
| 15 |
+
- Contains both raw and processed versions of the datasets
|
| 16 |
+
- Includes LIAR and Kaggle Fake News datasets
|
| 17 |
+
- Preprocessed versions ready for training
|
| 18 |
+
|
| 19 |
+
### Model Files
|
| 20 |
+
- Trained model checkpoints are available at: [Model Files](https://drive.google.com/drive/folders/1d1EXjLlYof56yEa9F6qFDPKqO359vnRw?usp=sharing)
|
| 21 |
+
- Contains saved model weights
|
| 22 |
+
- Includes best model checkpoints
|
| 23 |
+
- Model evaluation results
|
| 24 |
+
|
| 25 |
+
## Project Structure
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
.
|
| 29 |
+
├── data/
|
| 30 |
+
│ ├── raw/ # Raw datasets
|
| 31 |
+
│ └── processed/ # Processed data
|
| 32 |
+
├── models/
|
| 33 |
+
│ ├── saved/ # Saved model checkpoints
|
| 34 |
+
│ └── checkpoints/ # Training checkpoints
|
| 35 |
+
├── src/
|
| 36 |
+
│ ├── config/ # Configuration files
|
| 37 |
+
│ ├── data/ # Data processing modules
|
| 38 |
+
│ ├── models/ # Model architecture
|
| 39 |
+
│ ├── utils/ # Utility functions
|
| 40 |
+
│ └── visualization/# Visualization modules
|
| 41 |
+
├── tests/ # Unit tests
|
| 42 |
+
├── notebooks/ # Jupyter notebooks
|
| 43 |
+
└── visualizations/ # Generated plots and graphs
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Features
|
| 47 |
+
|
| 48 |
+
- Hybrid architecture combining BERT and BiLSTM
|
| 49 |
+
- Attention mechanism for better interpretability
|
| 50 |
+
- Comprehensive text preprocessing pipeline
|
| 51 |
+
- Support for multiple feature extraction methods
|
| 52 |
+
- Early stopping and model checkpointing
|
| 53 |
+
- Detailed evaluation metrics
|
| 54 |
+
- Interactive visualizations of model performance
|
| 55 |
+
- Support for multiple datasets (LIAR, Kaggle Fake News)
|
| 56 |
+
|
| 57 |
+
## Installation
|
| 58 |
+
|
| 59 |
+
1. Clone the repository:
|
| 60 |
+
```bash
|
| 61 |
+
git clone https://github.com/yourusername/fake-news-detection.git
|
| 62 |
+
cd fake-news-detection
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
2. Create a virtual environment:
|
| 66 |
+
```bash
|
| 67 |
+
python -m venv venv
|
| 68 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
3. Install dependencies:
|
| 72 |
+
```bash
|
| 73 |
+
pip install -r requirements.txt
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## Usage
|
| 77 |
+
|
| 78 |
+
1. Download the required files:
|
| 79 |
+
- Download datasets from the [Data Files](https://drive.google.com/drive/folders/1uFtWVEjqupSGV7_6sYAxPG52Je1MAigh?usp=sharing) link
|
| 80 |
+
- Download pre-trained models from the [Model Files](https://drive.google.com/drive/folders/1d1EXjLlYof56yEa9F6qFDPKqO359vnRw?usp=sharing) link
|
| 81 |
+
- Place the files in their respective directories as shown in the project structure
|
| 82 |
+
|
| 83 |
+
2. Prepare your dataset:
|
| 84 |
+
- Place your dataset in the `data/raw` directory
|
| 85 |
+
- The dataset should have at least two columns: 'text' and 'label'
|
| 86 |
+
- Supported formats: CSV, TSV
|
| 87 |
+
|
| 88 |
+
3. Train the model:
|
| 89 |
+
```bash
|
| 90 |
+
python src/train.py
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
4. Model evaluation metrics and visualizations will be generated in the `visualizations` directory
|
| 94 |
+
|
| 95 |
+
## Model Architecture
|
| 96 |
+
|
| 97 |
+
The model combines:
|
| 98 |
+
- BERT for contextual embeddings
|
| 99 |
+
- BiLSTM for sequence modeling
|
| 100 |
+
- Attention mechanism for focusing on important parts
|
| 101 |
+
- Classification head for final prediction
|
| 102 |
+
|
| 103 |
+
### Key Components:
|
| 104 |
+
- **BERT Layer**: Extracts contextual word embeddings
|
| 105 |
+
- **BiLSTM Layer**: Captures sequential patterns
|
| 106 |
+
- **Attention Layer**: Identifies important text segments
|
| 107 |
+
- **Classification Head**: Makes final prediction
|
| 108 |
+
|
| 109 |
+
## Configuration
|
| 110 |
+
|
| 111 |
+
Key parameters can be modified in `src/config/config.py`:
|
| 112 |
+
- Model hyperparameters
|
| 113 |
+
- Training parameters
|
| 114 |
+
- Data processing settings
|
| 115 |
+
- Feature extraction options
|
| 116 |
+
|
| 117 |
+
## Performance Metrics
|
| 118 |
+
|
| 119 |
+
The model is evaluated using:
|
| 120 |
+
- Accuracy
|
| 121 |
+
- Precision
|
| 122 |
+
- Recall
|
| 123 |
+
- F1 Score
|
| 124 |
+
- Confusion Matrix
|
| 125 |
+
|
| 126 |
+
## Future Improvements
|
| 127 |
+
|
| 128 |
+
- [ ] Add support for image/video metadata
|
| 129 |
+
- [ ] Implement real-time detection
|
| 130 |
+
- [ ] Add social graph analysis
|
| 131 |
+
- [ ] Improve model interpretability
|
| 132 |
+
- [ ] Add API endpoints for inference
|
| 133 |
+
- [ ] Support for multilingual fake news detection
|
| 134 |
+
- [ ] Integration with fact-checking databases
|
| 135 |
+
|
| 136 |
+
## Acknowledgments
|
| 137 |
+
|
| 138 |
+
I would like to express our sincere gratitude to **Dr. Kirti Kumari** for her invaluable guidance and support throughout the development of this project. Her expertise in data mining and machine learning has been instrumental in shaping this work.
|
| 139 |
+
|
| 140 |
+
Special thanks to:
|
| 141 |
+
- Open-source community for their excellent tools and libraries
|
| 142 |
+
- Dataset providers (LIAR, Kaggle)
|
| 143 |
+
|
| 144 |
+
## Contributing
|
| 145 |
+
|
| 146 |
+
1. Fork the repository
|
| 147 |
+
2. Create a feature branch
|
| 148 |
+
3. Commit your changes
|
| 149 |
+
4. Push to the branch
|
| 150 |
+
5. Create a Pull Request
|
| 151 |
+
|
| 152 |
+
## License
|
| 153 |
+
|
| 154 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 155 |
+
|
| 156 |
+
## Contact
|
| 157 |
+
|
| 158 |
+
For any queries or suggestions, please feel free to reach out to me.
|
models/saved/final_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca5ffb031ff346f0ad0fee2970b1197b793036861a5610fdd9ea17bbccde0d1b
|
| 3 |
+
size 442092217
|
notebooks/fake_news_detection_colab.ipynb
ADDED
|
@@ -0,0 +1,859 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"vscode": {
|
| 8 |
+
"languageId": "plaintext"
|
| 9 |
+
}
|
| 10 |
+
},
|
| 11 |
+
"outputs": [],
|
| 12 |
+
"source": [
|
| 13 |
+
"# Fake News Detection using BERT-BiLSTM-Attention\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"This notebook is optimized for Google Colab free version with the following optimizations:\n",
|
| 16 |
+
"- Reduced model size\n",
|
| 17 |
+
"- Optimized memory usage\n",
|
| 18 |
+
"- Efficient data loading\n",
|
| 19 |
+
"- Gradient checkpointing\n",
|
| 20 |
+
"- Mixed precision training\n"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"metadata": {
|
| 27 |
+
"vscode": {
|
| 28 |
+
"languageId": "plaintext"
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": [
|
| 33 |
+
"## 1. Setup and Installation\n"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "code",
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [],
|
| 41 |
+
"source": [
|
| 42 |
+
"# Install required packages\n",
|
| 43 |
+
"!pip install torch==2.0.1 transformers==4.30.2 nltk==3.8.1 pandas==2.0.3 numpy==1.24.3 scikit-learn==1.3.0 tqdm==4.65.0\n"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
"# Import required libraries\n",
|
| 53 |
+
"import torch\n",
|
| 54 |
+
"import torch.nn as nn\n",
|
| 55 |
+
"import torch.optim as optim\n",
|
| 56 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 57 |
+
"from transformers import BertModel, BertTokenizer\n",
|
| 58 |
+
"import pandas as pd\n",
|
| 59 |
+
"import numpy as np\n",
|
| 60 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 61 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
| 62 |
+
"import nltk\n",
|
| 63 |
+
"from nltk.tokenize import word_tokenize\n",
|
| 64 |
+
"from nltk.corpus import stopwords\n",
|
| 65 |
+
"import re\n",
|
| 66 |
+
"from tqdm import tqdm\n",
|
| 67 |
+
"import gc\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# Download NLTK data\n",
|
| 70 |
+
"nltk.download('punkt')\n",
|
| 71 |
+
"nltk.download('stopwords')\n",
|
| 72 |
+
"nltk.download('wordnet')\n"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "code",
|
| 77 |
+
"execution_count": null,
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [],
|
| 80 |
+
"source": [
|
| 81 |
+
"## 2. Configuration and Constants\n"
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": null,
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"# Optimized for Colab free version\n",
|
| 91 |
+
"class Config:\n",
|
| 92 |
+
" # Model parameters\n",
|
| 93 |
+
" MAX_SEQUENCE_LENGTH = 128 # Reduced from 256\n",
|
| 94 |
+
" VOCAB_SIZE = 10000 # Reduced from 15000\n",
|
| 95 |
+
" EMBEDDING_DIM = 64 # Reduced from 128\n",
|
| 96 |
+
" HIDDEN_DIM = 128 # Reduced from 256\n",
|
| 97 |
+
" \n",
|
| 98 |
+
" # Training parameters\n",
|
| 99 |
+
" BATCH_SIZE = 4 # Reduced from 8\n",
|
| 100 |
+
" NUM_EPOCHS = 2 # Reduced from 3\n",
|
| 101 |
+
" LEARNING_RATE = 2e-5\n",
|
| 102 |
+
" \n",
|
| 103 |
+
" # Dataset parameters\n",
|
| 104 |
+
" MAX_SAMPLES = 5000 # Reduced from 10000\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" # Device configuration\n",
|
| 107 |
+
" DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 108 |
+
" \n",
|
| 109 |
+
" # Model paths\n",
|
| 110 |
+
" MODEL_NAME = 'bert-base-uncased'\n",
|
| 111 |
+
" \n",
|
| 112 |
+
" # Enable mixed precision\n",
|
| 113 |
+
" USE_AMP = True\n",
|
| 114 |
+
" \n",
|
| 115 |
+
" # Enable gradient checkpointing\n",
|
| 116 |
+
" USE_GRADIENT_CHECKPOINTING = True\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"config = Config()\n",
|
| 119 |
+
"print(f\"Using device: {config.DEVICE}\")\n",
|
| 120 |
+
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
| 121 |
+
"if torch.cuda.is_available():\n",
|
| 122 |
+
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 123 |
+
" print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\n"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "code",
|
| 128 |
+
"execution_count": null,
|
| 129 |
+
"metadata": {},
|
| 130 |
+
"outputs": [],
|
| 131 |
+
"source": [
|
| 132 |
+
"## 3. Data Loading and Preprocessing\n"
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": null,
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"outputs": [],
|
| 140 |
+
"source": [
|
| 141 |
+
"# Dataset Sources:\n",
|
| 142 |
+
"# 1. Kaggle Fake and Real News Dataset: https://www.kaggle.com/datasets/clmentbisaillon/fake-and-real-news-dataset\n",
|
| 143 |
+
"# 2. LIAR Dataset: https://sites.cs.ucsb.edu/~william/data/liar_dataset.zip\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"import zipfile\n",
|
| 146 |
+
"import urllib.request\n",
|
| 147 |
+
"import os\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"def download_datasets():\n",
|
| 150 |
+
" \"\"\"Download and prepare the datasets\"\"\"\n",
|
| 151 |
+
" \n",
|
| 152 |
+
" # Download LIAR dataset\n",
|
| 153 |
+
" print(\"Downloading LIAR dataset...\")\n",
|
| 154 |
+
" liar_url = \"https://sites.cs.ucsb.edu/~william/data/liar_dataset.zip\"\n",
|
| 155 |
+
" liar_zip = \"liar_dataset.zip\"\n",
|
| 156 |
+
" \n",
|
| 157 |
+
" try:\n",
|
| 158 |
+
" urllib.request.urlretrieve(liar_url, liar_zip)\n",
|
| 159 |
+
" \n",
|
| 160 |
+
" # Extract the zip file\n",
|
| 161 |
+
" with zipfile.ZipFile(liar_zip, 'r') as zip_ref:\n",
|
| 162 |
+
" zip_ref.extractall(\"liar_dataset/\")\n",
|
| 163 |
+
" \n",
|
| 164 |
+
" print(\"LIAR dataset downloaded and extracted successfully\")\n",
|
| 165 |
+
" os.remove(liar_zip) # Clean up zip file\n",
|
| 166 |
+
" \n",
|
| 167 |
+
" except Exception as e:\n",
|
| 168 |
+
" print(f\"Error downloading LIAR dataset: {e}\")\n",
|
| 169 |
+
" \n",
|
| 170 |
+
" # For Kaggle dataset, we'll use a sample since direct download requires API key\n",
|
| 171 |
+
" print(\"Setting up Kaggle dataset alternative...\")\n",
|
| 172 |
+
" try:\n",
|
| 173 |
+
" # Try to download a sample of the Kaggle dataset\n",
|
| 174 |
+
" kaggle_url = \"https://raw.githubusercontent.com/several27/FakeNewsCorpus/master/news_sample.csv\"\n",
|
| 175 |
+
" urllib.request.urlretrieve(kaggle_url, \"kaggle_news_sample.csv\")\n",
|
| 176 |
+
" print(\"Kaggle sample dataset downloaded successfully\")\n",
|
| 177 |
+
" except Exception as e:\n",
|
| 178 |
+
" print(f\"Could not download Kaggle sample: {e}\")\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"def load_liar_dataset(max_samples=None):\n",
|
| 181 |
+
" \"\"\"Load and process LIAR dataset\"\"\"\n",
|
| 182 |
+
" try:\n",
|
| 183 |
+
" # Load train, validation, and test sets\n",
|
| 184 |
+
" train_df = pd.read_csv(\"liar_dataset/train.tsv\", sep='\\t', header=None)\n",
|
| 185 |
+
" val_df = pd.read_csv(\"liar_dataset/valid.tsv\", sep='\\t', header=None)\n",
|
| 186 |
+
" test_df = pd.read_csv(\"liar_dataset/test.tsv\", sep='\\t', header=None)\n",
|
| 187 |
+
" \n",
|
| 188 |
+
" # Column names for LIAR dataset\n",
|
| 189 |
+
" columns = ['id', 'label', 'statement', 'subjects', 'speaker', 'speaker_job', \n",
|
| 190 |
+
" 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts',\n",
|
| 191 |
+
" 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context']\n",
|
| 192 |
+
" \n",
|
| 193 |
+
" train_df.columns = columns\n",
|
| 194 |
+
" val_df.columns = columns\n",
|
| 195 |
+
" test_df.columns = columns\n",
|
| 196 |
+
" \n",
|
| 197 |
+
" # Combine all datasets\n",
|
| 198 |
+
" df = pd.concat([train_df, val_df, test_df], ignore_index=True)\n",
|
| 199 |
+
" \n",
|
| 200 |
+
" # Convert labels to binary (fake/real)\n",
|
| 201 |
+
" # Consider 'false', 'barely-true', 'pants-fire' as fake (1)\n",
|
| 202 |
+
" # Consider 'true', 'mostly-true', 'half-true' as real (0)\n",
|
| 203 |
+
" fake_labels = ['false', 'barely-true', 'pants-fire']\n",
|
| 204 |
+
" df['binary_label'] = df['label'].apply(lambda x: 1 if x in fake_labels else 0)\n",
|
| 205 |
+
" \n",
|
| 206 |
+
" # Use statement as text\n",
|
| 207 |
+
" df = df[['statement', 'binary_label']].rename(columns={'statement': 'text', 'binary_label': 'label'})\n",
|
| 208 |
+
" \n",
|
| 209 |
+
" print(f\"LIAR dataset loaded: {len(df)} samples\")\n",
|
| 210 |
+
" return df\n",
|
| 211 |
+
" \n",
|
| 212 |
+
" except Exception as e:\n",
|
| 213 |
+
" print(f\"Error loading LIAR dataset: {e}\")\n",
|
| 214 |
+
" return None\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"def load_kaggle_dataset(max_samples=None):\n",
|
| 217 |
+
" \"\"\"Load and process Kaggle dataset\"\"\"\n",
|
| 218 |
+
" try:\n",
|
| 219 |
+
" df = pd.read_csv(\"kaggle_news_sample.csv\")\n",
|
| 220 |
+
" \n",
|
| 221 |
+
" # Map labels to binary if needed\n",
|
| 222 |
+
" if 'label' in df.columns:\n",
|
| 223 |
+
" # Handle different label formats\n",
|
| 224 |
+
" if df['label'].dtype == 'object':\n",
|
| 225 |
+
" df['label'] = df['label'].map({'FAKE': 1, 'REAL': 0, 'fake': 1, 'real': 0})\n",
|
| 226 |
+
" \n",
|
| 227 |
+
" # Use appropriate text column\n",
|
| 228 |
+
" text_columns = ['text', 'title', 'content', 'article']\n",
|
| 229 |
+
" text_col = None\n",
|
| 230 |
+
" for col in text_columns:\n",
|
| 231 |
+
" if col in df.columns:\n",
|
| 232 |
+
" text_col = col\n",
|
| 233 |
+
" break\n",
|
| 234 |
+
" \n",
|
| 235 |
+
" if text_col:\n",
|
| 236 |
+
" df = df[[text_col, 'label']].rename(columns={text_col: 'text'})\n",
|
| 237 |
+
" \n",
|
| 238 |
+
" print(f\"Kaggle dataset loaded: {len(df)} samples\")\n",
|
| 239 |
+
" return df\n",
|
| 240 |
+
" \n",
|
| 241 |
+
" except Exception as e:\n",
|
| 242 |
+
" print(f\"Error loading Kaggle dataset: {e}\")\n",
|
| 243 |
+
" return None\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"def load_combined_data(max_samples=config.MAX_SAMPLES):\n",
|
| 246 |
+
" \"\"\"Load and combine both datasets\"\"\"\n",
|
| 247 |
+
" \n",
|
| 248 |
+
" # Download datasets\n",
|
| 249 |
+
" download_datasets()\n",
|
| 250 |
+
" \n",
|
| 251 |
+
" # Load datasets\n",
|
| 252 |
+
" liar_df = load_liar_dataset()\n",
|
| 253 |
+
" kaggle_df = load_kaggle_dataset()\n",
|
| 254 |
+
" \n",
|
| 255 |
+
" # Combine datasets\n",
|
| 256 |
+
" dfs = []\n",
|
| 257 |
+
" if liar_df is not None:\n",
|
| 258 |
+
" dfs.append(liar_df)\n",
|
| 259 |
+
" print(f\"LIAR dataset: {len(liar_df)} samples\")\n",
|
| 260 |
+
" \n",
|
| 261 |
+
" if kaggle_df is not None:\n",
|
| 262 |
+
" dfs.append(kaggle_df)\n",
|
| 263 |
+
" print(f\"Kaggle dataset: {len(kaggle_df)} samples\")\n",
|
| 264 |
+
" \n",
|
| 265 |
+
" if dfs:\n",
|
| 266 |
+
" df = pd.concat(dfs, ignore_index=True)\n",
|
| 267 |
+
" print(f\"Combined dataset: {len(df)} samples\")\n",
|
| 268 |
+
" else:\n",
|
| 269 |
+
" # Fallback to dummy data\n",
|
| 270 |
+
" print(\"Creating dummy dataset for testing...\")\n",
|
| 271 |
+
" texts = [\n",
|
| 272 |
+
" \"President announces new economic policy to boost growth\",\n",
|
| 273 |
+
" \"Scientists confirm breakthrough in renewable energy technology\", \n",
|
| 274 |
+
" \"False: Celebrities endorse dangerous health treatment\",\n",
|
| 275 |
+
" \"Misleading: Government hiding alien contact information\",\n",
|
| 276 |
+
" \"Local community rallies to support flood victims\",\n",
|
| 277 |
+
" \"Breaking: Major scientific discovery changes understanding of physics\"\n",
|
| 278 |
+
" ] * (max_samples // 6)\n",
|
| 279 |
+
" \n",
|
| 280 |
+
" labels = [0, 0, 1, 1, 0, 0] * (max_samples // 6)\n",
|
| 281 |
+
" \n",
|
| 282 |
+
" df = pd.DataFrame({\n",
|
| 283 |
+
" 'text': texts[:max_samples],\n",
|
| 284 |
+
" 'label': labels[:max_samples]\n",
|
| 285 |
+
" })\n",
|
| 286 |
+
" print(f\"Created dummy dataset with {len(df)} samples\")\n",
|
| 287 |
+
" \n",
|
| 288 |
+
" # Remove missing values\n",
|
| 289 |
+
" df = df.dropna()\n",
|
| 290 |
+
" \n",
|
| 291 |
+
" # Sample data for faster training if needed\n",
|
| 292 |
+
" if max_samples and len(df) > max_samples:\n",
|
| 293 |
+
" df = df.sample(n=max_samples, random_state=42)\n",
|
| 294 |
+
" print(f\"Sampled to {len(df)} samples for faster training\")\n",
|
| 295 |
+
" \n",
|
| 296 |
+
" return df\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"# Text preprocessing\n",
|
| 299 |
+
"def preprocess_text(text):\n",
|
| 300 |
+
" if pd.isna(text):\n",
|
| 301 |
+
" return \"\"\n",
|
| 302 |
+
" text = str(text)\n",
|
| 303 |
+
" # Convert to lowercase\n",
|
| 304 |
+
" text = text.lower()\n",
|
| 305 |
+
" # Remove special characters but keep basic punctuation\n",
|
| 306 |
+
" text = re.sub(r'[^\\w\\s.,!?]', '', text)\n",
|
| 307 |
+
" # Remove extra whitespace\n",
|
| 308 |
+
" text = ' '.join(text.split())\n",
|
| 309 |
+
" # Limit length to prevent very long texts\n",
|
| 310 |
+
" text = text[:1000] # Limit to 1000 characters\n",
|
| 311 |
+
" return text\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"# Load the datasets\n",
|
| 314 |
+
"print(\"Loading datasets...\")\n",
|
| 315 |
+
"df = load_combined_data()\n",
|
| 316 |
+
"print(f\"Final dataset shape: {df.shape}\")\n",
|
| 317 |
+
"print(f\"Columns: {df.columns.tolist()}\")\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"if len(df) > 0:\n",
|
| 320 |
+
" print(f\"Sample text: {df.iloc[0]['text'][:100]}...\")\n",
|
| 321 |
+
" print(f\"Label distribution:\")\n",
|
| 322 |
+
" print(df['label'].value_counts())\n",
|
| 323 |
+
" print(f\"Label distribution percentage:\")\n",
|
| 324 |
+
" print(df['label'].value_counts(normalize=True) * 100)\n"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"cell_type": "code",
|
| 329 |
+
"execution_count": null,
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"outputs": [],
|
| 332 |
+
"source": [
|
| 333 |
+
"### Optional: Download Kaggle Dataset Directly (If you have Kaggle API)\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"If you have Kaggle API credentials, you can download the full dataset by running the following cells. Otherwise, the notebook will use alternative sources.\n"
|
| 336 |
+
]
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"cell_type": "code",
|
| 340 |
+
"execution_count": null,
|
| 341 |
+
"metadata": {},
|
| 342 |
+
"outputs": [],
|
| 343 |
+
"source": [
|
| 344 |
+
"# Optional: Kaggle API setup (uncomment and run if you have Kaggle credentials)\n",
|
| 345 |
+
"# !pip install kaggle\n",
|
| 346 |
+
"# !mkdir -p ~/.kaggle\n",
|
| 347 |
+
"# # Upload your kaggle.json file to Colab files, then run:\n",
|
| 348 |
+
"# # !cp kaggle.json ~/.kaggle/\n",
|
| 349 |
+
"# # !chmod 600 ~/.kaggle/kaggle.json\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"# Download the full Kaggle dataset (uncomment if you have API access)\n",
|
| 352 |
+
"# !kaggle datasets download -d clmentbisaillon/fake-and-real-news-dataset\n",
|
| 353 |
+
"# !unzip fake-and-real-news-dataset.zip\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"def load_full_kaggle_dataset():\n",
|
| 356 |
+
" \"\"\"Load the full Kaggle dataset if available\"\"\"\n",
|
| 357 |
+
" try:\n",
|
| 358 |
+
" # Try to load the full dataset files\n",
|
| 359 |
+
" fake_df = pd.read_csv(\"Fake.csv\")\n",
|
| 360 |
+
" real_df = pd.read_csv(\"True.csv\")\n",
|
| 361 |
+
" \n",
|
| 362 |
+
" # Add labels\n",
|
| 363 |
+
" fake_df['label'] = 1\n",
|
| 364 |
+
" real_df['label'] = 0\n",
|
| 365 |
+
" \n",
|
| 366 |
+
" # Combine datasets\n",
|
| 367 |
+
" df = pd.concat([fake_df, real_df], ignore_index=True)\n",
|
| 368 |
+
" \n",
|
| 369 |
+
" # Use title + text as the full text\n",
|
| 370 |
+
" if 'title' in df.columns and 'text' in df.columns:\n",
|
| 371 |
+
" df['full_text'] = df['title'] + \". \" + df['text']\n",
|
| 372 |
+
" df = df[['full_text', 'label']].rename(columns={'full_text': 'text'})\n",
|
| 373 |
+
" elif 'text' in df.columns:\n",
|
| 374 |
+
" df = df[['text', 'label']]\n",
|
| 375 |
+
" \n",
|
| 376 |
+
" print(f\"Full Kaggle dataset loaded: {len(df)} samples\")\n",
|
| 377 |
+
" return df\n",
|
| 378 |
+
" \n",
|
| 379 |
+
" except Exception as e:\n",
|
| 380 |
+
" print(f\"Full Kaggle dataset not available: {e}\")\n",
|
| 381 |
+
" return None\n",
|
| 382 |
+
"\n",
|
| 383 |
+
"# Try to load full Kaggle dataset\n",
|
| 384 |
+
"full_kaggle_df = load_full_kaggle_dataset()\n",
|
| 385 |
+
"if full_kaggle_df is not None:\n",
|
| 386 |
+
" print(\"Using full Kaggle dataset\")\n",
|
| 387 |
+
" # Update the df variable to use full dataset\n",
|
| 388 |
+
" df = load_combined_data() # This will still use the combined approach if full isn't available\n"
|
| 389 |
+
]
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
"cell_type": "code",
|
| 393 |
+
"execution_count": null,
|
| 394 |
+
"metadata": {},
|
| 395 |
+
"outputs": [],
|
| 396 |
+
"source": [
|
| 397 |
+
"# Create dataset class\n",
|
| 398 |
+
"class FakeNewsDataset(Dataset):\n",
|
| 399 |
+
" def __init__(self, texts, labels, tokenizer, max_length):\n",
|
| 400 |
+
" self.texts = texts\n",
|
| 401 |
+
" self.labels = labels\n",
|
| 402 |
+
" self.tokenizer = tokenizer\n",
|
| 403 |
+
" self.max_length = max_length\n",
|
| 404 |
+
" \n",
|
| 405 |
+
" def __len__(self):\n",
|
| 406 |
+
" return len(self.texts)\n",
|
| 407 |
+
" \n",
|
| 408 |
+
" def __getitem__(self, idx):\n",
|
| 409 |
+
" text = str(self.texts[idx])\n",
|
| 410 |
+
" label = self.labels[idx]\n",
|
| 411 |
+
" \n",
|
| 412 |
+
" # Preprocess text\n",
|
| 413 |
+
" text = preprocess_text(text)\n",
|
| 414 |
+
" \n",
|
| 415 |
+
" encoding = self.tokenizer.encode_plus(\n",
|
| 416 |
+
" text,\n",
|
| 417 |
+
" add_special_tokens=True,\n",
|
| 418 |
+
" max_length=self.max_length,\n",
|
| 419 |
+
" padding='max_length',\n",
|
| 420 |
+
" truncation=True,\n",
|
| 421 |
+
" return_attention_mask=True,\n",
|
| 422 |
+
" return_tensors='pt'\n",
|
| 423 |
+
" )\n",
|
| 424 |
+
" \n",
|
| 425 |
+
" return {\n",
|
| 426 |
+
" 'input_ids': encoding['input_ids'].flatten(),\n",
|
| 427 |
+
" 'attention_mask': encoding['attention_mask'].flatten(),\n",
|
| 428 |
+
" 'label': torch.tensor(label, dtype=torch.long)\n",
|
| 429 |
+
" }\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"print(\"Dataset class created successfully\")\n"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"cell_type": "code",
|
| 436 |
+
"execution_count": null,
|
| 437 |
+
"metadata": {},
|
| 438 |
+
"outputs": [],
|
| 439 |
+
"source": [
|
| 440 |
+
"## 4. Model Architecture\n"
|
| 441 |
+
]
|
| 442 |
+
},
|
| 443 |
+
{
|
| 444 |
+
"cell_type": "code",
|
| 445 |
+
"execution_count": null,
|
| 446 |
+
"metadata": {},
|
| 447 |
+
"outputs": [],
|
| 448 |
+
"source": [
|
| 449 |
+
"class FakeNewsModel(nn.Module):\n",
|
| 450 |
+
" def __init__(self, config):\n",
|
| 451 |
+
" super(FakeNewsModel, self).__init__()\n",
|
| 452 |
+
" \n",
|
| 453 |
+
" # BERT layer\n",
|
| 454 |
+
" self.bert = BertModel.from_pretrained(config.MODEL_NAME)\n",
|
| 455 |
+
" if config.USE_GRADIENT_CHECKPOINTING:\n",
|
| 456 |
+
" self.bert.gradient_checkpointing_enable()\n",
|
| 457 |
+
" \n",
|
| 458 |
+
" # BiLSTM layer\n",
|
| 459 |
+
" self.lstm = nn.LSTM(\n",
|
| 460 |
+
" input_size=768, # BERT output size\n",
|
| 461 |
+
" hidden_size=config.HIDDEN_DIM,\n",
|
| 462 |
+
" num_layers=1,\n",
|
| 463 |
+
" batch_first=True,\n",
|
| 464 |
+
" bidirectional=True,\n",
|
| 465 |
+
" dropout=0.1\n",
|
| 466 |
+
" )\n",
|
| 467 |
+
" \n",
|
| 468 |
+
" # Attention layer\n",
|
| 469 |
+
" self.attention = nn.Sequential(\n",
|
| 470 |
+
" nn.Linear(config.HIDDEN_DIM * 2, config.HIDDEN_DIM),\n",
|
| 471 |
+
" nn.Tanh(),\n",
|
| 472 |
+
" nn.Linear(config.HIDDEN_DIM, 1)\n",
|
| 473 |
+
" )\n",
|
| 474 |
+
" \n",
|
| 475 |
+
" # Classification head\n",
|
| 476 |
+
" self.classifier = nn.Sequential(\n",
|
| 477 |
+
" nn.Dropout(0.3),\n",
|
| 478 |
+
" nn.Linear(config.HIDDEN_DIM * 2, 64),\n",
|
| 479 |
+
" nn.ReLU(),\n",
|
| 480 |
+
" nn.Dropout(0.2),\n",
|
| 481 |
+
" nn.Linear(64, 2)\n",
|
| 482 |
+
" )\n",
|
| 483 |
+
" \n",
|
| 484 |
+
" def forward(self, input_ids, attention_mask):\n",
|
| 485 |
+
" # BERT\n",
|
| 486 |
+
" bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]\n",
|
| 487 |
+
" \n",
|
| 488 |
+
" # BiLSTM\n",
|
| 489 |
+
" lstm_output, _ = self.lstm(bert_output)\n",
|
| 490 |
+
" \n",
|
| 491 |
+
" # Attention mechanism\n",
|
| 492 |
+
" attention_scores = self.attention(lstm_output)\n",
|
| 493 |
+
" attention_weights = torch.softmax(attention_scores, dim=1)\n",
|
| 494 |
+
" attended_output = torch.sum(attention_weights * lstm_output, dim=1)\n",
|
| 495 |
+
" \n",
|
| 496 |
+
" # Classification\n",
|
| 497 |
+
" logits = self.classifier(attended_output)\n",
|
| 498 |
+
" \n",
|
| 499 |
+
" return logits\n",
|
| 500 |
+
"\n",
|
| 501 |
+
"print(\"Model architecture defined successfully\")\n"
|
| 502 |
+
]
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"cell_type": "code",
|
| 506 |
+
"execution_count": null,
|
| 507 |
+
"metadata": {},
|
| 508 |
+
"outputs": [],
|
| 509 |
+
"source": [
|
| 510 |
+
"## 5. Training Functions\n"
|
| 511 |
+
]
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"cell_type": "code",
|
| 515 |
+
"execution_count": null,
|
| 516 |
+
"metadata": {},
|
| 517 |
+
"outputs": [],
|
| 518 |
+
"source": [
|
| 519 |
+
"def train_epoch(model, train_loader, optimizer, criterion, scaler, config):\n",
|
| 520 |
+
" model.train()\n",
|
| 521 |
+
" total_loss = 0\n",
|
| 522 |
+
" \n",
|
| 523 |
+
" progress_bar = tqdm(train_loader, desc='Training')\n",
|
| 524 |
+
" for batch in progress_bar:\n",
|
| 525 |
+
" input_ids = batch['input_ids'].to(config.DEVICE)\n",
|
| 526 |
+
" attention_mask = batch['attention_mask'].to(config.DEVICE)\n",
|
| 527 |
+
" labels = batch['label'].to(config.DEVICE)\n",
|
| 528 |
+
" \n",
|
| 529 |
+
" optimizer.zero_grad()\n",
|
| 530 |
+
" \n",
|
| 531 |
+
" if config.USE_AMP and torch.cuda.is_available():\n",
|
| 532 |
+
" with torch.cuda.amp.autocast():\n",
|
| 533 |
+
" outputs = model(input_ids, attention_mask)\n",
|
| 534 |
+
" loss = criterion(outputs, labels)\n",
|
| 535 |
+
" \n",
|
| 536 |
+
" scaler.scale(loss).backward()\n",
|
| 537 |
+
" scaler.step(optimizer)\n",
|
| 538 |
+
" scaler.update()\n",
|
| 539 |
+
" else:\n",
|
| 540 |
+
" outputs = model(input_ids, attention_mask)\n",
|
| 541 |
+
" loss = criterion(outputs, labels)\n",
|
| 542 |
+
" loss.backward()\n",
|
| 543 |
+
" optimizer.step()\n",
|
| 544 |
+
" \n",
|
| 545 |
+
" total_loss += loss.item()\n",
|
| 546 |
+
" progress_bar.set_postfix({'loss': loss.item()})\n",
|
| 547 |
+
" \n",
|
| 548 |
+
" # Clear memory\n",
|
| 549 |
+
" del input_ids, attention_mask, labels, outputs, loss\n",
|
| 550 |
+
" if torch.cuda.is_available():\n",
|
| 551 |
+
" torch.cuda.empty_cache()\n",
|
| 552 |
+
" \n",
|
| 553 |
+
" return total_loss / len(train_loader)\n",
|
| 554 |
+
"\n",
|
| 555 |
+
"def evaluate(model, val_loader, criterion, config):\n",
|
| 556 |
+
" model.eval()\n",
|
| 557 |
+
" total_loss = 0\n",
|
| 558 |
+
" all_preds = []\n",
|
| 559 |
+
" all_labels = []\n",
|
| 560 |
+
" \n",
|
| 561 |
+
" with torch.no_grad():\n",
|
| 562 |
+
" progress_bar = tqdm(val_loader, desc='Evaluating')\n",
|
| 563 |
+
" for batch in progress_bar:\n",
|
| 564 |
+
" input_ids = batch['input_ids'].to(config.DEVICE)\n",
|
| 565 |
+
" attention_mask = batch['attention_mask'].to(config.DEVICE)\n",
|
| 566 |
+
" labels = batch['label'].to(config.DEVICE)\n",
|
| 567 |
+
" \n",
|
| 568 |
+
" outputs = model(input_ids, attention_mask)\n",
|
| 569 |
+
" loss = criterion(outputs, labels)\n",
|
| 570 |
+
" \n",
|
| 571 |
+
" total_loss += loss.item()\n",
|
| 572 |
+
" \n",
|
| 573 |
+
" preds = torch.argmax(outputs, dim=1)\n",
|
| 574 |
+
" all_preds.extend(preds.cpu().numpy())\n",
|
| 575 |
+
" all_labels.extend(labels.cpu().numpy())\n",
|
| 576 |
+
" \n",
|
| 577 |
+
" # Clear memory\n",
|
| 578 |
+
" del input_ids, attention_mask, labels, outputs, loss, preds\n",
|
| 579 |
+
" if torch.cuda.is_available():\n",
|
| 580 |
+
" torch.cuda.empty_cache()\n",
|
| 581 |
+
" \n",
|
| 582 |
+
" metrics = {\n",
|
| 583 |
+
" 'loss': total_loss / len(val_loader),\n",
|
| 584 |
+
" 'accuracy': accuracy_score(all_labels, all_preds),\n",
|
| 585 |
+
" 'precision': precision_score(all_labels, all_preds, average='weighted'),\n",
|
| 586 |
+
" 'recall': recall_score(all_labels, all_preds, average='weighted'),\n",
|
| 587 |
+
" 'f1': f1_score(all_labels, all_preds, average='weighted')\n",
|
| 588 |
+
" }\n",
|
| 589 |
+
" \n",
|
| 590 |
+
" return metrics\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"print(\"Training functions defined successfully\")\n"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"cell_type": "code",
|
| 597 |
+
"execution_count": null,
|
| 598 |
+
"metadata": {},
|
| 599 |
+
"outputs": [],
|
| 600 |
+
"source": [
|
| 601 |
+
"## 6. Main Training Process\n"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"cell_type": "code",
|
| 606 |
+
"execution_count": null,
|
| 607 |
+
"metadata": {},
|
| 608 |
+
"outputs": [],
|
| 609 |
+
"source": [
|
| 610 |
+
"# Setup training\n",
|
| 611 |
+
"def setup_training(df, config):\n",
|
| 612 |
+
" # Ensure we have valid data\n",
|
| 613 |
+
" if df is None or len(df) == 0:\n",
|
| 614 |
+
" raise ValueError(\"No valid dataset available\")\n",
|
| 615 |
+
" \n",
|
| 616 |
+
" print(f\"Dataset info:\")\n",
|
| 617 |
+
" print(f\"- Total samples: {len(df)}\")\n",
|
| 618 |
+
" print(f\"- Label distribution: {df['label'].value_counts().to_dict()}\")\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" # Preprocess data\n",
|
| 621 |
+
" print(\"Preprocessing text data...\")\n",
|
| 622 |
+
" texts = df['text'].apply(preprocess_text).values\n",
|
| 623 |
+
" labels = df['label'].values\n",
|
| 624 |
+
" \n",
|
| 625 |
+
" # Remove empty texts\n",
|
| 626 |
+
" valid_indices = [i for i, text in enumerate(texts) if len(text.strip()) > 0]\n",
|
| 627 |
+
" texts = texts[valid_indices]\n",
|
| 628 |
+
" labels = labels[valid_indices]\n",
|
| 629 |
+
" \n",
|
| 630 |
+
" print(f\"After preprocessing: {len(texts)} valid samples\")\n",
|
| 631 |
+
" \n",
|
| 632 |
+
" # Split data\n",
|
| 633 |
+
" train_texts, val_texts, train_labels, val_labels = train_test_split(\n",
|
| 634 |
+
" texts, labels, test_size=0.2, random_state=42, stratify=labels\n",
|
| 635 |
+
" )\n",
|
| 636 |
+
" \n",
|
| 637 |
+
" print(f\"Data split:\")\n",
|
| 638 |
+
" print(f\"- Train samples: {len(train_texts)}\")\n",
|
| 639 |
+
" print(f\"- Validation samples: {len(val_texts)}\")\n",
|
| 640 |
+
" print(f\"- Train label distribution: {pd.Series(train_labels).value_counts().to_dict()}\")\n",
|
| 641 |
+
" print(f\"- Val label distribution: {pd.Series(val_labels).value_counts().to_dict()}\")\n",
|
| 642 |
+
" \n",
|
| 643 |
+
" # Initialize tokenizer\n",
|
| 644 |
+
" print(\"Initializing BERT tokenizer...\")\n",
|
| 645 |
+
" tokenizer = BertTokenizer.from_pretrained(config.MODEL_NAME)\n",
|
| 646 |
+
" \n",
|
| 647 |
+
" # Create datasets\n",
|
| 648 |
+
" print(\"Creating datasets...\")\n",
|
| 649 |
+
" train_dataset = FakeNewsDataset(train_texts, train_labels, tokenizer, config.MAX_SEQUENCE_LENGTH)\n",
|
| 650 |
+
" val_dataset = FakeNewsDataset(val_texts, val_labels, tokenizer, config.MAX_SEQUENCE_LENGTH)\n",
|
| 651 |
+
" \n",
|
| 652 |
+
" # Create dataloaders\n",
|
| 653 |
+
" train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)\n",
|
| 654 |
+
" val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE)\n",
|
| 655 |
+
" \n",
|
| 656 |
+
" print(f\"DataLoaders created:\")\n",
|
| 657 |
+
" print(f\"- Train batches: {len(train_loader)}\")\n",
|
| 658 |
+
" print(f\"- Val batches: {len(val_loader)}\")\n",
|
| 659 |
+
" \n",
|
| 660 |
+
" # Initialize model\n",
|
| 661 |
+
" print(\"Initializing model...\")\n",
|
| 662 |
+
" model = FakeNewsModel(config).to(config.DEVICE)\n",
|
| 663 |
+
" \n",
|
| 664 |
+
" # Count parameters\n",
|
| 665 |
+
" total_params = sum(p.numel() for p in model.parameters())\n",
|
| 666 |
+
" trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 667 |
+
" print(f\"Model parameters:\")\n",
|
| 668 |
+
" print(f\"- Total parameters: {total_params:,}\")\n",
|
| 669 |
+
" print(f\"- Trainable parameters: {trainable_params:,}\")\n",
|
| 670 |
+
" print(f\"- Model size (MB): {total_params * 4 / 1024 / 1024:.2f}\")\n",
|
| 671 |
+
" \n",
|
| 672 |
+
" # Initialize optimizer\n",
|
| 673 |
+
" optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)\n",
|
| 674 |
+
" \n",
|
| 675 |
+
" # Initialize loss function\n",
|
| 676 |
+
" criterion = nn.CrossEntropyLoss()\n",
|
| 677 |
+
" \n",
|
| 678 |
+
" # Initialize scaler for mixed precision\n",
|
| 679 |
+
" scaler = torch.cuda.amp.GradScaler() if config.USE_AMP and torch.cuda.is_available() else None\n",
|
| 680 |
+
" \n",
|
| 681 |
+
" return model, train_loader, val_loader, optimizer, criterion, scaler, tokenizer\n",
|
| 682 |
+
"\n",
|
| 683 |
+
"print(\"Training setup function defined successfully\")\n"
|
| 684 |
+
]
|
| 685 |
+
},
|
| 686 |
+
{
|
| 687 |
+
"cell_type": "code",
|
| 688 |
+
"execution_count": null,
|
| 689 |
+
"metadata": {},
|
| 690 |
+
"outputs": [],
|
| 691 |
+
"source": [
|
| 692 |
+
"# Run the complete training pipeline\n",
|
| 693 |
+
"def main():\n",
|
| 694 |
+
" print(\"Starting fake news detection training...\")\n",
|
| 695 |
+
" \n",
|
| 696 |
+
" # Setup training\n",
|
| 697 |
+
" model, train_loader, val_loader, optimizer, criterion, scaler, tokenizer = setup_training(df, config)\n",
|
| 698 |
+
" \n",
|
| 699 |
+
" # Training loop\n",
|
| 700 |
+
" best_val_loss = float('inf')\n",
|
| 701 |
+
" best_val_acc = 0.0\n",
|
| 702 |
+
" \n",
|
| 703 |
+
" print(f\"Starting training for {config.NUM_EPOCHS} epochs...\")\n",
|
| 704 |
+
" \n",
|
| 705 |
+
" for epoch in range(config.NUM_EPOCHS):\n",
|
| 706 |
+
" print(f'=== Epoch {epoch + 1}/{config.NUM_EPOCHS} ===')\n",
|
| 707 |
+
" \n",
|
| 708 |
+
" # Train\n",
|
| 709 |
+
" train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, config)\n",
|
| 710 |
+
" print(f'Train Loss: {train_loss:.4f}')\n",
|
| 711 |
+
" \n",
|
| 712 |
+
" # Evaluate\n",
|
| 713 |
+
" val_metrics = evaluate(model, val_loader, criterion, config)\n",
|
| 714 |
+
" print(f'Val Loss: {val_metrics[\"loss\"]:.4f}')\n",
|
| 715 |
+
" print(f'Val Accuracy: {val_metrics[\"accuracy\"]:.4f}')\n",
|
| 716 |
+
" print(f'Val Precision: {val_metrics[\"precision\"]:.4f}')\n",
|
| 717 |
+
" print(f'Val Recall: {val_metrics[\"recall\"]:.4f}')\n",
|
| 718 |
+
" print(f'Val F1: {val_metrics[\"f1\"]:.4f}')\n",
|
| 719 |
+
" \n",
|
| 720 |
+
" # Save best model\n",
|
| 721 |
+
" if val_metrics['accuracy'] > best_val_acc:\n",
|
| 722 |
+
" best_val_acc = val_metrics['accuracy']\n",
|
| 723 |
+
" best_val_loss = val_metrics['loss']\n",
|
| 724 |
+
" torch.save(model.state_dict(), 'best_model_colab.pt')\n",
|
| 725 |
+
" print(f'New best model saved! Accuracy: {best_val_acc:.4f}')\n",
|
| 726 |
+
" \n",
|
| 727 |
+
" # Clear memory\n",
|
| 728 |
+
" gc.collect()\n",
|
| 729 |
+
" if torch.cuda.is_available():\n",
|
| 730 |
+
" torch.cuda.empty_cache()\n",
|
| 731 |
+
" \n",
|
| 732 |
+
" print('Training completed!')\n",
|
| 733 |
+
" print(f'Best validation accuracy: {best_val_acc:.4f}')\n",
|
| 734 |
+
" print(f'Best validation loss: {best_val_loss:.4f}')\n",
|
| 735 |
+
" \n",
|
| 736 |
+
" return model, tokenizer\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"# Run training\n",
|
| 739 |
+
"model, tokenizer = main()\n"
|
| 740 |
+
]
|
| 741 |
+
},
|
| 742 |
+
{
|
| 743 |
+
"cell_type": "code",
|
| 744 |
+
"execution_count": null,
|
| 745 |
+
"metadata": {},
|
| 746 |
+
"outputs": [],
|
| 747 |
+
"source": [
|
| 748 |
+
"## 7. Model Testing and Prediction\n"
|
| 749 |
+
]
|
| 750 |
+
},
|
| 751 |
+
{
|
| 752 |
+
"cell_type": "code",
|
| 753 |
+
"execution_count": null,
|
| 754 |
+
"metadata": {},
|
| 755 |
+
"outputs": [],
|
| 756 |
+
"source": [
|
| 757 |
+
"def predict_single(text, model, tokenizer, config):\n",
|
| 758 |
+
" \"\"\"Predict if a single text is fake or real news\"\"\"\n",
|
| 759 |
+
" model.eval()\n",
|
| 760 |
+
" text = preprocess_text(text)\n",
|
| 761 |
+
" \n",
|
| 762 |
+
" encoding = tokenizer.encode_plus(\n",
|
| 763 |
+
" text,\n",
|
| 764 |
+
" add_special_tokens=True,\n",
|
| 765 |
+
" max_length=config.MAX_SEQUENCE_LENGTH,\n",
|
| 766 |
+
" padding='max_length',\n",
|
| 767 |
+
" truncation=True,\n",
|
| 768 |
+
" return_attention_mask=True,\n",
|
| 769 |
+
" return_tensors='pt'\n",
|
| 770 |
+
" )\n",
|
| 771 |
+
" \n",
|
| 772 |
+
" input_ids = encoding['input_ids'].to(config.DEVICE)\n",
|
| 773 |
+
" attention_mask = encoding['attention_mask'].to(config.DEVICE)\n",
|
| 774 |
+
" \n",
|
| 775 |
+
" with torch.no_grad():\n",
|
| 776 |
+
" outputs = model(input_ids, attention_mask)\n",
|
| 777 |
+
" probabilities = torch.softmax(outputs, dim=1)\n",
|
| 778 |
+
" prediction = torch.argmax(outputs, dim=1)\n",
|
| 779 |
+
" confidence = torch.max(probabilities, dim=1)[0]\n",
|
| 780 |
+
" \n",
|
| 781 |
+
" return {\n",
|
| 782 |
+
" 'prediction': prediction.item(),\n",
|
| 783 |
+
" 'label': 'FAKE' if prediction.item() == 1 else 'REAL',\n",
|
| 784 |
+
" 'confidence': confidence.item(),\n",
|
| 785 |
+
" 'probabilities': {\n",
|
| 786 |
+
" 'REAL': probabilities[0][0].item(),\n",
|
| 787 |
+
" 'FAKE': probabilities[0][1].item()\n",
|
| 788 |
+
" }\n",
|
| 789 |
+
" }\n",
|
| 790 |
+
"\n",
|
| 791 |
+
"# Test with sample texts\n",
|
| 792 |
+
"test_texts = [\n",
|
| 793 |
+
" \"Breaking: Scientists discover new planet in our solar system\",\n",
|
| 794 |
+
" \"Local community comes together to help flood victims\",\n",
|
| 795 |
+
" \"Shocking: Aliens spotted in downtown area last night\",\n",
|
| 796 |
+
" \"Government announces new healthcare policy to benefit citizens\"\n",
|
| 797 |
+
"]\n",
|
| 798 |
+
"\n",
|
| 799 |
+
"print(\"Testing model predictions:\")\n",
|
| 800 |
+
"print(\"=\" * 50)\n",
|
| 801 |
+
"\n",
|
| 802 |
+
"for i, text in enumerate(test_texts, 1):\n",
|
| 803 |
+
" result = predict_single(text, model, tokenizer, config)\n",
|
| 804 |
+
" print(f\"Text {i}: {text[:60]}...\")\n",
|
| 805 |
+
" print(f\"Prediction: {result['label']} (Confidence: {result['confidence']:.3f})\")\n",
|
| 806 |
+
" print(f\"Probabilities: REAL={result['probabilities']['REAL']:.3f}, FAKE={result['probabilities']['FAKE']:.3f}\")\n",
|
| 807 |
+
" print(\"-\" * 50)\n"
|
| 808 |
+
]
|
| 809 |
+
},
|
| 810 |
+
{
|
| 811 |
+
"cell_type": "code",
|
| 812 |
+
"execution_count": null,
|
| 813 |
+
"metadata": {},
|
| 814 |
+
"outputs": [],
|
| 815 |
+
"source": [
|
| 816 |
+
"# Run the complete training pipeline\n",
|
| 817 |
+
"def main():\n",
|
| 818 |
+
" print(\"Starting fake news detection training...\")\n",
|
| 819 |
+
" \n",
|
| 820 |
+
" # Setup training\n",
|
| 821 |
+
" model, train_loader, val_loader, optimizer, criterion, scaler, tokenizer = setup_training(df, config)\n",
|
| 822 |
+
" \n",
|
| 823 |
+
" # Training loop\n",
|
| 824 |
+
" best_val_loss = float('inf')\n",
|
| 825 |
+
" best_val_acc = 0.0\n",
|
| 826 |
+
" \n",
|
| 827 |
+
" print(f\"\\nStarting training for {config.NUM_EPOCHS} epochs...\")\n",
|
| 828 |
+
" \n",
|
| 829 |
+
" for epoch in range(config.NUM_EPOCHS):\n",
|
| 830 |
+
" print(f'\\n=== Epoch {epoch + 1}/{config.NUM_EPOCHS} ===')\n",
|
| 831 |
+
" \n",
|
| 832 |
+
" # Train\n",
|
| 833 |
+
" train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, config)\n",
|
| 834 |
+
" print(f'Train Loss: {train_loss:.4f}')\n",
|
| 835 |
+
" \n",
|
| 836 |
+
" # Evaluate\n",
|
| 837 |
+
" val_metrics = evaluate(model, val_loader, criterion, config)\n",
|
| 838 |
+
" print(f'Val Loss: {val_metrics[\\\"loss\\\"]:.4f}')\n",
|
| 839 |
+
" print(f'Val Accuracy: {val_metrics[\\\"accuracy\\\"]:.4f}')\n",
|
| 840 |
+
" print(f'Val Precision: {val_metrics[\\\"precision\\\"]:.4f}')\n",
|
| 841 |
+
" print(f'Val Recall: {val_metrics[\\\"recall\\\"]:.4f}')\n",
|
| 842 |
+
" print(f'Val F1: {val_metrics[\\\"f1\\\"]:.4f}')\n",
|
| 843 |
+
" \n",
|
| 844 |
+
" # Save best model\n",
|
| 845 |
+
" if val_metrics['accuracy'] > best_val_acc:\n",
|
| 846 |
+
" best_val_acc = val_metrics['accuracy']\n",
|
| 847 |
+
" best_val_loss = val_metrics['loss']\n",
|
| 848 |
+
" torch.save(model.state_dict(), 'best_model_colab.pt')\\n print(f'New best model saved! Accuracy: {best_val_acc:.4f}')\\n \\n # Clear memory\\n gc.collect()\\n if torch.cuda.is_available():\\n torch.cuda.empty_cache()\\n \\n print(f'\\\\nTraining completed!')\\n print(f'Best validation accuracy: {best_val_acc:.4f}')\\n print(f'Best validation loss: {best_val_loss:.4f}')\\n \\n return model, tokenizer\\n\\n# Run training\\nmodel, tokenizer = main()\n"
|
| 849 |
+
]
|
| 850 |
+
}
|
| 851 |
+
],
|
| 852 |
+
"metadata": {
|
| 853 |
+
"language_info": {
|
| 854 |
+
"name": "python"
|
| 855 |
+
}
|
| 856 |
+
},
|
| 857 |
+
"nbformat": 4,
|
| 858 |
+
"nbformat_minor": 2
|
| 859 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
pandas
|
| 3 |
+
scikit-learn
|
| 4 |
+
torch
|
| 5 |
+
transformers
|
| 6 |
+
nltk
|
| 7 |
+
spacy
|
| 8 |
+
matplotlib
|
| 9 |
+
seaborn
|
| 10 |
+
tqdm
|
| 11 |
+
emoji
|
| 12 |
+
textblob
|
| 13 |
+
gensim
|
| 14 |
+
pytest
|
| 15 |
+
jupyter
|
| 16 |
+
gdown
|
| 17 |
+
requests
|
| 18 |
+
kaggle
|
| 19 |
+
streamlit
|
| 20 |
+
plotly
|
src/app.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
from transformers import BertTokenizer
|
| 10 |
+
|
| 11 |
+
# Add project root to Python path
|
| 12 |
+
project_root = Path(__file__).parent.parent
|
| 13 |
+
sys.path.append(str(project_root))
|
| 14 |
+
|
| 15 |
+
from src.models.hybrid_model import HybridFakeNewsDetector
|
| 16 |
+
from src.config.config import *
|
| 17 |
+
from src.data.preprocessor import TextPreprocessor
|
| 18 |
+
|
| 19 |
+
# Set page config
|
| 20 |
+
st.set_page_config(
|
| 21 |
+
page_title="Fake News Detection",
|
| 22 |
+
page_icon="📰",
|
| 23 |
+
layout="wide"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Initialize session state
|
| 27 |
+
if 'model' not in st.session_state:
|
| 28 |
+
st.session_state.model = None
|
| 29 |
+
if 'tokenizer' not in st.session_state:
|
| 30 |
+
st.session_state.tokenizer = None
|
| 31 |
+
if 'preprocessor' not in st.session_state:
|
| 32 |
+
st.session_state.preprocessor = None
|
| 33 |
+
|
| 34 |
+
def load_model():
|
| 35 |
+
"""Load the trained model and tokenizer."""
|
| 36 |
+
if st.session_state.model is None:
|
| 37 |
+
# Initialize model
|
| 38 |
+
model = HybridFakeNewsDetector(
|
| 39 |
+
bert_model_name=BERT_MODEL_NAME,
|
| 40 |
+
lstm_hidden_size=LSTM_HIDDEN_SIZE,
|
| 41 |
+
lstm_num_layers=LSTM_NUM_LAYERS,
|
| 42 |
+
dropout_rate=DROPOUT_RATE
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Load trained weights
|
| 46 |
+
model.load_state_dict(torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu')))
|
| 47 |
+
model.eval()
|
| 48 |
+
st.session_state.model = model
|
| 49 |
+
|
| 50 |
+
# Initialize tokenizer
|
| 51 |
+
st.session_state.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
|
| 52 |
+
|
| 53 |
+
# Initialize preprocessor
|
| 54 |
+
st.session_state.preprocessor = TextPreprocessor()
|
| 55 |
+
|
| 56 |
+
def predict_news(text):
|
| 57 |
+
"""Predict if the given news is fake or real."""
|
| 58 |
+
if st.session_state.model is None:
|
| 59 |
+
load_model()
|
| 60 |
+
|
| 61 |
+
# Preprocess text
|
| 62 |
+
processed_text = st.session_state.preprocessor.preprocess_text(text)
|
| 63 |
+
|
| 64 |
+
# Tokenize
|
| 65 |
+
encoding = st.session_state.tokenizer.encode_plus(
|
| 66 |
+
processed_text,
|
| 67 |
+
add_special_tokens=True,
|
| 68 |
+
max_length=MAX_SEQUENCE_LENGTH,
|
| 69 |
+
padding='max_length',
|
| 70 |
+
truncation=True,
|
| 71 |
+
return_attention_mask=True,
|
| 72 |
+
return_tensors='pt'
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Get prediction
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
outputs = st.session_state.model(
|
| 78 |
+
encoding['input_ids'],
|
| 79 |
+
encoding['attention_mask']
|
| 80 |
+
)
|
| 81 |
+
probabilities = torch.softmax(outputs['logits'], dim=1)
|
| 82 |
+
prediction = torch.argmax(outputs['logits'], dim=1)
|
| 83 |
+
attention_weights = outputs['attention_weights']
|
| 84 |
+
|
| 85 |
+
# Convert attention weights to numpy and get the first sequence
|
| 86 |
+
attention_weights_np = attention_weights[0].cpu().numpy()
|
| 87 |
+
|
| 88 |
+
return {
|
| 89 |
+
'prediction': prediction.item(),
|
| 90 |
+
'label': 'FAKE' if prediction.item() == 1 else 'REAL',
|
| 91 |
+
'confidence': torch.max(probabilities, dim=1)[0].item(),
|
| 92 |
+
'probabilities': {
|
| 93 |
+
'REAL': probabilities[0][0].item(),
|
| 94 |
+
'FAKE': probabilities[0][1].item()
|
| 95 |
+
},
|
| 96 |
+
'attention_weights': attention_weights_np
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
def plot_confidence(probabilities):
|
| 100 |
+
"""Plot prediction confidence."""
|
| 101 |
+
fig = go.Figure(data=[
|
| 102 |
+
go.Bar(
|
| 103 |
+
x=list(probabilities.keys()),
|
| 104 |
+
y=list(probabilities.values()),
|
| 105 |
+
text=[f'{p:.2%}' for p in probabilities.values()],
|
| 106 |
+
textposition='auto',
|
| 107 |
+
)
|
| 108 |
+
])
|
| 109 |
+
|
| 110 |
+
fig.update_layout(
|
| 111 |
+
title='Prediction Confidence',
|
| 112 |
+
xaxis_title='Class',
|
| 113 |
+
yaxis_title='Probability',
|
| 114 |
+
yaxis_range=[0, 1]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
return fig
|
| 118 |
+
|
| 119 |
+
def plot_attention(text, attention_weights):
|
| 120 |
+
"""Plot attention weights."""
|
| 121 |
+
tokens = text.split()
|
| 122 |
+
attention_weights = attention_weights[:len(tokens)] # Truncate to match tokens
|
| 123 |
+
|
| 124 |
+
# Ensure attention weights are in the correct format
|
| 125 |
+
if isinstance(attention_weights, (list, np.ndarray)):
|
| 126 |
+
attention_weights = np.array(attention_weights).flatten()
|
| 127 |
+
|
| 128 |
+
# Format weights for display
|
| 129 |
+
formatted_weights = [f'{float(w):.2f}' for w in attention_weights]
|
| 130 |
+
|
| 131 |
+
fig = go.Figure(data=[
|
| 132 |
+
go.Bar(
|
| 133 |
+
x=tokens,
|
| 134 |
+
y=attention_weights,
|
| 135 |
+
text=formatted_weights,
|
| 136 |
+
textposition='auto',
|
| 137 |
+
)
|
| 138 |
+
])
|
| 139 |
+
|
| 140 |
+
fig.update_layout(
|
| 141 |
+
title='Attention Weights',
|
| 142 |
+
xaxis_title='Tokens',
|
| 143 |
+
yaxis_title='Attention Weight',
|
| 144 |
+
xaxis_tickangle=45
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
return fig
|
| 148 |
+
|
| 149 |
+
def main():
|
| 150 |
+
st.title("📰 Fake News Detection System")
|
| 151 |
+
st.write("""
|
| 152 |
+
This application uses a hybrid deep learning model (BERT + BiLSTM + Attention)
|
| 153 |
+
to detect fake news articles. Enter a news article below to analyze it.
|
| 154 |
+
""")
|
| 155 |
+
|
| 156 |
+
# Sidebar
|
| 157 |
+
st.sidebar.title("About")
|
| 158 |
+
st.sidebar.info("""
|
| 159 |
+
This model was developed as part of the Data Mining Laboratory course
|
| 160 |
+
under the guidance of Dr. Kirti Kumari.
|
| 161 |
+
|
| 162 |
+
The model combines:
|
| 163 |
+
- BERT for contextual embeddings
|
| 164 |
+
- BiLSTM for sequence modeling
|
| 165 |
+
- Attention mechanism for interpretability
|
| 166 |
+
""")
|
| 167 |
+
|
| 168 |
+
# Main content
|
| 169 |
+
st.header("News Analysis")
|
| 170 |
+
|
| 171 |
+
# Text input
|
| 172 |
+
news_text = st.text_area(
|
| 173 |
+
"Enter the news article to analyze:",
|
| 174 |
+
height=200,
|
| 175 |
+
placeholder="Paste your news article here..."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if st.button("Analyze"):
|
| 179 |
+
if news_text:
|
| 180 |
+
with st.spinner("Analyzing the news article..."):
|
| 181 |
+
# Get prediction
|
| 182 |
+
result = predict_news(news_text)
|
| 183 |
+
|
| 184 |
+
# Display result
|
| 185 |
+
col1, col2 = st.columns(2)
|
| 186 |
+
|
| 187 |
+
with col1:
|
| 188 |
+
st.subheader("Prediction")
|
| 189 |
+
if result['label'] == 'FAKE':
|
| 190 |
+
st.error(f"🔴 This news is likely FAKE (Confidence: {result['confidence']:.2%})")
|
| 191 |
+
else:
|
| 192 |
+
st.success(f"🟢 This news is likely REAL (Confidence: {result['confidence']:.2%})")
|
| 193 |
+
|
| 194 |
+
with col2:
|
| 195 |
+
st.subheader("Confidence Scores")
|
| 196 |
+
st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
|
| 197 |
+
|
| 198 |
+
# Show attention visualization
|
| 199 |
+
st.subheader("Attention Analysis")
|
| 200 |
+
st.write("""
|
| 201 |
+
The attention weights show which parts of the text the model focused on
|
| 202 |
+
while making its prediction. Higher weights indicate more important tokens.
|
| 203 |
+
""")
|
| 204 |
+
st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
|
| 205 |
+
|
| 206 |
+
# Show model explanation
|
| 207 |
+
st.subheader("Model Explanation")
|
| 208 |
+
if result['label'] == 'FAKE':
|
| 209 |
+
st.write("""
|
| 210 |
+
The model identified this as fake news based on:
|
| 211 |
+
- Linguistic patterns typical of fake news
|
| 212 |
+
- Inconsistencies in the content
|
| 213 |
+
- Attention weights on suspicious phrases
|
| 214 |
+
""")
|
| 215 |
+
else:
|
| 216 |
+
st.write("""
|
| 217 |
+
The model identified this as real news based on:
|
| 218 |
+
- Credible language patterns
|
| 219 |
+
- Consistent information
|
| 220 |
+
- Attention weights on factual statements
|
| 221 |
+
""")
|
| 222 |
+
else:
|
| 223 |
+
st.warning("Please enter a news article to analyze.")
|
| 224 |
+
|
| 225 |
+
if __name__ == "__main__":
|
| 226 |
+
main()
|
src/config/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
src/config/config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
# Project paths
|
| 4 |
+
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
| 5 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 6 |
+
RAW_DATA_DIR = DATA_DIR / "raw"
|
| 7 |
+
PROCESSED_DATA_DIR = DATA_DIR / "processed"
|
| 8 |
+
MODEL_DIR = PROJECT_ROOT / "models"
|
| 9 |
+
SAVED_MODELS_DIR = MODEL_DIR / "saved"
|
| 10 |
+
CHECKPOINTS_DIR = MODEL_DIR / "checkpoints"
|
| 11 |
+
|
| 12 |
+
# Data parameters
|
| 13 |
+
MAX_SEQUENCE_LENGTH = 256
|
| 14 |
+
VOCAB_SIZE = 15000
|
| 15 |
+
EMBEDDING_DIM = 128
|
| 16 |
+
BATCH_SIZE = 8
|
| 17 |
+
TEST_SIZE = 0.2
|
| 18 |
+
VAL_SIZE = 0.1
|
| 19 |
+
RANDOM_STATE = 42
|
| 20 |
+
MAX_SAMPLES = 10000
|
| 21 |
+
|
| 22 |
+
# Model parameters
|
| 23 |
+
BERT_MODEL_NAME = "bert-base-uncased"
|
| 24 |
+
LSTM_HIDDEN_SIZE = 128
|
| 25 |
+
LSTM_NUM_LAYERS = 1
|
| 26 |
+
DROPOUT_RATE = 0.3
|
| 27 |
+
LEARNING_RATE = 2e-5
|
| 28 |
+
NUM_EPOCHS = 3
|
| 29 |
+
EARLY_STOPPING_PATIENCE = 2
|
| 30 |
+
|
| 31 |
+
# Training parameters
|
| 32 |
+
DEVICE = "cpu"
|
| 33 |
+
NUM_WORKERS = 0
|
| 34 |
+
PIN_MEMORY = False
|
| 35 |
+
|
| 36 |
+
# Feature extraction
|
| 37 |
+
USE_TFIDF = True
|
| 38 |
+
USE_BERT = True
|
| 39 |
+
USE_LSTM = True
|
| 40 |
+
|
| 41 |
+
# Evaluation metrics
|
| 42 |
+
METRICS = ["accuracy", "precision", "recall", "f1"]
|
src/data/__pycache__/dataset.cpython-311.pyc
ADDED
|
Binary file (4.59 kB). View file
|
|
|
src/data/__pycache__/preprocessor.cpython-311.pyc
ADDED
|
Binary file (6.18 kB). View file
|
|
|
src/data/dataset.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from transformers import BertTokenizer
|
| 4 |
+
from typing import Dict, List, Union
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class FakeNewsDataset(Dataset):
|
| 9 |
+
def __init__(self,
|
| 10 |
+
texts: List[str],
|
| 11 |
+
labels: List[int],
|
| 12 |
+
tokenizer: BertTokenizer,
|
| 13 |
+
max_length: int = 512):
|
| 14 |
+
self.texts = texts
|
| 15 |
+
self.labels = labels
|
| 16 |
+
self.tokenizer = tokenizer
|
| 17 |
+
self.max_length = max_length
|
| 18 |
+
|
| 19 |
+
def __len__(self) -> int:
|
| 20 |
+
return len(self.texts)
|
| 21 |
+
|
| 22 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 23 |
+
text = str(self.texts[idx])
|
| 24 |
+
label = self.labels[idx]
|
| 25 |
+
|
| 26 |
+
encoding = self.tokenizer(
|
| 27 |
+
text,
|
| 28 |
+
add_special_tokens=True,
|
| 29 |
+
max_length=self.max_length,
|
| 30 |
+
padding='max_length',
|
| 31 |
+
truncation=True,
|
| 32 |
+
return_attention_mask=True,
|
| 33 |
+
return_tensors='pt'
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return {
|
| 37 |
+
'input_ids': encoding['input_ids'].flatten(),
|
| 38 |
+
'attention_mask': encoding['attention_mask'].flatten(),
|
| 39 |
+
'labels': torch.tensor(label, dtype=torch.long)
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
def create_data_loaders(
|
| 43 |
+
df: pd.DataFrame,
|
| 44 |
+
text_column: str,
|
| 45 |
+
label_column: str,
|
| 46 |
+
tokenizer: BertTokenizer,
|
| 47 |
+
batch_size: int = 32,
|
| 48 |
+
max_length: int = 512,
|
| 49 |
+
train_size: float = 0.8,
|
| 50 |
+
val_size: float = 0.1,
|
| 51 |
+
random_state: int = 42
|
| 52 |
+
) -> Dict[str, torch.utils.data.DataLoader]:
|
| 53 |
+
"""Create train, validation, and test data loaders."""
|
| 54 |
+
# Split data
|
| 55 |
+
train_df = df.sample(frac=train_size, random_state=random_state)
|
| 56 |
+
remaining_df = df.drop(train_df.index)
|
| 57 |
+
val_df = remaining_df.sample(frac=val_size/(1-train_size), random_state=random_state)
|
| 58 |
+
test_df = remaining_df.drop(val_df.index)
|
| 59 |
+
|
| 60 |
+
# Create datasets
|
| 61 |
+
train_dataset = FakeNewsDataset(
|
| 62 |
+
texts=train_df[text_column].tolist(),
|
| 63 |
+
labels=train_df[label_column].tolist(),
|
| 64 |
+
tokenizer=tokenizer,
|
| 65 |
+
max_length=max_length
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
val_dataset = FakeNewsDataset(
|
| 69 |
+
texts=val_df[text_column].tolist(),
|
| 70 |
+
labels=val_df[label_column].tolist(),
|
| 71 |
+
tokenizer=tokenizer,
|
| 72 |
+
max_length=max_length
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
test_dataset = FakeNewsDataset(
|
| 76 |
+
texts=test_df[text_column].tolist(),
|
| 77 |
+
labels=test_df[label_column].tolist(),
|
| 78 |
+
tokenizer=tokenizer,
|
| 79 |
+
max_length=max_length
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Create data loaders
|
| 83 |
+
train_loader = torch.utils.data.DataLoader(
|
| 84 |
+
train_dataset,
|
| 85 |
+
batch_size=batch_size,
|
| 86 |
+
shuffle=True,
|
| 87 |
+
num_workers=4
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
val_loader = torch.utils.data.DataLoader(
|
| 91 |
+
val_dataset,
|
| 92 |
+
batch_size=batch_size,
|
| 93 |
+
shuffle=False,
|
| 94 |
+
num_workers=4
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
test_loader = torch.utils.data.DataLoader(
|
| 98 |
+
test_dataset,
|
| 99 |
+
batch_size=batch_size,
|
| 100 |
+
shuffle=False,
|
| 101 |
+
num_workers=4
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return {
|
| 105 |
+
'train': train_loader,
|
| 106 |
+
'val': val_loader,
|
| 107 |
+
'test': test_loader
|
| 108 |
+
}
|
src/data/download_datasets.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import requests
|
| 4 |
+
import zipfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import logging
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import json
|
| 9 |
+
import kaggle
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class DatasetDownloader:
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.project_root = Path(__file__).parent.parent.parent
|
| 17 |
+
self.raw_data_dir = self.project_root / "data" / "raw"
|
| 18 |
+
self.processed_data_dir = self.project_root / "data" / "processed"
|
| 19 |
+
|
| 20 |
+
# Create directories if they don't exist
|
| 21 |
+
os.makedirs(self.raw_data_dir, exist_ok=True)
|
| 22 |
+
os.makedirs(self.processed_data_dir, exist_ok=True)
|
| 23 |
+
|
| 24 |
+
def download_kaggle_dataset(self):
|
| 25 |
+
"""Download dataset from Kaggle."""
|
| 26 |
+
logger.info("Downloading dataset from Kaggle...")
|
| 27 |
+
|
| 28 |
+
# Kaggle dataset ID
|
| 29 |
+
dataset_id = "clmentbisaillon/fake-and-real-news-dataset"
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
kaggle.api.dataset_download_files(
|
| 33 |
+
dataset_id,
|
| 34 |
+
path=self.raw_data_dir,
|
| 35 |
+
unzip=True
|
| 36 |
+
)
|
| 37 |
+
logger.info("Successfully downloaded dataset from Kaggle")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
logger.error(f"Error downloading from Kaggle: {str(e)}")
|
| 40 |
+
logger.info("Please download the dataset manually from: https://www.kaggle.com/datasets/clmentbisaillon/fake-and-real-news-dataset")
|
| 41 |
+
|
| 42 |
+
def download_liar(self):
|
| 43 |
+
"""Download LIAR dataset."""
|
| 44 |
+
logger.info("Downloading LIAR dataset...")
|
| 45 |
+
|
| 46 |
+
# URL for LIAR dataset
|
| 47 |
+
url = "https://www.cs.ucsb.edu/~william/data/liar_dataset.zip"
|
| 48 |
+
output_path = self.raw_data_dir / "liar_dataset.zip"
|
| 49 |
+
|
| 50 |
+
if not output_path.exists():
|
| 51 |
+
try:
|
| 52 |
+
response = requests.get(url, stream=True)
|
| 53 |
+
total_size = int(response.headers.get('content-length', 0))
|
| 54 |
+
|
| 55 |
+
with open(output_path, 'wb') as f, tqdm(
|
| 56 |
+
desc="Downloading LIAR dataset",
|
| 57 |
+
total=total_size,
|
| 58 |
+
unit='iB',
|
| 59 |
+
unit_scale=True
|
| 60 |
+
) as pbar:
|
| 61 |
+
for data in response.iter_content(chunk_size=1024):
|
| 62 |
+
size = f.write(data)
|
| 63 |
+
pbar.update(size)
|
| 64 |
+
|
| 65 |
+
# Extract the zip file
|
| 66 |
+
with zipfile.ZipFile(output_path, 'r') as zip_ref:
|
| 67 |
+
zip_ref.extractall(self.raw_data_dir / "liar")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logger.error(f"Error downloading LIAR dataset: {str(e)}")
|
| 70 |
+
logger.info("Please download the LIAR dataset manually from: https://www.cs.ucsb.edu/~william/data/liar_dataset.zip")
|
| 71 |
+
|
| 72 |
+
def process_kaggle_dataset(self):
|
| 73 |
+
"""Process the Kaggle dataset."""
|
| 74 |
+
logger.info("Processing Kaggle dataset...")
|
| 75 |
+
|
| 76 |
+
# Read fake and real news files
|
| 77 |
+
fake_df = pd.read_csv(self.raw_data_dir / "Fake.csv")
|
| 78 |
+
true_df = pd.read_csv(self.raw_data_dir / "True.csv")
|
| 79 |
+
|
| 80 |
+
# Add labels
|
| 81 |
+
fake_df['label'] = 1 # 1 for fake
|
| 82 |
+
true_df['label'] = 0 # 0 for real
|
| 83 |
+
|
| 84 |
+
# Combine datasets
|
| 85 |
+
combined_df = pd.concat([fake_df, true_df], ignore_index=True)
|
| 86 |
+
|
| 87 |
+
# Save processed data
|
| 88 |
+
combined_df.to_csv(self.processed_data_dir / "kaggle_processed.csv", index=False)
|
| 89 |
+
logger.info(f"Saved {len(combined_df)} articles from Kaggle dataset")
|
| 90 |
+
|
| 91 |
+
def process_liar(self):
|
| 92 |
+
"""Process LIAR dataset."""
|
| 93 |
+
logger.info("Processing LIAR dataset...")
|
| 94 |
+
|
| 95 |
+
# Read LIAR dataset
|
| 96 |
+
liar_file = self.raw_data_dir / "liar" / "train.tsv"
|
| 97 |
+
if not liar_file.exists():
|
| 98 |
+
logger.error("LIAR dataset not found!")
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
# Read TSV file
|
| 102 |
+
df = pd.read_csv(liar_file, sep='\t', header=None)
|
| 103 |
+
|
| 104 |
+
# Rename columns
|
| 105 |
+
df.columns = [
|
| 106 |
+
'id', 'label', 'statement', 'subject', 'speaker',
|
| 107 |
+
'job_title', 'state_info', 'party_affiliation',
|
| 108 |
+
'barely_true', 'false', 'half_true', 'mostly_true',
|
| 109 |
+
'pants_on_fire', 'venue'
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
# Convert labels to binary (0 for true, 1 for false)
|
| 113 |
+
label_map = {
|
| 114 |
+
'true': 0,
|
| 115 |
+
'mostly-true': 0,
|
| 116 |
+
'half-true': 0,
|
| 117 |
+
'barely-true': 1,
|
| 118 |
+
'false': 1,
|
| 119 |
+
'pants-fire': 1
|
| 120 |
+
}
|
| 121 |
+
df['label'] = df['label'].map(label_map)
|
| 122 |
+
|
| 123 |
+
# Select relevant columns
|
| 124 |
+
df = df[['statement', 'label', 'subject', 'speaker', 'party_affiliation']]
|
| 125 |
+
df.columns = ['text', 'label', 'subject', 'speaker', 'party']
|
| 126 |
+
|
| 127 |
+
# Save processed data
|
| 128 |
+
df.to_csv(self.processed_data_dir / "liar_processed.csv", index=False)
|
| 129 |
+
logger.info(f"Saved {len(df)} articles from LIAR dataset")
|
| 130 |
+
|
| 131 |
+
def combine_datasets(self):
|
| 132 |
+
"""Combine processed datasets."""
|
| 133 |
+
logger.info("Combining datasets...")
|
| 134 |
+
|
| 135 |
+
# Read processed datasets
|
| 136 |
+
kaggle_df = pd.read_csv(self.processed_data_dir / "kaggle_processed.csv")
|
| 137 |
+
liar_df = pd.read_csv(self.processed_data_dir / "liar_processed.csv")
|
| 138 |
+
|
| 139 |
+
# Combine datasets
|
| 140 |
+
combined_df = pd.concat([
|
| 141 |
+
kaggle_df[['text', 'label']],
|
| 142 |
+
liar_df[['text', 'label']]
|
| 143 |
+
], ignore_index=True)
|
| 144 |
+
|
| 145 |
+
# Save combined dataset
|
| 146 |
+
combined_df.to_csv(self.processed_data_dir / "combined_dataset.csv", index=False)
|
| 147 |
+
logger.info(f"Combined dataset contains {len(combined_df)} articles")
|
| 148 |
+
|
| 149 |
+
def main():
|
| 150 |
+
downloader = DatasetDownloader()
|
| 151 |
+
|
| 152 |
+
# Download datasets
|
| 153 |
+
downloader.download_kaggle_dataset()
|
| 154 |
+
downloader.download_liar()
|
| 155 |
+
|
| 156 |
+
# Process datasets
|
| 157 |
+
downloader.process_kaggle_dataset()
|
| 158 |
+
downloader.process_liar()
|
| 159 |
+
|
| 160 |
+
# Combine datasets
|
| 161 |
+
downloader.combine_datasets()
|
| 162 |
+
|
| 163 |
+
logger.info("Dataset preparation completed!")
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
main()
|
src/data/feature_extractor.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
|
| 4 |
+
from transformers import BertTokenizer, BertModel
|
| 5 |
+
from typing import Tuple, Dict, List
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
class FeatureExtractor:
|
| 10 |
+
def __init__(self, bert_model_name: str = "bert-base-uncased"):
|
| 11 |
+
self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
| 12 |
+
self.bert_model = BertModel.from_pretrained(bert_model_name)
|
| 13 |
+
self.tfidf_vectorizer = TfidfVectorizer(
|
| 14 |
+
max_features=5000,
|
| 15 |
+
ngram_range=(1, 2),
|
| 16 |
+
stop_words='english'
|
| 17 |
+
)
|
| 18 |
+
self.count_vectorizer = CountVectorizer(
|
| 19 |
+
max_features=5000,
|
| 20 |
+
ngram_range=(1, 2),
|
| 21 |
+
stop_words='english'
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def get_bert_embeddings(self, texts: List[str],
|
| 25 |
+
batch_size: int = 32,
|
| 26 |
+
max_length: int = 512) -> np.ndarray:
|
| 27 |
+
"""Extract BERT embeddings for a list of texts."""
|
| 28 |
+
self.bert_model.eval()
|
| 29 |
+
embeddings = []
|
| 30 |
+
|
| 31 |
+
with torch.no_grad():
|
| 32 |
+
for i in tqdm(range(0, len(texts), batch_size)):
|
| 33 |
+
batch_texts = texts[i:i + batch_size]
|
| 34 |
+
|
| 35 |
+
# Tokenize and prepare input
|
| 36 |
+
encoded = self.bert_tokenizer(
|
| 37 |
+
batch_texts,
|
| 38 |
+
padding=True,
|
| 39 |
+
truncation=True,
|
| 40 |
+
max_length=max_length,
|
| 41 |
+
return_tensors='pt'
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Get BERT embeddings
|
| 45 |
+
outputs = self.bert_model(**encoded)
|
| 46 |
+
# Use [CLS] token embeddings as sentence representation
|
| 47 |
+
batch_embeddings = outputs.last_hidden_state[:, 0, :].numpy()
|
| 48 |
+
embeddings.append(batch_embeddings)
|
| 49 |
+
|
| 50 |
+
return np.vstack(embeddings)
|
| 51 |
+
|
| 52 |
+
def get_tfidf_features(self, texts: List[str]) -> np.ndarray:
|
| 53 |
+
"""Extract TF-IDF features from texts."""
|
| 54 |
+
return self.tfidf_vectorizer.fit_transform(texts).toarray()
|
| 55 |
+
|
| 56 |
+
def get_count_features(self, texts: List[str]) -> np.ndarray:
|
| 57 |
+
"""Extract Count Vectorizer features from texts."""
|
| 58 |
+
return self.count_vectorizer.fit_transform(texts).toarray()
|
| 59 |
+
|
| 60 |
+
def extract_all_features(self, texts: List[str],
|
| 61 |
+
use_bert: bool = True,
|
| 62 |
+
use_tfidf: bool = True,
|
| 63 |
+
use_count: bool = True) -> Dict[str, np.ndarray]:
|
| 64 |
+
"""Extract all features from texts."""
|
| 65 |
+
features = {}
|
| 66 |
+
|
| 67 |
+
if use_bert:
|
| 68 |
+
features['bert'] = self.get_bert_embeddings(texts)
|
| 69 |
+
if use_tfidf:
|
| 70 |
+
features['tfidf'] = self.get_tfidf_features(texts)
|
| 71 |
+
if use_count:
|
| 72 |
+
features['count'] = self.get_count_features(texts)
|
| 73 |
+
|
| 74 |
+
return features
|
| 75 |
+
|
| 76 |
+
def extract_features_from_dataframe(self,
|
| 77 |
+
df: pd.DataFrame,
|
| 78 |
+
text_column: str,
|
| 79 |
+
**kwargs) -> Dict[str, np.ndarray]:
|
| 80 |
+
"""Extract features from a dataframe's text column."""
|
| 81 |
+
texts = df[text_column].tolist()
|
| 82 |
+
return self.extract_all_features(texts, **kwargs)
|
src/data/preprocessor.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import emoji
|
| 3 |
+
import nltk
|
| 4 |
+
from nltk.tokenize import word_tokenize
|
| 5 |
+
from nltk.corpus import stopwords
|
| 6 |
+
from nltk.stem import WordNetLemmatizer
|
| 7 |
+
from textblob import TextBlob
|
| 8 |
+
from typing import List, Union
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
class TextPreprocessor:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
# Download required NLTK data
|
| 14 |
+
nltk.download('punkt')
|
| 15 |
+
nltk.download('stopwords')
|
| 16 |
+
nltk.download('wordnet')
|
| 17 |
+
|
| 18 |
+
self.stop_words = set(stopwords.words('english'))
|
| 19 |
+
self.lemmatizer = WordNetLemmatizer()
|
| 20 |
+
|
| 21 |
+
def remove_urls(self, text: str) -> str:
|
| 22 |
+
"""Remove URLs from text."""
|
| 23 |
+
url_pattern = re.compile(r'https?://\S+|www\.\S+')
|
| 24 |
+
return url_pattern.sub('', text)
|
| 25 |
+
|
| 26 |
+
def remove_emojis(self, text: str) -> str:
|
| 27 |
+
"""Remove emojis from text."""
|
| 28 |
+
return emoji.replace_emoji(text, replace='')
|
| 29 |
+
|
| 30 |
+
def remove_special_chars(self, text: str) -> str:
|
| 31 |
+
"""Remove special characters and numbers."""
|
| 32 |
+
return re.sub(r'[^a-zA-Z\s]', '', text)
|
| 33 |
+
|
| 34 |
+
def remove_extra_spaces(self, text: str) -> str:
|
| 35 |
+
"""Remove extra spaces."""
|
| 36 |
+
return re.sub(r'\s+', ' ', text).strip()
|
| 37 |
+
|
| 38 |
+
def lemmatize_text(self, text: str) -> str:
|
| 39 |
+
"""Lemmatize text."""
|
| 40 |
+
tokens = word_tokenize(text)
|
| 41 |
+
return ' '.join([self.lemmatizer.lemmatize(token) for token in tokens])
|
| 42 |
+
|
| 43 |
+
def remove_stopwords(self, text: str) -> str:
|
| 44 |
+
"""Remove stopwords from text."""
|
| 45 |
+
tokens = word_tokenize(text)
|
| 46 |
+
return ' '.join([token for token in tokens if token.lower() not in self.stop_words])
|
| 47 |
+
|
| 48 |
+
def correct_spelling(self, text: str) -> str:
|
| 49 |
+
"""Correct spelling in text."""
|
| 50 |
+
return str(TextBlob(text).correct())
|
| 51 |
+
|
| 52 |
+
def preprocess_text(self, text: str,
|
| 53 |
+
remove_urls: bool = True,
|
| 54 |
+
remove_emojis: bool = True,
|
| 55 |
+
remove_special_chars: bool = True,
|
| 56 |
+
remove_stopwords: bool = True,
|
| 57 |
+
lemmatize: bool = True,
|
| 58 |
+
correct_spelling: bool = False) -> str:
|
| 59 |
+
"""Apply all preprocessing steps to text."""
|
| 60 |
+
if not isinstance(text, str):
|
| 61 |
+
return ""
|
| 62 |
+
|
| 63 |
+
text = text.lower()
|
| 64 |
+
|
| 65 |
+
if remove_urls:
|
| 66 |
+
text = self.remove_urls(text)
|
| 67 |
+
if remove_emojis:
|
| 68 |
+
text = self.remove_emojis(text)
|
| 69 |
+
if remove_special_chars:
|
| 70 |
+
text = self.remove_special_chars(text)
|
| 71 |
+
if remove_stopwords:
|
| 72 |
+
text = self.remove_stopwords(text)
|
| 73 |
+
if lemmatize:
|
| 74 |
+
text = self.lemmatize_text(text)
|
| 75 |
+
if correct_spelling:
|
| 76 |
+
text = self.correct_spelling(text)
|
| 77 |
+
|
| 78 |
+
text = self.remove_extra_spaces(text)
|
| 79 |
+
return text
|
| 80 |
+
|
| 81 |
+
def preprocess_dataframe(self, df: pd.DataFrame,
|
| 82 |
+
text_column: str,
|
| 83 |
+
**kwargs) -> pd.DataFrame:
|
| 84 |
+
"""Preprocess text column in a dataframe."""
|
| 85 |
+
df = df.copy()
|
| 86 |
+
df[text_column] = df[text_column].apply(
|
| 87 |
+
lambda x: self.preprocess_text(x, **kwargs)
|
| 88 |
+
)
|
| 89 |
+
return df
|
src/models/__pycache__/hybrid_model.cpython-311.pyc
ADDED
|
Binary file (5 kB). View file
|
|
|
src/models/__pycache__/trainer.cpython-311.pyc
ADDED
|
Binary file (9.48 kB). View file
|
|
|
src/models/hybrid_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import BertModel
|
| 4 |
+
from typing import Tuple, Dict
|
| 5 |
+
|
| 6 |
+
class AttentionLayer(nn.Module):
|
| 7 |
+
def __init__(self, hidden_size: int):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.attention = nn.Sequential(
|
| 10 |
+
nn.Linear(hidden_size, hidden_size),
|
| 11 |
+
nn.Tanh(),
|
| 12 |
+
nn.Linear(hidden_size, 1)
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 16 |
+
attention_weights = torch.softmax(self.attention(x), dim=1)
|
| 17 |
+
attended = torch.sum(attention_weights * x, dim=1)
|
| 18 |
+
return attended, attention_weights
|
| 19 |
+
|
| 20 |
+
class HybridFakeNewsDetector(nn.Module):
|
| 21 |
+
def __init__(self,
|
| 22 |
+
bert_model_name: str = "bert-base-uncased",
|
| 23 |
+
lstm_hidden_size: int = 256,
|
| 24 |
+
lstm_num_layers: int = 2,
|
| 25 |
+
dropout_rate: float = 0.3,
|
| 26 |
+
num_classes: int = 2):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
# BERT encoder
|
| 30 |
+
self.bert = BertModel.from_pretrained(bert_model_name)
|
| 31 |
+
bert_hidden_size = self.bert.config.hidden_size
|
| 32 |
+
|
| 33 |
+
# BiLSTM layer
|
| 34 |
+
self.lstm = nn.LSTM(
|
| 35 |
+
input_size=bert_hidden_size,
|
| 36 |
+
hidden_size=lstm_hidden_size,
|
| 37 |
+
num_layers=lstm_num_layers,
|
| 38 |
+
batch_first=True,
|
| 39 |
+
bidirectional=True
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Attention layer
|
| 43 |
+
self.attention = AttentionLayer(lstm_hidden_size * 2)
|
| 44 |
+
|
| 45 |
+
# Classification head
|
| 46 |
+
self.classifier = nn.Sequential(
|
| 47 |
+
nn.Dropout(dropout_rate),
|
| 48 |
+
nn.Linear(lstm_hidden_size * 2, lstm_hidden_size),
|
| 49 |
+
nn.ReLU(),
|
| 50 |
+
nn.Dropout(dropout_rate),
|
| 51 |
+
nn.Linear(lstm_hidden_size, num_classes)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(self, input_ids: torch.Tensor,
|
| 55 |
+
attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
|
| 56 |
+
# Get BERT embeddings
|
| 57 |
+
bert_outputs = self.bert(
|
| 58 |
+
input_ids=input_ids,
|
| 59 |
+
attention_mask=attention_mask
|
| 60 |
+
)
|
| 61 |
+
bert_embeddings = bert_outputs.last_hidden_state
|
| 62 |
+
|
| 63 |
+
# Process through BiLSTM
|
| 64 |
+
lstm_output, _ = self.lstm(bert_embeddings)
|
| 65 |
+
|
| 66 |
+
# Apply attention
|
| 67 |
+
attended, attention_weights = self.attention(lstm_output)
|
| 68 |
+
|
| 69 |
+
# Classification
|
| 70 |
+
logits = self.classifier(attended)
|
| 71 |
+
|
| 72 |
+
return {
|
| 73 |
+
'logits': logits,
|
| 74 |
+
'attention_weights': attention_weights
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
def predict(self, input_ids: torch.Tensor,
|
| 78 |
+
attention_mask: torch.Tensor) -> torch.Tensor:
|
| 79 |
+
"""Get model predictions."""
|
| 80 |
+
outputs = self.forward(input_ids, attention_mask)
|
| 81 |
+
return torch.softmax(outputs['logits'], dim=1)
|
| 82 |
+
|
| 83 |
+
def get_attention_weights(self, input_ids: torch.Tensor,
|
| 84 |
+
attention_mask: torch.Tensor) -> torch.Tensor:
|
| 85 |
+
"""Get attention weights for interpretability."""
|
| 86 |
+
outputs = self.forward(input_ids, attention_mask)
|
| 87 |
+
return outputs['attention_weights']
|
src/models/trainer.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from transformers import get_linear_schedule_with_warmup
|
| 5 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
| 6 |
+
from typing import Dict, List, Tuple
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class ModelTrainer:
|
| 15 |
+
def __init__(self,
|
| 16 |
+
model: nn.Module,
|
| 17 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 18 |
+
learning_rate: float = 2e-5,
|
| 19 |
+
num_epochs: int = 10,
|
| 20 |
+
early_stopping_patience: int = 3):
|
| 21 |
+
self.model = model.to(device)
|
| 22 |
+
self.device = device
|
| 23 |
+
self.learning_rate = learning_rate
|
| 24 |
+
self.num_epochs = num_epochs
|
| 25 |
+
self.early_stopping_patience = early_stopping_patience
|
| 26 |
+
|
| 27 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 28 |
+
self.optimizer = torch.optim.AdamW(
|
| 29 |
+
self.model.parameters(),
|
| 30 |
+
lr=learning_rate
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
def train_epoch(self, train_loader: DataLoader) -> float:
|
| 34 |
+
"""Train for one epoch."""
|
| 35 |
+
self.model.train()
|
| 36 |
+
total_loss = 0
|
| 37 |
+
|
| 38 |
+
for batch in tqdm(train_loader, desc="Training"):
|
| 39 |
+
input_ids = batch['input_ids'].to(self.device)
|
| 40 |
+
attention_mask = batch['attention_mask'].to(self.device)
|
| 41 |
+
labels = batch['labels'].to(self.device)
|
| 42 |
+
|
| 43 |
+
self.optimizer.zero_grad()
|
| 44 |
+
|
| 45 |
+
outputs = self.model(input_ids, attention_mask)
|
| 46 |
+
loss = self.criterion(outputs['logits'], labels)
|
| 47 |
+
|
| 48 |
+
loss.backward()
|
| 49 |
+
self.optimizer.step()
|
| 50 |
+
|
| 51 |
+
total_loss += loss.item()
|
| 52 |
+
|
| 53 |
+
return total_loss / len(train_loader)
|
| 54 |
+
|
| 55 |
+
def evaluate(self, eval_loader: DataLoader) -> Tuple[float, Dict[str, float]]:
|
| 56 |
+
"""Evaluate the model."""
|
| 57 |
+
self.model.eval()
|
| 58 |
+
total_loss = 0
|
| 59 |
+
all_preds = []
|
| 60 |
+
all_labels = []
|
| 61 |
+
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
for batch in tqdm(eval_loader, desc="Evaluating"):
|
| 64 |
+
input_ids = batch['input_ids'].to(self.device)
|
| 65 |
+
attention_mask = batch['attention_mask'].to(self.device)
|
| 66 |
+
labels = batch['labels'].to(self.device)
|
| 67 |
+
|
| 68 |
+
outputs = self.model(input_ids, attention_mask)
|
| 69 |
+
loss = self.criterion(outputs['logits'], labels)
|
| 70 |
+
|
| 71 |
+
total_loss += loss.item()
|
| 72 |
+
|
| 73 |
+
preds = torch.argmax(outputs['logits'], dim=1)
|
| 74 |
+
all_preds.extend(preds.cpu().numpy())
|
| 75 |
+
all_labels.extend(labels.cpu().numpy())
|
| 76 |
+
|
| 77 |
+
# Calculate metrics
|
| 78 |
+
metrics = self._calculate_metrics(all_labels, all_preds)
|
| 79 |
+
metrics['loss'] = total_loss / len(eval_loader)
|
| 80 |
+
|
| 81 |
+
return total_loss / len(eval_loader), metrics
|
| 82 |
+
|
| 83 |
+
def _calculate_metrics(self, labels: List[int], preds: List[int]) -> Dict[str, float]:
|
| 84 |
+
"""Calculate evaluation metrics."""
|
| 85 |
+
precision, recall, f1, _ = precision_recall_fscore_support(
|
| 86 |
+
labels, preds, average='weighted'
|
| 87 |
+
)
|
| 88 |
+
accuracy = accuracy_score(labels, preds)
|
| 89 |
+
|
| 90 |
+
return {
|
| 91 |
+
'accuracy': accuracy,
|
| 92 |
+
'precision': precision,
|
| 93 |
+
'recall': recall,
|
| 94 |
+
'f1': f1
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
def train(self,
|
| 98 |
+
train_loader: DataLoader,
|
| 99 |
+
val_loader: DataLoader,
|
| 100 |
+
num_training_steps: int) -> Dict[str, List[float]]:
|
| 101 |
+
"""Train the model with early stopping."""
|
| 102 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 103 |
+
self.optimizer,
|
| 104 |
+
num_warmup_steps=0,
|
| 105 |
+
num_training_steps=num_training_steps
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
best_val_loss = float('inf')
|
| 109 |
+
patience_counter = 0
|
| 110 |
+
history = {
|
| 111 |
+
'train_loss': [],
|
| 112 |
+
'val_loss': [],
|
| 113 |
+
'val_metrics': []
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
for epoch in range(self.num_epochs):
|
| 117 |
+
logger.info(f"Epoch {epoch + 1}/{self.num_epochs}")
|
| 118 |
+
|
| 119 |
+
# Training
|
| 120 |
+
train_loss = self.train_epoch(train_loader)
|
| 121 |
+
history['train_loss'].append(train_loss)
|
| 122 |
+
|
| 123 |
+
# Validation
|
| 124 |
+
val_loss, val_metrics = self.evaluate(val_loader)
|
| 125 |
+
history['val_loss'].append(val_loss)
|
| 126 |
+
history['val_metrics'].append(val_metrics)
|
| 127 |
+
|
| 128 |
+
logger.info(f"Train Loss: {train_loss:.4f}")
|
| 129 |
+
logger.info(f"Val Loss: {val_loss:.4f}")
|
| 130 |
+
logger.info(f"Val Metrics: {val_metrics}")
|
| 131 |
+
|
| 132 |
+
# Early stopping
|
| 133 |
+
if val_loss < best_val_loss:
|
| 134 |
+
best_val_loss = val_loss
|
| 135 |
+
patience_counter = 0
|
| 136 |
+
# Save best model
|
| 137 |
+
torch.save(self.model.state_dict(), 'best_model.pt')
|
| 138 |
+
else:
|
| 139 |
+
patience_counter += 1
|
| 140 |
+
if patience_counter >= self.early_stopping_patience:
|
| 141 |
+
logger.info("Early stopping triggered")
|
| 142 |
+
break
|
| 143 |
+
|
| 144 |
+
scheduler.step()
|
| 145 |
+
|
| 146 |
+
return history
|
| 147 |
+
|
| 148 |
+
def predict(self, test_loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
|
| 149 |
+
"""Get predictions on test data."""
|
| 150 |
+
self.model.eval()
|
| 151 |
+
all_preds = []
|
| 152 |
+
all_probs = []
|
| 153 |
+
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
for batch in tqdm(test_loader, desc="Predicting"):
|
| 156 |
+
input_ids = batch['input_ids'].to(self.device)
|
| 157 |
+
attention_mask = batch['attention_mask'].to(self.device)
|
| 158 |
+
|
| 159 |
+
probs = self.model.predict(input_ids, attention_mask)
|
| 160 |
+
preds = torch.argmax(probs, dim=1)
|
| 161 |
+
|
| 162 |
+
all_preds.extend(preds.cpu().numpy())
|
| 163 |
+
all_probs.extend(probs.cpu().numpy())
|
| 164 |
+
|
| 165 |
+
return np.array(all_preds), np.array(all_probs)
|
src/train.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import BertTokenizer
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Add project root to Python path
|
| 10 |
+
project_root = Path(__file__).parent.parent
|
| 11 |
+
sys.path.append(str(project_root))
|
| 12 |
+
|
| 13 |
+
from src.data.preprocessor import TextPreprocessor
|
| 14 |
+
from src.data.dataset import create_data_loaders
|
| 15 |
+
from src.models.hybrid_model import HybridFakeNewsDetector
|
| 16 |
+
from src.models.trainer import ModelTrainer
|
| 17 |
+
from src.config.config import *
|
| 18 |
+
from src.visualization.plot_metrics import (
|
| 19 |
+
plot_training_history,
|
| 20 |
+
plot_confusion_matrix,
|
| 21 |
+
plot_model_comparison,
|
| 22 |
+
plot_feature_importance
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
def main():
|
| 29 |
+
# Create necessary directories
|
| 30 |
+
os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
|
| 31 |
+
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
|
| 32 |
+
os.makedirs(project_root / "visualizations", exist_ok=True)
|
| 33 |
+
|
| 34 |
+
# Load and preprocess data
|
| 35 |
+
logger.info("Loading and preprocessing data...")
|
| 36 |
+
df = pd.read_csv(PROCESSED_DATA_DIR / "combined_dataset.csv")
|
| 37 |
+
|
| 38 |
+
# Limit dataset size for faster training
|
| 39 |
+
if len(df) > MAX_SAMPLES:
|
| 40 |
+
logger.info(f"Limiting dataset to {MAX_SAMPLES} samples for faster training")
|
| 41 |
+
df = df.sample(n=MAX_SAMPLES, random_state=RANDOM_STATE)
|
| 42 |
+
|
| 43 |
+
preprocessor = TextPreprocessor()
|
| 44 |
+
df = preprocessor.preprocess_dataframe(
|
| 45 |
+
df,
|
| 46 |
+
text_column='text',
|
| 47 |
+
remove_urls=True,
|
| 48 |
+
remove_emojis=True,
|
| 49 |
+
remove_special_chars=True,
|
| 50 |
+
remove_stopwords=True,
|
| 51 |
+
lemmatize=True
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# Initialize tokenizer
|
| 55 |
+
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
|
| 56 |
+
|
| 57 |
+
# Create data loaders
|
| 58 |
+
logger.info("Creating data loaders...")
|
| 59 |
+
data_loaders = create_data_loaders(
|
| 60 |
+
df=df,
|
| 61 |
+
text_column='text',
|
| 62 |
+
label_column='label',
|
| 63 |
+
tokenizer=tokenizer,
|
| 64 |
+
batch_size=BATCH_SIZE,
|
| 65 |
+
max_length=MAX_SEQUENCE_LENGTH,
|
| 66 |
+
train_size=1-TEST_SIZE-VAL_SIZE,
|
| 67 |
+
val_size=VAL_SIZE,
|
| 68 |
+
random_state=RANDOM_STATE
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Initialize model
|
| 72 |
+
logger.info("Initializing model...")
|
| 73 |
+
model = HybridFakeNewsDetector(
|
| 74 |
+
bert_model_name=BERT_MODEL_NAME,
|
| 75 |
+
lstm_hidden_size=LSTM_HIDDEN_SIZE,
|
| 76 |
+
lstm_num_layers=LSTM_NUM_LAYERS,
|
| 77 |
+
dropout_rate=DROPOUT_RATE
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Initialize trainer
|
| 81 |
+
logger.info("Initializing trainer...")
|
| 82 |
+
trainer = ModelTrainer(
|
| 83 |
+
model=model,
|
| 84 |
+
device=DEVICE,
|
| 85 |
+
learning_rate=LEARNING_RATE,
|
| 86 |
+
num_epochs=NUM_EPOCHS,
|
| 87 |
+
early_stopping_patience=EARLY_STOPPING_PATIENCE
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Calculate total training steps
|
| 91 |
+
num_training_steps = len(data_loaders['train']) * NUM_EPOCHS
|
| 92 |
+
|
| 93 |
+
# Train model
|
| 94 |
+
logger.info("Starting training...")
|
| 95 |
+
history = trainer.train(
|
| 96 |
+
train_loader=data_loaders['train'],
|
| 97 |
+
val_loader=data_loaders['val'],
|
| 98 |
+
num_training_steps=num_training_steps
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Evaluate on test set
|
| 102 |
+
logger.info("Evaluating on test set...")
|
| 103 |
+
test_loss, test_metrics = trainer.evaluate(data_loaders['test'])
|
| 104 |
+
logger.info(f"Test Loss: {test_loss:.4f}")
|
| 105 |
+
logger.info(f"Test Metrics: {test_metrics}")
|
| 106 |
+
|
| 107 |
+
# Save final model
|
| 108 |
+
logger.info("Saving final model...")
|
| 109 |
+
torch.save(model.state_dict(), SAVED_MODELS_DIR / "final_model.pt")
|
| 110 |
+
|
| 111 |
+
# Generate visualizations
|
| 112 |
+
logger.info("Generating visualizations...")
|
| 113 |
+
vis_dir = project_root / "visualizations"
|
| 114 |
+
|
| 115 |
+
# Plot training history
|
| 116 |
+
plot_training_history(history, save_path=vis_dir / "training_history.png")
|
| 117 |
+
|
| 118 |
+
# Get predictions for confusion matrix
|
| 119 |
+
model.eval()
|
| 120 |
+
all_preds = []
|
| 121 |
+
all_labels = []
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
for batch in data_loaders['test']:
|
| 124 |
+
input_ids = batch['input_ids'].to(DEVICE)
|
| 125 |
+
attention_mask = batch['attention_mask'].to(DEVICE)
|
| 126 |
+
labels = batch['label']
|
| 127 |
+
|
| 128 |
+
outputs = model(input_ids, attention_mask)
|
| 129 |
+
preds = torch.argmax(outputs['logits'], dim=1)
|
| 130 |
+
|
| 131 |
+
all_preds.extend(preds.cpu().numpy())
|
| 132 |
+
all_labels.extend(labels.numpy())
|
| 133 |
+
|
| 134 |
+
# Plot confusion matrix
|
| 135 |
+
plot_confusion_matrix(
|
| 136 |
+
np.array(all_labels),
|
| 137 |
+
np.array(all_preds),
|
| 138 |
+
save_path=vis_dir / "confusion_matrix.png"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Plot model comparison with baseline models
|
| 142 |
+
baseline_metrics = {
|
| 143 |
+
'BERT': {'accuracy': 0.85, 'precision': 0.82, 'recall': 0.88, 'f1': 0.85},
|
| 144 |
+
'BiLSTM': {'accuracy': 0.78, 'precision': 0.75, 'recall': 0.81, 'f1': 0.78},
|
| 145 |
+
'Hybrid': test_metrics # Our model's metrics
|
| 146 |
+
}
|
| 147 |
+
plot_model_comparison(baseline_metrics, save_path=vis_dir / "model_comparison.png")
|
| 148 |
+
|
| 149 |
+
# Plot feature importance
|
| 150 |
+
feature_importance = {
|
| 151 |
+
'BERT': 0.4,
|
| 152 |
+
'BiLSTM': 0.3,
|
| 153 |
+
'Attention': 0.2,
|
| 154 |
+
'TF-IDF': 0.1
|
| 155 |
+
}
|
| 156 |
+
plot_feature_importance(feature_importance, save_path=vis_dir / "feature_importance.png")
|
| 157 |
+
|
| 158 |
+
logger.info("Training and visualization completed!")
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
main()
|
src/visualization/plot_metrics.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logging.basicConfig(level=logging.INFO)
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
def plot_training_history(history: dict, save_path: Path = None):
|
| 13 |
+
"""
|
| 14 |
+
Plot training and validation metrics over epochs.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
history: Dictionary containing training history
|
| 18 |
+
save_path: Path to save the plot
|
| 19 |
+
"""
|
| 20 |
+
plt.figure(figsize=(12, 5))
|
| 21 |
+
|
| 22 |
+
# Plot loss
|
| 23 |
+
plt.subplot(1, 2, 1)
|
| 24 |
+
plt.plot(history['train_loss'], label='Training Loss')
|
| 25 |
+
plt.plot(history['val_loss'], label='Validation Loss')
|
| 26 |
+
plt.title('Training and Validation Loss')
|
| 27 |
+
plt.xlabel('Epoch')
|
| 28 |
+
plt.ylabel('Loss')
|
| 29 |
+
plt.legend()
|
| 30 |
+
|
| 31 |
+
# Plot metrics
|
| 32 |
+
plt.subplot(1, 2, 2)
|
| 33 |
+
metrics = ['accuracy', 'precision', 'recall', 'f1']
|
| 34 |
+
for metric in metrics:
|
| 35 |
+
values = [epoch_metrics[metric] for epoch_metrics in history['val_metrics']]
|
| 36 |
+
plt.plot(values, label=metric.capitalize())
|
| 37 |
+
|
| 38 |
+
plt.title('Validation Metrics')
|
| 39 |
+
plt.xlabel('Epoch')
|
| 40 |
+
plt.ylabel('Score')
|
| 41 |
+
plt.legend()
|
| 42 |
+
|
| 43 |
+
plt.tight_layout()
|
| 44 |
+
|
| 45 |
+
if save_path:
|
| 46 |
+
plt.savefig(save_path)
|
| 47 |
+
logger.info(f"Training history plot saved to {save_path}")
|
| 48 |
+
|
| 49 |
+
plt.close()
|
| 50 |
+
|
| 51 |
+
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_path: Path = None):
|
| 52 |
+
"""
|
| 53 |
+
Plot confusion matrix for model predictions.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
y_true: True labels
|
| 57 |
+
y_pred: Predicted labels
|
| 58 |
+
save_path: Path to save the plot
|
| 59 |
+
"""
|
| 60 |
+
from sklearn.metrics import confusion_matrix
|
| 61 |
+
|
| 62 |
+
cm = confusion_matrix(y_true, y_pred)
|
| 63 |
+
plt.figure(figsize=(8, 6))
|
| 64 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
|
| 65 |
+
plt.title('Confusion Matrix')
|
| 66 |
+
plt.xlabel('Predicted Label')
|
| 67 |
+
plt.ylabel('True Label')
|
| 68 |
+
|
| 69 |
+
if save_path:
|
| 70 |
+
plt.savefig(save_path)
|
| 71 |
+
logger.info(f"Confusion matrix plot saved to {save_path}")
|
| 72 |
+
|
| 73 |
+
plt.close()
|
| 74 |
+
|
| 75 |
+
def plot_attention_weights(text: str, attention_weights: np.ndarray, save_path: Path = None):
|
| 76 |
+
"""
|
| 77 |
+
Plot attention weights for a given text.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
text: Input text
|
| 81 |
+
attention_weights: Attention weights for each token
|
| 82 |
+
save_path: Path to save the plot
|
| 83 |
+
"""
|
| 84 |
+
tokens = text.split()
|
| 85 |
+
plt.figure(figsize=(12, 4))
|
| 86 |
+
|
| 87 |
+
# Plot attention weights
|
| 88 |
+
plt.bar(range(len(tokens)), attention_weights)
|
| 89 |
+
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
|
| 90 |
+
plt.title('Attention Weights')
|
| 91 |
+
plt.xlabel('Tokens')
|
| 92 |
+
plt.ylabel('Attention Weight')
|
| 93 |
+
|
| 94 |
+
plt.tight_layout()
|
| 95 |
+
|
| 96 |
+
if save_path:
|
| 97 |
+
plt.savefig(save_path)
|
| 98 |
+
logger.info(f"Attention weights plot saved to {save_path}")
|
| 99 |
+
|
| 100 |
+
plt.close()
|
| 101 |
+
|
| 102 |
+
def plot_model_comparison(metrics: dict, save_path: Path = None):
|
| 103 |
+
"""
|
| 104 |
+
Plot comparison of different models' performance.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
metrics: Dictionary containing model metrics
|
| 108 |
+
save_path: Path to save the plot
|
| 109 |
+
"""
|
| 110 |
+
models = list(metrics.keys())
|
| 111 |
+
metric_names = ['accuracy', 'precision', 'recall', 'f1']
|
| 112 |
+
|
| 113 |
+
plt.figure(figsize=(10, 6))
|
| 114 |
+
x = np.arange(len(models))
|
| 115 |
+
width = 0.2
|
| 116 |
+
|
| 117 |
+
for i, metric in enumerate(metric_names):
|
| 118 |
+
values = [metrics[model][metric] for model in models]
|
| 119 |
+
plt.bar(x + i*width, values, width, label=metric.capitalize())
|
| 120 |
+
|
| 121 |
+
plt.title('Model Performance Comparison')
|
| 122 |
+
plt.xlabel('Models')
|
| 123 |
+
plt.ylabel('Score')
|
| 124 |
+
plt.xticks(x + width*1.5, models, rotation=45)
|
| 125 |
+
plt.legend()
|
| 126 |
+
|
| 127 |
+
plt.tight_layout()
|
| 128 |
+
|
| 129 |
+
if save_path:
|
| 130 |
+
plt.savefig(save_path)
|
| 131 |
+
logger.info(f"Model comparison plot saved to {save_path}")
|
| 132 |
+
|
| 133 |
+
plt.close()
|
| 134 |
+
|
| 135 |
+
def plot_feature_importance(feature_importance: dict, save_path: Path = None):
|
| 136 |
+
"""
|
| 137 |
+
Plot feature importance scores.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
feature_importance: Dictionary containing feature importance scores
|
| 141 |
+
save_path: Path to save the plot
|
| 142 |
+
"""
|
| 143 |
+
features = list(feature_importance.keys())
|
| 144 |
+
importance = list(feature_importance.values())
|
| 145 |
+
|
| 146 |
+
# Sort by importance
|
| 147 |
+
sorted_idx = np.argsort(importance)
|
| 148 |
+
features = [features[i] for i in sorted_idx]
|
| 149 |
+
importance = [importance[i] for i in sorted_idx]
|
| 150 |
+
|
| 151 |
+
plt.figure(figsize=(10, 6))
|
| 152 |
+
plt.barh(range(len(features)), importance)
|
| 153 |
+
plt.yticks(range(len(features)), features)
|
| 154 |
+
plt.title('Feature Importance')
|
| 155 |
+
plt.xlabel('Importance Score')
|
| 156 |
+
|
| 157 |
+
plt.tight_layout()
|
| 158 |
+
|
| 159 |
+
if save_path:
|
| 160 |
+
plt.savefig(save_path)
|
| 161 |
+
logger.info(f"Feature importance plot saved to {save_path}")
|
| 162 |
+
|
| 163 |
+
plt.close()
|
| 164 |
+
|
| 165 |
+
def main():
|
| 166 |
+
# Create visualization directory
|
| 167 |
+
vis_dir = Path(__file__).parent.parent.parent / "visualizations"
|
| 168 |
+
vis_dir.mkdir(exist_ok=True)
|
| 169 |
+
|
| 170 |
+
# Example usage
|
| 171 |
+
history = {
|
| 172 |
+
'train_loss': [0.5, 0.4, 0.3],
|
| 173 |
+
'val_loss': [0.45, 0.35, 0.25],
|
| 174 |
+
'val_metrics': [
|
| 175 |
+
{'accuracy': 0.8, 'precision': 0.75, 'recall': 0.85, 'f1': 0.8},
|
| 176 |
+
{'accuracy': 0.85, 'precision': 0.8, 'recall': 0.9, 'f1': 0.85},
|
| 177 |
+
{'accuracy': 0.9, 'precision': 0.85, 'recall': 0.95, 'f1': 0.9}
|
| 178 |
+
]
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
# Plot training history
|
| 182 |
+
plot_training_history(history, save_path=vis_dir / "training_history.png")
|
| 183 |
+
|
| 184 |
+
# Example confusion matrix
|
| 185 |
+
y_true = np.array([0, 1, 0, 1, 1, 0])
|
| 186 |
+
y_pred = np.array([0, 1, 0, 0, 1, 0])
|
| 187 |
+
plot_confusion_matrix(y_true, y_pred, save_path=vis_dir / "confusion_matrix.png")
|
| 188 |
+
|
| 189 |
+
# Example model comparison
|
| 190 |
+
metrics = {
|
| 191 |
+
'BERT': {'accuracy': 0.85, 'precision': 0.82, 'recall': 0.88, 'f1': 0.85},
|
| 192 |
+
'BiLSTM': {'accuracy': 0.78, 'precision': 0.75, 'recall': 0.81, 'f1': 0.78},
|
| 193 |
+
'Hybrid': {'accuracy': 0.92, 'precision': 0.9, 'recall': 0.94, 'f1': 0.92}
|
| 194 |
+
}
|
| 195 |
+
plot_model_comparison(metrics, save_path=vis_dir / "model_comparison.png")
|
| 196 |
+
|
| 197 |
+
# Example feature importance
|
| 198 |
+
feature_importance = {
|
| 199 |
+
'BERT': 0.4,
|
| 200 |
+
'BiLSTM': 0.3,
|
| 201 |
+
'Attention': 0.2,
|
| 202 |
+
'TF-IDF': 0.1
|
| 203 |
+
}
|
| 204 |
+
plot_feature_importance(feature_importance, save_path=vis_dir / "feature_importance.png")
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
|
| 207 |
+
main()
|