import os
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap
from matplotlib.colors import ListedColormap
from scipy.sparse import csr_matrix
from sklearn.manifold import TSNE
from sklearn.metrics import davies_bouldin_score, calinski_harabasz_score
from sklearn.metrics import silhouette_score
from STMiner.Algorithm.AlgUtils import get_exp_array
def _adjust_arr(arr, rotate, reverse_x, reverse_y):
    """Apply the requested orientation changes to a 2-D array.

    The transforms are applied in a fixed order: rotation first, then the
    vertical flip, then the horizontal flip.

    Args:
        arr: Array to transform.
        rotate: If truthy, rotate with ``np.rot90``.
        reverse_x: If truthy, flip left/right with ``np.fliplr``.
        reverse_y: If truthy, flip up/down with ``np.flipud``.

    Returns:
        The transformed array, or ``arr`` itself when no flag is set.
    """
    # Order matters: rot90 -> flipud -> fliplr.
    steps = (
        (rotate, np.rot90),
        (reverse_y, np.flipud),
        (reverse_x, np.fliplr),
    )
    for enabled, transform in steps:
        if enabled:
            arr = transform(arr)
    return arr
def _get_figure(fig_count, num_cols):
    """Create a grid of subplots large enough for ``fig_count`` panels.

    The grid has ``num_cols`` columns and as many rows as needed. Each row is
    3 inches high and the figure is 12 inches wide. Every axis is switched
    off.

    Args:
        fig_count (int): Number of panels that must fit in the grid.
        num_cols (int): Number of columns in the grid.

    Returns:
        tuple: ``(axes, fig)`` where ``axes`` is always a NumPy array of
        Axes: 1-D when the grid has a single row or a single column
        (including the 1x1 case), 2-D otherwise.

    Raises:
        ValueError: If ``fig_count`` is smaller than 1.
    """
    if fig_count < 1:
        raise ValueError("fig_count must be at least 1, got %r" % (fig_count,))
    # Ceiling division: the last row may be only partially filled.
    num_rows = (fig_count + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 3 * num_rows))
    # For a 1x1 grid plt.subplots returns a bare Axes object rather than an
    # array; normalise so that ``.flat`` / ``.shape`` work for every grid.
    axes = np.atleast_1d(axes)
    # Disable the axis for each subplot
    for ax in axes.flat:
        ax.axis("off")
    return axes, fig
def is_path(image_path):
    """Tell whether ``image_path`` is a string naming an existing file or directory.

    Args:
        image_path: Candidate value; anything that is not a ``str`` is rejected.

    Returns:
        bool: ``True`` only for a ``str`` for which ``os.path.isfile`` or
        ``os.path.isdir`` succeeds.
    """
    if not isinstance(image_path, str):
        return False
    return any(check(image_path) for check in (os.path.isfile, os.path.isdir))
[docs]
class Plot:
    """Plotting helpers that draw from an analysis object stored as ``self.sp``.

    The methods of this class read the following members of ``sp``:
    ``adata``, ``genes_labels``, ``patterns_matrix_dict``,
    ``patterns_binary_matrix_dict``, ``genes_distance_array``,
    ``kmeans_fit_result``, ``mds_features`` and the method ``cluster_gene``.
    """

    def __init__(self, sp):
        """Store the analysis object whose data will be plotted.

        Args:
            sp: Object exposing the attributes listed in the class docstring.
        """
        self.sp = sp

    @staticmethod
    def _percentile_or_none(values, q):
        """Return the ``q``-th percentile of ``values``, or None when empty.

        ``None`` lets matplotlib pick the colour limit itself instead of
        asking NumPy for the percentile of an empty array.
        """
        if np.size(values) == 0:
            return None
        return np.percentile(values, q)

    def plot_gene(
        self,
        gene,
        cmap="Spectral_r",
        reverse_y=False,
        reverse_x=False,
        rotate=False,
        figsize=(8, 6),
        s=5,
        log1p=False,
        save_path="",
        format="eps",
        dpi=400,
        vmax=99,
    ):
        """
        Plots the spatial expression of a given gene on a scatter plot.

        All spots with a non-zero total count are drawn first in light gray,
        then the spots where ``gene`` is non-zero are drawn on top, coloured
        by expression value.

        Args:
            gene (str): The name of the gene to plot.
            cmap (str, optional): Colormap to use for gene expression. Default is 'Spectral_r'.
            reverse_y (bool, optional): If True, flip the array up/down. Default is False.
            reverse_x (bool, optional): If True, flip the array left/right. Default is False.
            rotate (bool, optional): If True, rotate the array by 90 degrees. Default is False.
            figsize (tuple, optional): Figure size in inches (width, height). Default is (8, 6).
            s (int or None, optional): Size of the scatter plot points. If None, the
                size argument is omitted for the gene layer. Default is 5.
            log1p (bool, optional): If True, apply log1p transformation to expression values. Default is False.
            save_path (str, optional): Existing path under which the plot is saved as
                ``<gene>.<format>``. If it does not exist (e.g. ''), nothing is saved. Default is ''.
            format (str, optional): File format for saving the plot (e.g., 'eps', 'png'). Default is 'eps'.
            dpi (int, optional): Dots per inch for saved figure. Default is 400.
            vmax (float, optional): Percentile of the non-zero values used as the upper limit
                of the color scale. Default is 99.

        Returns:
            None

        Side Effects:
            - Displays the plot using matplotlib.
            - Saves the plot under `save_path` if it is an existing path.
        """
        global_matrix = self.get_global_matrix(reverse_x, reverse_y, rotate)
        plt.figure(figsize=figsize)
        bg_rows, bg_cols = global_matrix.nonzero()
        sns.scatterplot(
            x=bg_rows,
            y=bg_cols,
            s=s,
            color="#ced4da",
            edgecolor="none",
        )

        arr = get_exp_array(self.sp.adata, gene)
        arr = _adjust_arr(arr, rotate, reverse_x, reverse_y)
        if log1p:
            arr = np.log1p(arr)
        sparse_matrix = csr_matrix(arr)
        rows, cols = sparse_matrix.nonzero()
        values = sparse_matrix.data

        # Only forward ``s`` when the caller supplied one, so that None keeps
        # meaning "do not pass a size for the gene layer".
        size_kwargs = {} if s is None else {"s": s}
        ax = sns.scatterplot(
            x=rows,
            y=cols,
            c=values,
            edgecolor="none",
            vmax=self._percentile_or_none(values, vmax),
            cmap=cmap,
            **size_kwargs,
        )
        ax.set_axis_off()
        ax.set_title(gene)

        if is_path(save_path):
            target = os.path.join(save_path, f"{gene}.{format}")
            ax.get_figure().savefig(
                fname=target, dpi=dpi, format=format, bbox_inches="tight"
            )
        plt.show()

    def get_global_matrix(self, reverse_x, reverse_y, rotate):
        """Build a sparse matrix of the per-observation total of ``adata.X``.

        Each observation's row sum of ``adata.X`` is placed at the position
        given by ``adata.obs["x"]`` (row index) and ``adata.obs["y"]``
        (column index); the result is then oriented with ``_adjust_arr``.

        Args:
            reverse_x (bool): Flip left/right.
            reverse_y (bool): Flip up/down.
            rotate (bool): Rotate by 90 degrees.

        Returns:
            scipy.sparse.csr_matrix: The oriented matrix.
        """
        adata = self.sp.adata
        expression_data = np.array(adata.X.sum(axis=1)).flatten()
        row_indices = np.array(adata.obs["x"].values).flatten()
        column_indices = np.array(adata.obs["y"].values).flatten()
        global_matrix = csr_matrix((expression_data, (row_indices, column_indices)))
        # toarray() yields a plain ndarray (todense() would give np.matrix).
        dense = _adjust_arr(global_matrix.toarray(), rotate, reverse_x, reverse_y)
        return csr_matrix(dense)

    def plot_genes(
        self,
        label=None,
        gene_list=None,
        n_gene=12,
        cmap=None,
        num_cols=4,
        vmax=99,
        vmin=0,
        rotate=False,
        reverse_y=False,
        reverse_x=False,
        plot_type="scatter",
        s=1,
    ):
        """
        Plots several genes in a grid, one subplot per gene.

        Args:
            label (optional): Label whose genes are plotted when `gene_list` is None.
                Genes are taken from ``self.sp.genes_labels`` rows whose "labels"
                column equals `label`.
            gene_list (list, optional): Explicit list of genes to plot. Takes precedence over `label`.
            n_gene (int, optional): Maximum number of genes taken for `label`. Default is 12.
            cmap (optional): Colormap. Defaults to seaborn's "viridis" colormap if None.
            num_cols (int, optional): Number of columns in the subplot grid. Default is 4.
            vmax (float, optional): Upper percentile for the color scale; used only
                when `plot_type` is "heatmap". Default is 99.
            vmin (float, optional): Lower percentile for the color scale; used only
                when `plot_type` is "heatmap". Default is 0.
            rotate (bool, optional): Whether to rotate each array. Default is False.
            reverse_y (bool, optional): Whether to flip each array up/down. Default is False.
            reverse_x (bool, optional): Whether to flip each array left/right. Default is False.
            plot_type (str, optional): "scatter" or "heatmap". Any other value leaves
                the subplots empty apart from their titles. Default is "scatter".
            s (int or float, optional): Size of scatter plot points. Default is 1.

        Returns:
            None. Displays the generated plots.

        Raises:
            ValueError: If `gene_list` is None and either `label` or
                ``self.sp.genes_labels`` is None.
        """
        result = self.sp.genes_labels
        adata = self.sp.adata
        if gene_list is None:
            if label is None or result is None:
                # The original raised a bare string, which is itself an error
                # in Python 3; raise a real exception with the same message.
                raise ValueError(
                    "Error: Parameter [label] and [result] should not be None."
                )
            gene_list = list(result[result["labels"] == label]["gene_id"])[:n_gene]

        axes, fig = _get_figure(len(gene_list), num_cols)
        fig.subplots_adjust(hspace=0.5)

        # Loop-invariant setup.
        sns.set(style="white")
        if cmap is None:
            cmap = sns.color_palette("viridis", as_cmap=True)

        for i, gene in enumerate(gene_list):
            # ``flat`` walks the grid row by row for both 1-D and 2-D layouts.
            ax = axes.flat[i]
            arr = get_exp_array(adata, gene)
            arr = _adjust_arr(arr, rotate, reverse_x, reverse_y)
            if plot_type == "heatmap":
                sns.heatmap(
                    arr,
                    cbar=False,
                    ax=ax,
                    cmap=cmap,
                    vmax=np.percentile(arr, vmax),
                    vmin=np.percentile(arr, vmin),
                )
            elif plot_type == "scatter":
                sparse_matrix = csr_matrix(arr)
                rows, cols = sparse_matrix.nonzero()
                sns.scatterplot(
                    x=rows,
                    y=cols,
                    ax=ax,
                    c=sparse_matrix.data,
                    cmap=cmap,
                    s=s,
                )
            ax.set_axis_off()
            ax.set_title(gene)
        plt.tight_layout()
        plt.show()

    def plot_pattern(
        self,
        cmap=None,
        vmax=99,
        num_cols=4,
        rotate=False,
        reverse_y=False,
        reverse_x=False,
        heatmap=False,
        s=1,
        image_path=None,
        rotate_img=False,
        k=1,
        aspect=1,
        output_path=None,
        plot_bg=False,
    ):
        """
        Plots spatial patterns for each label in the dataset as either heatmaps or scatter plots.

        One subplot is drawn per distinct value of the "labels" column of
        ``self.sp.genes_labels``, in sorted label order, using the matrix
        stored under that label in ``self.sp.patterns_matrix_dict``.

        Args:
            cmap (str or matplotlib colormap, optional): Colormap to use for plotting. Defaults to "viridis" if None.
            vmax (float, optional): Percentile value to use as the maximum value for color scaling. Default is 99.
            num_cols (int, optional): Number of columns in the subplot grid. Default is 4.
            rotate (bool, optional): Whether to rotate the pattern matrices. Default is False.
            reverse_y (bool, optional): Whether to flip the pattern matrices up/down. Default is False.
            reverse_x (bool, optional): Whether to flip the pattern matrices left/right. Default is False.
            heatmap (bool, optional): If True, plot patterns as heatmaps; otherwise, use scatter plots. Default is False.
            s (int or float, optional): Size of scatter plot points. Default is 1.
            image_path (str, optional): Path to a background image to display under the scatter plot. Default is None.
            rotate_img (bool, optional): Whether to rotate the background image. Default is False.
            k (int, optional): Number of 90-degree rotations to apply to the background image if rotate_img is True. Default is 1.
            aspect (float, optional): Aspect ratio for the background image. Default is 1.
            output_path (str, optional): If provided, saves the plot to this path in EPS format. Default is None.
            plot_bg (bool, optional): Whether to plot the global background points in gray
                (scatter mode only). Default is False.

        Returns:
            None. Displays the generated plots and optionally saves them to a file.
        """
        result = self.sp.genes_labels
        # Sorted so that the panel order is deterministic across runs.
        label_list = sorted(set(result["labels"]))
        axes, fig = _get_figure(len(label_list), num_cols=num_cols)
        fig.subplots_adjust(hspace=0.5)
        cmap = "viridis" if cmap is None else cmap

        # Loop-invariant inputs of the scatter mode: load them once.
        bg_img = None
        bg_rows = bg_cols = None
        if not heatmap:
            if is_path(image_path):
                bg_img = mpimg.imread(image_path)
                if rotate_img:
                    bg_img = np.rot90(bg_img, k=k)
            if plot_bg:
                global_matrix = self.get_global_matrix(reverse_x, reverse_y, rotate)
                bg_rows, bg_cols = global_matrix.nonzero()

        for i, label in enumerate(label_list):
            ax = axes.flat[i]
            total_count = self.sp.patterns_matrix_dict[label]
            total_count = _adjust_arr(total_count, rotate, reverse_x, reverse_y)
            if heatmap:
                sns.heatmap(
                    total_count,
                    ax=ax,
                    cbar=False,
                    cmap=cmap,
                    vmax=np.percentile(total_count, vmax),
                )
            else:
                # Row indices are plotted on x and column indices on y, so
                # the axis limits follow shape[0] and shape[1] respectively.
                x_extent = total_count.shape[0]
                y_extent = total_count.shape[1]
                if bg_img is not None:
                    ax.imshow(
                        bg_img,
                        extent=[0, x_extent, 0, y_extent],
                        aspect=aspect,
                    )
                if plot_bg:
                    # ``ax=ax`` keeps the background on this panel instead of
                    # whichever axes happens to be current.
                    sns.scatterplot(
                        x=bg_rows,
                        y=bg_cols,
                        s=s,
                        color="#ced4da",
                        ax=ax,
                    )
                sparse_matrix = csr_matrix(total_count)
                rows, cols = sparse_matrix.nonzero()
                sns.scatterplot(
                    x=rows,
                    y=cols,
                    c=sparse_matrix.data,
                    ax=ax,
                    cmap=cmap,
                    s=s,
                    vmax=self._percentile_or_none(sparse_matrix.data, vmax),
                    edgecolor="none",
                )
                ax.set_xlim(0, x_extent)
                ax.set_ylim(0, y_extent)
            ax.set_title("Pattern " + str(label))

        # Lay out first so the saved file matches what is displayed.
        plt.tight_layout()
        if output_path is not None:
            fig.savefig(output_path, dpi=400, format="eps")
        plt.show()

    def plot_intersection(
        self,
        pattern_list,
        cmap=None,
        s=None,
        rotate=False,
        reverse_x=False,
        reverse_y=False,
        figsize=(12, 8),
        image_path=None,
        rotate_img=False,
        plot_bg=True,
        k=1,
        bgs=5,
        aspect=1,
    ):
        """
        Plots the intersection of multiple patterns as a scatter plot, optionally overlaying a background image and global pattern matrix.

        The n-th pattern in `pattern_list` contributes the value n (starting
        at 1) wherever its binary matrix is positive; the contributions are
        summed and the sum is used as the point colour. Different
        combinations can therefore share a value (e.g. patterns 1 and 2
        together give 3, the same as pattern 3 alone).

        Args:
            pattern_list (list): Non-empty list of keys of ``self.sp.patterns_binary_matrix_dict`` to combine.
            cmap (matplotlib.colors.Colormap, optional): Colormap to use for the intersection points. Defaults to a preset three-colour colormap.
            s (float or array-like, optional): Size of the scatter plot points. Defaults to 10.
            rotate (bool, optional): Whether to rotate the intersection matrix by 90 degrees. Defaults to False.
            reverse_x (bool, optional): Whether to flip the matrix left/right. Defaults to False.
            reverse_y (bool, optional): Whether to flip the matrix up/down. Defaults to False.
            figsize (tuple, optional): Figure size for the plot. Defaults to (12, 8).
            image_path (str, optional): Path to a background image to overlay. If None, no image is shown.
            rotate_img (bool, optional): Whether to rotate the background image. Defaults to False.
            plot_bg (bool, optional): Whether to plot the global matrix as a gray background. Defaults to True.
            k (int, optional): Number of times to rotate the background image by 90 degrees. Defaults to 1.
            bgs (float, optional): Size of the background scatter plot points. Defaults to 5.
            aspect (float, optional): Aspect ratio for the background image. Defaults to 1.

        Returns:
            None: Draws on a newly created matplotlib figure. ``plt.show()``
            is not called here, so the caller can still modify or save the
            figure.
        """
        binary_matrices = self.sp.patterns_binary_matrix_dict
        sum_array = np.zeros(binary_matrices[pattern_list[0]].shape)
        for weight, pattern in enumerate(pattern_list, start=1):
            sum_array += np.where(binary_matrices[pattern] > 0, weight, 0)
        sum_array = _adjust_arr(
            sum_array, rotate=rotate, reverse_x=reverse_x, reverse_y=reverse_y
        )
        sparse_matrix = csr_matrix(sum_array)

        plt.figure(figsize=figsize)
        sns.set_style("white")
        if is_path(image_path):
            bg_img = mpimg.imread(image_path)
            if rotate_img:
                bg_img = np.rot90(bg_img, k=k)
            plt.imshow(
                bg_img,
                extent=[0, sum_array.shape[0], 0, sum_array.shape[1]],
                aspect=aspect,
            )
        if plot_bg:
            global_matrix = self.get_global_matrix(reverse_x, reverse_y, rotate)
            bg_rows, bg_cols = global_matrix.nonzero()
            sns.scatterplot(
                x=bg_rows,
                y=bg_cols,
                s=bgs,
                color="#ced4da",
                edgecolor="none",
            )

        default_cmap = ListedColormap(["#06d6a0", "#fb8500", "#ff006e"])
        rows, cols = sparse_matrix.nonzero()
        sns.scatterplot(
            x=rows,
            y=cols,
            c=sparse_matrix.data,
            cmap=cmap if cmap is not None else default_cmap,
            s=s if s is not None else 10,
            edgecolor="none",
        )
        plt.axis("off")

    def plot_cluster_score(self, mds_comp, min_cluster, max_cluster):
        """
        Plots normalized clustering scores for a range of cluster counts.

        For every cluster count in ``[min_cluster, max_cluster]`` this calls
        ``self.sp.cluster_gene`` and then scores
        ``self.sp.genes_distance_array`` against
        ``self.sp.kmeans_fit_result.labels_`` with the inverse
        Davies-Bouldin, Silhouette and Calinski-Harabasz scores. Each score
        is min-max normalized over the range before plotting.

        Args:
            mds_comp: Forwarded to ``cluster_gene`` as ``mds_components``.
            min_cluster (int): Smallest number of clusters to evaluate.
            max_cluster (int): Largest number of clusters to evaluate (inclusive).

        Returns:
            None. Displays the plot.

        Side Effects:
            ``self.sp.cluster_gene`` is called repeatedly; the last call uses
            ``max_cluster``.
        """
        db_dict = {}
        ch_dict = {}
        si_dict = {}
        for cluster_number in range(min_cluster, max_cluster + 1):
            # cluster_gene is a bound method of ``sp``; the Plot instance
            # must not be passed as an extra positional argument.
            self.sp.cluster_gene(n_clusters=cluster_number, mds_components=mds_comp)
            distances = self.sp.genes_distance_array
            labels = self.sp.kmeans_fit_result.labels_
            # Davies-Bouldin is "lower is better"; invert it so that higher
            # is better for all three curves.
            db_dict[cluster_number] = 1 / davies_bouldin_score(distances, labels)
            ch_dict[cluster_number] = calinski_harabasz_score(distances, labels)
            si_dict[cluster_number] = silhouette_score(distances, labels)

        score_df = pd.DataFrame(
            [db_dict, si_dict, ch_dict],
            index=["1/Davies-Bouldin", "Silhouette", "Calinski-Harabasz"],
        ).T
        norm_score_df = (score_df - score_df.min()) / (score_df.max() - score_df.min())
        sns.lineplot(data=norm_score_df, markers=True)
        plt.xticks(list(range(min_cluster, max_cluster + 1, 1)))
        plt.title("Evaluate Clustering Performance")
        plt.xlabel("Number of Clusters")
        plt.ylabel("Normalized Score")
        plt.show()

    def plot_tsne(self, method="tsne", s=10, show_bar=False):
        """
        Plots a 2-D embedding of ``self.sp.mds_features`` coloured by cluster label.

        Args:
            method (str, optional): "tsne" uses ``sklearn.manifold.TSNE``; any
                other value uses ``umap.UMAP``. Default is "tsne".
            s (int or float, optional): Marker size. Default is 10.
            show_bar (bool, optional): Whether to add a colorbar. Default is False.

        Returns:
            None. Displays the plot.
        """
        labels = self.sp.kmeans_fit_result.labels_
        n_clusters = len(set(labels))
        # Use a qualitative palette while it has enough distinct colours.
        if n_clusters <= 10:
            cmap = "tab10"
        elif n_clusters <= 20:
            cmap = "tab20"
        else:
            cmap = "viridis"

        if method == "tsne":
            tsne = TSNE(n_components=2)
            embedded_data = tsne.fit_transform(self.sp.mds_features)
        else:
            umap_model = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2)
            embedded_data = umap_model.fit_transform(self.sp.mds_features)

        plt.figure(figsize=(8, 6))
        scatter = plt.scatter(
            embedded_data[:, 0],
            embedded_data[:, 1],
            c=labels,
            cmap=cmap,
            s=s,
        )
        plt.xlabel("Dimension 1")
        plt.ylabel("Dimension 2")
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        if show_bar:
            plt.colorbar(scatter, label="Labels")
        plt.show()