Skip to content

Commit 005689a

Browse files
authored
Merge pull request #64 from PolymathicAI/fix/spectrum_crop
automatically crop spectra
2 parents 926c515 + 6a925d7 commit 005689a

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch.nn.functional as F
2+
3+
from aion.modalities import Spectrum
4+
5+
6+
def pad_spectrum(x: Spectrum) -> Spectrum:
7+
"""Pad a Spectrum object to its specified pad_length.
8+
9+
Note: Each of the sequence attributes (flux, ivar,
10+
mask, wavelength) should be 2D tensors.
11+
"""
12+
padding_values = {"lambda": 99999, "mask": True, "ivar": 0}
13+
14+
for k in ["flux", "ivar", "mask", "wavelength"]:
15+
setattr(
16+
x,
17+
k,
18+
F.pad(
19+
getattr(x, k),
20+
(0, x.pad_length - getattr(x, k).shape[-1]),
21+
mode="constant",
22+
value=padding_values[k] if k in padding_values else 0,
23+
),
24+
)
25+
26+
return x

aion/codecs/spectrum.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from aion.codecs.base import Codec
77
from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d
88
from aion.codecs.modules.spectrum import LatentSpectralGrid
9+
from aion.codecs.preprocessing.spectrum import pad_spectrum
910
from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer
1011
from aion.codecs.utils import CodecPytorchHubMixin
1112
from aion.modalities import Spectrum
@@ -54,6 +55,9 @@ def quantizer(self) -> Quantizer:
5455
return self._quantizer
5556

5657
def _encode(self, x: Spectrum) -> Float[torch.Tensor, "b c t"]:
58+
if hasattr(x, "pad_length"):
59+
x = pad_spectrum(x)
60+
5761
# Extract fields from Spectrum instance
5862
flux = x.flux
5963
ivar = x.ivar

aion/modalities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class Spectrum(Modality):
9797
ivar: Float[Tensor, " batch length"]
9898
mask: Bool[Tensor, " batch length"]
9999
wavelength: Float[Tensor, " batch length"]
100+
pad_length: ClassVar[int]
100101

101102
def __repr__(self) -> str:
102103
repr_str = (
@@ -112,13 +113,15 @@ class DESISpectrum(Spectrum):
112113

113114
token_key: ClassVar[str] = "tok_spectrum_desi"
114115
num_tokens: ClassVar[int] = 273
116+
pad_length: ClassVar[int] = 7808
115117

116118

117119
class SDSSSpectrum(Spectrum):
118120
"""SDSS spectrum modality data."""
119121

120122
token_key: ClassVar[str] = "tok_spectrum_sdss"
121123
num_tokens: ClassVar[int] = 273
124+
pad_length: ClassVar[int] = 4800
122125

123126

124127
# Catalog modality

0 commit comments

Comments
 (0)