File tree Expand file tree Collapse file tree 3 files changed +33
-0
lines changed
Expand file tree Collapse file tree 3 files changed +33
-0
lines changed Original file line number Diff line number Diff line change 1+ import torch .nn .functional as F
2+
3+ from aion .modalities import Spectrum
4+
5+
6+ def pad_spectrum (x : Spectrum ) -> Spectrum :
7+ """Pad a Spectrum object to its specified pad_length.
8+
9+ Note: Each of the sequence attributes (flux, ivar,
10+ mask, wavelength) should be 2D tensors.
11+ """
12+ padding_values = {"lambda" : 99999 , "mask" : True , "ivar" : 0 }
13+
14+ for k in ["flux" , "ivar" , "mask" , "wavelength" ]:
15+ setattr (
16+ x ,
17+ k ,
18+ F .pad (
19+ getattr (x , k ),
20+ (0 , x .pad_length - getattr (x , k ).shape [- 1 ]),
21+ mode = "constant" ,
22+ value = padding_values [k ] if k in padding_values else 0 ,
23+ ),
24+ )
25+
26+ return x
Original file line number Diff line number Diff line change 66from aion .codecs .base import Codec
77from aion .codecs .modules .convnext import ConvNextDecoder1d , ConvNextEncoder1d
88from aion .codecs .modules .spectrum import LatentSpectralGrid
9+ from aion .codecs .preprocessing .spectrum import pad_spectrum
910from aion .codecs .quantizers import LucidrainsLFQ , Quantizer , ScalarLinearQuantizer
1011from aion .codecs .utils import CodecPytorchHubMixin
1112from aion .modalities import Spectrum
@@ -54,6 +55,9 @@ def quantizer(self) -> Quantizer:
5455 return self ._quantizer
5556
5657 def _encode (self , x : Spectrum ) -> Float [torch .Tensor , "b c t" ]:
58+ if hasattr (x , "pad_length" ):
59+ x = pad_spectrum (x )
60+
5761 # Extract fields from Spectrum instance
5862 flux = x .flux
5963 ivar = x .ivar
Original file line number Diff line number Diff line change @@ -97,6 +97,7 @@ class Spectrum(Modality):
9797 ivar : Float [Tensor , " batch length" ]
9898 mask : Bool [Tensor , " batch length" ]
9999 wavelength : Float [Tensor , " batch length" ]
100+ pad_length : ClassVar [int ]
100101
101102 def __repr__ (self ) -> str :
102103 repr_str = (
@@ -112,13 +113,15 @@ class DESISpectrum(Spectrum):
112113
113114 token_key : ClassVar [str ] = "tok_spectrum_desi"
114115 num_tokens : ClassVar [int ] = 273
116+ pad_length : ClassVar [int ] = 7808
115117
116118
117119class SDSSSpectrum (Spectrum ):
118120 """SDSS spectrum modality data."""
119121
120122 token_key : ClassVar [str ] = "tok_spectrum_sdss"
121123 num_tokens : ClassVar [int ] = 273
124+ pad_length : ClassVar [int ] = 4800
122125
123126
124127# Catalog modality
You can’t perform that action at this time.
0 commit comments