Source code for smoder.models.deconvolution

import scanpy as sc
import numpy as np
import random  # set random seed
from sklearn.decomposition import PCA, TruncatedSVD  # LSI所需TruncatedSVD
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import kneighbors_graph
import matplotlib.pyplot as plt
import pandas as pd
import anndata as ad
import os
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import sparse
import warnings
warnings.filterwarnings('ignore', message='DataFrame is highly fragmented')

# 导入外部模块
from smoder.preprocessing import rna as pr
from smoder.preprocessing import modality2 as pad  # ADT/Peak预处理模块
from smoder.models.gnn import *


def selectInfoGenes(Basis, sc_adata, commonGene, ct_select, ct_varname, log_FC=1.25, top_n_filter=False, top_n=200):
    """Select informative marker genes for deconvolution.

    Two filters are applied in sequence:

    1. Differential expression. For each cell type in ``ct_select`` a
       natural-log fold change is computed between that cell type's column of
       ``Basis`` and the mean of all other columns. By default a gene is kept
       when FC > ``log_FC`` and its expression in that cell type is > 0. When
       ``top_n_filter`` is True, the ``top_n`` highest-FC genes with
       expression > 0 are kept instead. The union over cell types is then
       intersected with ``commonGene``.
    2. Within-group dispersion. For every cell type with more than one cell in
       ``sc_adata.obs[ct_varname]``, the variance/mean ratio of each candidate
       gene is computed (0 where the mean is 0). Genes whose mean ratio across
       cell types is at or above the 99th percentile are discarded.

    Parameters
    ----------
    Basis : pandas.DataFrame
        Mean expression matrix, genes as rows and cell types as columns.
    sc_adata : AnnData-like
        Object exposing ``X`` (cells x genes, dense or scipy sparse),
        ``var_names`` and ``obs``.
    commonGene : iterable of str
        Genes allowed in the result.
    ct_select : iterable
        Cell types (columns of ``Basis``) used in the DE step.
    ct_varname : str
        Column of ``sc_adata.obs`` holding cell-type labels.
    log_FC : float, default 1.25
        Fold-change threshold for the default DE mode.
    top_n_filter : bool, default False
        Use top-n selection instead of the ``log_FC`` threshold.
    top_n : int, default 200
        Number of genes kept per cell type in top-n mode.

    Returns
    -------
    list
        Selected gene names in ``sc_adata.var_names`` order. Empty when no
        gene survives the DE step.
    """
    import numpy as np
    import pandas as pd
    from scipy import sparse

    # Step 1: differential-expression filter.
    gene1_list = []
    for ict in ct_select:
        rest_columns = [col for col in Basis.columns if col != ict]
        rest = Basis[rest_columns].mean(axis=1)
        # The 1e-6 pseudo-count keeps log() finite for zero expression.
        FC = np.log(Basis[ict] + 1e-06) - np.log(rest + 1e-06)
        expressed = Basis[ict] > 0
        if top_n_filter:
            selected_genes = FC[expressed].sort_values(ascending=False).head(top_n).index
        else:
            selected_genes = Basis.index[(FC > log_FC) & expressed]
        gene1_list.extend(selected_genes)
        print(f"细胞类型 {ict} 筛选出 {len(selected_genes)} 个基因")

    gene1 = set(gene1_list) & set(commonGene)
    var_names = sc_adata.var_names
    # np.isin needs a list: a set would be treated as a single object scalar.
    gene_mask = np.isin(np.asarray(var_names), list(gene1))
    if not gene_mask.any():
        # Nothing to score; np.quantile below would raise on an empty array.
        return []

    # Step 2: densify only the candidate genes, laid out as genes x cells.
    counts_filtered = sc_adata.X[:, gene_mask]
    if sparse.issparse(counts_filtered):
        counts_filtered = counts_filtered.toarray()
    counts_filtered = np.asarray(counts_filtered).T
    gene_positions = np.flatnonzero(gene_mask)

    cell_types = sc_adata.obs[ct_varname]
    ct_counts = cell_types.value_counts()
    ct_select_filtered = ct_counts[ct_counts > 1].index.tolist()
    labels = np.asarray(cell_types)

    # Step 3: variance/mean ratio per (gene, cell type); built in one go to
    # avoid a fragmented DataFrame.
    ratios = {}
    for ict in ct_select_filtered:
        temp = counts_filtered[:, labels == ict]
        variance = np.var(temp, axis=1)
        mean_val = np.mean(temp, axis=1)
        ratios[ict] = np.divide(variance, mean_val, out=np.zeros_like(variance), where=mean_val != 0)
    sd_within = pd.DataFrame(ratios, index=gene_positions)

    mean_ratio = sd_within.mean(axis=1, skipna=True)
    threshold = np.quantile(mean_ratio, 0.99)
    keep_positions = gene_positions[(mean_ratio < threshold).to_numpy()]
    return [var_names[i] for i in keep_positions]
class SpaMultiDecon_two_modals:
    """Two-modality spatial deconvolution model (spatial RNA + ADT or peaks).

    ``ref_adata_dict['modal1']`` holds the single-cell RNA reference.
    ``smo_adata_dict['modal1']`` and ``smo_adata_dict['modal2']`` hold the
    spatial RNA and the second spatial modality (``modal2_type`` 'adt' or
    'peak').

    The methods are called in order:
    ``preprocess`` -> ``run_feature_engineering_and_mapping`` ->
    ``create_spatialgraph`` -> ``create_featuregraph`` -> ``train``.

    Hyper-parameters live in ``nn_para_dict`` and are exposed as properties.
    """

    def __init__(self, ref_adata_dict=None, smo_adata_dict=None, nn_para_dict=None, modal2_type='adt'):
        """Store inputs and hyper-parameters.

        Parameters
        ----------
        ref_adata_dict : dict, optional
            Reference data, defaults to ``{'modal1': None}``. A fresh dict is
            created per instance because ``preprocess`` writes into it.
        smo_adata_dict : dict, optional
            Spatial data, defaults to ``{'modal1': None, 'modal2': None}``.
            A fresh dict is created per instance.
        nn_para_dict : dict, optional
            Hyper-parameters. When None or empty, the built-in defaults below
            are used.
        modal2_type : str
            'adt' or 'peak' (case-insensitive).
        """
        if ref_adata_dict is None:
            ref_adata_dict = {'modal1': None}
        if smo_adata_dict is None:
            smo_adata_dict = {'modal1': None, 'modal2': None}
        if nn_para_dict is None or nn_para_dict == {}:
            nn_para_dict = {
                # Core training parameters.
                'epochs': 1000,
                'learning_rate': 1e-3,
                'device': 'cuda:0',
                'hidden_dim': 64,
                'weight_loss': [1, 0.1],
                'seed': 1,
                'weight_spatial': 1,
                'weight_consistency': 0.5,
                # Dimensionality reduction: one parameter per modality.
                'pca_n_components_rna': 256,   # final RNA PCA dimension
                'modal2_target_dim': 256,      # final second-modality dimension
                'lsi_random_state': 1,
            }
        self.smo_adata_dict = smo_adata_dict
        self.ref_adata_dict = ref_adata_dict
        self.nn_para_dict = nn_para_dict
        self.modal2_type = modal2_type.lower()
        assert self.modal2_type in ['adt', 'peak'], "modal2_type仅支持 'adt' 或 'peak'!"
        # Set True by run_feature_engineering_and_mapping(); train() requires it.
        self.feature_engineering_done = False

    # ------------------------------------------------------------------
    # Hyper-parameter properties (backed by self.nn_para_dict).
    # ------------------------------------------------------------------
    @property
    def weight_loss(self):
        return self.nn_para_dict.get('weight_loss', [1, 0.001])

    @weight_loss.setter
    def weight_loss(self, value):
        self.nn_para_dict['weight_loss'] = value

    @property
    def hidden_dim(self):
        return self.nn_para_dict.get('hidden_dim', 64)

    @hidden_dim.setter
    def hidden_dim(self, value):
        self.nn_para_dict['hidden_dim'] = value

    @property
    def learning_rate(self):
        return self.nn_para_dict.get('learning_rate', 2e-3)

    @learning_rate.setter
    def learning_rate(self, value):
        self.nn_para_dict['learning_rate'] = value

    @property
    def device(self):
        return self.nn_para_dict.get('device', 'cuda:0')

    @device.setter
    def device(self, value):
        self.nn_para_dict['device'] = value

    @property
    def epochs(self):
        return self.nn_para_dict.get('epochs', 1000)

    @epochs.setter
    def epochs(self, value):
        self.nn_para_dict['epochs'] = value

    @property
    def seed(self):
        return self.nn_para_dict.get('seed', 1)

    @seed.setter
    def seed(self, value):
        self.nn_para_dict['seed'] = value

    @property
    def weight_spatial(self):
        return self.nn_para_dict.get('weight_spatial', 1e-5)

    @weight_spatial.setter
    def weight_spatial(self, value):
        self.nn_para_dict['weight_spatial'] = value

    @property
    def weight_consistency(self):
        return self.nn_para_dict.get('weight_consistency', 0.5)

    @weight_consistency.setter
    def weight_consistency(self, value):
        self.nn_para_dict['weight_consistency'] = value

    # Dimensionality-reduction properties.
    @property
    def pca_n_components_rna(self):
        return self.nn_para_dict.get('pca_n_components_rna', 256)

    @pca_n_components_rna.setter
    def pca_n_components_rna(self, value):
        assert value > 0, "pca_n_components_rna必须大于0!"
        self.nn_para_dict['pca_n_components_rna'] = value

    @property
    def modal2_target_dim(self):
        """Final dimension of the second modality (ADT/Peak) model input."""
        return self.nn_para_dict.get('modal2_target_dim', 256)

    @modal2_target_dim.setter
    def modal2_target_dim(self, value):
        assert value > 0, "modal2_target_dim必须大于0!"
        self.nn_para_dict['modal2_target_dim'] = value

    @property
    def lsi_random_state(self):
        return self.nn_para_dict.get('lsi_random_state', 1)

    @lsi_random_state.setter
    def lsi_random_state(self, value):
        self.nn_para_dict['lsi_random_state'] = value

    def set_nn_para(self, **kwargs):
        """Set several hyper-parameters at once; unknown keys are reported and ignored."""
        valid_keys = [
            'epochs', 'learning_rate', 'device', 'hidden_dim', 'seed',
            'weight_loss', 'weight_consistency', 'weight_spatial',
            'pca_n_components_rna', 'modal2_target_dim', 'lsi_random_state'
        ]
        for key, value in kwargs.items():
            if key in valid_keys:
                setattr(self, key, value)
            else:
                print(f"警告: 忽略无效参数 '{key}'")

    # ------------------------------------------------------------------
    # 1. Preprocessing (cleaning / normalization only, no dimensionality reduction)
    # ------------------------------------------------------------------
    def preprocess(self, ref_celltype_col, sample_id_col, ct_select, log_FC=1.25, top_n_filter=False, top_n=200, do_select_info_genes=True):
        """Clean, normalize and align the data; no dimensionality reduction.

        Steps:
        - Process the single-cell reference.
        - Normalize it on genes shared with spatial RNA.
        - Optionally select informative genes with ``selectInfoGenes``.
        - Build the cell-type basis matrix, stored in
          ``ref_adata_dict['basis_matrix']`` with columns re-ordered to
          match the final modality-1 genes.
        - Process spatial RNA.
        - Normalize modality 2 (ADT via ``pad.process_spatial_adt_data``,
          peaks via ``pad.process_peak_data``).
        - Keep only spots present in both modalities.

        Resets ``feature_engineering_done`` to False.
        """
        print(f"="*60)
        print(f"开始数据预处理(modal2类型: {self.modal2_type.upper()})")
        print(f"="*60)

        # Step 1: single-cell RNA reference.
        print(f"[Step 1/3] 处理参考RNA数据...")
        sc_shape_before_filter = self.ref_adata_dict['modal1'].shape
        adata_sc_filtered = pr.process_single_cell(self.ref_adata_dict['modal1'], ref_celltype_col)
        print(f"参考RNA数据形状:预处理前 {sc_shape_before_filter} → 预处理后 {adata_sc_filtered.shape}")

        # Step 2: spatial RNA.
        print(f"[Step 2/3] 处理空间RNA数据...")
        adata_st_filtered = self.smo_adata_dict['modal1'].copy()
        adata_st_filtered.var_names_make_unique()
        common_genes = adata_sc_filtered.var_names.intersection(adata_st_filtered.var_names)
        print(f"空间与参考RNA共同基因数量: {len(common_genes)}")
        adata_sc_common = adata_sc_filtered[:, common_genes].copy()
        sc.pp.normalize_total(adata_sc_common, target_sum=1e4)

        # Informative-gene selection (otherwise all shared genes are used).
        marker_genes = common_genes
        if do_select_info_genes:
            print(f"筛选信息基因...")
            temp_Basis = pr.calculate_celltype_averages(
                adata_sc_norm=adata_sc_common,
                celltype_col=ref_celltype_col,
                sample_id_col=sample_id_col,
                target_sum=1,
                normalize=True
            )
            marker_genes = selectInfoGenes(temp_Basis.T, adata_sc_common, common_genes, ct_select, ref_celltype_col, log_FC, top_n_filter, top_n)
            print(f"信息基因筛选完成,共 {len(marker_genes)} 个")

        # Cell-type x gene basis matrix on the marker genes.
        adata_sc_marker = adata_sc_common[:, marker_genes].copy()
        Basis = pr.calculate_celltype_averages(
            adata_sc_norm=adata_sc_marker,
            celltype_col=ref_celltype_col,
            sample_id_col=sample_id_col,
            target_sum=1,
            normalize=True
        )
        self.ref_adata_dict['basis_matrix'] = Basis

        adata_gene_norm = pr.process_spatial_data(
            adata_st_filtered,
            common_genes=common_genes,
            hvgs=marker_genes
        )
        print(f"空间RNA数据预处理完成,形状: {adata_gene_norm.shape}")

        # Step 3: second modality (normalization only).
        print(f"[Step 3/3] 处理模态2({self.modal2_type.upper()})数据(仅归一化,不执行降维)...")
        adata_modal2_raw = self.smo_adata_dict['modal2'].copy()
        if self.modal2_type == 'adt':
            # ADT: CLR normalization.
            adata_modal2_norm = pad.process_spatial_adt_data(adata_modal2_raw)
            print(f"ADT数据预处理完成(CLR标准化),形状: {adata_modal2_norm.shape}")
        elif self.modal2_type == 'peak':
            # Peak: TF-IDF normalization.
            adata_modal2_norm = pad.process_peak_data(adata_modal2_raw)
            print(f"Peak数据预处理完成(TF-IDF归一化),形状: {adata_modal2_norm.shape}")

        # Step 4: keep spots present in both modalities.
        common_spots = adata_gene_norm.obs_names.intersection(adata_modal2_norm.obs_names)
        adata_gene_norm = adata_gene_norm[common_spots, :].copy()
        adata_modal2_norm = adata_modal2_norm[common_spots, :].copy()
        num_spots = len(common_spots)
        print(f"空间位点对齐完成,共保留 {num_spots} 个共同位点")

        # Align RNA genes with the basis matrix. The basis columns are
        # re-ordered to the same gene order so that train() can pair
        # basis_matrix columns with layers['counts'] columns positionally.
        common_genes = adata_gene_norm.var_names.intersection(Basis.columns)
        adata_gene_norm = adata_gene_norm[:, common_genes].copy()
        self.ref_adata_dict['basis_matrix'] = Basis.loc[:, common_genes]
        print(f"RNA基因对齐完成,共保留 {len(common_genes)} 个共同基因")

        self.smo_adata_dict['modal1'] = adata_gene_norm
        self.smo_adata_dict['modal2'] = adata_modal2_norm
        self.feature_engineering_done = False
        print(f"="*60)
        print(f"数据预处理全部完成!")
        print(f"="*60)

    # ------------------------------------------------------------------
    # 2. Feature engineering + model-input mapping
    # ------------------------------------------------------------------
    def run_feature_engineering_and_mapping(self, obsm_name_rna=None, obsm_name_modal2=None):
        """Compute model input features for both modalities.

        Modality 1 (RNA):
            Per-gene z-scoring, then PCA with ``pca_n_components_rna``
            components (clamped to the data shape). Stored in
            ``obsm['X_pca_rna']`` and ``obsm[obsm_name_rna]``.

        Modality 2, peak:
            When the feature count exceeds ``modal2_target_dim``, run LSI
            via ARPACK TruncatedSVD and drop the first component. Otherwise
            use the data as-is. Stored in ``obsm['X_lsi_peak']``.

        Modality 2, ADT:
            Per-feature z-scoring, then PCA to ``modal2_target_dim`` when the
            feature count exceeds it. Stored in ``obsm['X_pca_adt']``.

        Both modality-2 outputs are also stored under ``obsm[obsm_name_modal2]``.

        Parameters
        ----------
        obsm_name_rna : str, optional
            Defaults to "X_feat_rna".
        obsm_name_modal2 : str, optional
            Defaults to "X_feat_modal2".
        """
        obsm_name_rna = obsm_name_rna or "X_feat_rna"
        obsm_name_modal2 = obsm_name_modal2 or "X_feat_modal2"

        # 1. Modality 1 (RNA): z-score, then PCA.
        adata_rna = self.smo_adata_dict["modal1"]
        mat_rna = adata_rna.X.toarray() if sparse.issparse(adata_rna.X) else adata_rna.X
        mat_rna_norm = StandardScaler().fit_transform(mat_rna)
        # PCA requires n_components <= min(n_samples, n_features).
        n_pc_rna = min(self.pca_n_components_rna, *mat_rna_norm.shape)
        pca_rna = PCA(n_components=n_pc_rna)
        adata_rna.obsm["X_pca_rna"] = pca_rna.fit_transform(mat_rna_norm)
        adata_rna.obsm[obsm_name_rna] = adata_rna.obsm["X_pca_rna"]

        # 2. Modality 2: reduce only when it has more features than the target.
        adata_modal2 = self.smo_adata_dict["modal2"]
        n_spots, modal2_original_dim = adata_modal2.shape
        target_dim = self.modal2_target_dim
        print(f"模态二原始维度: {modal2_original_dim}, 判断阈值: {target_dim}")

        if self.modal2_type == "peak":
            if not sparse.issparse(adata_modal2.X):
                adata_modal2.X = sparse.csr_matrix(adata_modal2.X)
            else:
                adata_modal2.X = adata_modal2.X.tocsr()
            if modal2_original_dim <= target_dim:
                feat_out = adata_modal2.X.toarray()
                print(f"维度≤{target_dim},直接使用原始数据")
            else:
                # Compute target_dim + 1 components and drop the first.
                # Dropping it is conventional in LSI, where the first
                # component tends to follow sequencing depth.
                # ARPACK requires n_components < min(n_samples, n_features).
                n_svd = min(target_dim + 1, min(n_spots, modal2_original_dim) - 1)
                svd = TruncatedSVD(n_components=n_svd, random_state=self.lsi_random_state, algorithm='arpack')
                emb_full = svd.fit_transform(adata_modal2.X)
                feat_out = emb_full[:, 1:]
                print(f"维度>{target_dim},LSI降维到目标维度")
            adata_modal2.obsm["X_lsi_peak"] = feat_out
            adata_modal2.obsm[obsm_name_modal2] = feat_out

        elif self.modal2_type == "adt":
            mat_adt = adata_modal2.X.toarray() if sparse.issparse(adata_modal2.X) else adata_modal2.X
            mat_adt_norm = StandardScaler().fit_transform(mat_adt)
            if modal2_original_dim <= target_dim:
                feat_out = mat_adt_norm
                print(f"维度≤{target_dim},直接使用标准化后数据")
            else:
                # PCA requires n_components <= min(n_samples, n_features).
                pca_adt = PCA(n_components=min(target_dim, n_spots))
                feat_out = pca_adt.fit_transform(mat_adt_norm)
                print(f"维度>{target_dim},PCA降维到目标维度")
            adata_modal2.obsm["X_pca_adt"] = feat_out
            adata_modal2.obsm[obsm_name_modal2] = feat_out

        # 3. Store results and mark completion.
        self.smo_adata_dict["modal1"] = adata_rna
        self.smo_adata_dict["modal2"] = adata_modal2
        self.feature_engineering_done = True

    # ------------------------------------------------------------------
    # 3. Graph construction and training
    # ------------------------------------------------------------------
    def create_spatialgraph(self, obsm_spatial, K_neighbors = 8):
        """Build a kNN connectivity graph on spot coordinates.

        Coordinates come from ``smo_adata_dict['modal1'].obsm[obsm_spatial]``.
        The edge list is stored as a (2, n_edges) array in
        ``smo_adata_dict['modal1'].uns['spatial_graph']``. Only warns (does
        not stop) when feature engineering has not run yet.
        """
        if not self.feature_engineering_done:
            warnings.warn("请先完成特征工程+输入映射,再构建空间图!", UserWarning)
        coords = self.smo_adata_dict['modal1'].obsm[obsm_spatial]
        adj = kneighbors_graph(coords, K_neighbors, mode='connectivity')
        edge_index = np.array(np.nonzero(adj))
        print(f"空间图构建完成,存储在 smo_adata_dict['modal1'].uns['spatial_graph']")
        self.smo_adata_dict['modal1'].uns['spatial_graph'] = edge_index

    def get_laplacian(self, normalize=True):
        """Return (L, A, D) for the symmetrized spatial graph.

        A is the dense 0/1 adjacency (every edge added in both directions)
        and D = diag(degree).

        When ``normalize`` is True, L = I - D^{-1/2} A D^{-1/2}. The
        inverse square root is taken element-wise on the degree vector;
        1e-8 is added only to the degrees, which avoids a dense matrix
        inverse. Otherwise L = D - A.

        Raises
        ------
        ValueError
            If ``create_spatialgraph`` has not been called.
        """
        edge_index = self.smo_adata_dict['modal1'].uns.get('spatial_graph')
        if edge_index is None:
            raise ValueError("请先调用create_spatialgraph()生成空间图")
        if not isinstance(edge_index, torch.Tensor):
            edge_index = torch.tensor(edge_index, dtype=torch.long)
        device = edge_index.device
        num_spots = self.smo_adata_dict['modal1'].shape[0]

        A = torch.zeros((num_spots, num_spots), device=device)
        A[edge_index[0], edge_index[1]] = 1.0
        A[edge_index[1], edge_index[0]] = 1.0
        degree = A.sum(dim=1)
        D = torch.diag(degree)
        I = torch.eye(num_spots, device=device)

        if normalize:
            d_inv_sqrt = torch.rsqrt(degree + 1e-8)
            L = I - d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)
        else:
            L = D - A
        return L, A, D

    def create_featuregraph(self, obsm_name='X_pca', K_neighbors = 8, name = 'namemod'):
        """Build a kNN graph on ``obsm[obsm_name]`` for every AnnData in smo_adata_dict.

        Each edge list is stored in ``adata.uns[name]``. Entries that are not
        AnnData objects, or that lack ``obsm_name``, are skipped with a
        warning.
        """
        if not isinstance(self.smo_adata_dict, dict):
            raise ValueError("smo_adata_dict must be a dictionary of AnnData objects.")
        for key, adata in self.smo_adata_dict.items():
            if not isinstance(adata, ad.AnnData):
                warnings.warn(
                    f"Skipping key '{key}': Expected AnnData object, got {type(adata)}.",
                    UserWarning
                )
                continue
            if obsm_name not in adata.obsm:
                warnings.warn(
                    f"Skipping key '{key}': obsm '{obsm_name}' not found in AnnData object.",
                    UserWarning
                )
                continue
            adj = kneighbors_graph(adata.obsm[obsm_name], K_neighbors, mode='connectivity')
            edge_index = np.array(np.nonzero(adj))
            print(f"为 {key} 构建特征图完成,存储在 adata.uns['{name}']")
            adata.uns[name] = edge_index

    def train(self, embd_dim=50, method=1, name1='name1', name2='name2', obsm_name_rna='X_pca', obsm_name_adt='X_pca', plot=True, model_save=False, model_save_dir="trained_models"):
        """Train the fusion/deconvolution network and return per-spot results.

        The training objective is the sum of:
        - a negative-binomial loss on modality-1 counts;
        - MSE reconstruction of both modality inputs;
        - a cross-modal consistency MSE;
        - a Laplacian smoothness penalty 0.5 * tr(P^T L P) on the predicted
          proportions P.

        Each term is scaled by its weight hyper-parameter. Training stops
        early when the best total loss is more than 60 epochs old.

        Parameters
        ----------
        embd_dim : int
            Fused embedding dimension.
        method : {1, 2}
            1 uses GCN_between_attention_encoder with spatial graphs only.
            2 uses GCN_between_within_attention_encoder with spatial and
            feature graphs.
        name1, name2 : str
            ``.uns`` keys of the modality-1 / modality-2 feature graphs.
        obsm_name_rna, obsm_name_adt : str
            ``.obsm`` keys of the modality-1 / modality-2 model inputs.
        plot : bool
            Save loss curves to ``model_save_dir/training_losses.png``.
        model_save : bool
            Save a timestamped checkpoint to ``model_save_dir``.
        model_save_dir : str
            Output directory for the figure and the checkpoint.

        Returns
        -------
        (adata_result, cross_fusion)
            ``adata_result`` is a copy of modality 1 with proportions in
            ``obs`` and ``obsm['cell_type_proportions']``, plus
            embedding / reconstruction / attention arrays in ``obsm``.
            ``cross_fusion`` is the trained encoder.

        Raises
        ------
        ValueError
            - Feature engineering has not run.
            - ``method`` is not 1 or 2.
            - ``epochs`` < 1.
            - The spatial or feature graphs are missing.
        """
        if not self.feature_engineering_done:
            raise ValueError("请先调用 run_feature_engineering_and_mapping() 完成特征工程+输入映射!")
        if method not in (1, 2):
            raise ValueError(f"Unsupported method {method}! Use 1 or 2.")
        epochs = self.epochs
        if epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {epochs}")

        # 1. Fix random seeds. Properties are used so that a partial
        # nn_para_dict falls back to defaults instead of raising KeyError.
        device = self.device
        seed = self.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        os.environ['PYTHONHASHSEED'] = str(seed)

        # 2. Data preparation.
        adata_modal1 = self.smo_adata_dict['modal1']
        adata_modal2 = self.smo_adata_dict['modal2']
        gene_counts = adata_modal1.layers['counts']
        gene_counts = gene_counts.toarray() if hasattr(gene_counts, 'toarray') else gene_counts
        gene_spot_counts_sum = np.sum(gene_counts, axis=1, keepdims=True)
        gene_pca = adata_modal1.obsm[obsm_name_rna]
        modal2_pca = adata_modal2.obsm[obsm_name_adt]

        if 'spatial_graph' not in adata_modal1.uns:
            raise ValueError("spatial_graph not found! Run create_spatialgraph() first.")
        g_spatial = torch.tensor(adata_modal1.uns['spatial_graph'], dtype=torch.long, device=device)

        if name1 not in adata_modal1.uns or name2 not in adata_modal2.uns:
            raise ValueError("feature_graph not found! Run create_featuregraph() first.")
        g_feature_modal1 = torch.tensor(adata_modal1.uns[name1], dtype=torch.long, device=device)
        g_feature_modal2 = torch.tensor(adata_modal2.uns[name2], dtype=torch.long, device=device)

        gene_pca_tensor = torch.tensor(gene_pca, dtype=torch.float, device=device)
        modal2_pca_tensor = torch.tensor(modal2_pca, dtype=torch.float, device=device)
        gene_recon_target = gene_pca_tensor
        modal2_recon_target = modal2_pca_tensor
        gene_counts_tensor = torch.tensor(gene_counts, dtype=torch.float, device=device)
        gene_spot_sum_tensor = torch.tensor(gene_spot_counts_sum, dtype=torch.float, device=device)
        cell_type_gene_tensor = torch.tensor(self.ref_adata_dict['basis_matrix'].values, dtype=torch.float, device=device)
        num_cell_types = cell_type_gene_tensor.shape[0]

        # 3. Model dimensions (taken from the actual input arrays).
        hidden_dim = self.hidden_dim
        embedding_dim = embd_dim
        in_fea_dim_modal1 = gene_pca.shape[1]
        in_fea_dim_modal2 = modal2_pca.shape[1]
        print(f"="*60)
        print(f"模型输入维度确认:")
        print(f"  - RNA模态:{in_fea_dim_modal1} 维")
        print(f"  - 第二模态({self.modal2_type.upper()}):{in_fea_dim_modal2} 维(= modal2_target_dim)")
        print(f"="*60)

        # 4. Model components.
        encoder_cls = GCN_between_attention_encoder if method == 1 else GCN_between_within_attention_encoder
        cross_fusion = encoder_cls(
            in_fea_dim_modal1, in_fea_dim_modal2, hidden_dim, embedding_dim
        ).to(device)
        modal1_encoder = cross_fusion.encode_modal1
        modal2_encoder = cross_fusion.encode_modal2

        reconstructor = ExpressionReconstructor(
            embedding_dim,
            pca_dim_rna=in_fea_dim_modal1,
            pca_dim_adt=in_fea_dim_modal2
        ).to(device)
        proportion_predictor = CellTypeProportionPredictor(embedding_dim, num_cell_types).to(device)
        consistency_verifier = ConsistencyVerifier(
            modal1_encoder=modal1_encoder,
            modal2_encoder=modal2_encoder,
            reconstructor=reconstructor
        ).to(device)

        # Learnable negative-binomial parameters: per-gene dispersion plus
        # per-spot and per-gene log-offsets.
        gene_theta = nn.Parameter(torch.ones(gene_counts.shape[1], device=device))
        gene_spot_offset = nn.Parameter(torch.zeros(gene_counts.shape[0], device=device))
        gene_gene_offset = nn.Parameter(torch.zeros(gene_counts.shape[1], device=device))

        # 5. Optimizer.
        optimizer = torch.optim.Adam(
            list(cross_fusion.parameters()) +
            list(reconstructor.parameters()) +
            list(proportion_predictor.parameters()) +
            list(consistency_verifier.parameters()) +
            [gene_theta, gene_spot_offset, gene_gene_offset],
            lr=self.learning_rate,
            weight_decay=1e-5
        )

        # 6. Normalized graph Laplacian for the smoothness penalty.
        L, _, _ = self.get_laplacian(normalize=True)
        L = L.to(device)

        # 7. Training loop.
        weight_nb, weight_recon = self.weight_loss
        weight_consistency = self.weight_consistency
        weight_spatial = self.weight_spatial

        total_losses = []
        gene_nb_losses = []
        recon_losses = []
        consistency_losses = []
        spatial_losses = []

        for epoch in range(epochs):
            cross_fusion.train()
            reconstructor.train()
            proportion_predictor.train()
            consistency_verifier.train()

            if method == 1:
                embedding, y1, y2, attn_gene2other, attn_other2gene = cross_fusion(
                    g_spatial_omics1=g_spatial,
                    feat_omics1=gene_pca_tensor,
                    g_spatial_omics2=g_spatial,
                    feat_omics2=modal2_pca_tensor
                )
            else:
                embedding, y1, y2, attn_gene2other, attn_other2gene = cross_fusion(
                    g_spatial_omics1=g_spatial,
                    g_feature_omics1=g_feature_modal1,
                    feat_omics1=gene_pca_tensor,
                    g_spatial_omics2=g_spatial,
                    g_feature_omics2=g_feature_modal2,
                    feat_omics2=modal2_pca_tensor
                )

            # Loss 1: reconstruction of both modality inputs.
            gene_recon, modal2_recon = reconstructor(embedding)
            gene_recon_loss = F.mse_loss(gene_recon, gene_recon_target)
            modal2_recon_loss = F.mse_loss(modal2_recon, modal2_recon_target)
            total_recon_loss = weight_recon * (gene_recon_loss + modal2_recon_loss)

            # Loss 2: negative-binomial likelihood of modality-1 counts.
            cell_proportions = proportion_predictor(embedding)
            gene_mu = compute_gene_nb_mu(cell_type_gene_tensor, cell_proportions)
            gene_mu = torch.exp(torch.log(gene_mu) + gene_spot_offset.unsqueeze(1) + gene_gene_offset.unsqueeze(0))
            gene_mu = gene_mu * gene_spot_sum_tensor
            gene_nb_loss = weight_nb * gene_negative_binomial_loss(gene_counts_tensor, gene_mu, gene_theta)

            # Loss 3: cross-modal consistency.
            y1_recon, y2_recon = consistency_verifier(
                y1=y1, y2=y2,
                g_spatial1=g_spatial, g_feature1=g_feature_modal1,
                g_spatial2=g_spatial, g_feature2=g_feature_modal2
            )
            consistency_loss = weight_consistency * (F.mse_loss(y1, y1_recon) + F.mse_loss(y2, y2_recon))

            # Loss 4: spatial smoothness, 0.5 * tr(P^T L P).
            L_P = torch.matmul(L, cell_proportions)
            P_T_L_P = torch.matmul(cell_proportions.T, L_P)
            spatial_loss = weight_spatial * 0.5 * torch.trace(P_T_L_P)

            total_loss = gene_nb_loss + total_recon_loss + consistency_loss + spatial_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            total_losses.append(total_loss.item())
            gene_nb_losses.append(gene_nb_loss.item())
            recon_losses.append(total_recon_loss.item())
            consistency_losses.append(consistency_loss.item())
            spatial_losses.append(spatial_loss.item())

            if (epoch + 1) % 100 == 0:
                torch.cuda.empty_cache()
                print(
                    f"Epoch {epoch + 1}/{epochs} | "
                    f"总损失: {total_loss.item():.4f} | "
                    f"基因NB损失: {gene_nb_loss.item():.4f} | "
                    f"平均重构损失: {total_recon_loss.item()/2:.4f} | "
                    f"一致性损失: {consistency_loss.item():.4f} | "
                    f"空间平滑损失: {spatial_loss.item():.4f}"
                )

            # Early stop: the best loss so far was not reached in the last 60 epochs.
            if len(total_losses) > 60 and min(total_losses) != min(total_losses[-60:]):
                print(f"Early stopped at epoch {epoch + 1}.")
                break

        # Snapshot the outputs of the last executed epoch (early stop or not).
        final_embedding = embedding.detach().cpu()
        final_gene_recon = gene_recon.detach().cpu()
        final_modal2_recon = modal2_recon.detach().cpu()
        final_cell_proportions = cell_proportions.detach().cpu()
        final_attn_gene2other = attn_gene2other.detach().cpu()
        final_attn_other2gene = attn_other2gene.detach().cpu()

        # 8. Loss curves.
        if plot:
            plt.figure(figsize=(25, 5))
            plt.subplot(1, 5, 1)
            plt.plot(total_losses, 'b-')
            plt.title('Total Loss')
            plt.xlabel('Epoch')
            plt.subplot(1, 5, 2)
            plt.plot(gene_nb_losses, 'g-')
            plt.title(f'Gene Negative Binomial Loss (w={weight_nb})')
            plt.xlabel('Epoch')
            plt.subplot(1, 5, 3)
            plt.plot(recon_losses, 'r-')
            plt.title(f'Average Reconstruction Loss (w={weight_recon})')
            plt.xlabel('Epoch')
            plt.subplot(1, 5, 4)
            plt.plot(consistency_losses, 'orange')
            plt.title(f'Consistency Loss (w={weight_consistency})')
            plt.xlabel('Epoch')
            plt.subplot(1, 5, 5)
            plt.plot(spatial_losses, 'purple')
            plt.title(f'Spatial Smooth Loss (w={weight_spatial})')
            plt.xlabel('Epoch')
            plt.tight_layout()
            os.makedirs(model_save_dir, exist_ok=True)
            plt.savefig(os.path.join(model_save_dir, "training_losses.png"), dpi=300)
            plt.close()

        # 9. Checkpoint.
        if model_save:
            os.makedirs(model_save_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_path = os.path.join(model_save_dir, f"SpaMultiDecon_model_{timestamp}.pth")
            checkpoint = {
                'method': method,
                'cross_fusion_state_dict': cross_fusion.state_dict(),
                'reconstructor_state_dict': reconstructor.state_dict(),
                'proportion_predictor_state_dict': proportion_predictor.state_dict(),
                'consistency_verifier_state_dict': consistency_verifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'losses': {
                    'total': total_losses,
                    'gene_nb': gene_nb_losses,
                    'recon': recon_losses,
                    'consistency': consistency_losses,
                    'spatial': spatial_losses
                },
                'model_config': {
                    'embd_dim': embedding_dim,
                    'in_fea_dim_modal1': in_fea_dim_modal1,
                    'in_fea_dim_modal2': in_fea_dim_modal2,
                    'hidden_dim': hidden_dim,
                    'modal2_type': self.modal2_type,
                    'modal2_target_dim': self.modal2_target_dim
                },
                'data_params': {
                    'gene_theta': gene_theta,
                    'gene_spot_offset': gene_spot_offset,
                    'gene_gene_offset': gene_gene_offset
                }
            }
            torch.save(checkpoint, model_path)
            print(f"模型已保存到: {model_path}")

        # 10. Results.
        adata_result = adata_modal1.copy()
        adata_result.obsm['cell_type_proportions'] = final_cell_proportions.numpy()
        proportion_df = pd.DataFrame(
            final_cell_proportions.numpy(),
            index=adata_result.obs_names,
            columns=self.ref_adata_dict['basis_matrix'].index.tolist()
        )
        adata_result.obs = pd.concat([adata_result.obs, proportion_df], axis=1)
        adata_result.obsm['embedding'] = final_embedding.numpy()
        adata_result.obsm['reconstructed_gene'] = final_gene_recon.numpy()
        adata_result.obsm['reconstructed_modal2'] = final_modal2_recon.numpy()
        adata_result.obsm['attn_gene2other'] = final_attn_gene2other.numpy()
        adata_result.obsm['attn_other2gene'] = final_attn_other2gene.numpy()

        return adata_result, cross_fusion