| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import joblib | |
| import gradio as gr | |
| from dateutil.relativedelta import relativedelta | |
| import calendar | |
def load_model(path='arima_sales_model.pkl'):
    """Load the serialized ARIMA model with joblib.

    Args:
        path: Location of the joblib file to load. Defaults to
            ``'arima_sales_model.pkl'`` (resolved against the current working
            directory), which preserves the original hard-coded behavior.

    Returns:
        ``(model, None)`` when loading succeeds, or
        ``(None, "Failed to load model: <reason>")`` when any exception is
        raised while loading.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code; only point this at a trusted, locally produced model file.
    try:
        model = joblib.load(path)
    except Exception as e:
        return None, f"Failed to load model: {str(e)}"
    return model, None
def parse_date(date_str):
    """Parse a 'Month-Year' string such as 'January-2024' into month bounds.

    Args:
        date_str: Month and year text in ``%B-%Y`` form (e.g.
            ``'January-2024'``). Surrounding whitespace is ignored.

    Returns:
        ``(start_date, end_date, None)`` where ``start_date`` is midnight on
        the first day of the month and ``end_date`` is midnight on its last
        day, or ``(None, None, message)`` when the input is missing, blank,
        or not in the expected format.
    """
    error_message = "Date format should be 'Month-Year', e.g., 'January-2024'."
    # An empty Gradio textbox yields "" (or None). pandas maps "" to NaT or
    # None instead of raising ValueError, which previously crashed later in
    # calendar.monthrange / attribute access, so reject blanks up front.
    if date_str is None or (isinstance(date_str, str) and not date_str.strip()):
        return None, None, error_message
    text = date_str.strip() if isinstance(date_str, str) else date_str
    try:
        date = pd.to_datetime(text, format="%B-%Y")
    except (ValueError, TypeError):
        return None, None, error_message
    if pd.isna(date):
        return None, None, error_message
    _, last_day = calendar.monthrange(date.year, date.month)
    return date.replace(day=1), date.replace(day=last_day), None
def _read_sales_csv(uploaded_file):
    """Read the uploaded CSV and check that it has 'Date' and 'Sale' columns.

    Returns ``(DataFrame, None)`` on success or ``(None, message)`` on failure.
    """
    # Depending on the Gradio version, gr.File delivers either a path string
    # or a temp-file wrapper exposing the path as ``.name``; prefer the path
    # so pandas opens the file itself rather than relying on the wrapper's
    # handle state. Plain paths and in-memory buffers pass through unchanged.
    source = getattr(uploaded_file, "name", uploaded_file)
    try:
        df = pd.read_csv(source)
    except Exception as e:
        return None, f"Failed to read the uploaded CSV file: {str(e)}"
    if 'Date' not in df.columns or 'Sale' not in df.columns:
        return None, "The uploaded file must contain 'Date' and 'Sale' columns."
    return df, None


def _plot_forecast(actual_df, forecast_df):
    """Draw actual sales ('ds'/'y') and the forecast ('Date'/'Sales Forecast')."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(actual_df['ds'], actual_df['y'], label='Actual Sales', color='blue')
        ax.plot(forecast_df['Date'], forecast_df['Sales Forecast'],
                label='Sales Forecast', color='red', linestyle='--')
        ax.set_xlabel('Date')
        ax.set_ylabel('Sales')
        ax.set_title('Sales Forecasting with ARIMA')
        ax.legend()
    finally:
        # Drop the figure from pyplot's global registry so repeated requests
        # do not accumulate open figures; the Figure object returned below is
        # still renderable by the caller.
        plt.close(fig)
    return fig


def forecast_sales(uploaded_file, start_date_str, end_date_str, forecast_steps=60):
    """Plot actual sales for a month range next to an ARIMA forecast.

    Always returns exactly two values, ``(figure, message)``, so every path
    lines up with a two-output (plot, notification) UI binding.

    Args:
        uploaded_file: The CSV upload. It may be a path string, a file-like
            object, or an object whose ``.name`` is a path. The CSV must have
            'Date' and 'Sale' columns.
        start_date_str: First month to display, as 'Month-Year'
            (e.g. 'January-2024').
        end_date_str: Last month to display (inclusive), same format.
        forecast_steps: Horizon passed to the model's ``get_forecast``; the
            values are plotted as consecutive days after the end month.

    Returns:
        ``(fig, message)``: a matplotlib Figure (or ``None`` on failure) and
        a status or error string.
    """
    if uploaded_file is None:
        return None, "No file uploaded. Please upload a file."

    df, error = _read_sales_csv(uploaded_file)
    if error:
        return None, error

    start_date, _, error = parse_date(start_date_str)
    _, end_date, error_end = parse_date(end_date_str)
    if error or error_end:
        return None, error or error_end
    if start_date > end_date:
        return None, "Start Date must not be later than End Date."

    try:
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.rename(columns={'Date': 'ds', 'Sale': 'y'})
        # end_date is midnight on the last day of the month; use an exclusive
        # bound one day later so rows timed later that day are still included.
        end_exclusive = end_date + pd.Timedelta(days=1)
        in_range = (df['ds'] >= start_date) & (df['ds'] < end_exclusive)
        # Sort so the line plot does not zig-zag on unordered input rows.
        df_filtered = df[in_range].sort_values('ds')
    except (ValueError, TypeError) as e:
        return None, f"Failed to parse the 'Date' column: {str(e)}"

    arima_model, error = load_model()
    if arima_model is None:
        return None, error

    try:
        forecast = arima_model.get_forecast(steps=forecast_steps)
        # Accept a pandas Series or a NumPy array and discard any index the
        # model attached, so values pair with the dates below by position.
        predicted = pd.Series(forecast.predicted_mean).to_numpy()
        # NOTE(review): assumes a daily model whose horizon starts the day
        # after the selected end month. For statsmodels results, get_forecast
        # continues from the end of the model's own training data, so confirm
        # that the two coincide.
        forecast_index = pd.date_range(
            start=end_date, periods=len(predicted) + 1, freq='D')[1:]
        forecast_df = pd.DataFrame({'Date': forecast_index, 'Sales Forecast': predicted})
        fig = _plot_forecast(df_filtered, forecast_df)
    except Exception as e:
        return None, f"Failed to generate plot: {str(e)}"

    message = "File loaded and processed successfully."
    if df_filtered.empty:
        message += " Note: no rows in the file fall within the selected date range."
    return fig, message
def setup_interface():
    """Assemble and return the Gradio Blocks app for MLCast.

    The left column holds the CSV upload, the start/end month textboxes and
    the trigger button. The right column holds the plot and a notification
    textbox. Clicking the button calls ``forecast_sales`` with the three
    inputs and routes its results to the plot and the notification box.
    """
    with gr.Blocks() as app:
        gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
        with gr.Row():
            with gr.Column(scale=1):
                csv_upload = gr.File(label="Upload your store data")
                start_box = gr.Textbox(label="Start Date", placeholder="January-2024")
                end_box = gr.Textbox(label="End Date", placeholder="December-2024")
                run_button = gr.Button("Forecast Sales")
            with gr.Column(scale=2):
                plot_area = gr.Plot()
                notice_box = gr.Textbox(label="Notifications", visible=True, lines=2)
        handler_inputs = [csv_upload, start_box, end_box]
        handler_outputs = [plot_area, notice_box]
        run_button.click(forecast_sales, inputs=handler_inputs, outputs=handler_outputs)
    return app
if __name__ == "__main__":
    # Script entry point: build the UI and start serving it.
    setup_interface().launch()