from matplotlib import pyplot as plt
import seaborn as sns
%matplotlib inline
def dist_plotting(dataset):
    """Plot the distribution of every column of ``dataset``.

    Numeric columns are drawn as histograms with a KDE overlay, laid out
    on a square grid inside one figure. All remaining columns are drawn as
    horizontal count plots, stacked one per row inside a second figure,
    with the bars ordered from most to least frequent value.

    The last column of ``dataset`` is drawn in a different colour
    (``'salmon'`` histogram / ``"YlOrBr_r"`` count plot) so it stands out
    from the other columns.

    Args:
        dataset: A pandas ``DataFrame``. Columns selected by
            ``select_dtypes(include="number")`` are treated as numeric;
            everything else is treated as categorical.

    Returns:
        None. The plots are drawn onto newly created matplotlib figures.
        A figure is only created when it has at least one column to show,
        and nothing is drawn for a frame without columns.
    """
    if len(dataset.columns) == 0:
        return

    last_col = dataset.columns[-1]

    # select_dtypes understands pandas extension dtypes (categorical,
    # nullable integers, ...) on which np.issubdtype raises TypeError, and
    # it keeps the original column order.
    numerical_set = list(dataset.select_dtypes(include="number").columns)
    categorical_set = list(dataset.select_dtypes(exclude="number").columns)
    count_numerical = len(numerical_set)
    count_categorical = len(categorical_set)

    # Smallest square grid able to hold every numeric plot. Computed
    # unconditionally so it is also defined for 0, 1 or 2 numeric columns.
    grid_side = 1
    while grid_side * grid_side < count_numerical:
        grid_side += 1

    if count_numerical:
        fig = plt.figure(figsize=(18, 18))
        for position, col in enumerate(numerical_set, start=1):
            axes = fig.add_subplot(grid_side, grid_side, position)
            # color=None keeps seaborn's default colour.
            highlight = 'salmon' if col == last_col else None
            sns.histplot(x=dataset[col], kde=True, color=highlight, ax=axes)

    if count_categorical:
        # NOTE(review): the height is derived from the numeric grid size,
        # not from the number of categorical rows. Kept as in the original
        # so existing output does not change; confirm whether it should
        # scale with count_categorical instead.
        fig = plt.figure(figsize=(18, 18 * grid_side / 2))
        for position, col in enumerate(categorical_set, start=1):
            axes = fig.add_subplot(count_categorical, 1, position)
            sorted_counts = dataset[col].value_counts(
            ).sort_values(ascending=False).index
            palette = "YlOrBr_r" if col == last_col else "Blues_r"
            sns.countplot(data=dataset, y=col, palette=palette,
                          orient='h', order=sorted_counts, ax=axes)