/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.server;

import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.DataSegmentAndDescriptor;
import org.apache.druid.query.LeafSegmentsBundle;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.segment.IndexIO;
import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.ReferenceCountedSegmentProvider;
import org.apache.druid.segment.SegmentLazyLoadFailCallback;
import org.apache.druid.segment.SegmentMapFunction;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.TestIndex;
import org.apache.druid.segment.TestSegmentUtils;
import org.apache.druid.segment.loading.AcquireSegmentAction;
import org.apache.druid.segment.loading.AcquireSegmentResult;
import org.apache.druid.segment.loading.LeastBytesUsedStorageLocationSelectorStrategy;
import org.apache.druid.segment.loading.LocalDataSegmentPuller;
import org.apache.druid.segment.loading.LocalLoadSpec;
import org.apache.druid.segment.loading.SegmentLoaderConfig;
import org.apache.druid.segment.loading.SegmentLoadingException;
import org.apache.druid.segment.loading.SegmentLocalCacheManager;
import org.apache.druid.segment.loading.StorageLocation;
import org.apache.druid.segment.loading.StorageLocationConfig;
import org.apache.druid.server.SegmentManager.DataSourceState;
import org.apache.druid.server.metrics.NoopServiceEmitter;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.VersionedIntervalTimeline;
import org.apache.druid.timeline.partition.NumberedOverwriteShardSpec;
import org.apache.druid.timeline.partition.PartitionIds;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;

/**
 * Tests for {@link SegmentManager}: concurrent and sequential segment loads and drops, duplicate loads, removal of
 * empty timelines, and {@code getSegmentsBundle} resolution both with a regular segment cache and with a cache whose
 * loader config enables virtual storage.
 */
public class SegmentManagerTest extends InitializedNullHandlingTest
{
  /**
   * Fixture segments spanning two datasources. The last two "large_source" segments share the interval
   * 1000/2000 but have different versions.
   */
  private static final List<DataSegment> SEGMENTS = ImmutableList.of(
      TestSegmentUtils.makeSegment("small_source", "0", Intervals.of("0/1000")),
      TestSegmentUtils.makeSegment("small_source", "0", Intervals.of("1000/2000")),
      TestSegmentUtils.makeSegment("large_source", "0", Intervals.of("0/1000")),
      TestSegmentUtils.makeSegment("large_source", "0", Intervals.of("1000/2000")),
      TestSegmentUtils.makeSegment("large_source", "1", Intervals.of("1000/2000"))
  );

  /** Pool with one thread per fixture segment, so every segment can be loaded/dropped concurrently. */
  private ExecutorService executor;
  /** Regular (non-virtual) segment cache and the manager built on top of it. */
  private SegmentLocalCacheManager cacheManager;
  private SegmentManager segmentManager;
  /** Segment cache whose loader config reports {@code isVirtualStorage() == true}, and its manager. */
  private SegmentLocalCacheManager virtualCacheManager;
  private SegmentManager virtualSegmentManager;

  @Rule
  public TemporaryFolder temporaryFolder = new TemporaryFolder();

  /**
   * Builds two independent cache managers (regular and virtual-storage) rooted in fresh temporary folders, plus a
   * shared executor for the concurrency tests.
   */
  @Before
  public void setup() throws IOException
  {
    EmittingLogger.registerEmitter(new NoopServiceEmitter());
    final File segmentCacheDir = temporaryFolder.newFolder();
    // The regular cache uses the same directory for its info dir and its single storage location.
    final SegmentLoaderConfig loaderConfig = new SegmentLoaderConfig()
    {
      @Override
      public File getInfoDir()
      {
        return segmentCacheDir;
      }

      @Override
      public List<StorageLocationConfig> getLocations()
      {
        return Collections.singletonList(
            new StorageLocationConfig(segmentCacheDir, null, null)
        );
      }
    };

    // The virtual-storage cache keeps its info dir and its storage location in separate sub-directories.
    final File vsfRoot = temporaryFolder.newFolder();
    final File virtualSegmentCacheDir = new File(vsfRoot, "segmentCache");
    FileUtils.mkdirp(virtualSegmentCacheDir);
    final File vsfInfoDir = new File(vsfRoot, "info");
    FileUtils.mkdirp(vsfInfoDir);
    final SegmentLoaderConfig virtualLoaderConfig = new SegmentLoaderConfig()
    {
      @Override
      public File getInfoDir()
      {
        return vsfInfoDir;
      }

      @Override
      public List<StorageLocationConfig> getLocations()
      {
        return Collections.singletonList(
            new StorageLocationConfig(virtualSegmentCacheDir, null, null)
        );
      }

      @Override
      public boolean isVirtualStorage()
      {
        return true;
      }
    };

    // Register the load specs / segmentizers used by the fixtures ("test") and by the virtual test ("local").
    final ObjectMapper objectMapper = TestHelper.makeJsonMapper();
    objectMapper.registerSubtypes(TestSegmentUtils.TestLoadSpec.class);
    objectMapper.registerSubtypes(TestSegmentUtils.TestSegmentizerFactory.class);
    objectMapper.registerSubtypes(LocalLoadSpec.class);
    objectMapper.setInjectableValues(
        new InjectableValues.Std()
            .addValue(ExprMacroTable.class.getName(), TestExprMacroTable.INSTANCE)
            .addValue(ObjectMapper.class.getName(), objectMapper)
            .addValue(DataSegment.PruneSpecsHolder.class, DataSegment.PruneSpecsHolder.DEFAULT)
            .addValue(LocalDataSegmentPuller.class, new LocalDataSegmentPuller())
            .addValue(IndexIO.class, TestHelper.getTestIndexIO())
    );

    final List<StorageLocation> storageLocations = loaderConfig.toStorageLocations();
    cacheManager = new SegmentLocalCacheManager(
        storageLocations,
        loaderConfig,
        new LeastBytesUsedStorageLocationSelectorStrategy(storageLocations),
        TestIndex.INDEX_IO,
        objectMapper
    );
    segmentManager = new SegmentManager(cacheManager);

    final List<StorageLocation> virtualStorageLocations = virtualLoaderConfig.toStorageLocations();
    virtualCacheManager = new SegmentLocalCacheManager(
        virtualStorageLocations,
        virtualLoaderConfig,
        new LeastBytesUsedStorageLocationSelectorStrategy(virtualStorageLocations),
        TestIndex.INDEX_IO,
        objectMapper
    );

    virtualSegmentManager = new SegmentManager(virtualCacheManager);
    executor = Execs.multiThreaded(SEGMENTS.size(), "SegmentManagerTest-%d");
  }

  @After
  public void tearDown()
  {
    executor.shutdownNow();
  }

  @Test
  public void testLoadSegment() throws ExecutionException, InterruptedException
  {
    awaitAll(submitLoads(SEGMENTS));

    assertResult(SEGMENTS);
  }

  @Test
  public void testLoadBootstrapSegment() throws ExecutionException, InterruptedException
  {
    final List<Future<Void>> loadFutures =
        SEGMENTS.stream()
                .map(segment -> executor.submit(() -> loadBootstrapSegmentOrFail(segment)))
                .collect(Collectors.toList());

    awaitAll(loadFutures);

    assertResult(SEGMENTS);
  }

  @Test
  public void testDropSegment() throws SegmentLoadingException, ExecutionException, InterruptedException, IOException
  {
    // Load everything up front and capture each segment's reference provider so we can check which get closed.
    List<ReferenceCountedSegmentProvider> referenceProviders = new ArrayList<>();
    for (DataSegment eachSegment : SEGMENTS) {
      segmentManager.loadSegment(eachSegment);
      ReferenceCountedSegmentProvider refProvider = cacheManager.getSegmentReferenceProvider(eachSegment);
      referenceProviders.add(refProvider);
      Assert.assertFalse(refProvider.isClosed());
    }

    awaitAll(submitDrops(ImmutableList.of(SEGMENTS.get(0), SEGMENTS.get(2))));

    assertResult(
        ImmutableList.of(SEGMENTS.get(1), SEGMENTS.get(3), SEGMENTS.get(4))
    );
    // Only the providers of the dropped segments (indexes 0 and 2) should be closed; none should hold references.
    for (int i = 0; i < SEGMENTS.size(); i++) {
      Assert.assertEquals(0, referenceProviders.get(i).getNumReferences());
      if (i == 0 || i == 2) {
        Assert.assertTrue(referenceProviders.get(i).isClosed());
      } else {
        Assert.assertFalse(referenceProviders.get(i).isClosed());
      }
    }
  }

  @Test
  public void testLoadDropSegment()
      throws SegmentLoadingException, ExecutionException, InterruptedException, IOException
  {
    segmentManager.loadSegment(SEGMENTS.get(0));
    segmentManager.loadSegment(SEGMENTS.get(2));

    // Submit loads of new segments and drops of the pre-loaded ones before waiting on either, so they interleave.
    final List<Future<Void>> loadFutures = submitLoads(
        ImmutableList.of(SEGMENTS.get(1), SEGMENTS.get(3), SEGMENTS.get(4))
    );
    final List<Future<Void>> dropFutures = submitDrops(ImmutableList.of(SEGMENTS.get(0), SEGMENTS.get(2)));

    awaitAll(loadFutures);
    awaitAll(dropFutures);

    assertResult(
        ImmutableList.of(SEGMENTS.get(1), SEGMENTS.get(3), SEGMENTS.get(4))
    );
  }

  @Test
  public void testLoadDuplicatedSegmentsSequentially() throws SegmentLoadingException, IOException
  {
    for (DataSegment segment : SEGMENTS) {
      segmentManager.loadSegment(segment);
    }
    // try to load an existing segment
    segmentManager.loadSegment(SEGMENTS.get(0));

    assertResult(SEGMENTS);
  }

  @Test
  public void testLoadDuplicatedSegmentsInParallel()
      throws ExecutionException, InterruptedException
  {
    awaitAll(submitLoads(ImmutableList.of(SEGMENTS.get(0), SEGMENTS.get(0), SEGMENTS.get(0))));

    assertResult(ImmutableList.of(SEGMENTS.get(0)));
  }

  @Test
  public void testNonExistingSegmentsSequentially() throws SegmentLoadingException, IOException
  {
    segmentManager.loadSegment(SEGMENTS.get(0));

    // try to drop a non-existing segment of a different data source
    segmentManager.dropSegment(SEGMENTS.get(2));
    assertResult(ImmutableList.of(SEGMENTS.get(0)));
  }

  @Test
  public void testNonExistingSegmentsInParallel()
      throws SegmentLoadingException, ExecutionException, InterruptedException, IOException
  {
    segmentManager.loadSegment(SEGMENTS.get(0));
    // Drop one never-loaded segment from the same data source and one from a different data source.
    awaitAll(submitDrops(ImmutableList.of(SEGMENTS.get(1), SEGMENTS.get(2))));

    assertResult(ImmutableList.of(SEGMENTS.get(0)));
  }

  @Test
  public void testRemoveEmptyTimeline() throws SegmentLoadingException, IOException
  {
    segmentManager.loadSegment(SEGMENTS.get(0));
    assertResult(ImmutableList.of(SEGMENTS.get(0)));
    Assert.assertEquals(1, segmentManager.getDataSources().size());
    segmentManager.dropSegment(SEGMENTS.get(0));
    Assert.assertEquals(0, segmentManager.getDataSources().size());
  }

  @Test
  public void testGetNonExistingTimeline()
  {
    Assert.assertEquals(
        Optional.empty(),
        segmentManager.getTimeline((new TableDataSource("nonExisting")))
    );
  }

  @Test
  public void testLoadAndDropNonRootGenerationSegment() throws SegmentLoadingException, IOException
  {
    final DataSegment segment = new DataSegment(
        "small_source",
        Intervals.of("0/1000"),
        "0",
        ImmutableMap.of("type", "test", "interval", Intervals.of("0/1000"), "version", 0),
        new ArrayList<>(),
        new ArrayList<>(),
        new NumberedOverwriteShardSpec(
            PartitionIds.NON_ROOT_GEN_START_PARTITION_ID + 10,
            10,
            20,
            (short) 1,
            (short) 1
        ),
        0,
        10
    );

    segmentManager.loadSegment(segment);
    assertResult(ImmutableList.of(segment));

    segmentManager.dropSegment(segment);
    assertResult(ImmutableList.of());
  }

  @Test
  public void testGetSegmentsBundle() throws SegmentLoadingException, IOException
  {
    segmentManager.loadSegment(SEGMENTS.get(0));
    segmentManager.loadSegment(SEGMENTS.get(1));

    DataSegmentAndDescriptor d1 = new DataSegmentAndDescriptor(SEGMENTS.get(0), SEGMENTS.get(0).toDescriptor());
    DataSegmentAndDescriptor d2 = new DataSegmentAndDescriptor(SEGMENTS.get(1), SEGMENTS.get(1).toDescriptor());
    DataSegmentAndDescriptor d3 = new DataSegmentAndDescriptor(SEGMENTS.get(2), SEGMENTS.get(2).toDescriptor());
    DataSegmentAndDescriptor d4 = new DataSegmentAndDescriptor(null, SEGMENTS.get(3).toDescriptor());

    LeafSegmentsBundle bundle = segmentManager.getSegmentsBundle(
        List.of(d1, d2, d3, d4),
        SegmentMapFunction.IDENTITY
    );

    // expect 2 cached segments: d1 and d2 were loaded above
    Assert.assertEquals(2, bundle.getCachedSegments().size());
    Assert.assertEquals(
        d1.getDescriptor(),
        bundle.getCachedSegments().get(0).getSegmentDescriptor()
    );
    Assert.assertEquals(
        d2.getDescriptor(),
        bundle.getCachedSegments().get(1).getSegmentDescriptor()
    );
    // no loadable segments, since this manager's cache does not have virtual storage enabled
    Assert.assertEquals(
        List.of(),
        bundle.getLoadableSegments()
    );
    // 2 missing segments: d3 was never loaded into the cache and cannot be loaded on demand without virtual
    // storage, and d4 has no DataSegment attached at all
    Assert.assertEquals(
        List.of(d3.getDescriptor(), d4.getDescriptor()),
        bundle.getMissingSegments()
    );
  }

  @Test
  public void testGetSegmentsBundleVirtual()
      throws SegmentLoadingException, IOException, ExecutionException, InterruptedException
  {
    // Persist a real index to disk and point a "local" load spec at it so the segment can actually be loaded.
    File loc = temporaryFolder.newFolder();
    File seg = TestIndex.persist(TestIndex.getIncrementalTestIndex(), IndexSpec.getDefault(), loc);
    DataSegment toLoad = SEGMENTS.get(1).withLoadSpec(
        Map.of(
            "type", "local",
            "path", seg.getAbsolutePath() + "/"
        )
    );

    // NOTE(review): `action` is never released here; if AcquireSegmentAction holds a reservation on the cache
    // entry, consider releasing it once the assertions below are done — confirm against its API.
    final AcquireSegmentAction action = virtualSegmentManager.acquireSegment(toLoad);
    AcquireSegmentResult result = action.getSegmentFuture().get();
    Assert.assertNotNull(result);
    Assert.assertEquals(1L, result.getLoadSizeBytes());
    Assert.assertTrue(result.getLoadTimeNanos() > 0);

    DataSegmentAndDescriptor d1 = new DataSegmentAndDescriptor(SEGMENTS.get(0), SEGMENTS.get(0).toDescriptor());
    DataSegmentAndDescriptor d2 = new DataSegmentAndDescriptor(toLoad, toLoad.toDescriptor());
    DataSegmentAndDescriptor d3 = new DataSegmentAndDescriptor(SEGMENTS.get(2), SEGMENTS.get(2).toDescriptor());
    DataSegmentAndDescriptor d4 = new DataSegmentAndDescriptor(null, SEGMENTS.get(3).toDescriptor());

    LeafSegmentsBundle bundle = virtualSegmentManager.getSegmentsBundle(
        List.of(d1, d2, d3, d4),
        SegmentMapFunction.IDENTITY
    );

    // expect 1 cached segment (d2), since we called acquireSegment for it above
    Assert.assertEquals(1, bundle.getCachedSegments().size());
    Assert.assertEquals(
        d2.getDescriptor(),
        bundle.getCachedSegments().get(0).getSegmentDescriptor()
    );
    // 2 loadable segments; they are only reported as loadable, not loaded, which matters because they are not
    // backed by real segment files and an actual load attempt would fail
    Assert.assertEquals(
        List.of(d1, d3),
        bundle.getLoadableSegments()
    );
    // 1 missing segment: d4 has no DataSegment attached
    Assert.assertEquals(
        List.of(d4.getDescriptor()),
        bundle.getMissingSegments()
    );
  }

  /**
   * Loads {@code segment} through {@link SegmentManager#loadSegment}, rethrowing checked failures as
   * {@link RuntimeException} so the call can be used as a {@code Callable<Void>} body.
   */
  private Void loadSegmentOrFail(DataSegment segment)
  {
    try {
      segmentManager.loadSegment(segment);
    }
    catch (IOException | SegmentLoadingException e) {
      throw new RuntimeException(e);
    }
    return null;
  }

  /**
   * Bootstrap-path counterpart of {@link #loadSegmentOrFail}: loads {@code segment} via
   * {@code loadSegmentOnBootstrap} with a no-op lazy-load-failure callback, rethrowing checked failures as
   * {@link RuntimeException}.
   */
  private Void loadBootstrapSegmentOrFail(DataSegment segment)
  {
    try {
      segmentManager.loadSegmentOnBootstrap(
          segment,
          SegmentLazyLoadFailCallback.NOOP
      );
    }
    catch (IOException | SegmentLoadingException e) {
      throw new RuntimeException(e);
    }
    return null;
  }

  /** Submits one {@link #loadSegmentOrFail} task per segment and returns the futures in submission order. */
  private List<Future<Void>> submitLoads(List<DataSegment> segments)
  {
    return segments.stream()
                   .map(segment -> executor.submit(() -> loadSegmentOrFail(segment)))
                   .collect(Collectors.toList());
  }

  /** Submits one {@code segmentManager.dropSegment} task per segment and returns the futures in submission order. */
  private List<Future<Void>> submitDrops(List<DataSegment> segments)
  {
    return segments.stream()
                   .map(
                       segment -> executor.submit(
                           () -> {
                             segmentManager.dropSegment(segment);
                             return (Void) null;
                           }
                       )
                   )
                   .collect(Collectors.toList());
  }

  /**
   * Blocks until every future completes, in list order.
   *
   * @throws ExecutionException if any task failed; the first failure encountered in list order is surfaced
   */
  private static void awaitAll(List<Future<Void>> futures) throws ExecutionException, InterruptedException
  {
    for (Future<Void> future : futures) {
      future.get();
    }
  }

  /**
   * Asserts that {@link #segmentManager} holds exactly {@code expectedExistingSegments}: same datasource names,
   * per-datasource segment counts and summed sizes, and per-datasource timeline entries.
   */
  private void assertResult(List<DataSegment> expectedExistingSegments)
  {
    final Map<String, Long> expectedDataSourceSizes =
        expectedExistingSegments.stream()
                                .collect(Collectors.toMap(DataSegment::getDataSource, DataSegment::getSize, Long::sum));
    final Map<String, Long> expectedDataSourceCounts =
        expectedExistingSegments.stream()
                                .collect(Collectors.toMap(DataSegment::getDataSource, segment -> 1L, Long::sum));
    final Set<String> expectedDataSourceNames = expectedExistingSegments.stream()
                                                                        .map(DataSegment::getDataSource)
                                                                        .collect(Collectors.toSet());
    // Build the timelines we expect by adding each segment to a fresh per-datasource timeline.
    final Map<String, VersionedIntervalTimeline<String, DataSegment>> expectedTimelines = new HashMap<>();
    for (DataSegment segment : expectedExistingSegments) {
      final VersionedIntervalTimeline<String, DataSegment> expectedTimeline =
          expectedTimelines.computeIfAbsent(
              segment.getDataSource(),
              k -> new VersionedIntervalTimeline<>(Ordering.natural())
          );
      expectedTimeline.add(
          segment.getInterval(),
          segment.getVersion(),
          segment.getShardSpec().createChunk(segment)
      );
    }

    Assert.assertEquals(expectedDataSourceNames, segmentManager.getDataSourceNames());
    Assert.assertEquals(expectedDataSourceCounts, segmentManager.getDataSourceCounts());
    Assert.assertEquals(expectedDataSourceSizes, segmentManager.getDataSourceSizes());

    final Map<String, DataSourceState> dataSources = segmentManager.getDataSources();
    Assert.assertEquals(expectedTimelines.size(), dataSources.size());

    dataSources.forEach(
        (sourceName, dataSourceState) -> {
          Assert.assertEquals(expectedDataSourceCounts.get(sourceName).longValue(), dataSourceState.getNumSegments());
          Assert.assertEquals(
              expectedDataSourceSizes.get(sourceName).longValue(),
              dataSourceState.getTotalSegmentSize()
          );
          Assert.assertEquals(
              expectedTimelines.get(sourceName).getAllTimelineEntries(),
              dataSourceState.getTimeline().getAllTimelineEntries()
          );
        }
    );
  }
}
