# Generative_jax

**Repository Path**: fmscole/generative_jax

## Basic Information

- **Project Name**: Generative_jax
- **Description**: Generative models implemented in pure JAX
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2026-06-03
- **Last Updated**: 2026-06-04

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# Generative Models in JAX / Stax

Pure JAX + Stax implementations of GANs, VAEs, and Diffusion Models.

## Structure

```
├── lib/                        # Shared library
│   ├── stax_plus.py            # Extended JAX stax layer library
│   └── datasets.py             # MNIST / Fashion-MNIST data loader
├── gans/                       # Generative Adversarial Networks
│   ├── cgan_stax.py            # Conditional GAN
│   ├── cgan_stax_fori_loop.py  # C-GAN with lax.fori_loop
│   ├── cgan_stax_scan.py       # C-GAN with lax.scan
│   └── wcgan_stax.py           # Wasserstein GAN with Gradient Penalty
├── vae/                        # Variational Autoencoders
│   └── vae_stax.py             # VAE (convolutional encoder/decoder)
├── diffusion/                  # Diffusion Models
│   ├── diffusion_stax.py       # DDPM with Stax U-Net
│   └── configs/mnist.py        # MNIST configuration
```

## Installation

```bash
pip install jax jaxlib optax numpy matplotlib torch torchvision
```

> `torch` and `torchvision` are used only for MNIST data loading.
> The models themselves are pure JAX/Stax.

## Usage

```bash
cd generative

# Conditional GAN on MNIST
python gans/cgan_stax.py

# Wasserstein GAN on MNIST
python gans/wcgan_stax.py

# VAE on MNIST
python vae/vae_stax.py

# Diffusion model on MNIST
python diffusion/diffusion_stax.py
```

MNIST is automatically downloaded to `data/` on first run.
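
The GAN scripts listed under Structure come in a plain version, a `lax.fori_loop` version, and a `lax.scan` version. The sketch below is not taken from the repository; it only illustrates, under stated assumptions, how the two loop primitives differ when driving a Stax model. The toy MLP, the regression data, and the names `loss_fn`, `sgd_step`, `LEARNING_RATE`, and `NUM_STEPS` are hypothetical, and it uses the stock `jax.example_libraries.stax` rather than `lib/stax_plus.py`.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.example_libraries import stax

# Hypothetical toy model: a small MLP built from stock stax layers.
# stax returns an (init_fun, apply_fun) pair instead of a stateful object.
init_fun, apply_fun = stax.serial(
    stax.Dense(16), stax.Relu,
    stax.Dense(1),
)

LEARNING_RATE = 0.05  # assumed value, chosen only for this toy problem
NUM_STEPS = 50


def loss_fn(params, x, y):
    # Mean squared error on a toy regression target.
    pred = apply_fun(params, x)
    return jnp.mean((pred - y) ** 2)


def sgd_step(params, x, y):
    # One plain SGD update; also returns the loss measured before the update.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(0)
k_init, k_data = jax.random.split(key)
_, params0 = init_fun(k_init, (-1, 4))   # input shape: (batch, 4 features)
x = jax.random.normal(k_data, (32, 4))
y = jnp.sum(x, axis=1, keepdims=True)    # toy target: sum of the features

# Variant 1: lax.fori_loop carries only the parameters between iterations.
params_fori = lax.fori_loop(
    0, NUM_STEPS, lambda i, p: sgd_step(p, x, y)[0], params0)

# Variant 2: lax.scan carries the parameters and stacks the per-step losses.
params_scan, losses = lax.scan(
    lambda p, _: sgd_step(p, x, y), params0, None, length=NUM_STEPS)
```

Both variants compile the whole training loop into a single XLA computation instead of dispatching each step from Python. `lax.scan` is the more convenient choice when per-step values such as the loss curve are needed afterwards, since `lax.fori_loop` returns only the final carry.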