Warm-up: string vacua and automatic differentiation in JAX
What’s in this notebook? This notebook provides an introduction to JAX and some of its features, such as automatic differentiation and just-in-time compilation. We also illustrate the general implementation of JAXVacua by working through an example, namely compactification on the symmetric \(T^6\), following the conventions of section 4 of hep-th/0411061.
(Created: Andreas Schachner, August 19, 2024)
Setup
# General imports
import warnings
import time
import numpy as np
from tqdm.auto import tqdm
# JAX imports
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
# Plotting
import seaborn as sn
import matplotlib.pyplot as plt
cmap = sn.color_palette("viridis", as_cmap=True)
warnings.filterwarnings('ignore')
Introduction to JAX
JAX is an open-source software package developed by Google that enables high-performance numerical computing and machine learning research. It combines NumPy-like array operations with automatic differentiation (autograd), Just-In-Time (JIT) compilation via XLA (Accelerated Linear Algebra), and parallel execution, making it a powerful tool for deep learning, scientific computing, and large-scale optimisation tasks.
JAX is designed to be both flexible and efficient, offering key features such as:
Automatic Differentiation: Supports forward- and reverse-mode differentiation, making it ideal for gradient-based optimisation.
JIT Compilation: Uses XLA to compile and optimize code for execution on CPUs, GPUs, and TPUs, leading to significant speedups.
Vectorization (vmap): Enables easy and efficient batch processing without manual looping.
Parallel Execution (pmap): Allows scalable computation across multiple devices, including multi-GPU and TPU setups.
JAX serves as the foundation for deep learning frameworks such as Flax and Haiku, making it a popular choice for AI research and experimentation. With its NumPy-like API and composable transformations, JAX provides a modern and efficient alternative to traditional deep learning libraries.
Let us highlight the feature at the heart of JAX: automatic differentiation (AD), a computational technique for computing derivatives of functions efficiently and exactly, up to floating-point round-off. Unlike numerical differentiation (which approximates derivatives using finite differences) or symbolic differentiation (which manipulates mathematical expressions), AD applies the chain rule programmatically: it decomposes a function into a sequence of elementary operations with known derivatives and systematically propagates derivatives through this sequence.
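To make the contrast with numerical differentiation concrete, the small sketch below compares `jax.grad` with a hand-derived derivative and with a forward finite difference. The toy function `f` is chosen purely for illustration and is not part of JAXVacua. The AD result should agree with the exact derivative to machine precision. The finite-difference error, by contrast, depends on the step size `h`, through truncation error for large `h` and round-off error for small `h`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
    # Composition of elementary operations: sin, exp, power and product
    return jnp.sin(x) * jnp.exp(-x**2)

def df_by_hand(x):
    # Derivative derived by hand with the product and chain rules
    return (jnp.cos(x) - 2 * x * jnp.sin(x)) * jnp.exp(-x**2)

def df_finite_difference(x, h):
    # Forward difference: truncation error O(h), round-off error O(eps/h)
    return (f(x + h) - f(x)) / h

df_ad = jax.grad(f)  # reverse-mode automatic differentiation

x0 = 0.7
err_ad = float(abs(df_ad(x0) - df_by_hand(x0)))
err_fd_coarse = float(abs(df_finite_difference(x0, 1e-2) - df_by_hand(x0)))
err_fd_fine = float(abs(df_finite_difference(x0, 1e-10) - df_by_hand(x0)))
print(f"AD: {err_ad:.1e}, FD (h=1e-2): {err_fd_coarse:.1e}, FD (h=1e-10): {err_fd_fine:.1e}")
```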
AD is indispensable in machine learning, especially for neural networks. Training involves minimising a loss function by adjusting model parameters with gradient-based optimisation algorithms (e.g., stochastic gradient descent). AD automates the computation of gradients with respect to millions (or even billions) of parameters, ensuring both accuracy and efficiency. JAX takes AD further by combining it with JIT compilation, which allows researchers to achieve exceptional computational speed and scalability. Key benefits of AD include:
Efficiency: AD avoids the computational inefficiencies of numerical differentiation and the algebraic complexity of symbolic differentiation.
Generality: AD can handle arbitrary architectures, including complex, custom-designed neural networks, without requiring manual derivation of gradients.
Scalability: Libraries like JAX leverage AD to compute gradients across entire datasets and optimise performance on modern hardware such as GPUs and TPUs.
For theoretical physicists, AD is particularly useful in applications such as lattice field theory and the exploration of energy landscapes, where gradient computations play a central role. By using AD-enabled libraries, physicists can focus on the physical and mathematical aspects of their research without being bogged down by the intricacies of derivative calculations.
In this notebook, rather than providing a general introduction to jax, we focus directly on applications to string compactifications.
We discuss the simple example of a symmetric \(T^6\) to highlight the main features of jax and to explain the basic principles of our implementation, which is described in more detail in the notebook 02_jaxvacua_overview.ipynb.
There are many online resources and general introductions to jax and related packages, to which we refer for more details. Here are some of them:
Lectures at Amsterdam University on jax and ML in Physics
Applications to string vacua
Automatic differentiation (AD) is useful in string compactifications for moduli stabilisation because it enables efficient computation of gradients of scalar potentials, which govern the dynamics of moduli fields. These gradients are essential for finding stable vacua in string theory, where the potential landscape is often complex and high-dimensional. JAX is a strong choice for implementation due to its efficient AD (jax.grad), Just-In-Time (JIT) compilation (jax.jit) for speed-up, and automatic vectorisation (jax.vmap) for parallel computation. Its ability to seamlessly run on GPUs and TPUs further accelerates numerical experiments, making it ideal for exploring the vast moduli space in string theory.
Below, we use flux compactifications on the symmetric \(T^6\) as a testing ground; see the last section of this notebook for further background and conventions.
In this case, the idea is simple: we only have to define two input functions, namely the Kähler potential \(K\) and the superpotential \(W\). All other objects, such as the Kähler covariant derivatives \(DW\) or the scalar potential \(V\), can be computed through automatic differentiation. More details on the general implementation in JAXVacua are provided in the notebook 02_jaxvacua_overview.ipynb.
Computing holomorphic derivatives with jax
Automatic differentiation (AD) is a technique for computing exact derivatives of functions efficiently and accurately, without relying on symbolic differentiation or numerical approximations. JAX implements AD using forward-mode and reverse-mode differentiation, making it well-suited for both scalar computations and large-scale optimisation problems. The jax.grad function enables efficient gradient computation via reverse-mode AD, which is particularly useful for optimising high-dimensional functions, such as those in deep learning and physics simulations. Additionally, JAX supports higher-order differentiation with jax.jacfwd and jax.jacrev, allowing computation of Jacobians and Hessians. By leveraging XLA (Accelerated Linear Algebra) for Just-In-Time (JIT) compilation (see below), JAX further accelerates differentiation, making it ideal for scientific computing and machine learning.
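The sketch below uses two toy functions, `g` and `h`, chosen purely for illustration. It shows that `jax.jacfwd` (forward mode) and `jax.jacrev` (reverse mode) produce the same Jacobian, and that `jax.hessian` returns the matrix of second derivatives. The JAX documentation describes `jax.hessian` as `jacfwd(jacrev(...))`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def g(x):
    # Vector-valued map R^3 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

def h(x):
    # Scalar function R^3 -> R
    return jnp.sum(x ** 3) + x[0] * x[1] * x[2]

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(g)(x)  # forward mode: one JVP per input component
J_rev = jax.jacrev(g)(x)  # reverse mode: one VJP per output component
grad_h = jax.grad(h)(x)   # reverse mode for a scalar output
H = jax.hessian(h)(x)     # second derivatives (forward-over-reverse)
print(J_fwd, grad_h, H, sep="\n")
```

Forward mode is usually preferable when a function has few inputs and many outputs, and reverse mode in the opposite case. This is why `jax.grad`, which targets scalar outputs, is built on reverse mode.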
Let us start with a simple example of a Kähler potential for two moduli:
Kahler = lambda x,y: -jnp.log((-1.j*(x[0]-y[0]))+(-1.j*(x[1]-y[1])))
Here, x denotes the holomorphic and y the anti-holomorphic (conjugate) coordinates, which we treat as independent variables at this stage. We can take the derivative with respect to the x coordinates via
dKx = jax.grad(Kahler,argnums=0,holomorphic=True)
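Here, argnums=0 specifies the argument with respect to which the derivative is taken. Further, holomorphic=True tells JAX to treat the function as holomorphic in its complex inputs and to return the complex derivative; this requires complex input and output dtypes.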
This function can be evaluated at an arbitrary point in moduli space using:
test_x=jnp.array([1.+1j*1.,1.+1j*1])
test_y=jnp.conj(test_x)
dKx(test_x,test_y)
Similarly, the derivative along the y direction is
dKy=jax.grad(Kahler,argnums=1,holomorphic=True)
dKy(test_x,test_y)
This is simply the complex conjugate of dKx as expected.
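This conjugation property can be checked against the chain rule. With \(S=-\mathrm{i}(x_0-y_0)-\mathrm{i}(x_1-y_1)\), one finds \(\partial_{x_i}K=\mathrm{i}/S\) and \(\partial_{y_i}K=-\mathrm{i}/S\). At the test point \(S=4\), so these derivatives equal \(\mathrm{i}/4\) and \(-\mathrm{i}/4\) respectively. The following is a minimal self-contained sketch of this check; the helper `S` is introduced only for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Toy Kähler potential from the text: K = -log(S), S = -i(x0 - y0) - i(x1 - y1)
Kahler = lambda x, y: -jnp.log((-1.j * (x[0] - y[0])) + (-1.j * (x[1] - y[1])))
dKx = jax.grad(Kahler, argnums=0, holomorphic=True)
dKy = jax.grad(Kahler, argnums=1, holomorphic=True)

def S(x, y):
    # Argument of the logarithm (illustrative helper)
    return -1.j * (x[0] - y[0]) - 1.j * (x[1] - y[1])

test_x = jnp.array([1. + 1j * 1., 1. + 1j * 1.])
test_y = jnp.conj(test_x)

# Chain rule: dS/dx_i = -i and dS/dy_i = +i, hence dK/dx_i = i/S and dK/dy_i = -i/S
analytic_dKx = jnp.full(2, 1j / S(test_x, test_y))
analytic_dKy = jnp.full(2, -1j / S(test_x, test_y))

print(dKx(test_x, test_y), analytic_dKx)
print(dKy(test_x, test_y), analytic_dKy)
```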
We can also compute the derivatives with respect to both x and y at once using jax.jacfwd:
dKi=jax.jacfwd(Kahler,argnums=(0,1),holomorphic=True)
dKi(test_x,test_y)
Another important object is the Kähler metric \(K_{i\bar{j}}=\partial_i\partial_{\bar{j}}K\), given by the mixed second derivatives of the Kähler potential. In principle, we can compute the Kähler metric in our toy setup from the full Hessian via
Kij=jax.hessian(Kahler, argnums=(0,1), holomorphic=True)
Kij(test_x,test_y)
We then only need the mixed block Kij(test_x,test_y)[0][1]. However, this approach also computes derivatives like \(K_{xx}\) and \(K_{yy}\) that we do not need. It is therefore more efficient to nest the derivatives instead:
Kij=jax.jacfwd(jax.grad(Kahler, argnums=(0), holomorphic=True), argnums=(1), holomorphic=True)
Kij(test_x,test_y)
This gives the expected result: at the test point, every entry of the mixed block equals \(1/16\).
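The sketch below carries out this comparison. It checks that the nested derivative agrees with the mixed block of the full Hessian and with the closed form \(K_{x_i y_j}=1/S^2\). That closed form follows from differentiating \(\partial_{x_i}K=\mathrm{i}/S\) once more with respect to \(y_j\). Everything is redefined so that the block runs on its own.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

Kahler = lambda x, y: -jnp.log((-1.j * (x[0] - y[0])) + (-1.j * (x[1] - y[1])))
test_x = jnp.array([1. + 1j * 1., 1. + 1j * 1.])
test_y = jnp.conj(test_x)

# Full Hessian: nested tuple ((K_xx, K_xy), (K_yx, K_yy)), of which we keep K_xy
K_full = jax.hessian(Kahler, argnums=(0, 1), holomorphic=True)(test_x, test_y)
K_mixed_from_hessian = K_full[0][1]

# Only the mixed block: reverse mode in x, then forward mode in y
K_mixed = jax.jacfwd(jax.grad(Kahler, argnums=0, holomorphic=True),
                     argnums=1, holomorphic=True)(test_x, test_y)

# Closed form: K_{x_i y_j} = 1/S^2 for all i, j (S = 4 at this point)
S = -1.j * (test_x[0] - test_y[0]) - 1.j * (test_x[1] - test_y[1])
K_analytic = jnp.full((2, 2), 1.0 / S**2)
print(K_mixed)
```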
The symmetric \(T^6\)
Let us now apply the above approach to flux compactifications on the symmetric torus \(T^6\), whose conventions are summarised in the background section at the end of this notebook. We first define the Kähler potential as
kahler = lambda z,cz,tau,ctau: -3*jnp.log((-1.j*(z-cz)))-jnp.log((-1.j*(tau-ctau)))
We can evaluate the Kähler potential for a particular choice for the values of the moduli:
z0=1.+1j*2.
cz0=jnp.conj(z0)
tau0=1.+1j*1
ctau0=jnp.conj(tau0)
kahler(z0,cz0,tau0,ctau0)
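Note that we have to provide the complex conjugate fields as separate inputs, because they have to be treated as independent variables in this framework. If we instead computed the conjugates inside the function, there would be no separate argument with respect to which to take the \(\bar{z}\)-derivative. Moreover, jnp.conj is not holomorphic, so holomorphic differentiation of such a function does not yield the Kähler metric, as we illustrate below: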
def kahler_wrong(z,tau):
    cz = jnp.conj(z)
    ctau = jnp.conj(tau)
    return -3*jnp.log((-1.j*(z-cz)))-jnp.log((-1.j*(tau-ctau)))
# Correct Kähler metric
K_z_cz = jax.jacfwd(jax.grad(kahler,argnums=0,holomorphic=True),argnums=1,holomorphic=True)
# Wrong Kähler metric
K_z_cz_wrong = jax.jacfwd(jax.grad(kahler_wrong,argnums=0,holomorphic=True),argnums=0,holomorphic=True)
K_z_cz(z0,cz0,tau0,ctau0), K_z_cz_wrong(z0,tau0)
As expected, the correctly defined Kähler metric \(K_{z\bar{z}}=\partial_z \partial_{\bar{z}} K\) is real and positive, whereas the naive version does not reproduce it.
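For this Kähler potential, the metric components have simple closed forms once \(\bar z\) and \(\bar\tau\) are set to the complex conjugates: \(K_{z\bar{z}}=3/(4(\operatorname{Im}z)^2)\) and \(K_{\tau\bar{\tau}}=1/(4(\operatorname{Im}\tau)^2)\). The following is a minimal sketch comparing the nested AD result with these expressions at a few points. The points are chosen arbitrarily, with positive imaginary parts.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

kahler = lambda z, cz, tau, ctau: -3 * jnp.log(-1.j * (z - cz)) - jnp.log(-1.j * (tau - ctau))

# Mixed second derivatives: holomorphic index first (reverse mode), then the conjugate (forward mode)
K_z_cz = jax.jacfwd(jax.grad(kahler, argnums=0, holomorphic=True), argnums=1, holomorphic=True)
K_tau_ctau = jax.jacfwd(jax.grad(kahler, argnums=2, holomorphic=True), argnums=3, holomorphic=True)

# Closed forms: K_{z zbar} = 3 / (4 Im(z)^2), K_{tau taubar} = 1 / (4 Im(tau)^2)
points = [(1. + 2.j, 1. + 1.j), (-0.3 + 0.5j, 2. + 4.j), (0.7 + 3.j, -1. + 0.25j)]
comparison = []
for z, tau in points:
    args = (z, jnp.conj(z), tau, jnp.conj(tau))
    comparison.append((K_z_cz(*args), 3 / (4 * z.imag ** 2),
                       K_tau_ctau(*args), 1 / (4 * tau.imag ** 2)))

for row in comparison:
    print(row)
```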
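Next, we define the superpotential \(W=P_1(z)-\tau P_2(z)\), with the flux vector ordered as \(F=(a^0,b_0,c^0,d_0,a,b,c,d)\) in the notation of the background section: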
P1 = lambda z,F: -F[1] - 3*F[5]*z - 3*F[4]*z**2 + F[0]*z**3
P2 = lambda z,F: -F[3] - 3*F[7]*z - 3*F[6]*z**2 + F[2]*z**3
superpotential = lambda z,tau,F: P1(z,F)-tau*P2(z,F)
We test these expressions on a particular choice of fluxes and values for the moduli:
FF=jnp.array([-1.,1., 1., 0., 2., 1., 4., 0.])
superpotential(z0,tau0,FF)
The first derivatives of the Kähler potential are easily obtained using jax.grad:
Kz=jax.grad(kahler,argnums=0,holomorphic=True)
Ktau=jax.grad(kahler,argnums=2,holomorphic=True)
Similarly, the derivatives \(\partial_z W\) and \(\partial_\tau W\) of the superpotential are computed as:
Wz=jax.grad(superpotential,argnums=0,holomorphic=True)
Wtau=jax.grad(superpotential,argnums=1,holomorphic=True)
Next, we can derive the Kähler covariant derivatives \(D_i W = \partial_i W+ (\partial_i K) W\) as
DWz= lambda z,cz,tau,ctau,F: Wz(z,tau,F)+Kz(z,cz,tau,ctau)*superpotential(z,tau,F)
DWtau= lambda z,cz,tau,ctau,F: Wtau(z,tau,F)+Ktau(z,cz,tau,ctau)*superpotential(z,tau,F)
Evaluating these expressions, we get
DWz(z0,cz0,tau0,ctau0,FF),DWtau(z0,cz0,tau0,ctau0,FF)
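The Kähler metric is obtained from the full Hessian with respect to all four arguments. The entry [0][1] is \(K_{z\bar{z}}\), and the entry [2][3] is \(K_{\tau\bar{\tau}}\):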
Kij=jax.hessian(kahler, argnums=(0,1,2,3), holomorphic=True)
Kij(z0,cz0,tau0,ctau0)[0][1]
We now have everything at our disposal to define the scalar potential for flux compactifications on the symmetric \(T^6\), \(V=e^{K}\left(K^{z\bar{z}}|D_zW|^2+K^{\tau\bar{\tau}}|D_\tau W|^2\right)\). It can be computed as follows:
def scalar_potential(z,cz,tau,ctau,fluxes):
    r"""
    **Description:**
    Computes the flux scalar potential for compactification on the symmetric T^6.
    Args:
        z (complex): Value of the complex structure modulus.
        cz (complex): Value of the conjugate complex structure modulus.
        tau (complex): Value of the axio-dilaton.
        ctau (complex): Value of the conjugate axio-dilaton.
        fluxes (Array): Fluxes.
    Returns:
        complex: Value of the flux scalar potential.
    """
    # Compute the Kähler metric
    KM=Kij(z,cz,tau,ctau)
    # Compute the inverse Kähler metric; inverting entry by entry is valid
    # because K_{z tau-bar} = 0, i.e. the metric is diagonal
    IKM_csm=1/KM[0][1]
    IKM_tau=1/KM[2][3]
    # Compute the F-terms
    DWt=DWtau(z,cz,tau,ctau,fluxes)
    DWm=DWz(z,cz,tau,ctau,fluxes)
    return jnp.exp(kahler(z,cz,tau,ctau))*(IKM_csm*DWm*jnp.conj(DWm)+IKM_tau*DWt*jnp.conj(DWt))
We can evaluate it as follows
scalar_potential(z0,cz0,tau0,ctau0,FF)
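The function scalar_potential inverts the metric entry by entry. This works because \(K\) splits into a \(z\)-part and a \(\tau\)-part, so that \(K_{z\bar{\tau}}=0\). For models with a non-diagonal Kähler metric, one would instead invert the full mixed block. The sketch below does this for a vector of moduli, assuming the same form of the potential as above, \(V=e^{K}K^{i\bar{j}}D_iW\,\overline{D_jW}\) with no \(-3|W|^2\) term. The helper `make_fterm_potential` and the packed functions `kahler_t6` and `superpotential_t6` are hypothetical names introduced here; they are not part of JAXVacua.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def make_fterm_potential(kahler, superpotential):
    """Hypothetical helper: V = e^K K^{i jbar} D_i W conj(D_j W) for a moduli vector.

    kahler(x, cx) takes the moduli vector x and its conjugate cx as independent
    arguments; superpotential(x, fluxes) is holomorphic in x.
    """
    dK = jax.grad(kahler, argnums=0, holomorphic=True)          # K_i
    dW = jax.grad(superpotential, argnums=0, holomorphic=True)  # W_i
    metric = jax.jacfwd(dK, argnums=1, holomorphic=True)        # M[i, j] = K_{i jbar}

    def potential(x, cx, fluxes):
        W = superpotential(x, fluxes)
        DW = dW(x, fluxes) + dK(x, cx) * W          # F-terms D_i W
        M_inv = jnp.linalg.inv(metric(x, cx))       # full inverse, no diagonal assumption
        # conj(DW)^T M^{-1} DW equals K^{i jbar} D_i W conj(D_j W)
        return jnp.exp(kahler(x, cx)) * jnp.einsum("i,ij,j->", jnp.conj(DW), M_inv, DW)

    return potential

# The symmetric T^6 model from the text, packed as x = (z, tau)
def kahler_t6(x, cx):
    return -3 * jnp.log(-1.j * (x[0] - cx[0])) - jnp.log(-1.j * (x[1] - cx[1]))

def superpotential_t6(x, F):
    z, tau = x[0], x[1]
    P1 = -F[1] - 3 * F[5] * z - 3 * F[4] * z**2 + F[0] * z**3
    P2 = -F[3] - 3 * F[7] * z - 3 * F[6] * z**2 + F[2] * z**3
    return P1 - tau * P2

V = make_fterm_potential(kahler_t6, superpotential_t6)

FF_sol = jnp.array([0., 1., 1., 0., 0., 1., 1., 0.])
x_sol = jnp.array([1j, 1j])
print(V(x_sol, jnp.conj(x_sol), FF_sol))  # vanishes up to round-off at the SUSY vacuum
```

For the diagonal \(T^6\) metric, this reduces to the entry-by-entry inversion used in scalar_potential. The general version costs one small matrix inversion per evaluation.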
As explained in the background section at the end of this notebook, the \(F\)-term conditions for flux compactifications on the symmetric \(T^6\) can be solved analytically. Let us therefore test our implementation against an actual minimum. We choose the fluxes \begin{equation} F=(a^0,b_0,c^0,d_0,a,b,c,d)=(0,1,1,0,0,1,1,0)\,, \end{equation} i.e. \(P_1(z)=-1-3z\) and \(P_2(z)=z^3-3z^2\). These should lead to the VEVs \begin{equation} \langle z\rangle = \mathrm{i}\,,\quad \langle \tau\rangle = \mathrm{i}\,, \end{equation} with a value of the superpotential at the minimum of \begin{equation} \langle W\rangle = -2-6\mathrm{i}\,. \end{equation} Let us first test this value of the superpotential with our function from above:
FF=jnp.array([0.,1., 1., 0., 0., 1., 1., 0.])
z_sol=1j*1.
tau_sol=1j*1.
superpotential(z_sol,tau_sol,FF)
This is indeed the expected result.
Let us verify that this is an actual SUSY vacuum, i.e. that the \(F\)-terms \(D_zW\) and \(D_\tau W\) vanish:
DWz(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF),DWtau(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF)
Both \(F\)-terms indeed vanish. Since \(V\) is a sum of non-negative terms \(e^{K}K^{i\bar{i}}|D_iW|^2\), this is equivalent to the flux potential vanishing:
scalar_potential(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF)
Indeed, we have obtained a SUSY Minkowski vacuum. As a more explicit check that this is a critical point, the first derivatives of the scalar potential should also vanish. These can be computed as follows:
dV_z = jax.grad(scalar_potential,argnums=0,holomorphic=True)
dV_tau = jax.grad(scalar_potential,argnums=2,holomorphic=True)
dV_z(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF),dV_tau(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF)
Both derivatives indeed vanish, so these values correspond to a critical point of the potential. Since \(V\geq 0\) is a sum of non-negative terms and vanishes here, this critical point is in fact a minimum. More generally, for this potential every solution of the \(F\)-term conditions is automatically a minimum. The above nevertheless serves as a consistency check for our implementation.
Note that we intentionally defined the scalar potential above with complex output, even though its value is real.
This is because we want to take holomorphic derivatives with respect to our complex scalar fields, which requires both input and output types to be complex.
If we instead defined the scalar potential with real output, for instance by taking the real part, we could write
scalar_potential_real = lambda z,cz,tau,ctau,fluxes: scalar_potential(z,cz,tau,ctau,fluxes).real
scalar_potential_real(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF)
As we see, the output dtype is now float64 instead of complex128 as above.
However, if we now try to compute holomorphic derivatives, we obtain the following error:
try:
    dV_z_real = jax.grad(scalar_potential_real,argnums=0,holomorphic=True)
    dV_z_real_val = dV_z_real(z_sol,jnp.conj(z_sol),tau_sol,jnp.conj(tau_sol),FF)
except Exception as e:
    print(e)
The exception that you should find is: grad with holomorphic=True requires outputs with complex dtype, but got float64.
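If a real-valued potential is preferred, one option is to parametrise the moduli by real coordinates and use ordinary (non-holomorphic) `jax.grad` and `jax.hessian`, which require real output. The sketch below follows that approach. The function `V_real` and the coordinate ordering `u = (Re z, Im z, Re tau, Im tau)` are our own illustrative choices, and the diagonal inverse metric is written in closed form. At the SUSY vacuum \(z=\tau=\mathrm{i}\), the gradient should vanish and the eigenvalues of the real Hessian should be non-negative, consistent with a minimum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

kahler = lambda z, cz, tau, ctau: -3 * jnp.log(-1.j * (z - cz)) - jnp.log(-1.j * (tau - ctau))
P1 = lambda z, F: -F[1] - 3 * F[5] * z - 3 * F[4] * z**2 + F[0] * z**3
P2 = lambda z, F: -F[3] - 3 * F[7] * z - 3 * F[6] * z**2 + F[2] * z**3
superpotential = lambda z, tau, F: P1(z, F) - tau * P2(z, F)

# Holomorphic first derivatives, as in the text
Kz = jax.grad(kahler, argnums=0, holomorphic=True)
Ktau = jax.grad(kahler, argnums=2, holomorphic=True)
Wz = jax.grad(superpotential, argnums=0, holomorphic=True)
Wtau = jax.grad(superpotential, argnums=1, holomorphic=True)

def V_real(u, F):
    # Hypothetical real parametrisation u = (Re z, Im z, Re tau, Im tau)
    z, tau = u[0] + 1j * u[1], u[2] + 1j * u[3]
    cz, ctau = jnp.conj(z), jnp.conj(tau)
    W = superpotential(z, tau, F)
    DWz = Wz(z, tau, F) + Kz(z, cz, tau, ctau) * W
    DWtau = Wtau(z, tau, F) + Ktau(z, cz, tau, ctau) * W
    # Diagonal inverse metric: K^{z zbar} = 4 Im(z)^2 / 3, K^{tau taubar} = 4 Im(tau)^2.
    # |DW|^2 is written as DW * conj(DW), which stays smooth at DW = 0 (unlike jnp.abs)
    V = jnp.exp(kahler(z, cz, tau, ctau)) * (
        4 * u[1] ** 2 / 3 * DWz * jnp.conj(DWz) + 4 * u[3] ** 2 * DWtau * jnp.conj(DWtau))
    return jnp.real(V)

FF = jnp.array([0., 1., 1., 0., 0., 1., 1., 0.])
u_sol = jnp.array([0., 1., 0., 1.])  # z = tau = i

grad_sol = jax.grad(V_real)(u_sol, FF)     # ordinary real-mode gradient
hess_sol = jax.hessian(V_real)(u_sol, FF)  # 4x4 real Hessian
eigs = jnp.linalg.eigvalsh(hess_sol)
print(V_real(u_sol, FF), grad_sol, eigs)
```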
Just-in-time compilation and automatic vectorisation
Up to this point, we have mainly focussed on using automatic differentiation to compute the scalar potential and related objects.
However, there are additional features implemented within the jax framework which make the above implementations even more powerful and versatile.
The first is Just-In-Time (JIT) compilation which is a technique that improves execution speed by compiling code at runtime rather than before execution. In JAX, JIT compilation is implemented using XLA (Accelerated Linear Algebra), which optimizes and fuses operations for efficient execution on GPUs and TPUs. When a function is decorated with jax.jit, JAX traces its computation graph, compiles it using XLA, and caches the optimized version for reuse. This significantly speeds up numerical computations, especially in machine learning and scientific computing, by reducing overhead and leveraging hardware acceleration.
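The trace-then-cache behaviour can be observed with a Python side effect, which runs only while JAX traces the function. In this sketch, the counter `trace_count` and the function `f` are purely illustrative. A second call with the same input shape and dtype reuses the compiled version, while a new shape triggers a new trace and compilation.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, incremented only while JAX traces f

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

f(jnp.ones(3))                  # first call: trace and compile for shape (3,)
f(jnp.zeros(3))                 # same shape and dtype: reuses the cached executable
after_same_shape = trace_count  # 1
f(jnp.ones(5))                  # new shape (5,): new trace and compilation
after_new_shape = trace_count   # 2
print(after_same_shape, after_new_shape)
```

Python side effects like this counter are only a debugging device. Inside jitted code, only array operations are compiled.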
Let us apply jax.jit to the function scalar_potential as defined above:
scalar_potential_jit = jax.jit(scalar_potential)
To compile this function, we first have to evaluate it once:
scalar_potential_jit(z0,cz0,tau0,ctau0,FF)
We can then compare the timing for evaluating the scalar potential with and without just-in-time compilation:
%%timeit
res = jax.block_until_ready(scalar_potential(z0,cz0,tau0,ctau0,FF))
%%timeit
res = jax.block_until_ready(scalar_potential_jit(z0,cz0,tau0,ctau0,FF))
We find a speed-up of about two orders of magnitude for the function compiled with jax.jit. (Note that the actual timing depends on the machine used to run the code.) Importantly, jax.jit is a function transformation, which can also be used as a decorator, that provides a large speed-up without changing the underlying implementation.
Automatic vectorisation transforms a function written for a single input into one that acts on a whole batch of inputs. In JAX, this is implemented by jax.vmap, which maps a function over a batch axis of its inputs without explicit Python loops. Rather than calling the function repeatedly, vmap rewrites its operations into batched array operations, which can then be executed efficiently on CPUs, GPUs and TPUs. Unlike manual batching, vmap keeps the code simple while ensuring efficient execution, making it particularly useful for deep learning, numerical simulations, and large-scale tensor computations.
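The `in_axes` argument of `jax.vmap` controls which inputs are batched. The minimal sketch below uses the polynomial \(P_1(z)\) from the text, reimplemented so that the block is self-contained; the batch values are arbitrary. `in_axes=(0, None)` broadcasts one flux vector over many moduli values, `in_axes=(0, 0)` pairs the two inputs element-wise, and nesting two vmaps evaluates all combinations.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def P1(z, F):
    # Cubic polynomial P_1 from the text for a single modulus z and flux vector F
    return -F[1] - 3 * F[5] * z - 3 * F[4] * z**2 + F[0] * z**3

zs = jnp.array([1j, 1. + 1j, 2j])                         # batch of moduli values
F_single = jnp.array([0., 1., 1., 0., 0., 1., 1., 0.])    # a single flux vector
F_batch = jnp.stack([F_single, 2 * F_single, -F_single])  # batch of flux vectors

# Map over z only; the same flux vector is broadcast to every entry
P1_over_z = jax.vmap(P1, in_axes=(0, None))(zs, F_single)
# Map over both arguments: the i-th z is paired with the i-th flux vector
P1_paired = jax.vmap(P1, in_axes=(0, 0))(zs, F_batch)
# Nested vmap: all combinations, P1_grid[i, j] = P1(zs[i], F_batch[j])
P1_grid = jax.vmap(jax.vmap(P1, in_axes=(None, 0)), in_axes=(0, None))(zs, F_batch)
print(P1_over_z.shape, P1_paired.shape, P1_grid.shape)  # (3,) (3,) (3, 3)
```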
Let us now apply jax.vmap to the scalar potential above:
scalar_potential_vmap = jax.vmap(scalar_potential)
N=10**2
FF=np.random.randint(-10,10,(N,8))
z0=np.random.uniform(-1,1,N)+1j*np.random.uniform(0,5,N)
cz0=jnp.conj(z0)
tau0=np.random.uniform(-1,1,N)+1j*np.random.uniform(0,5,N)
ctau0=jnp.conj(tau0)
scalar_potential_vmap(z0,cz0,tau0,ctau0,FF)[:10]
In this way, we evaluated the scalar potential for N=100 random choices of fluxes and moduli in a single function call.
We can speed up the above code with jax.jit:
scalar_potential_vmap_jit = jax.vmap(scalar_potential_jit)
scalar_potential_vmap_jit(z0,cz0,tau0,ctau0,FF)[:10]
Comparing the timings of the two functions, we find:
%%timeit
res = jax.block_until_ready(scalar_potential_vmap(z0,cz0,tau0,ctau0,FF))
%%timeit
res = jax.block_until_ready(scalar_potential_vmap_jit(z0,cz0,tau0,ctau0,FF))
Again, we find a huge speed-up of around two orders of magnitude.
The advantage of using jax.vmap becomes even more apparent by comparing it with a standard for-loop:
import time
from tqdm.auto import tqdm
tic = time.time()
res = []
for i in tqdm(range(N)):
    res.append(scalar_potential(z0[i],cz0[i],tau0[i],ctau0[i],FF[i]))
toc = time.time()
print(f"Time for evaluation: {np.around(toc-tic,2)}")
jnp.array(res)[:10]
The evaluation in a standard for-loop takes several seconds, while the vmapped function scalar_potential_vmap takes only a couple of milliseconds.
Thus, simply applying the transformation jax.vmap provides an easy and much more efficient way to evaluate a single function on many data points.
Note that, for the most part, the order in which we apply jax.jit and jax.vmap does not matter:
scalar_potential_jit_vmap = jax.jit(jax.vmap(scalar_potential))
scalar_potential_jit_vmap(z0,cz0,tau0,ctau0,FF)
%%timeit
res = jax.block_until_ready(scalar_potential_jit_vmap(z0,cz0,tau0,ctau0,FF))
We see approximately the same speed-up.
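These transformations also compose with differentiation. The sketch below re-implements the \(T^6\) scalar potential with the same structure as scalar_potential (metric components from nested AD, complex output). It then builds `jax.jit(jax.vmap(dV_z))` to evaluate \(\partial_z V\) for a batch of random fluxes and moduli. The SUSY vacuum from above is placed in the first slot as a check. The sampling ranges and the random seed are arbitrary choices.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

kahler = lambda z, cz, tau, ctau: -3 * jnp.log(-1.j * (z - cz)) - jnp.log(-1.j * (tau - ctau))
P1 = lambda z, F: -F[1] - 3 * F[5] * z - 3 * F[4] * z**2 + F[0] * z**3
P2 = lambda z, F: -F[3] - 3 * F[7] * z - 3 * F[6] * z**2 + F[2] * z**3
superpotential = lambda z, tau, F: P1(z, F) - tau * P2(z, F)

Kz = jax.grad(kahler, argnums=0, holomorphic=True)
Ktau = jax.grad(kahler, argnums=2, holomorphic=True)
K_z_cz = jax.jacfwd(Kz, argnums=1, holomorphic=True)
K_tau_ctau = jax.jacfwd(Ktau, argnums=3, holomorphic=True)
Wz = jax.grad(superpotential, argnums=0, holomorphic=True)
Wtau = jax.grad(superpotential, argnums=1, holomorphic=True)

def scalar_potential(z, cz, tau, ctau, F):
    # Same structure as in the text; complex output so that holomorphic grad applies
    W = superpotential(z, tau, F)
    DWz = Wz(z, tau, F) + Kz(z, cz, tau, ctau) * W
    DWtau = Wtau(z, tau, F) + Ktau(z, cz, tau, ctau) * W
    return jnp.exp(kahler(z, cz, tau, ctau)) * (
        DWz * jnp.conj(DWz) / K_z_cz(z, cz, tau, ctau)
        + DWtau * jnp.conj(DWtau) / K_tau_ctau(z, cz, tau, ctau))

# Compose the transformations: differentiate, then batch, then compile
dV_z = jax.grad(scalar_potential, argnums=0, holomorphic=True)
dV_z_batched = jax.jit(jax.vmap(dV_z))

N = 100
rng = np.random.default_rng(0)
F = rng.integers(-10, 10, (N, 8)).astype(np.float64)
z = rng.uniform(-1, 1, N) + 1j * rng.uniform(0.5, 5, N)
tau = rng.uniform(-1, 1, N) + 1j * rng.uniform(0.5, 5, N)
# Place the SUSY vacuum from the text (z = tau = i) in the first slot
F[0] = [0., 1., 1., 0., 0., 1., 1., 0.]
z[0], tau[0] = 1j, 1j

grads = dV_z_batched(z, np.conj(z), tau, np.conj(tau), F)
print(grads[:5])
```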
Background Material: flux compactifications on a symmetric \(T^6\)
For completeness, we provide a short summary of our conventions for Type IIB flux compactifications on the symmetric \(T^6\). One interesting feature of the symmetric \(T^6\) is the existence of vacua with vanishing tree-level superpotential. These special vacua offer the opportunity to study how they emerge in searches with genetic algorithms (GAs). A major issue with finding such vacua is that they are quite scarce compared to generic vacua. Hence, we cannot take for granted that GAs can identify these exceptional cases within the landscape.
We follow the conventions of hep-th/0411061. The symmetric torus can be viewed as a direct product of three \(T^{2}\) with all modular parameters set equal, \(z\equiv z_{1}=z_{2}=z_{3}\). The moduli space has two complex dimensions, so upon gauge fixing we have \(8\) independent flux parameters. Let us first parametrize a general \(T^6\) before we specialize to the symmetric case further below. We define coordinates \(x^i,y^i\) for \(i=1,2,3\) with periodicity \(x^i\equiv x^i+1\), \(y^i\equiv y^i+1\), such that the three holomorphic 1-forms can be written as \(\mathrm{d} z^i=\mathrm{d} x^i+z^{ij}\,\mathrm{d} y^j\). We take the orientation \begin{align} \int \mathrm{d} x^1\wedge \mathrm{d} x^2\wedge \mathrm{d} x^3\wedge \mathrm{d} y^1\wedge \mathrm{d} y^2\wedge \mathrm{d} y^3=1 \end{align} and choose a symplectic basis for \(H^3(T^6,\mathbb{Z})\), namely \begin{align} \alpha^0&=\mathrm{d} x^1\wedge \mathrm{d} x^2\wedge \mathrm{d} x^3\,,\quad \alpha_{ij}=\frac{1}{2}\epsilon_{ilm}\,\mathrm{d} x^l\wedge \mathrm{d} x^m\wedge \mathrm{d} y^j\,,\\ \beta^{ij}&=-\frac{1}{2}\epsilon_{jlm}\,\mathrm{d} y^l\wedge \mathrm{d} y^m\wedge \mathrm{d} x^i\,,\quad \beta^0=\mathrm{d} y^1\wedge \mathrm{d} y^2\wedge \mathrm{d} y^3\,. \end{align} The holomorphic 3-form can be written as \begin{equation} \Omega=\mathrm{d} z^1\wedge \mathrm{d} z^2\wedge \mathrm{d} z^3\,. \end{equation} We can expand the \(3\)-form fluxes in terms of the symplectic basis: \begin{align} F_3&=a^0\alpha^0+a^{ij}\alpha_{ij}+b_{ij}\beta^{ij}+b_0\beta^0\,,\\ H_3&=c^0\alpha^0+c^{ij}\alpha_{ij}+d_{ij}\beta^{ij}+d_0\beta^0\,. \end{align}
For a symmetric \(T^{6}\), we take \begin{equation} z^{ij}=z\,\delta^{ij}\,. \end{equation} This is equivalent to taking the \(T^6\) to be factorizable as three two-tori with equal modular parameter. Similarly, the fluxes get reduced to \begin{align} a^{ij}=a\,\delta^{ij}\,,\quad b_{ij}=b\,\delta_{ij}\,,\quad c^{ij}=c\,\delta^{ij}\,,\quad d_{ij}=d\,\delta_{ij}\,. \end{align} The superpotential takes the simple form \begin{align} W=P_1(z)-\tau P_2(z)\,, \end{align} where the \(P_i\) are cubic polynomials in \(z\), i.e., \begin{align} P_1(z)&= a^0z^3-3az^2-3bz-b_0\,,\\ P_2(z)&= c^0z^3-3cz^2-3dz-d_0\,. \end{align} The Kähler potential for \(z\) and \(\tau\) reads \begin{align} \mathcal{K}=-3\log(-i(z-\overline{z}))-\log(-i(\tau-\overline{\tau})) \end{align} and the D3-brane charge induced by the fluxes is \begin{align} N_{\rm flux}=b_0c^0-a^0d_0+3(bc-ad)\,. \end{align}
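The definitions of P1 and P2 in the code of this notebook correspond to the flux ordering \(F=(a^0,b_0,c^0,d_0,a,b,c,d)\). This lets us connect the tadpole formula above with the flux vectors used there. A hypothetical helper `flux_tadpole` (not part of JAXVacua) that evaluates \(N_{\rm flux}\) in that ordering, for a single flux vector or a batch, could look as follows:

```python
import numpy as np

def flux_tadpole(F):
    """Hypothetical helper: N_flux = b0*c0 - a0*d0 + 3*(b*c - a*d).

    Assumes the flux ordering F = (a0, b0, c0, d0, a, b, c, d) used by P1 and P2;
    F may have shape (8,) or (N, 8).
    """
    a0, b0, c0, d0, a, b, c, d = np.moveaxis(np.asarray(F), -1, 0)
    return b0 * c0 - a0 * d0 + 3 * (b * c - a * d)

print(flux_tadpole([0, 1, 1, 0, 0, 1, 1, 0]))  # prints 4
print(flux_tadpole(np.random.default_rng(0).integers(-10, 10, (5, 8))))  # one value per row
```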
The F-term constraints can be written in the form \begin{align} P_1(z)-\overline{\tau}P_2(z)&=0\,,\\ P_1(z)-\tau P_2(z)&=\frac{1}{3}(z-\overline{z})\left(P_{1}^{\prime}(z)-\tau P_{2}^{\prime}(z)\right)\,, \end{align} where the factor \(\tfrac{1}{3}\) follows from \(\partial_z\mathcal{K}=-3/(z-\overline{z})\). For a non-zero VEV of the superpotential, \(W_{0}\neq 0\), the axio-dilaton can be obtained from the first equation, so that \begin{equation} \tau=\dfrac{\overline{P_{1}(z)}}{\overline{P_{2}(z)}}\,. \end{equation} Plugging this into the second equation, one finds for \(z=x+iy\) \begin{align} q_1(x)\,y^2=q_3(x)\,,\quad q_0(x)\,y^4=q_4(x)\,. \end{align} The \(q_i\) are polynomials in \(x\) which have been computed, for instance, in the appendix of hep-th/0411061. Surprisingly, after combining both equations to eliminate \(y\), a cubic (rather than sextic) equation in \(x\) remains, so that \(x\) can be found by solving \begin{align} \alpha_3 x^3+\alpha_2 x^2+\alpha_1 x+\alpha_0=0\,. \end{align} The coefficients \(\alpha_i\) are combinations of flux integers and can again be found in hep-th/0411061.
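As a quick numerical sanity check of the \(W_0\neq 0\) branch, the sketch below evaluates the \(F\)-term equations for the flux choice of the SUSY vacuum example earlier in this notebook, where \(z=\tau=\mathrm{i}\). The second condition is written here as \(W=\tfrac{1}{3}(z-\overline{z})(P_1'-\tau P_2')\), which follows from \(D_zW=\partial_zW-3W/(z-\overline{z})=0\) with the Kähler potential above. The sketch uses `numpy.polynomial` instead of JAX purely for brevity.

```python
import numpy as np

# Flux choice of the SUSY vacuum example, ordered as F = (a0, b0, c0, d0, a, b, c, d)
a0, b0, c0, d0, a, b, c, d = [0., 1., 1., 0., 0., 1., 1., 0.]

# P1 and P2 as defined above; numpy.polynomial coefficients start with the constant term
P1 = np.polynomial.Polynomial([-b0, -3 * b, -3 * a, a0])
P2 = np.polynomial.Polynomial([-d0, -3 * d, -3 * c, c0])
dP1, dP2 = P1.deriv(), P2.deriv()

z = 1j
tau = np.conj(P1(z)) / np.conj(P2(z))  # tau from the first F-term equation
W = P1(z) - tau * P2(z)

eq1 = P1(z) - np.conj(tau) * P2(z)                          # first equation, should vanish
eq2 = W - (z - np.conj(z)) / 3 * (dP1(z) - tau * dP2(z))   # second equation, should vanish
print(tau, W, eq1, eq2)
```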
Solutions with \(W=0\) satisfy \begin{equation} P_{1}(z)=P_{2}(z)=0\,. \end{equation} The second \(F\)-term condition then reduces to \(P_1^{\prime}(z)-\tau P_2^{\prime}(z)=0\), so that the solution for \(\tau\) simply reads \begin{equation} \tau=\dfrac{P_{1}^{\prime}(z)}{P_{2}^{\prime}(z)}\,. \end{equation} As shown in hep-th/0201028, these solutions have the special property that \(P_{1}\) and \(P_{2}\) must factorize over the integers, cf. Sect. 4.3.3 in hep-th/0411061.
Take-aways
JAX exposes forward- and reverse-mode automatic differentiation via jax.grad, jax.jacfwd/jax.jacrev, and jax.hessian.
Moduli derivatives are obtained with holomorphic=True, which requires complex-valued inputs and outputs. This is why the scalar potential is kept complex-valued: taking .real before applying jax.grad with holomorphic=True raises an error.
For Type IIB flux compactifications on the symmetric \(T^6\), solutions of the \(F\)-term equations \(D_iW = 0\) are SUSY Minkowski vacua of the scalar potential \(V\).
jax.jit and jax.vmap deliver order-of-magnitude speed-ups when evaluating the scalar potential across many flux samples.
These primitives are the building blocks that the higher-level JAXVacua API wraps in subsequent notebooks.