{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fairness by Counterfactuals\n", "This notebook demonstrates how to put the concept of \"fairness\" into practice in a multiclass classification problem. \n", "Usually, fairness and explainability focus on the descriptive side of the model, reporting the variables that most affect the final prediction/classification. \n", "We will use a library named DiCE (https://github.com/interpretml/DiCE), which provides active explanations through counterfactuals built from the real data behind a Machine Learning (ML) model. This makes it possible to understand not only which variables matter, but also how much each one should change to reach the desired outcome.\n", "\n", "## Case Study: NBA Players Salary Expectations\n", "We will apply DiCE to an official NBA database reporting players' stats and salaries for the 2022/2023 season. The aim is to suggest which stats need to improve (and, most importantly, by how much) in order to expect a better salary. Let's jump right in." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install and load the necessary libraries\n", "DiCE can be easily installed (check the package documentation). Other than that, only *pandas* and *scikit-learn* are needed: standard libraries for ML classification, nothing really fancy. \n", "DiCE claims to work with any ML model, so you could also just build your own." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import dice_ml\n", "from dice_ml import Dice\n", "\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and preprocess the NBA dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For multiclass classification, we will try to predict each player's salary range from his stats. We start from the assumption that better stats mean a better salary, disregarding intangible skills such as leadership, locker-room chemistry and charisma: it's a shame, but this is a toy example to play around with.\n", "In all honesty, it would have been better to average the stats of the past five seasons, which is when the current contract is usually earned, but that would have made it complicated to account for younger players, so for the moment this will do." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>Player</th><th>Pos</th><th>Age</th><th>G</th><th>GS</th><th>MP</th><th>FG</th><th>FGA</th><th>FT</th><th>FTA</th><th>ORB</th><th>DRB</th><th>TRB</th><th>AST</th><th>STL</th><th>BLK</th><th>TOV</th><th>PF</th><th>PTS</th><th>salary</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>Aaron Gordon</td><td>PF</td><td>26</td><td>75</td><td>75</td><td>31.7</td><td>5.8</td><td>11.1</td><td>2.3</td><td>3.1</td><td>1.7</td><td>4.2</td><td>5.9</td><td>2.5</td><td>0.6</td><td>0.6</td><td>1.8</td><td>2.0</td><td>15.0</td><td>19690909.0</td></tr>\n",
"    <tr><th>1</th><td>Aaron Holiday</td><td>PG</td><td>25</td><td>63</td><td>15</td><td>16.2</td><td>2.4</td><td>5.4</td><td>0.9</td><td>1.1</td><td>0.4</td><td>1.6</td><td>1.9</td><td>2.4</td><td>0.7</td><td>0.1</td><td>1.1</td><td>1.5</td><td>6.3</td><td>1836090.0</td></tr>\n",
"    <tr><th>2</th><td>Aaron Nesmith</td><td>SF</td><td>22</td><td>52</td><td>3</td><td>11.0</td><td>1.4</td><td>3.5</td><td>0.4</td><td>0.5</td><td>0.3</td><td>1.4</td><td>1.7</td><td>0.4</td><td>0.4</td><td>0.1</td><td>0.6</td><td>1.3</td><td>3.8</td><td>3804360.0</td></tr>\n",
"    <tr><th>3</th><td>Aaron Wiggins</td><td>SG</td><td>23</td><td>50</td><td>35</td><td>24.2</td><td>3.1</td><td>6.7</td><td>1.2</td><td>1.7</td><td>1.0</td><td>2.5</td><td>3.6</td><td>1.4</td><td>0.6</td><td>0.2</td><td>1.1</td><td>1.9</td><td>8.3</td><td>1563518.0</td></tr>\n",
"    <tr><th>4</th><td>Admiral Schofield</td><td>SF</td><td>24</td><td>38</td><td>1</td><td>12.3</td><td>1.4</td><td>3.4</td><td>0.3</td><td>0.4</td><td>0.4</td><td>1.9</td><td>2.3</td><td>0.7</td><td>0.1</td><td>0.1</td><td>0.6</td><td>1.5</td><td>3.8</td><td>506508.0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " Player Pos Age G GS MP FG FGA FT FTA ORB DRB \\\n", "0 Aaron Gordon PF 26 75 75 31.7 5.8 11.1 2.3 3.1 1.7 4.2 \n", "1 Aaron Holiday PG 25 63 15 16.2 2.4 5.4 0.9 1.1 0.4 1.6 \n", "2 Aaron Nesmith SF 22 52 3 11.0 1.4 3.5 0.4 0.5 0.3 1.4 \n", "3 Aaron Wiggins SG 23 50 35 24.2 3.1 6.7 1.2 1.7 1.0 2.5 \n", "4 Admiral Schofield SF 24 38 1 12.3 1.4 3.4 0.3 0.4 0.4 1.9 \n", "\n", " TRB AST STL BLK TOV PF PTS salary \n", "0 5.9 2.5 0.6 0.6 1.8 2.0 15.0 19690909.0 \n", "1 1.9 2.4 0.7 0.1 1.1 1.5 6.3 1836090.0 \n", "2 1.7 0.4 0.4 0.1 0.6 1.3 3.8 3804360.0 \n", "3 3.6 1.4 0.6 0.2 1.1 1.9 8.3 1563518.0 \n", "4 2.3 0.7 0.1 0.1 0.6 1.5 3.8 506508.0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "url = 'https://raw.githubusercontent.com/antoniocollesei/nba-fairness-salary/main/stats_salary_NBA_2223.csv'\n", "\n", "df = pd.read_csv(url)\n", "df = df.dropna()\n", "# remove columns that have a '.' or an 'X' in the column name (i.e. the percentage columns)\n", "# regex=False matches the dot literally instead of relying on an escaped regex\n", "df = df.loc[:, ~df.columns.str.contains('.', regex=False)]\n", "df = df.loc[:, ~df.columns.str.contains('X')]\n", "# reset the index (it will be useful later, trust me)\n", "df = df.reset_index(drop=True)\n", "df.head(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the target\n", "Although this could be framed as a regression problem, we use classification to keep the results easy to interpret. We therefore split the continuous salary into four classes. We tried to keep the classes as balanced as possible without undermining the purpose, but it is clear that most players are not superstars and earn far less than their star colleagues." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5m- 226\n", "5-15m 132\n", "25m+ 50\n", "15-25m 41\n", "Name: salary, dtype: int64" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "outcome_name = \"salary\"\n", "continuous_features = df.drop(outcome_name, axis=1).select_dtypes(include=['float64', 'int64']).columns\n", "target = df[outcome_name]\n", "# bin the continuous salary into 4 classes\n", "target_cat = pd.cut(target, bins=[0, 5e6, 1.5e7, 2.5e7, target.max()], labels=[\"5m-\", \"5-15m\", \"15-25m\", \"25m+\"])\n", "# replace the target with its binned version\n", "df[outcome_name] = target_cat\n", "\n", "target_cat.value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-Class Modeling\n", "Here we build the ML model (a Random Forest Classifier). Note that the pipeline also includes steps to scale the continuous variables and one-hot encode the categorical ones." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# save the vector of player names and drop it from df\n", "players = df['Player']\n", "df = df.drop(['Player'], axis=1)\n", "\n", "# Split data into train and test\n", "datasetX = df.drop(outcome_name, axis=1)\n", "x_train, x_test, y_train, y_test = train_test_split(datasetX,\n", "    target_cat,\n", "    test_size=0.3,\n", "    random_state=42,\n", "    stratify=target_cat)\n", "\n", "# Create the same split but with the player names (used later to select the player);\n", "# the same random_state and stratify keep the rows aligned with x_test\n", "datasetX['Player'] = players\n", "x_player_train, x_player_test, y_player_train, y_player_test = train_test_split(datasetX,\n", "    target_cat,\n", "    test_size=0.3,\n", "    random_state=42,\n", "    stratify=target_cat)\n", "\n", "# Build the transformers for each variable type\n", "categorical_features = x_train.columns.difference(continuous_features)\n", "\n", "numeric_transformer = Pipeline(steps=[\n", "    ('scaler', StandardScaler())])\n", "\n", 
"categorical_transformer = Pipeline(steps=[\n", "    ('onehot', OneHotEncoder(handle_unknown='ignore'))])\n", "\n", "transformations = ColumnTransformer(\n", "    transformers=[\n", "        ('num', numeric_transformer, continuous_features),\n", "        ('cat', categorical_transformer, categorical_features)])\n", "\n", "# Create a pipeline with the transformations and the classifier\n", "clf = Pipeline(steps=[('preprocessor', transformations),\n", "    ('classifier', RandomForestClassifier())])\n", "model = clf.fit(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build DiCE model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "d = dice_ml.Data(dataframe=df,\n", "    continuous_features=list(continuous_features),\n", "    outcome_name=outcome_name)\n", "\n", "# We provide the type of model as a parameter (model_type)\n", "m = dice_ml.Model(model=model, backend=\"sklearn\", model_type='classifier')\n", "\n", "exp_genetic = Dice(d, m, method=\"genetic\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we list some of the players that can be analyzed through DiCE: for simplicity we print a pool of 19 players from the top of the test dataset (the slice `players[1:20]` skips the first row)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Saben Lee' 'Isaac Okoro' 'Dewayne Dedmon' 'Sam Merrill'\n", " 'Kevin Porter Jr.' 'Devin Cannady' 'Doug McDermott' 'PJ Dozier'\n", " 'Richaun Holmes' 'Naz Reid' 'Kai Jones' 'Max Strus' 'Marvin Bagley III'\n", " 'Mike Muscala' 'Aaron Nesmith' 'Jrue Holiday' 'Marcus Smart'\n", " 'Derrick Favors' 'Cam Thomas']\n" ] } ], "source": [ "players = x_player_test['Player'].values\n", "print(players[1:20])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are ready to select an NBA player who desperately needs a contract upgrade. 
What should he do in terms of stats to improve his salary?\n", "Let's take Naz Reid, who has just been named Sixth Man of the Year for the 2023/24 season with the Minnesota Timberwolves.\n", "We want to see, for example, how much he should contribute in points scored ('PTS') to upgrade his salary." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/1 [00:00<?, ?it/s]" ] }, { "data": { "text/html": [ "<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>Pos</th><th>Age</th><th>G</th><th>GS</th><th>MP</th><th>FG</th><th>FGA</th><th>FT</th><th>FTA</th><th>ORB</th><th>DRB</th><th>TRB</th><th>AST</th><th>STL</th><th>BLK</th><th>TOV</th><th>PF</th><th>PTS</th><th>salary</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>C</td><td>22</td><td>77</td><td>6</td><td>15.8</td><td>3.0</td><td>6.2</td><td>1.5</td><td>1.9</td><td>1.3</td><td>2.6</td><td>3.9</td><td>0.9</td><td>0.5</td><td>0.9</td><td>1.1</td><td>2.2</td><td>8.3</td><td>5m-</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " Pos Age G GS MP FG FGA FT FTA ORB DRB TRB AST STL BLK \\\n", "0 C 22 77 6 15.8 3.0 6.2 1.5 1.9 1.3 2.6 3.9 0.9 0.5 0.9 \n", "\n", " TOV PF PTS salary \n", "0 1.1 2.2 8.3 5m- " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Diverse Counterfactual set (new outcome: 25m+)\n" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>Pos</th><th>Age</th><th>G</th><th>GS</th><th>MP</th><th>FG</th><th>FGA</th><th>FT</th><th>FTA</th><th>ORB</th><th>DRB</th><th>TRB</th><th>AST</th><th>STL</th><th>BLK</th><th>TOV</th><th>PF</th><th>PTS</th><th>salary</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>-</td><td>32</td><td>-</td><td>-</td><td>-</td><td>10.2</td><td>-</td><td>-</td><td>-</td><td>2.8</td><td>7.9</td><td>-</td><td>-</td><td>-</td><td>1.6</td><td>-</td><td>-</td><td>29.0</td><td>25m+</td></tr>\n",
"    <tr><th>0</th><td>-</td><td>35</td><td>-</td><td>-</td><td>-</td><td>9.2</td><td>-</td><td>-</td><td>-</td><td>1.2</td><td>10.8</td><td>-</td><td>-</td><td>0.7</td><td>0.3</td><td>-</td><td>-</td><td>25.0</td><td>25m+</td></tr>\n",
"    <tr><th>0</th><td>-</td><td>29</td><td>-</td><td>-</td><td>-</td><td>11.0</td><td>-</td><td>-</td><td>-</td><td>1.7</td><td>10.8</td><td>-</td><td>-</td><td>-</td><td>2.3</td><td>-</td><td>-</td><td>29.0</td><td>25m+</td></tr>\n",
"    <tr><th>0</th><td>-</td><td>32</td><td>-</td><td>-</td><td>-</td><td>8.4</td><td>-</td><td>-</td><td>-</td><td>2.2</td><td>10.8</td><td>-</td><td>-</td><td>0.4</td><td>0.2</td><td>-</td><td>-</td><td>25.0</td><td>25m+</td></tr>\n",
"    <tr><th>0</th><td>-</td><td>40</td><td>-</td><td>-</td><td>-</td><td>11.3</td><td>-</td><td>-</td><td>-</td><td>2.0</td><td>10.1</td><td>-</td><td>-</td><td>0.4</td><td>0.2</td><td>-</td><td>-</td><td>28.0</td><td>25m+</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " Pos Age G GS MP FG FGA FT FTA ORB DRB TRB AST STL BLK TOV PF PTS \\\n", "0 - 32 - - - 10.2 - - - 2.8 7.9 - - - 1.6 - - 29.0 \n", "0 - 35 - - - 9.2 - - - 1.2 10.8 - - 0.7 0.3 - - 25.0 \n", "0 - 29 - - - 11.0 - - - 1.7 10.8 - - - 2.3 - - 29.0 \n", "0 - 32 - - - 8.4 - - - 2.2 10.8 - - 0.4 0.2 - - 25.0 \n", "0 - 40 - - - 11.3 - - - 2.0 10.1 - - 0.4 0.2 - - 28.0 \n", "\n", " salary \n", "0 25m+ \n", "0 25m+ \n", "0 25m+ \n", "0 25m+ \n", "0 25m+ " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "player_name = 'Naz Reid'\n", "\n", "x_player_test = x_player_test.reset_index(drop=True)\n", "x_test = x_test.reset_index(drop=True)\n", "# locate the player's row by name, then find the matching row in x_test\n", "player_row = x_player_test[x_player_test['Player'].str.contains(player_name)].iloc[0]\n", "player_row = player_row.drop(['Player'])\n", "player_index = x_test[x_test.eq(player_row).all(axis=1)].index[0]\n", "\n", "# Generate counterfactuals for the player\n", "query_instances = x_test[player_index:player_index+1]\n", "genetic = exp_genetic.generate_counterfactuals(query_instances,\n", "    total_CFs=5,  # matches the five counterfactuals displayed in the output\n", "    #features_to_vary=['Age', 'PTS', 'ORB', 'DRB', 'STL', 'BLK'],\n", "    desired_class=1  # class index; this run reported it as '25m+'\n", ")\n", "genetic.visualize_as_dataframe(show_only_changes=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a reference, these are the league averages for each stat." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>Age</th><th>G</th><th>GS</th><th>MP</th><th>FG</th><th>FGA</th><th>FT</th><th>FTA</th><th>ORB</th><th>DRB</th><th>TRB</th><th>AST</th><th>STL</th><th>BLK</th><th>TOV</th><th>PF</th><th>PTS</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>25.837416</td><td>52.5902</td><td>26.585746</td><td>21.711136</td><td>3.581514</td><td>7.8049</td><td>1.475724</td><td>1.906682</td><td>0.926058</td><td>3.032517</td><td>3.960356</td><td>2.211359</td><td>0.69265</td><td>0.429621</td><td>1.168597</td><td>1.785523</td><td>9.747439</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " Age G GS MP FG FGA FT \\\n", "0 25.837416 52.5902 26.585746 21.711136 3.581514 7.8049 1.475724 \n", "\n", " FTA ORB DRB TRB AST STL BLK \\\n", "0 1.906682 0.926058 3.032517 3.960356 2.211359 0.69265 0.429621 \n", "\n", " TOV PF PTS \n", "0 1.168597 1.785523 9.747439 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# numeric_only=True skips the non-numeric columns (Pos, salary) and avoids the pandas FutureWarning\n", "df.mean(numeric_only=True).to_frame().T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Considerations\n", "A few key points can be drawn from this simple analysis, both for the player and for the league itself. It appears that Naz Reid, in order to secure a 25M+ contract, should clearly increase his scoring output. His mere 8.3 points and 3.0 field goals per game are not enough to be considered an elite player (the counterfactuals ask for 25 to 29 points and roughly 8 to 11 field goals), so he should focus on buckets! And that makes complete sense.\n", "Other interesting takeaways come from noticeable trends at the league level. Let's see some of them:\n", "* the NBA seems to care more about defensive rebounds (DRB) than offensive ones (ORB): Naz Reid should grab 3 to 4 times as many DRB (from 2.6 to between 7.9 and 10.8), while on the ORB side roughly doubling his numbers, at most, would be enough;\n", "* steals (STL) and blocks (BLK) do not seem to be highly rewarded: the suggested values move in both directions with no clear trend, so his defensive profile could work more or less as it is (his steals are close to the league average, while his blocks are already at a medium-high level);\n", "* the league rewards mature players (every counterfactual raises his age from 22 to between 29 and 40), and it actually makes sense: young players coming into the league have a capped salary during their first years in the NBA; only after that can they be granted a max extension, if their performance is elite." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusions\n", "DiCE seems to be a pretty powerful tool for gaining visibility into ML classification tasks. 
Not only does it suggest personalized actions to move into a different category, but it can also surface domain takeaways for analyzing high-level trends.\n", "In conclusion, if some NBA player wants a better contract, he can send me his stats, and I can definitely help him improve his game!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 4 }