# TGNN-Solv Architecture

## System Overview

The repository maintains two closely related model families:

- `TGNN-Solv`: predicts crystal and interaction parameters, solves SLE, then applies a bounded correction in solver parameter space
- `DirectGNN`: reuses the same graph backbone but predicts `ln(x2)` directly, without the thermodynamic bottleneck
The central comparison is whether the explicit physics bottleneck helps relative to the same backbone trained directly on solubility.
## Internal Package Map

The implementation still lives in the legacy flat modules such as:

- `src/tgnn_solv/model.py`
- `src/tgnn_solv/trainer.py`
- `src/tgnn_solv/inference.py`

For contributor navigation, the package now also exposes grouped namespaces:

- `tgnn_solv.models`
- `tgnn_solv.physics`
- `tgnn_solv.training`
- `tgnn_solv.evaluation`
- `tgnn_solv.chemistry`
- `tgnn_solv.core`
Those grouped imports are thin compatibility re-exports over the legacy flat modules. They do not change runtime behavior.
## Optional Stage 0 Pretraining

Outside the main paper curriculum, the repository also implements a standalone encoder/readout pretraining stage in `src/tgnn_solv/pretrain.py`. This is distinct from Phase 1 in `trainer.py`:

- Stage 0: optional pre-curriculum molecular pretraining on large SMILES sets
- Phase 1: supervised auxiliary warmup on the processed solubility training split

The maintained `Pretrainer` updates `model.gnn` and `model.readout` in place, using four tasks:
- masked 2-hop subgraph atom reconstruction
- bond-type prediction
- RDKit property regression over a 12-descriptor target vector
- graph contrastive learning
The temporary Stage 0 heads are discarded after pretraining, so the downstream model remains the normal TGNN-Solv architecture.
The maintained Stage 0 surface is no longer only a raw Python class. `src/tgnn_solv/pretrain_pipeline.py`:

- loads SMILES sources
- runs Stage 0
- saves/restores encoder checkpoints
- applies warm-start weights back into a fresh TGNN model

The saved Stage 0 checkpoint format intentionally contains:

- `gnn_state_dict`
- `readout_state_dict`
- `pretrain_history`
- `pretrain_metadata`
That makes Stage 0 a reusable warm-start artifact rather than a one-off notebook trick.
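The warm-start round trip can be sketched with plain dicts standing in for tensors; the helper names below are illustrative, not the `pretrain_pipeline.py` API:

```python
# Hypothetical sketch of the Stage 0 warm-start round trip, assuming the
# checkpoint is a plain mapping with the four documented keys.

def make_stage0_checkpoint(gnn_state, readout_state, history, metadata):
    """Bundle Stage 0 artifacts into the documented checkpoint layout."""
    return {
        "gnn_state_dict": gnn_state,          # encoder weights
        "readout_state_dict": readout_state,  # readout weights
        "pretrain_history": history,          # per-epoch task losses
        "pretrain_metadata": metadata,        # SMILES sources, config, etc.
    }

def warm_start(model_state, checkpoint):
    """Copy only encoder/readout entries into a fresh TGNN state dict;
    all other submodules (heads, solver-facing blocks) keep fresh weights."""
    warmed = dict(model_state)
    for key, value in checkpoint["gnn_state_dict"].items():
        warmed[f"gnn.{key}"] = value
    for key, value in checkpoint["readout_state_dict"].items():
        warmed[f"readout.{key}"] = value
    return warmed
```

The point of the layout is exactly this selectivity: only `model.gnn` and `model.readout` are warm-started, so the downstream model stays the normal TGNN-Solv architecture.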
## TGNN-Solv Forward Pass

The maintained `TGNNSolv` path in `src/tgnn_solv/model.py` runs in this order.

### 1. Dual-graph encoder

The encoder is selected by `encoder_type`:

- `encoder_type="mpnn"`: the maintained local MPNN path via `GNNEncoder`
- `encoder_type="gps"`: `GPSEncoder`, which combines the local message-passing block with graph-global multi-head attention and per-graph positional encodings
When `encoder_type="gps"`, positional encodings are computed on the fly from the current batch graph:

- `gps_positional_encoding="laplacian"`: absolute non-trivial Laplacian eigenvectors
- `gps_positional_encoding="rwse"`: random-walk structural encodings

`gps_pe_dim`, `gps_num_heads`, and `gps_use_edge_attr` control the size and behavior of that encoder path.
Both encoder families support:

- `encoder_role_mode="shared_residual"` by default
- optional `encoder_role_mode="split_late"` for late role-specific blocks

Temperature is typically excluded from the crystal-property encoder path:

- `use_temperature_in_encoder=False`
- `use_temperature_in_interaction=False`
- `use_temperature_in_nrtl_head=True`
This keeps the crystal heads mostly temperature-invariant and injects temperature later where it matters physically.
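A toy sketch of that routing policy; the function and its argument names are illustrative, not the model's internals:

```python
# Illustrative sketch (not the real API) of how the three temperature flags
# route a thermometer feature: excluded from the encoder and interaction
# inputs by default, appended only for the NRTL head.

def route_temperature(features, temp_feat, *, use_in_encoder=False,
                      use_in_interaction=False, use_in_nrtl_head=True):
    """Return per-stage inputs with temperature appended only where enabled."""
    stage_inputs = {
        "encoder": list(features),
        "interaction": list(features),
        "nrtl_head": list(features),
    }
    flags = {
        "encoder": use_in_encoder,
        "interaction": use_in_interaction,
        "nrtl_head": use_in_nrtl_head,
    }
    for stage, enabled in flags.items():
        if enabled:
            stage_inputs[stage] = stage_inputs[stage] + list(temp_feat)
    return stage_inputs
```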
### 2. Pre-interaction auxiliary heads

Before solute-solvent interaction, the model predicts:

- `HansenHead`
- `AuxPropsHead` for `V_m`
Two optional prior paths can bound those predictions:
- `use_descriptor_priors=True`: learned adapter from compact RDKit-derived prior features
- `use_group_priors=True`: fixed fragment-count priors
In either mode, the graph branch predicts a bounded residual around the prior.
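The prior-plus-bounded-residual pattern can be sketched as follows; the helper name and the `residual_max` bound are illustrative:

```python
import math

# Minimal sketch of the "bounded residual around a prior" pattern described
# above. The graph branch emits a raw residual; tanh keeps the final
# prediction within prior +/- residual_max no matter how large the raw
# output becomes.

def bounded_prediction(prior, raw_residual, residual_max):
    """Prior plus a tanh-bounded residual."""
    return prior + residual_max * math.tanh(raw_residual)
```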
### 3. Optional Morgan side information
If use_morgan_features=True, Morgan fingerprints are projected into the
molecular representation space:
- before the crystal-property heads
- again after interaction/readout
This augments the graph branch without bypassing the physics path.
### 4. Solute-solvent interaction and pair representation

The default interaction block is stacked cross-attention (`interaction_mode="cross_attn"`). An alternative bipartite message-passing block is also supported (`interaction_mode="bipartite"`).
`PhysicsAwareReadout` produces graph-level vectors, and the pair vector is then constructed as:

`[g_sol, g_slv, g_sol * g_slv, |g_sol - g_slv|]`

If `use_descriptor_augmentation=True`, the model also computes normalized RDKit descriptor embeddings for solute and solvent, forms the analogous descriptor interaction block

`[d_sol, d_slv, d_sol * d_slv, |d_sol - d_slv|]`

and projects the concatenated graph-plus-descriptor pair state back to `pair_dim` before `FusionHead` and `NRTLHead`.
This is a true TGNN branch, not a DirectGNN-only feature. The descriptor path augments the TGNN pair state without removing the solver bottleneck, so it is useful when you want to test whether missing chemistry signal is upstream of the thermodynamic head rather than inside it.
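The symmetric pair-feature construction used by both branches can be sketched with plain lists standing in for batched tensors:

```python
# Sketch of the pair-feature construction; in the real model these are
# batched tensor ops, here plain lists keep the example self-contained.

def pair_features(g_sol, g_slv):
    """[g_sol, g_slv, g_sol * g_slv, |g_sol - g_slv|] as one flat vector."""
    prod = [a * b for a, b in zip(g_sol, g_slv)]   # order-symmetric
    diff = [abs(a - b) for a, b in zip(g_sol, g_slv)]  # order-symmetric
    return g_sol + g_slv + prod + diff
```

The product and absolute-difference blocks are invariant to swapping solute and solvent, while the first two blocks preserve role identity, so the head sees both symmetric and role-specific signal.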
Optional solvent-type routing is handled by `SolventTypeMoE`.
### 5. FusionHead crystal-property prediction

`FusionHead` predicts:

- `T_m`
- `dH_fus`
- `dCp_fus`

There are two maintained crystal modes.

Standard mode:

- `T_m = T_m_min + (T_m_max - T_m_min) * sigmoid(...)`
- `dH_fus = S_H * softplus(...)`
- `dCp_fus = fixed_dCp_fus` unless `predict_dCp_fus=True`
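The standard-mode bounding can be sketched as follows; the bound values below are placeholders, not the repository defaults:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softplus(z):
    return math.log1p(math.exp(z))

# Sketch of the standard-mode bounding: sigmoid keeps T_m inside a fixed
# physical window, softplus keeps dH_fus strictly positive. The window and
# scale values here are illustrative.

def standard_crystal(raw_tm, raw_dh, *, t_m_min=200.0, t_m_max=600.0, s_h=50.0):
    t_m = t_m_min + (t_m_max - t_m_min) * sigmoid(raw_tm)
    dh_fus = s_h * softplus(raw_dh)
    return t_m, dh_fus
```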
Crystal GC-prior mode:

- enabled with `use_gc_priors_crystal=True`
- raw `T_m_gc`, `dH_fus_gc`, and `dCp_fus_gc` are computed per solute in `group_contribution.py`
- raw GC priors come from SMARTS-based Joback-style fragmentation
- partial fragmentation is allowed; the hard `400 K` fallback is used only when no usable counts exist or a required increment is missing
- `scripts/training/train.py` fits a train-only affine calibration for the melting prior: `T_m_gc_calibrated = gc_prior_tm_scale * T_m_gc + gc_prior_tm_bias`
- `FusionHead` then predicts only bounded residuals around the calibrated GC priors:
  - `T_m = T_m_gc_calibrated + gc_prior_Tm_residual_max * tanh(...)`
  - `dH_fus = dH_fus_gc * zero_centered_scale(...)`
  - `dCp_fus = dCp_fus_gc`
Important implementation details:

- the last linear layers of the GC residual branches are zero-initialized
- at initialization, the GC crystal prediction is exactly the calibrated prior
- `gc_prior_residual_freeze_epochs` can freeze those residual branches during the early part of Phase 1
- `use_gc_priors_crystal=True` requires `predict_dCp_fus=False`
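Putting the calibration and residual together, a minimal sketch for the melting point (illustrative names and bound):

```python
import math

# Sketch of the GC-prior residual scheme for T_m: a train-set affine
# calibration of the raw Joback-style prior, then a tanh-bounded residual.
# Because the residual branch's last linear layer is zero-initialized, its
# raw output is 0.0 at init, so the initial prediction equals the calibrated
# prior exactly. The residual_max default here is illustrative.

def gc_crystal_tm(t_m_gc, raw_residual, *, scale=1.0, bias=0.0, residual_max=30.0):
    t_m_calibrated = scale * t_m_gc + bias
    return t_m_calibrated + residual_max * math.tanh(raw_residual)
```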
### 6. NRTLHead

`NRTLHead` predicts binary interaction parameters from the pair vector plus explicit temperature features.

Supported parameterizations:

- `ref_invT` (default): predicts `tau(T_ref)` plus one inverse-temperature slope per direction
- `abc`: legacy

`ref_invT` is the maintained compact form.
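Under the stated parameterization, a plausible sketch of the temperature dependence is the following; the exact functional form inside `NRTLHead` is an assumption here, consistent with "tau at a reference temperature plus one inverse-temperature slope per direction":

```python
# Assumed sketch of the ref_invT parameterization:
#   tau(T) = tau(T_ref) + slope * (1/T - 1/T_ref)
# so the predicted value at T = T_ref is exactly tau_ref, and the slope
# controls the inverse-temperature dependence for one direction.

def tau_ref_invT(tau_ref, slope, T, T_ref=298.15):
    return tau_ref + slope * (1.0 / T - 1.0 / T_ref)
```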
### 7. Solver-facing substitution and oracle injection

Before the SLE solver is called, the model can alter which crystal parameters are sent into the solver while keeping the raw head outputs available for auxiliary losses.

Relevant mechanisms:

- GC crystal priors
  - determine how `FusionHead` builds crystal predictions
- oracle injection
  - enabled with `use_oracle_injection=True`
  - during training, supervised `T_m` and/or `dH_fus` can replace predicted values in the solver path
  - raw `fusion_params` still store the model prediction, not the substituted value
`model.forward(...)` therefore exposes both views, plus two conditional extras:

- `fusion_params`: raw head predictions used by losses
- `solver_fusion_params`: actual values sent into the solver
- `fusion_gc_priors`: present when crystal GC priors are enabled
- `oracle_injection_masks`: present when oracle injection is active
For diagnostics, evaluation can force oracle injection explicitly without changing normal inference behavior.
### 8. SLESolver

`SLESolver` contains zero learnable parameters. It combines:

- ideal solubility
- NRTL activity-coefficient terms
- an iterative fixed-point solution

Key properties:

- training uses `n_iter_train`; evaluation uses `n_iter_eval`
- solver math runs in float32 for stability
- implicit differentiation is controlled by `use_implicit_diff`
- in GC crystal mode, `dCp_fus_gc` flows into the ideal-solubility term through the solver-facing fusion parameters
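A toy version of the fixed-point solve, using the simplified Schroeder-van Laar ideal term; the `dCp_fus` contribution and the real NRTL gamma are omitted, and `ln_gamma` is a caller-supplied stand-in:

```python
import math

R = 8.314  # gas constant, J/(mol K)

# Toy sketch of the fixed-point SLE solve: ideal solubility from the
# simplified Schroeder-van Laar equation,
#   ln x2_ideal = -(dH_fus / R) * (1/T - 1/T_m),
# then iterate ln_x2 = ln_x2_ideal - ln_gamma(x2, T) until the
# composition-dependent activity coefficient is self-consistent.

def solve_sle(dh_fus, t_m, temp, ln_gamma, n_iter=20):
    ln_x2_ideal = -(dh_fus / R) * (1.0 / temp - 1.0 / t_m)
    ln_x2 = ln_x2_ideal
    for _ in range(n_iter):
        ln_x2 = ln_x2_ideal - ln_gamma(math.exp(ln_x2), temp)
    return ln_x2
```

With an ideal solvent (`ln_gamma == 0`) the fixed point is the ideal solubility itself; a positive activity coefficient lowers `ln_x2` below ideal.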
### 9. AdaptivePhysicsCorrection

The correction path is still structured around solver parameters. It does not replace `ln(x2)` with an unconstrained direct bypass.

It:

- predicts bounded deltas for `T_m`, `dH_fus`, `tau_12`, and `tau_21`
- reruns the solver with those corrected parameters
- blends the corrected result through a learned gate
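The final blend step can be sketched as a convex combination driven by the learned gate (the function name is illustrative):

```python
# Sketch of the gated blend at the end of the correction path: a learned
# gate in (0, 1) mixes the corrected solver output with the uncorrected one,
# so the correction can be smoothly attenuated per sample.

def gated_blend(ln_x2_base, ln_x2_corrected, gate):
    """gate=0 keeps the uncorrected solve; gate=1 fully trusts the correction."""
    return (1.0 - gate) * ln_x2_base + gate * ln_x2_corrected
```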
## DirectGNN

`src/tgnn_solv/baselines/direct_gnn.py` reuses:

- the same encoder
- the same interaction stack
- the same readout
- thermometer temperature features

It removes:

- `FusionHead`
- `NRTLHead`
- `SLESolver`
- `AdaptivePhysicsCorrection`

and replaces them with a direct MLP to `ln(x2)`.

Optional DirectGNN feature paths:

- `use_morgan_features=True`
- `use_descriptor_augmentation=True`
## Descriptor augmentation

The maintained descriptor path now does the following:

- computes the full RDKit descriptor vector for solute and solvent
- sanitizes non-finite values to zero before model use
- normalizes descriptors with train-set mean/std only
- stores `descriptor_mean` and `descriptor_std` in the checkpoint
- reuses one shared descriptor MLP for both molecular roles
- augments the pair representation with `[d_sol, d_slv, d_sol * d_slv, |d_sol - d_slv|]`
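The sanitize-then-normalize steps above can be sketched as follows, with plain lists standing in for descriptor arrays; statistics come from the training rows only and would be stored for reuse at inference time:

```python
import math

# Sketch of the descriptor preprocessing listed above: non-finite values
# (NaN/inf from RDKit) are zeroed, then features are standardized with
# mean/std fit on the training set only.

def sanitize(vec):
    return [v if math.isfinite(v) else 0.0 for v in vec]

def fit_normalizer(train_rows, eps=1e-8):
    """Per-dimension mean and (floored) std over the training rows."""
    n = len(train_rows)
    dim = len(train_rows[0])
    mean = [sum(row[j] for row in train_rows) / n for j in range(dim)]
    std = [
        max(eps, math.sqrt(sum((row[j] - mean[j]) ** 2 for row in train_rows) / n))
        for j in range(dim)
    ]
    return mean, std

def normalize(vec, mean, std):
    return [(v - m) / s for v, m, s in zip(vec, mean, std)]
```

Fitting on the training split only avoids leaking validation/test statistics into the model inputs.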
## Training Behavior

`src/tgnn_solv/trainer.py` implements a three-phase curriculum:

- Phase 1
  - supervised property warmup only
  - no solubility loss
  - correction frozen
  - GC crystal residual branches can be frozen for the first `gc_prior_residual_freeze_epochs`
- Phase 2
  - full SLE training
  - correction unfreezes at `phase2_correction_unfreeze_epoch`
  - oracle injection, if enabled, anneals over the last part of the phase
- Phase 3
  - low-learning-rate fine-tuning
  - oracle injection forced off

The canonical paper budget remains 50 / 200 / 50.

## Inference-Time Utilities

The repository also ships post-training utilities that sit outside `model.forward(...)` but are part of the maintained surface:

- `src/tgnn_solv/inference.py`
  - checkpoint loading/saving
  - single-system prediction
  - temperature scans
  - human-readable interpretation
- `src/tgnn_solv/uncertainty.py`: MC-dropout and deep-ensemble uncertainty estimation
- `src/tgnn_solv/domain.py`: applicability-domain / OOD screening

These are deployment and diagnostics layers, not extra learnable blocks inside TGNN-Solv itself.

### Applicability domain

The current OOD helper fits on the training loader and combines:

- Mahalanobis distance in pair latent space
- nearest-neighbor Morgan Tanimoto similarity for solute and solvent

It also reports exact seen/unseen flags for the queried solute and solvent.

Despite some older high-level references, leverage is not part of the current implementation's decision path.
## Pair-Aware Temperature Batching

The main TGNN training path uses pair-aware batching by default so that losses like `pair_temp_rank` and `vant_hoff_local` can observe multiple temperatures from the same (solute, solvent) pair in one batch.
DirectGNN reuses the same batching controls where applicable.
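A minimal sketch of the batching idea; the grouping helper below is illustrative, not the repository's actual sampler:

```python
from collections import defaultdict

# Sketch of pair-aware batching: samples are grouped by (solute, solvent)
# pair key so each batch contains several temperatures of the same pair,
# which is what losses like pair_temp_rank and vant_hoff_local need.

def pair_aware_batches(samples, pairs_per_batch):
    """samples: iterable of dicts with 'pair_key' and 'T'. Yields batches
    that keep all temperatures of a pair together."""
    groups = defaultdict(list)
    for s in samples:
        groups[s["pair_key"]].append(s)
    batch = []
    for key in sorted(groups):
        batch.extend(sorted(groups[key], key=lambda s: s["T"]))
        if len({s["pair_key"] for s in batch}) >= pairs_per_batch:
            yield batch
            batch = []
    if batch:
        yield batch
```

A plain shuffled sampler would usually scatter a pair's temperatures across many batches, making per-pair temperature-ranking losses unobservable.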
## High-Signal Config Flags

The easiest configuration flags to miss are:

- `encoder_role_mode`
- `nrtl_tau_mode`
- `use_morgan_features`
- `use_descriptor_augmentation`
- `use_descriptor_priors`
- `use_group_priors`
- `use_gc_priors_crystal`
- `gc_prior_tm_scale`
- `gc_prior_tm_bias`
- `gc_prior_residual_freeze_epochs`
- `use_oracle_injection`
- `bridge_loss_weight`
- `use_walden_check`
- `use_pair_temperature_batching`

All live in `src/tgnn_solv/config.py`.
## Dataset Outputs

`TGNNSolvDataset` returns `(solute_graph, solvent_graph, targets_dict)`.

Core keys include:

- `T`
- `ln_x2`
- `has_solubility`
- `pair_key`
- `solvent_type`
- `T_m`, `T_m_mask`, `has_T_m`
- `dH_fus`, `dH_mask`, `has_dH_fus`
- `hansen_sol`, `hansen_mask`
- `ln_gamma_inf`, `gamma_mask`

Optional keys appear when their feature paths are enabled:

- `solute_morgan_fp`, `solvent_morgan_fp`
- `solute_descriptors`, `solvent_descriptors`
- `solute_descriptor_prior_features`, `solvent_descriptor_prior_features`
- `solute_group_prior_features`, `solvent_group_prior_features`
- `T_m_gc`, `dH_fus_gc`, `dCp_fus_gc`
## Checkpoints and Resume

The main training CLIs now support resumable checkpoints:

- `scripts/training/train.py --checkpoint-every ... --resume ...`
- `scripts/training/train_directgnn.py --checkpoint-every ... --resume ...`

The heavy experiment runners reuse those checkpoints automatically:

- `scripts/experiments/run_full_budget_experiment.py`
- `scripts/experiments/run_medium_budget_comparison.py`
## Artifact and Benchmark Metadata

The architecture layer now extends beyond the forward graph itself. Maintained training and evaluation entry points emit structured sidecars so downstream benchmarking and reproduction can recover provenance without relying on shell history:

- checkpoint sidecars: `<checkpoint>.manifest.json`, `<checkpoint>.model_card.json`
- benchmark bundle sidecars: `run_manifest.json`, `benchmark_card.json`
These metadata files are built through `tgnn_solv.artifacts` and describe:
- model family
- resolved config path
- input data paths and checksums
- git commit and dirty state
- output artifact paths
- supported capability flags such as uncertainty, OOD, and physics reporting
That provenance layer is now part of the maintained system architecture, not an optional afterthought.