Releases: google/flax
0.12.6
What's Changed
- Annotate `wrapped_fn`'s `self` argument in `decorator_lift_transform` for trace-ability by @copybara-service[bot] in #5298
- add with_attributes by @copybara-service[bot] in #5308
- Remove optax pin by @samanklesaria in #5292
- Removed redundant code checking wrt in nnx.Optimizer by @vfdev-5 in #5226
- Fix PyTreeNode + Generic losing parameters when Generic is last in bases by @mohsinm-dev in #5237
- add recursive_map example by @copybara-service[bot] in #5311
- Add sharding propagation support in nnx.eval_shape (clone of #5111) by @samanklesaria in #5247
- fix eval_shape's _to_variable by @copybara-service[bot] in #5316
- support out_sharding mapping in set_metadata by @copybara-service[bot] in #5312
- feat(nnx): add GQA support to MultiHeadAttention by @ayulockedin in #5259
- Allow DenyList to be compared. by @copybara-service[bot] in #5322
- add graph_updates argument to jit by @copybara-service[bot] in #5317
- add graph_updates argument in shard_map by @copybara-service[bot] in #5319
- Remove jax/tools/colab_tpu.py. by @copybara-service[bot] in #5324
- add graph_updates to vmap by @copybara-service[bot] in #5320
- Add split method to RngStream by @samanklesaria in #5270
- add graph_updates for scan by @copybara-service[bot] in #5327
- add graph_updates to while_loop by @copybara-service[bot] in #5328
- Added support for data masking in Average, Accuracy and MultiMetric by @vfdev-5 in #5326
- add graph_updates to fori_loop by @copybara-service[bot] in #5329
- add graph_updates to pmap, grad, and value_and_grad by @copybara-service[bot] in #5330
- add tree-mode-nnx FLIP by @copybara-service[bot] in #5310
- Copybara import of the project: by @copybara-service[bot] in #5331
- add graph_updates to remat by @copybara-service[bot] in #5336
- add graph_updates to eval_shape and checkify by @copybara-service[bot] in #5338
- add compat module by @copybara-service[bot] in #5340
- add more tests that check for consistent aliasing in transforms by @copybara-service[bot] in #5341
- don't allow Variable mutation in custom_vjp on differentiable arguments by @copybara-service[bot] in #5342
- support hijax and ref Variables in simple transforms by @copybara-service[bot] in #5345
- allow in/out_axes when graph_updates=False by @copybara-service[bot] in #5323
- add transform_metadata transform by @copybara-service[bot] in #5346
- clean up custom_vjp's graph_updates=False section by @copybara-service[bot] in #5350
- improve error messages for tree mode duplicates check by @copybara-service[bot] in #5347
- Introduce `manual_type: ManualAxisType` parameter on ShapedArray to track varying/unreduced/reduced and remove `vma: frozenset` parameter. by @copybara-service[bot] in #5339
- remove aliases from nnx.graph in favor of nnx.compat by @copybara-service[bot] in #5348
- check aliases on all transform args and simplify apply_variable_updates by @copybara-service[bot] in #5349
- Add intermediate value captures (extends #4925) by @samanklesaria in #5257
- fix jit_partial lower and compile by @copybara-service[bot] in #5355
- Added support for data masking in Average, Accuracy and MultiMetric by @vfdev-5 in #5332
- Do a few more cleanups after pmap_shmap merge by @copybara-service[bot] in #5358
- update to version 0.12.6 by @copybara-service[bot] in #5356
- Do a few more cleanups after pmap_shmap merge by @copybara-service[bot] in #5359
- Do a few more cleanups after the pmap shmap merge deletion. by @copybara-service[bot] in #5361
- simplify SimpleScan by @copybara-service[bot] in #5357
- improve nnx.Dict error handling by @copybara-service[bot] in #5362
- add graph node in prefix checks by @copybara-service[bot] in #5365
- improve aliasing error msg by @copybara-service[bot] in #5364
- Generalize out_sharding to work with NamedSharding and Format by @samanklesaria in #5246
- add nnx.map and nnx.abstract_with_sharding by @copybara-service[bot] in #5366
Full Changelog: v0.12.5...v0.12.6
0.12.5
What's Changed
- maintain data/static definition in split / state by @copybara-service[bot] in #5243
- Add graph=False support for nnx.grad and nnx.value_and_grad by @copybara-service[bot] in #5240
- Add tree-mode support for nnx.remat by @copybara-service[bot] in #5242
- Add tree-mode support for nnx.vmap by @copybara-service[bot] in #5250
- Add tree-mode arg error handling by @copybara-service[bot] in #5251
- [flax:benchmarks] Fix flax benchmarks. by @copybara-service[bot] in #5241
- Improve unregistered data detection by @copybara-service[bot] in #5249
- Add graph argument to split, state, graphdef, clone by @copybara-service[bot] in #5258
- Add nnx_graph_mode config flag by @copybara-service[bot] in #5261
- Add tree-mode support to nnx.{cond,switch,eval_shape,checkify} by @copybara-service[bot] in #5252
- Add tree-mode support to nnx.scan by @copybara-service[bot] in #5262
- Add tree-mode support for nnx.custom_vjp by @copybara-service[bot] in #5264
- Add tree-mode support to nnx.{while,fori}_loop by @copybara-service[bot] in #5263
- add tree-mode support for pmap by @copybara-service[bot] in #5265
- Add tree-mode support to iter_graph by @copybara-service[bot] in #5266
- Add tree-mode support for nnx.recursive_map, nnx.view, and nnx.view_info by @copybara-service[bot] in #5267
- Use linen_vars_to_nnx_attrs in ToLinen variable restoration by @copybara-service[bot] in #5272
- Add tree-mode-only nnx.{vjp,jvp,jit_partial} by @copybara-service[bot] in #5268
- Use linen_vars_to_nnx_attrs in ToLinen variable restoration by @copybara-service[bot] in #5273
- fix nn.vmap's _broadcast_prefix_tree by @copybara-service[bot] in #5282
- fix nn.vmap's split_fn by @copybara-service[bot] in #5286
- [pmap] Remove `jax.config.pmap_shmap_merge`. by @copybara-service[bot] in #5289
- add set_graph_mode by @copybara-service[bot] in #5269
- add error handling to iter_children by @copybara-service[bot] in #5296
- expose and update view guide by @cgarciae in #5294
- Remove chex dependency by @samanklesaria in #5295
- improve error messages for tree mode errors by @copybara-service[bot] in #5301
- remove experimental Variable.mutable property by @copybara-service[bot] in #5304
- update version to 0.12.5 by @copybara-service[bot] in #5305
- simplify nnx.display for Variable by @copybara-service[bot] in #5306
- simplify set_view API by @copybara-service[bot] in #5303
Full Changelog: v0.12.4...v0.12.5
0.12.4
This release fixes an issue where nnx.List and nnx.Sequential passed all of their elements as static when being treated as a pytree.
What's Changed
- fix List and improve Pytree by @cgarciae in #5072
- Copybara import of the project: by @copybara-service[bot] in #5211
- Fixed mask sharding if inputs is sharded by @vfdev-5 in #5212
- nnx.view by @chapman20j in #5204
- bonsai link by @chapman20j in #5220
- Add out_sharding argument to Embed layer call by @samanklesaria in #5205
- Improve Pytree flatten/unflatten by @copybara-service[bot] in #5216
- Pin optax version to fix CI issues by @samanklesaria in #5225
- Pin doc deps to lower versions by @samanklesaria in #5222
- Fix Sequential unwrapping for static attributes by @samanklesaria in #5218
- feat(nnx): add Grouped Query Attention (GQA) support by @ayulockedin in #5180
- chore: Migrate gsutil usage to gcloud storage by @gurusai-voleti in #5230
- Makes `linen.Partitioned` compatible with `linen.WeightNorm` by @copybara-service[bot] in #5234
- maintain data/static definition in split / state by @copybara-service[bot] in #5227
- Adds a test for `linen.WeightNorm` and `linen.with_partitioning` compatibility by @copybara-service[bot] in #5236
- Add Tunix link by @copybara-service[bot] in #5239
- Add tree-mode support for nnx.shard_map by @copybara-service[bot] in #5238
- Rename sharding_names to out_sharding in NNX Variable metadata by @copybara-service[bot] in #5215
New Contributors
- @gurusai-voleti made their first contribution in #5230
Full Changelog: v0.12.3...v0.12.4
0.12.3
What's Changed
- Set numpy<2.4 to fix DeprecationWarning in the CI doctest by @vfdev-5 in #5163
- ignore optax deprecation warning by @cgarciae in #5165
- fix general guides landing page by @ivrolan in #5139
- Remove nnx.split_rngs call wrapping nnx.scan in linen to nnx tutorial by @samanklesaria in #5160
- Add out_sharding arguments to linear layers where supported by @jackopenn in #5156
- fix kw_only_dataclasses for python 3.14 (part 2) by @copybara-service[bot] in #5135
- No public description by @copybara-service[bot] in #5164
- Have _graph_flatten respect nnx.data declarations (extension of #5140) by @samanklesaria in #5159
- [pmap] Avoid degraded performance under the new `jax.pmap`. by @copybara-service[bot] in #5152
- Support multiple None and UNCONSTRAINED when resolving logical rules by @copybara-service[bot] in #5129
- improve hijax guide by @cgarciae in #5115
- docs: fix typo 'paramater' -> 'parameter' by @ayulockedin in #5166
- Make `nnx.pop` remove sown attributes. by @samanklesaria in #5133
- Rename sharding_names to sharding_metadata by @samanklesaria in #5089
- Fix bug in graph overhead benchmark by @samanklesaria in #5183
- Fixed typos in the docstrings using antigravity by @vfdev-5 in #5145
- docs(nnx): add missing functional args to Conv and LinearGeneral by @ayulockedin in #5174
- empty change by @copybara-service[bot] in #5184
- Use nnx split during tabulate (clone of #5069) by @samanklesaria in #5186
- Handle pure bodies in nnx.fori_loop by @samanklesaria in #5141
- Typo fix in _cached_partial method by @vfdev-5 in #5142
- Update mnist example to use NNX (clone of #5064) by @samanklesaria in #5188
- Docs: Fix typo and clarify introduction in Functional API section by @Moriyuki-S in #5157
- Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated. by @copybara-service[bot] in #5189
- Remove ref cycles introduced by self-calling nested functions. by @copybara-service[bot] in #5193
- Add HijaxTransformCoverageTest by @cgarciae in #5190
- allow nnx standalone import p1 by @copybara-service[bot] in #5196
- Add _graph_node_set_key method for List class by @samanklesaria in #5171
- _apply_sharding disallow mixed Explicit/Auto mesh by @copybara-service[bot] in #5199
- update flax to version 0.12.3 by @copybara-service[bot] in #5206
New Contributors
- @ivrolan made their first contribution in #5139
- @jackopenn made their first contribution in #5156
- @ayulockedin made their first contribution in #5166
- @Moriyuki-S made their first contribution in #5157
Full Changelog: v0.12.2...v0.12.3
Version 0.12.2
What's Changed
- [flax:examples:wmt] Small linter fixes. by @copybara-service[bot] in #5012
- [flax:examples:seq2seq] Create main and default config based on seq2seq.ipynb. by @copybara-service[bot] in #5119
- [flax:examples:vae] Small linter fixes. by @copybara-service[bot] in #5014
- [flax:examples:gemma] Fixing linter errors. by @copybara-service[bot] in #5013
- [flax:examples:sst2] Fix pytype errors. by @copybara-service[bot] in #5118
- Allow substring matching in `nnx.PathContains` by @thijs-vanweezel in #5094
- [flax:examples:sst2] Fix notebook error. by @copybara-service[bot] in #5122
- [flax:examples:ppo] Fix some linter / import issues. #jax-fixit by @copybara-service[bot] in #5120
- Avoid passing `concrete` argument to `jax.remat` by @copybara-service[bot] in #5121
- [flax:examples:lm1b_nnx] Update example to work internally. #jax-fixit. by @copybara-service[bot] in #5125
- [flax:examples:nlp_seq] Create a main.py file to run tests with config files to match other examples. #jax-fixit by @copybara-service[bot] in #5126
- [jax:benchmarks] Add tracing/lowering benchmarks for a few flax examples. by @copybara-service[bot] in #4911
- remove abstracted_axes from nnx.jit by @copybara-service[bot] in #5132
- Pooling operation by @jorisSchaller in #5057
- Added is_causal mask argument to flax.nnx.dot_product_attention by @ibbyml in #5093
- Add out_sharding argument to call methods for layers with jax calls that support it by @samanklesaria in #5102
- Temporary fix for failing CI by @vfdev-5 in #5144
- New release 0.12.2 by @IvyZX in #5149
New Contributors
- @thijs-vanweezel made their first contribution in #5094
- @ibbyml made their first contribution in #5093
Full Changelog: v0.12.1...v0.12.2
v0.12.1
Deprecations
Variable.value
Variable.value is now deprecated. Consider the following example:

```python
import jax.numpy as jnp
import jax
from flax import nnx

my_param = nnx.Param({'a': 0.0})

@nnx.jit
def f(m):
  m.value['a'] = 1.0
  return m
```

Running f(my_param) produces Param(value={'a': 0.0}), not Param(value={'a': 1.0}) as before. This is because getting the value property now returns a copy of the pytree values (like dict / list). Instead, use the __setitem__ method to update the value:

```python
@nnx.jit
def f(m):
  m['a'] = 1.0
  return m
```
nnx.Data and nnx.Static
nnx.Data and nnx.Static annotations are now deprecated. To create nnx.Pytree or nnx.Module dataclasses use the new nnx.dataclass with nnx.data and nnx.static as field descriptors.
```python
# old
@dataclasses.dataclass
class Foo(nnx.Pytree):
  a: nnx.Data[int]
  b: nnx.Static[str]

# new
@nnx.dataclass
class Foo(nnx.Pytree):
  a: int = nnx.data()
  b: str = nnx.static()
```
Pull Requests
- Clarify `*Norm` layer docstrings: `axis_index_groups` is unused under SPMD jit. by @copybara-service[bot] in #4940
- Move `ArrayRef` creation to the end of `Variable` creation by @IvyZX in #4980
- clean up jax.Ref-related names by @copybara-service[bot] in #4988
- Add compute_flops and compute_vjp_flops options to `nnx.tabulate` by @samanklesaria in #4948
- Fix nnx.tabulate crash with empty dict/None values (fixes #4889) by @mohsinm-dev in #4891
- Future-proof imports of jax.new_ref / jax.Ref. by @copybara-service[bot] in #4986
- Use `jnp.stack` instead of `np.stack` in `flax.training.common_utils.stack_forest` by @vfdev-5 in #4991
- Fixed broken nnx.statelib.diff by @vfdev-5 in #4992
- Implemented spectral norm in NNX by @mattbahr in #4623
- Improve Variable.{get,set}_metadata by @cgarciae in #4985
- Move iter_children and iter_modules to functions by @samanklesaria in #4961
- Avoid install, import, or tests with tensorflow-text under Python 3.13+. by @jburnim in #5001
- disallow setting metadata through settattr by @cgarciae in #4993
- Use sphinx 6.2+ for docs, which works with Python 3.13. by @jburnim in #5009
- Removed kernel_init/bias_init attributes from popular layers by @vfdev-5 in #4998
- Migrate from `jax.experimental.enable_x64` to `jax.enable_x64`. by @copybara-service[bot] in #5011
- Add Rngs KeylessInitializers by @cgarciae in #5017
- optimize scan transpositions by @cgarciae in #5015
- Variable refactor by @cgarciae in #5006
- Remove invalid gymnasium dependency in pyproject.toml by @IvyZX in #5016
- Use jax.shard_map in flax by @copybara-service[bot] in #5020
- use jax.shard_map by @copybara-service[bot] in #5018
- Fix formatting in PR template checklist by @rapsealk in #5024
- Fixed attribute visualization in treescope_repr by @vfdev-5 in #5022
- feat: add `nnx.set_metadata` to in-place change metadata of the state variables of `nnx.Module`s by @pfackeldey in #5007
- Update README to use fully qualified `nnx.Linear` in example by @rapsealk in #5023
- Fix nnx tabulate variable hooks by @mohsinm-dev in #5008
- python 3.13 support by @cgarciae in #4987
- Added a note in nnx.jit about arg donation by @vfdev-5 in #5031
- Add flip doc link to eager sharding error message by @IvyZX in #5033
- fix reseed for abstract values by @cgarciae in #5034
- Deduplicate `Variable` nodes in `iter_graph` and eliminate recursion. by @copybara-service[bot] in #5035
- Support for python 3.14 by @vfdev-5 in #5032
- [docs] Exposed more helper functions/classes in state.rst by @vfdev-5 in #5037
- Copybara import of the project: by @copybara-service[bot] in #5041
- Internal change by @copybara-service[bot] in #5048
- filter grad state in nnx.Optimizer by @copybara-service[bot] in #5049
- Add NNX WeightNorm (update of #4568) by @samanklesaria in #5043
- Fix shard_map documentation link in compilation.py by @vfdev-5 in #5038
- Fix ValueError when `nnx.jit` is used with `nnx.custom_vjp` by @samanklesaria in #5045
- Recursive map by @chapman20j in #5042
- Convert linen pytorch guide to nnx by @samanklesaria in #4999
- Set Mode with Tests by @chapman20j in #5056
- Fixing Optimizer docstring - fixing #5060 by @Lucas-Fernandes-Martins in #5061
- Update tutorial examples to thread explicit RNGs by @samanklesaria in #4975
- Fix NNX jit static args with in_shardings issue #4989 by @mohsinm-dev in #4996
- support explicit sharding in eager sharding by @cgarciae in #5070
- Added missing LayerNorm test case into TestLayersSameGraph by @vfdev-5 in #5076
- fix main by @cgarciae in #5081
- docs: Document `allow_duplicates` argument of `nnx.to_arrays`. by @dan-zheng in #5083
- add promote_dtype to all standard layers by @cgarciae in #5080
- add nnx.dataclass by @cgarciae in #5066
- Expand ConvTranspose padding documentation by @samanklesaria in #4990
- Added kernel_metadata/bias_metadata args to nnx layers by @vfdev-5 in #5074
- Add nnx.use_eager_sharding context manager by @samanklesaria in #5079
- fix main by @cgarciae in #5090
- Adding set_mode_info by @chapman20j in #5071
- Fixed nnx.scan with carry as pytree and sow by @vfdev-5 in #5073
- Fix bound method auto-unbinding for NNX transforms by @mohsinm-dev in #5055
- deprecate Variable.value by @cgarciae in #5052
- Add eq for variables by @samanklesaria in #5084
- Fixed deprecated .value usage failing CI tests by @vfdev-5 in #5097
- update jax minver to 0.8.1 by @cgarciae in #5095
New Contributors
- @samanklesaria made their first contribution in #4948
- @jburnim made their first contribution in #5001
- @rapsealk made their first contribution in #5024
- @pfackeldey made their first contribution in #5007
- @chapman20j made their first contribution in #5042
- @Lucas-Fernandes-Martins made their first contribution in #5061
Full Changelog: v0.12.0...v0.12.1
0.12.0
Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.
Breaking Changes
Pytree Strict Attributes
nnx.Pytree and therefore nnx.Module are now stricter with regards to attributes that contain Arrays and changing the status of attributes. For example, the code below now fails:
```python
from flax import nnx
import jax
import jax.numpy as jnp

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = [  # ERROR
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ]
    self.bias = None  # status = static
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform(3,))  # ERROR
```
This happens for two reasons:
- JAX pytree structures that contain Arrays now have to be marked with `nnx.data`. Alternatively, if the container pytree is a `list` or a `dict`, you can use `nnx.List` or `nnx.Dict`, which additionally allow mixed "data" and "static" elements.
- Attributes will no longer automatically change their status; this now has to be done explicitly using `nnx.data` or `nnx.static`. Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.
To fix the above, create layers as an nnx.List (a Module that is automatically recognized as data), and be explicit about bias being a data attribute on its first assignment by using nnx.data:
```python
class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform(3,))
```
For more information check the Module & Pytree guide.
Eager Sharding
Variables will now eagerly shard their values when sharding_names metadata is provided. A mesh is required—it can be provided either via passing a mesh metadata attribute or setting the global mesh context via jax.set_mesh. This simplifies the process of sharding a Variable to construction time:
```python
jax.config.update('jax_num_cpu_devices', 8)
mesh = jax.make_mesh((2, 4), ('data', 'model'))

with jax.set_mesh(mesh):
  variable = nnx.Param(jnp.ones((16, 32)), sharding_names=(None, 'model'))

print(variable.value.sharding)
```
Eager sharding will also occur when using the nnx.with_partitioning initializer decorator and will automatically extend to the Optimizer. This means that both model and optimizer will be sharded at construction without the need for the somewhat cumbersome nnx.get_partition_spec + jax.lax.with_sharding_constraint + nnx.update pattern:
```python
with jax.set_mesh(mesh):
  linear = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)

print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```
For projects that currently rely on other means for sharding, eager sharding can be turned off by passing eager_sharding=False to the Variable constructor, either directly or through initializer decorators like nnx.with_partitioning:
```python
linear = nnx.Linear(
  in_features=16, out_features=16, use_bias=False,
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), (None, 'model'), eager_sharding=False
  ),
  rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```
Eager sharding can also be turned off globally via the flax_always_shard_variable config flag or the FLAX_ALWAYS_SHARD_VARIABLE environment variable:

```python
import flax
flax.config.update('flax_always_shard_variable', False)
```
For more information, check out the Variable eager sharding FLIP.
In-Place Operators No Longer Allowed
In-place operators will now raise an error. This is done as part of the push for Variables to be compatible with Tracer semantics:
```python
w = nnx.Variable(jnp.array(0))
w += 1  # ERROR
```
The fix is to simply operate on the .value property instead:

```python
w.value += 1
```
All Changes
- Doc fix: remove dead link to pre-Orbax checkpointing. by @copybara-service[bot] in #4914
- Fix typo in unflatten docs by @copybara-service[bot] in #4918
- fix RNN by @copybara-service[bot] in #4917
- Update optimizer.py to support masked variable from optax. by @ywrt in #4904
- Added missing functions to graph.rst by @vfdev-5 in #4922
- Update flax/docs_nnx/guides/performance.md and .ipynb by @hanrach9 in #4919
- Added preferred_element_type arg to nnx.Linear*, nnx.Conv*, nnx.Einsum by @vfdev-5 in #4920
- Update README badges and remove invalid ones by @IvyZX in #4905
- static + pytree guide by @cgarciae in #4897
- fix mypy by @copybara-service[bot] in #4931
- Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. by @copybara-service[bot] in #4923
- Ported nnx.PReLU from linen by @vfdev-5 in #4934
- Added nnx.scan docs and few minor docs fixes by @vfdev-5 in #4930
- add variables argument to nnx.clone by @cgarciae in #4945
- only copy dicts on State.getitem by @cgarciae in #4946
- always differentiate standalone Variables in nnx.grad by @cgarciae in #4947
- Implement instance norm in NNX by @mattbahr in #4939
- Automatically apply sharding constraints to sharded models by @IvyZX in #4844
- Add reference of flip doc to gspmd guide by @IvyZX in #4949
- Fixed nnx.is_data docstring rendering by @vfdev-5 in #4957
- expose pytree guide by @cgarciae in #4951
- fix toy examples by @cgarciae in #4952
- Explicitly cast attribute names to string before checking for private attributes. by @copybara-service[bot] in #4955
- add flax_hijax_variable flag by @cgarciae in #4953
- mark shard_map as implemented in transforms guide by @cgarciae in #4738
- improve Variable flatten by @cgarciae in #4954
- Minor typo fix in nnx.call docstring by @vfdev-5 in #4959
- allow split tuples in Rngs.fork by @cgarciae in #4958
- Fixed Gemma example using Gemma2 models by @vfdev-5 in #4830
- finish pytree guide by @cgarciae in #4929
- update bridge wrappers from maxtext by @cgarciae in #4937
- fix HashableMapping hash definition for mixed key types by @copybara-service[bot] in #4936
- Flax RNG guide for jax.jit: clarify rng outputs are shared but not inputs. by @copybara-service[bot] in #4956
- fix Variable pytree flatten by @copybara-service[bot] in #4962
- import PathParts from flax.typing by @cgarciae in #4966
- Correctly expose `flax.config.temp_flip_flag` by @IvyZX in #4969
- raise on Variable inplace operators by @cgarciae in #4967
- Copybara import of the project: by @copybara-service[bot] in #4976
- update to version 0.12.0 by @cgarciae in #4982
- Minor typo fixes in flax gspmd guide by @vfdev-5 in #4970
- ignore uv.lock by @copybara-service[bot] in #4974
- [nnx] preserve the function's type information in jit by @cgarciae in #4981
- add Variable.set_metadata by @cgarciae in #4968
- propagate eager sharding by @cgarciae in #4983
New Contributors
Full Changelog: v0.11.2...v0.12.0
0.11.2
What's Changed
nnx.merge now doesn't create a copy of the Variables in the incoming states by default, meaning that the new merged structure holds references to the incoming Variables. This enables new patterns; for example, it's now possible to create models with the same state but with different runtime behavior:
```python
model = SomeModel(...)
# create eval model
eval_model = nnx.merge(*nnx.split(model))  # same Variables, different structure
eval_model.eval()
```
model and eval_model share the same Variables and are therefore kept in sync, but have different runtime behavior. This avoids having to constantly mutate a single model back and forth between different runtime modes, which can be error prone or cause unwanted recompilation.
To keep the old behavior use nnx.merge(..., copy=True).
PRs
- add Rngs random helpers by @cgarciae in #4876
- Fix re-export and docs for identity by @jlperla in #4850
- Fix ToLinen docstring return description by @mohsinm-dev in #4852
- Update doc build instructions and clean up unused packages by @IvyZX in #4885
- Improve docs related with dataclasses by @IvyZX in #4884
- Fix broken contributing documentation link by @mohsinm-dev in #4855
- Internal change by @copybara-service[bot] in #4886
- Fix string key preservation in replace_by_pure_dict by @mohsinm-dev in #4860
- Remove the need for Conv and ConvTranspose to know the precise batch size. by @copybara-service[bot] in #4877
- call jax's source_info_util.register_exclusion in flax's traceback_util.register_exclusion by @copybara-service[bot] in #4887
- Update typo in nnx.Optimizer by @codinfox in #4880
- Exposed split_rngs docstring in the docs_nnx by @vfdev-5 in #4846
- Pin sentencepiece version to 0.2.0 to fix head by @IvyZX in #4892
- Relax duplicate check to exclude non-string values such as PartitionSpec.UNCONSTRAINED, since those can be repeated. by @copybara-service[bot] in #4881
- add find_duplicates by @cgarciae in #4894
- Sharding API improvements (non breaking) by @IvyZX in #4893
- document jax.random shorthand methods by @cgarciae in #4899
- Optimiser was already instantiated using the model - 05_vae.py by @nenuadrian in #4857
- revert is_leaf logic in _check_carry_same_references by @copybara-service[bot] in #4903
- Doc fix: remove outdated advice on flax v0.6.10; it was released two years ago. by @copybara-service[bot] in #4910
- Fix bug when raising ScopeParamNotFoundError. by @copybara-service[bot] in #4898
- fix mypy on main by @cgarciae in #4909
- merge no copy Variables by @cgarciae in #4912
- update version to 0.11.2 by @copybara-service[bot] in #4915
New Contributors
- @mohsinm-dev made their first contribution in #4852
- @codinfox made their first contribution in #4880
- @nenuadrian made their first contribution in #4857
Full Changelog: v0.11.1...v0.11.2
v0.11.1
What's Changed
- Make `Sequential()` be identity by @SobhanMP in #4796
- Add a JAX/Flax key concepts doc by @IvyZX in #4795
- miscellaneous improvements by @cgarciae in #4859
- Replace `jax.sharding.use_mesh` with `jax.set_mesh`. `jax.set_mesh` can act as a global setter or a context manager. by @copybara-service[bot] in #4862
- Pytree and ArrayRef refactor by @cgarciae in #4863
- Add old property attributes for object->pytree rename. by @copybara-service[bot] in #4864
- Add BatchNorm layers to CNN in MNIST tutorial for improved training stability by @sanepunk in #4773
- Description by @copybara-service[bot] in #4866
- update and pop for dict by @cgarciae in #4869
- simplify nnx_basics by @cgarciae in #4868
- updates to version 0.11.1 by @cgarciae in #4878
New Contributors
Full Changelog: v0.11.0...v0.11.1
v0.11.0
v0.11.0 - Pytrees, MutableArrays, and more!
This version of Flax introduces some changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! However, some breaking changes to align with the JAX way of doing things were necessary. Most code should remain intact; however, the following changes deviate from the previous behavior:
- `Rngs` in standard layers: all standard layers no longer hold a shared reference to the `rngs` object given in the constructor; instead they now keep a `fork`-ed copy of the `Rngs` or `RngStream` objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
- Optimizer Updates: the Optimizer abstraction no longer holds a reference to the `model` to avoid reference sharing; instead the `model` must be provided as the first argument to `update`.
- Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of `split` and `merge` when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects operating on Pytrees containing NNX Objects with `jax.tree.*` APIs.
Check out the full NNX 0.10 to NNX 0.11 migration guide.
In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
What's Changed
- [nnx] mutable array p3 by @cgarciae in #4755
- [nnx] allow method calls in ToLinen by @cgarciae in #4808
- Internal change by @copybara-service[bot] in #4807
- Preserve sharding information in axes_scan by @copybara-service[bot] in #4806
- Deduplicate contributing and philosophy and move to main site by @IvyZX in #4809
- Fixed nnx.remat docstring rendering by @vfdev-5 in #4790
- Added a note to gemma guide about model's license consent on kaggle by @vfdev-5 in #4776
- [nnx] ToLinen add abtract_init flag by @cgarciae in #4813
- Modify NNX to use id(variable) instead of nnx.Variables as dictionary by @divyashreepathihalli in #4814
- Allow using LazyRngs for flax init/apply. by @copybara-service[bot] in #4818
- [nnx] remove VariableState by @cgarciae in #4800
- Fix failing CI jobs: trailing whitespace, deprecated `.type` usage by @vfdev-5 in #4823
- [nnx] fix Rngs dtype check by @cgarciae in #4820
- refactor: move usages of `.value` to `[...]` in modules_test.py by @lukeyeh in #4815
- Added training script for Gemma model by @vfdev-5 in #4822
- [nnx] add flax_pytree_module flag by @cgarciae in #4811
- create ModelAndOptimizer symbol by @copybara-service[bot] in #4849
- [nnx] remove Optimizer.model attribute by @cgarciae in #4842
- [nnx] add mutable array support in update by @cgarciae in #4851
- Migrate `transforms_test.py` from `.value` to `[...]` by @lukeyeh in #4841
- 0.11.0 migration guide by @cgarciae in #4854
New Contributors
- @divyashreepathihalli made their first contribution in #4814
- @lukeyeh made their first contribution in #4815
Full Changelog: v0.10.7...v0.11.0