emsesc's picture
add formatting and fix bug
77e9502
raw
history blame
10.4 kB
# Import packages
from dash import Dash, html, dcc, Input, Output
import pandas as pd
import plotly.express as px
from graphs.model_market_share import create_plotly_stacked_area_chart, create_plotly_world_map, create_plotly_range_slider, create_leaderboard
from graphs.model_characteristics import create_plotly_language_concentration_chart, create_plotly_publication_curves_with_legend
# Initialize the app
# The Dash application is created with its default constructor arguments.
app = Dash()
# Re-export the app's `server` attribute as a module-level name.
# NOTE(review): presumably so a WSGI host can import `server` directly from
# this module -- confirm against the deployment configuration.
server = app.server
# Load pre-processed data frames
# Every frame is read once, at import time, from a pickle under the relative
# "data_frames/" directory, so the process must be started from a working
# directory that contains that folder.
# NOTE(review): pd.read_pickle deserializes pickle data; only load files that
# were produced by a trusted pipeline.

# Inputs to the stacked-area market-share chart (top-k, Gini, HHI).
model_topk_df = pd.read_pickle("data_frames/model_topk_df.pkl")
model_gini_df = pd.read_pickle("data_frames/model_gini_df.pkl")
model_hhi_df = pd.read_pickle("data_frames/model_hhi_df.pkl")
# Inputs to the "Model Characteristics" charts.
language_concentration_df = pd.read_pickle("data_frames/language_concentration_df.pkl")
# Note: the variable name differs from the file name (download_license_cumsum_df.pkl).
license_concentration_df = pd.read_pickle("data_frames/download_license_cumsum_df.pkl")
download_method_cumsum_df = pd.read_pickle("data_frames/download_method_cumsum_df.pkl")
download_arch_cumsum_df = pd.read_pickle("data_frames/download_arch_cumsum_df.pkl")
# nat_topk_df is loaded here but not referenced again in the visible part of this file.
nat_topk_df = pd.read_pickle("data_frames/nat_topk_df.pkl")
# Inputs to the world map and the leaderboard.
country_concentration_df = pd.read_pickle("data_frames/country_concentration_df.pkl")
author_concentration_df = pd.read_pickle("data_frames/author_concentration_df.pkl")
model_concentration_df = pd.read_pickle("data_frames/model_concentration_df.pkl")
# Event label -> ISO date string ("YYYY-MM-DD"), passed to the stacked-area
# chart builder.  Entries that are commented out are kept for reference and
# are not passed to the chart.
TEMP_MODEL_EVENTS = {
    # "Yolo World Mirror": "2024-03-01",
    "Llama 3": "2024-04-17",
    "Stable Cascade": "2024-02-02",
    "Stable Diffusion 3": "2024-05-30",
    # "embed/upscale": "2023-03-24",
    "DeepSeek-R1": "2025-01-20",
    "Gemma-3 12B QAT": "2025-04-15",  # gemma-3-12b-it-qat-4bit
    # "Qwen": "2025-03-05",
    # "Flux RedFlux": "2025-04-12",
    # "DeepSeek-V3": "2025-03-24",
    # "bloom": "2022-05-19",
    "DALLE2-PyTorch": "2022-06-25",
    "Stable Diffusion": "2022-08-10",
    "CLIP ViT": "2021-01-05",
    "YOLOv8": "2023-04-26",
    "Sentence Transformer MiniLM v2": "2021-08-30",
}
# Five hex colours shared by every chart builder that accepts a palette.
PALETTE_0 = ["#335C67", "#FFF3B0", "#E09F3E", "#9E2A2B", "#540B0E"]
# --- Pre-built figures -------------------------------------------------------
# Each figure below is constructed once at import time and reused by the
# callbacks whenever no explicit time range has been selected.

# Stacked-area chart built from the top-k, Gini and HHI frames, together with
# the event dates and the shared palette.
fig = create_plotly_stacked_area_chart(
    model_topk_df,
    model_gini_df,
    model_hhi_df,
    TEMP_MODEL_EVENTS,
    PALETTE_0,
)

# Segment order handed to the language concentration chart.
LANG_SEGMENT_ORDER = [
    'Monolingual: EN',
    'Monolingual: HR',
    'Monolingual: M/LR',
    'Multilingual: HR',
    'Multilingual',
    'Unknown',
]
fig2 = create_plotly_language_concentration_chart(
    language_concentration_df,
    'time',
    'metric',
    'value',
    LANG_SEGMENT_ORDER,
    PALETTE_0,
)

# Segment order handed to the license chart, which reuses the language
# concentration chart builder with different column names.
LICENSE_SEGMENT_ORDER = [
    "Open Use",
    "Open Use (Acceptable Use Policy)",
    "Open Use (Non-Commercial Only)",
    "Attribution",
    "Acceptable Use Policy",
    "Non-Commercial Only",
    "Undocumented",
    "Undocumented (Acceptable Use Policy)",
]
fig3 = create_plotly_language_concentration_chart(
    license_concentration_df,
    'period',
    'status',
    'percent',
    LICENSE_SEGMENT_ORDER,
    PALETTE_0,
)

# Plot options for the download-method publication curves.
METHOD_PLOT_CHOICES = dict(
    cumulative="none",  # none, mean, sum
    y_col="percent",  # percent count
    y_log=False,  # True, False
    period="W",
)
fig4 = create_plotly_publication_curves_with_legend(
    download_method_cumsum_df,
    METHOD_PLOT_CHOICES,
    PALETTE_0,
)

# Plot options for the architecture publication curves (kept as a separate
# dict from METHOD_PLOT_CHOICES so the two can be tuned independently).
ARCHITECTURE_PLOT_CHOICES = dict(
    cumulative="none",  # none, mean, sum
    y_col="percent",  # percent count
    y_log=False,  # True, False
    period="W",
)
fig5 = create_plotly_publication_curves_with_legend(
    download_arch_cumsum_df,
    ARCHITECTURE_PLOT_CHOICES,
    PALETTE_0,
)

# World map and leaderboard covering the full data range.
fig6 = create_plotly_world_map(country_concentration_df, "time", "metric", "value")
fig7 = create_leaderboard(
    country_concentration_df,
    author_concentration_df,
    model_concentration_df,
)

# Range-slider figures; `slider2` is the one placed in the layout.
slider = create_plotly_range_slider(model_topk_df)
slider2 = create_plotly_range_slider(country_concentration_df)
# Make global font family
# Apply the dashboard font to every pre-built figure.  fig7 (the leaderboard)
# is included so that its initial render uses the same font as the
# leaderboards regenerated by the `update_leaderboard` callback, which already
# sets font_family="Inter" on each figure it returns.
for _prebuilt_figure in (fig, fig2, fig3, fig4, fig5, fig6, fig7, slider, slider2):
    _prebuilt_figure.update_layout(font_family="Inter")
# App layout
# The component tree is assembled from named sections so that each part of the
# page can be read on its own; the resulting tree is unchanged.

# Centered page title, subtitle and divider.
_page_header = html.Div(
    children=[
        html.Div(
            children='Visualizing the Open Model Ecosystem',
            style={'fontSize': 28, 'fontWeight': 'bold', 'marginBottom': 6},
        ),
        html.Div(
            children='An interactive dashboard to explore trends in open models on Hugging Face',
            style={'fontSize': 16, 'marginBottom': 12},
        ),
        html.Hr(style={'marginTop': 8, 'marginBottom': 8}),
    ],
    style={'textAlign': 'center'},
)

# Time-range selector graph plus the text box that echoes the selection
# (filled in by the `update_output` callback).
_time_range_panel = html.Div(
    children=[
        html.Div(
            children='Select time range to update all graphs below:',
            style={'fontSize': 16, 'marginBottom': 6, 'marginTop': 10},
        ),
        dcc.Graph(figure=slider2, id='time-slider', style={'height': '100px'}),
        html.Div(
            id='output-container-range-slider',
            style={
                'textAlign': 'center',
                'fontSize': 20,
                'marginBottom': 15,
                'marginTop': 30,
                'backgroundColor': 'white',
                'borderRadius': '12px',
                'boxShadow': '0 2px 12px rgba(0,0,0,0.10)',
                'padding': '18px',
                'display': 'inline-block',
            },
        ),
    ],
    style={'marginBottom': 12, 'justifyContent': 'center', 'textAlign': 'center'},
)

# First tab: the graphs whose figures are supplied by the time-slider callbacks.
_market_share_tab = dcc.Tab(
    label='Model Market Share',
    children=[
        _time_range_panel,
        html.Div(
            children=[dcc.Graph(id='stacked-area-chart')],
            style={'marginBottom': 12},
        ),
        html.Div(
            children=[
                html.Div(
                    children=dcc.Graph(id='world-map-with-slider'),
                    style={'display': 'flex', 'justifyContent': 'center'},
                ),
                dcc.Graph(id='leaderboard'),
            ],
            style={'marginBottom': 12},
        ),
    ],
)

# Second tab: one graph whose figure is chosen through the dropdown.
_characteristics_tab = dcc.Tab(
    label='Model Characteristics',
    children=[
        dcc.Graph(id='language-concentration-chart'),
        html.Div(
            children=[
                dcc.Dropdown(
                    ['Language Concentration', 'Architecture', 'License', 'Method'],
                    'Language Concentration',
                    id='dropdown',
                ),
            ],
            style={'marginTop': 6},
        ),
    ],
)

# White rounded card that holds the tab container.
_content_card = html.Div(
    children=[dcc.Tabs([_market_share_tab, _characteristics_tab])],
    style={
        'backgroundColor': 'white',
        'borderRadius': '18px',
        'boxShadow': '0 4px 24px rgba(0,0,0,0.10)',
        'padding': '32px',
        'margin': '32px auto',
        'maxWidth': '1250px',
    },
)

app.layout = html.Div(
    children=[_page_header, _content_card],
    style={'fontFamily': 'Inter', 'backgroundColor': '#f7f7fa', 'minHeight': '100vh'},
)
def _selected_time_range(relayout_data):
    """Extract the selected x-axis range from the time slider's relayoutData.

    Plotly can report a range change in two shapes, and both are accepted:

    * two scalar keys, ``'xaxis.range[0]'`` and ``'xaxis.range[1]'``;
    * a single ``'xaxis.range'`` key holding a two-element sequence (the
      shape emitted when a range slider is dragged).

    Args:
        relayout_data: The ``relayoutData`` dict of the ``time-slider`` graph,
            or ``None`` before any interaction has happened.

    Returns:
        A ``(start_time, end_time)`` tuple of ``'YYYY-MM-DD'`` strings, or
        ``None`` when the payload carries no explicit range (for example on
        the initial call or after an autorange reset).
    """
    if not relayout_data:
        return None

    if 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
        bounds = (relayout_data['xaxis.range[0]'], relayout_data['xaxis.range[1]'])
    else:
        bounds = relayout_data.get('xaxis.range')
        if not bounds or len(bounds) != 2:
            return None

    start_time, end_time = (
        pd.to_datetime(bound).strftime('%Y-%m-%d') for bound in bounds
    )
    return start_time, end_time


@app.callback(
    Output('output-container-range-slider', 'children'),
    [Input('time-slider', 'relayoutData')]
)
def update_output(relayout_data):
    """Return the text describing the currently selected time range."""
    selected = _selected_time_range(relayout_data)
    if selected is None:
        return 'Selected time range: All data'
    start_time, end_time = selected
    return f'Selected time range: {start_time} to {end_time}'


@app.callback(
    Output('language-concentration-chart', 'figure'),
    [Input('dropdown', 'value')]
)
def update_graph(selected_metric):
    """Return the pre-built figure that matches the dropdown selection.

    When the dropdown is cleared (value ``None``) or carries an unknown
    value, the currently displayed figure is left untouched instead of being
    replaced by ``None``.
    """
    # Imported locally so this block does not depend on the file-level
    # import list being edited.
    from dash import no_update

    figures_by_metric = {
        'Language Concentration': fig2,
        'License': fig3,
        'Method': fig4,
        'Architecture': fig5,
    }
    return figures_by_metric.get(selected_metric, no_update)


@app.callback(
    Output('world-map-with-slider', 'figure'),
    [Input('time-slider', 'relayoutData')]
)
def update_map(relayout_data):
    """Return the world map for the selected range, or the full-range map."""
    selected = _selected_time_range(relayout_data)
    if selected is None:
        return fig6
    start_time, end_time = selected
    updated_fig = create_plotly_world_map(
        country_concentration_df, "time", "metric", "value",
        start_time=start_time, end_time=end_time
    )
    updated_fig.update_layout(font_family="Inter")
    return updated_fig


@app.callback(
    Output('leaderboard', 'figure'),
    [Input('time-slider', 'relayoutData')]
)
def update_leaderboard(relayout_data):
    """Return the leaderboard for the selected range, or the full-range one."""
    selected = _selected_time_range(relayout_data)
    if selected is None:
        return fig7
    start_time, end_time = selected
    updated_fig = create_leaderboard(
        country_concentration_df, author_concentration_df, model_concentration_df,
        start_time=start_time, end_time=end_time
    )
    updated_fig.update_layout(font_family="Inter")
    return updated_fig


@app.callback(
    Output('stacked-area-chart', 'figure'),
    [Input('time-slider', 'relayoutData')]
)
def update_stacked_area(relayout_data):
    """Return the stacked-area chart for the selected range, or the full one."""
    selected = _selected_time_range(relayout_data)
    if selected is None:
        return fig
    start_time, end_time = selected
    updated_fig = create_plotly_stacked_area_chart(
        model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0,
        start_time=start_time, end_time=end_time
    )
    updated_fig.update_layout(font_family="Inter")
    return updated_fig
# Run the app
def main():
    """Start the Dash server with debug mode switched on."""
    app.run(debug=True)


# Only start the server when this file is executed as a script; importing the
# module (for example to reach `server`) does not start it.
if __name__ == '__main__':
    main()