diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java index f5cd53c42ffa..03125f6d8ce8 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java @@ -53,7 +53,9 @@ public class LazyFlinkSourceSplitEnumerator private final PipelineOptions pipelineOptions; private final int numSplits; private final List> pendingSplits; + private final List pendingRequests; private boolean splitsInitialized; + private boolean splitsReady; public LazyFlinkSourceSplitEnumerator( SplitEnumeratorContext> context, @@ -66,7 +68,9 @@ public LazyFlinkSourceSplitEnumerator( this.pipelineOptions = pipelineOptions; this.numSplits = numSplits; this.pendingSplits = new ArrayList<>(numSplits); + this.pendingRequests = new ArrayList<>(); this.splitsInitialized = splitInitialized; + this.splitsReady = false; } @Override @@ -94,9 +98,13 @@ public void initializeSplits() { }, (sourceSplits, error) -> { if (error != null) { - pendingSplits.addAll(sourceSplits); throw new RuntimeException("Failed to start source enumerator.", error); } + splitsReady = true; + for (int subtask : pendingRequests) { + handleSplitRequest(subtask, null); + } + pendingRequests.clear(); }); } @@ -113,6 +121,12 @@ public void handleSplitRequest(int subtask, @Nullable String hostname) { LOG.info("Subtask {} {} is requesting a file source split", subtask, hostInfo); } + if (!splitsReady) { + LOG.info("Subtask {} requested split before enumeration done, buffering", subtask); + pendingRequests.add(subtask); + return; + } + if (!pendingSplits.isEmpty()) { final FlinkSourceSplit split = pendingSplits.remove(pendingSplits.size() - 1); context.assignSplit(split, subtask); diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java index d25a0d31c919..fc6af93a34d0 100644 --- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java +++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java @@ -17,23 +17,30 @@ */ package org.apache.beam.sdk.io.snowflake; -import static org.apache.beam.sdk.io.TextIO.readFiles; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import com.google.auto.value.AutoValue; import com.opencsv.CSVParser; import com.opencsv.CSVParserBuilder; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.charset.StandardCharsets; import java.security.PrivateKey; import java.sql.SQLException; import java.time.LocalDateTime; import java.time.ZoneId; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; import javax.annotation.Nullable; import javax.sql.DataSource; import net.snowflake.client.api.datasource.SnowflakeDataSource; @@ -43,11 +50,13 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Compression; import org.apache.beam.sdk.io.FileIO; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.WriteFilesResult; +import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MoveOptions; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema; @@ -59,6 +68,7 @@ import org.apache.beam.sdk.io.snowflake.services.SnowflakeServices; import org.apache.beam.sdk.io.snowflake.services.SnowflakeServicesImpl; import org.apache.beam.sdk.io.snowflake.services.SnowflakeStreamingServiceConfig; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Create; @@ -67,7 +77,6 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reify; -import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.Values; @@ -424,31 +433,26 @@ public Read withQuotationMark(ValueProvider quotationMark) { public PCollection expand(PBegin input) { checkArguments(); - PCollection emptyCollection = input.apply(Create.of((Void) null)); String tmpDirName = makeTmpDirName(); - PCollection output = - emptyCollection - .apply( - ParDo.of( - new CopyIntoStageFn( - getDataSourceProviderFn(), - getQuery(), - getTable(), - getStorageIntegrationName(), - getStagingBucketName(), - tmpDirName, - getSnowflakeServices(), - getQuotationMark()))) - .apply(Reshuffle.viaRandomKey()) - .apply(FileIO.matchAll()) - .apply(FileIO.readMatches()) - .apply(readFiles()) - .apply(ParDo.of(new MapCsvToStringArrayFn(getQuotationMark()))) - .apply(ParDo.of(new MapStringArrayToUserDataFn<>(getCsvMapper()))); + SnowflakeBoundedSource source = + new SnowflakeBoundedSource<>( + getDataSourceProviderFn(), + getQuery(), + getTable(), + getStorageIntegrationName(), + getStagingBucketName(), + tmpDirName, + getSnowflakeServices(), + getQuotationMark(), + getCsvMapper(), + getCoder()); + + PCollection output = input.apply(org.apache.beam.sdk.io.Read.from(source)); output.setCoder(getCoder()); - emptyCollection + input + .apply(Create.of((Void) null)) .apply(Wait.on(output)) .apply(ParDo.of(new CleanTmpFilesFromGcsFn(getStagingBucketName(), tmpDirName))); return output; @@ -483,103 +487,212 @@ private String makeTmpDirName() { ); } - private static class CopyIntoStageFn extends DoFn { + /** + * A {@link BoundedSource} that reads from Snowflake by running COPY INTO to stage CSV files, + * then splitting into one sub-source per file. + */ + private static class SnowflakeBoundedSource extends BoundedSource { + private static final Logger LOG = LoggerFactory.getLogger(SnowflakeBoundedSource.class); + private final SerializableFunction dataSourceProviderFn; - private final ValueProvider query; - private final ValueProvider database; - private final ValueProvider schema; - private final ValueProvider table; + private final @Nullable ValueProvider query; + private final @Nullable ValueProvider table; private final ValueProvider storageIntegrationName; - private final ValueProvider stagingBucketDir; + private final ValueProvider stagingBucketName; private final String tmpDirName; private final SnowflakeServices snowflakeServices; private final ValueProvider quotationMark; + private final CsvMapper csvMapper; + private final Coder coder; + + // Non-null only for child sources (one per staged file) + private final @Nullable String filePath; + private final long fileSize; - private CopyIntoStageFn( + SnowflakeBoundedSource( SerializableFunction dataSourceProviderFn, - ValueProvider query, - ValueProvider table, + @Nullable ValueProvider query, + @Nullable ValueProvider table, ValueProvider storageIntegrationName, - ValueProvider stagingBucketDir, + ValueProvider stagingBucketName, String tmpDirName, SnowflakeServices snowflakeServices, - ValueProvider quotationMark) { + ValueProvider quotationMark, + CsvMapper csvMapper, + Coder coder) { + this( + dataSourceProviderFn, + query, + table, + storageIntegrationName, + stagingBucketName, + tmpDirName, + snowflakeServices, + quotationMark, + csvMapper, + coder, + null, + 0); + } + + private SnowflakeBoundedSource( + SerializableFunction dataSourceProviderFn, + @Nullable ValueProvider query, + @Nullable ValueProvider table, + ValueProvider storageIntegrationName, + ValueProvider stagingBucketName, + String tmpDirName, + SnowflakeServices snowflakeServices, + ValueProvider quotationMark, + CsvMapper csvMapper, + Coder coder, + @Nullable String filePath, + long fileSize) { this.dataSourceProviderFn = dataSourceProviderFn; this.query = query; this.table = table; this.storageIntegrationName = storageIntegrationName; + this.stagingBucketName = stagingBucketName; + this.tmpDirName = tmpDirName; this.snowflakeServices = snowflakeServices; this.quotationMark = quotationMark; - this.stagingBucketDir = stagingBucketDir; - this.tmpDirName = tmpDirName; - DataSourceProviderFromDataSourceConfiguration - dataSourceProviderFromDataSourceConfiguration = - (DataSourceProviderFromDataSourceConfiguration) this.dataSourceProviderFn; - DataSourceConfiguration config = dataSourceProviderFromDataSourceConfiguration.getConfig(); - - this.database = config.getDatabase(); - this.schema = config.getSchema(); + this.csvMapper = csvMapper; + this.coder = coder; + this.filePath = filePath; + this.fileSize = fileSize; } - @ProcessElement - public void processElement(ProcessContext context) throws Exception { - String databaseValue = getValueOrNull(this.database); - String schemaValue = getValueOrNull(this.schema); - String tableValue = getValueOrNull(this.table); - String queryValue = getValueOrNull(this.query); + @Override + public List> split( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + if (filePath != null) { + return Collections.singletonList(this); + } String stagingBucketRunDir = String.format( "%s/%s/run_%s/", - stagingBucketDir.get(), tmpDirName, UUID.randomUUID().toString().subSequence(0, 8)); + stagingBucketName.get(), + tmpDirName, + UUID.randomUUID().toString().subSequence(0, 8)); - SnowflakeBatchServiceConfig config = + DataSourceProviderFromDataSourceConfiguration dsProvider = + (DataSourceProviderFromDataSourceConfiguration) dataSourceProviderFn; + DataSourceConfiguration config = dsProvider.getConfig(); + + SnowflakeBatchServiceConfig batchConfig = new SnowflakeBatchServiceConfig( dataSourceProviderFn, - databaseValue, - schemaValue, - tableValue, - queryValue, + getValueOrNull(config.getDatabase()), + getValueOrNull(config.getSchema()), + getValueOrNull(table), + getValueOrNull(query), storageIntegrationName.get(), stagingBucketRunDir, quotationMark.get()); - String output = snowflakeServices.getBatchService().read(config); + LOG.info("Running Snowflake COPY INTO stage: {}", stagingBucketRunDir); + String globPattern = snowflakeServices.getBatchService().read(batchConfig); + + List files = FileSystems.match(globPattern).metadata(); + LOG.info("Snowflake COPY INTO produced {} files", files.size()); + + return files.stream() + .map( + metadata -> + new SnowflakeBoundedSource( + dataSourceProviderFn, + query, + table, + storageIntegrationName, + stagingBucketName, + tmpDirName, + snowflakeServices, + quotationMark, + csvMapper, + coder, + metadata.resourceId().toString(), + metadata.sizeBytes())) + .collect(Collectors.toList()); + } - context.output(output); + @Override + public long getEstimatedSizeBytes(PipelineOptions options) { + return fileSize; } - } - /** - * Parses {@code String} from incoming data in {@link PCollection} to have proper format for CSV - * files. - */ - public static class MapCsvToStringArrayFn extends DoFn { - private ValueProvider quoteChar; + @Override + public BoundedReader createReader(PipelineOptions options) throws IOException { + if (filePath == null) { + throw new IOException("Cannot create reader from unsplit parent source"); + } + return new SnowflakeFileReader<>(this); + } - public MapCsvToStringArrayFn(ValueProvider quoteChar) { - this.quoteChar = quoteChar; + @Override + public Coder getOutputCoder() { + return coder; } - @ProcessElement - public void processElement(ProcessContext c) throws IOException { - String csvLine = c.element(); - CSVParser parser = new CSVParserBuilder().withQuoteChar(quoteChar.get().charAt(0)).build(); - String[] parts = parser.parseLine(csvLine); - c.output(parts); + @Override + public void validate() { + // Validation is done in SnowflakeIO.Read.checkArguments() } - } - private static class MapStringArrayToUserDataFn extends DoFn { - private final CsvMapper csvMapper; + private static class SnowflakeFileReader extends BoundedReader { + private final SnowflakeBoundedSource source; + private transient BufferedReader reader; + private transient CSVParser csvParser; + private T current; - public MapStringArrayToUserDataFn(CsvMapper csvMapper) { - this.csvMapper = csvMapper; - } + SnowflakeFileReader(SnowflakeBoundedSource source) { + this.source = source; + } - @ProcessElement - public void processElement(ProcessContext context) throws Exception { - context.output(csvMapper.mapRow(context.element())); + @Override + public boolean start() throws IOException { + ResourceId resourceId = FileSystems.matchNewResource(source.filePath, false); + ReadableByteChannel channel = FileSystems.open(resourceId); + InputStream inputStream = new GZIPInputStream(Channels.newInputStream(channel)); + + reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); + csvParser = + new CSVParserBuilder().withQuoteChar(source.quotationMark.get().charAt(0)).build(); + + return advance(); + } + + @Override + public boolean advance() throws IOException { + String line = reader.readLine(); + if (line == null) { + return false; + } + try { + String[] parts = csvParser.parseLine(line); + current = source.csvMapper.mapRow(parts); + return true; + } catch (Exception e) { + throw new IOException("Error mapping CSV row: " + line, e); + } + } + + @Override + public T getCurrent() { + return current; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + } + } + + @Override + public BoundedSource getCurrentSource() { + return source; + } } } diff --git a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBatchServiceImpl.java b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBatchServiceImpl.java index 79a3900f3a2c..47361cb504ae 100644 --- a/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBatchServiceImpl.java +++ b/sdks/java/io/snowflake/src/test/java/org/apache/beam/sdk/io/snowflake/test/FakeSnowflakeBatchServiceImpl.java @@ -18,12 +18,15 @@ package org.apache.beam.sdk.io.snowflake.test; import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; +import java.util.zip.GZIPOutputStream; import net.snowflake.client.api.exception.SnowflakeSQLException; import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema; import org.apache.beam.sdk.io.snowflake.enums.CreateDisposition; @@ -123,8 +126,13 @@ private void writeToFile(String stagingBucketNameTmp, List rows) { Path filePath = Paths.get(String.format("./%s/table.csv.gz", stagingBucketNameTmp)); try { Files.createDirectories(filePath.getParent()); - Files.createFile(filePath); - Files.write(filePath, rows); + try (OutputStream os = Files.newOutputStream(filePath); + GZIPOutputStream gzip = new GZIPOutputStream(os)) { + for (String row : rows) { + gzip.write(row.getBytes(StandardCharsets.UTF_8)); + gzip.write('\n'); + } + } } catch (IOException e) { throw new RuntimeException("Failed to create files", e); }