/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.IsCollectionContaining.hasItem;
import static org.hamcrest.core.IsCollectionContaining.hasItems;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import android.database.Cursor;
import android.database.sqlite.SQLiteException;

import androidx.annotation.NonNull;
import androidx.arch.core.executor.JunitTaskExecutorRule;
import androidx.sqlite.db.SupportSQLiteDatabase;
import androidx.sqlite.db.SupportSQLiteOpenHelper;
import androidx.sqlite.db.SupportSQLiteStatement;

import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

@RunWith(JUnit4.class)
public class InvalidationTrackerTest {
    private InvalidationTracker mTracker;
    private RoomDatabase mRoomDatabase;
    private SupportSQLiteOpenHelper mOpenHelper;
    @Rule
    public JunitTaskExecutorRule mTaskExecutorRule = new JunitTaskExecutorRule(1, true);

    @Before
    public void setup() {
        mRoomDatabase = mock(RoomDatabase.class);
        SupportSQLiteDatabase sqliteDb = mock(SupportSQLiteDatabase.class);
        final SupportSQLiteStatement statement = mock(SupportSQLiteStatement.class);
        mOpenHelper = mock(SupportSQLiteOpenHelper.class);

        doReturn(statement).when(sqliteDb).compileStatement(eq(InvalidationTracker.CLEANUP_SQL));
        doReturn(sqliteDb).when(mOpenHelper).getWritableDatabase();
        doReturn(true).when(mRoomDatabase).isOpen();
        ReentrantLock closeLock = new ReentrantLock();
        doReturn(closeLock).when(mRoomDatabase).getCloseLock();
        //noinspection ResultOfMethodCallIgnored
        doReturn(mOpenHelper).when(mRoomDatabase).getOpenHelper();

        mTracker = new InvalidationTracker(mRoomDatabase, "a", "B", "i");
        mTracker.internalInit(sqliteDb);
    }

    @Before
    public void setLocale() {
        Locale.setDefault(Locale.forLanguageTag("tr-TR"));
    }

    @After
    public void unsetLocale() {
        Locale.setDefault(Locale.US);
    }

    @Test
    public void tableIds() {
        assertThat(mTracker.mTableIdLookup.get("a"), is(0));
        assertThat(mTracker.mTableIdLookup.get("b"), is(1));
    }

    @Test
    public void testWeak() throws InterruptedException {
        final AtomicInteger data = new AtomicInteger(0);
        InvalidationTracker.Observer observer = new InvalidationTracker.Observer("a") {
            @Override
            public void onInvalidated(@NonNull Set<String> tables) {
                data.incrementAndGet();
            }
        };
        mTracker.addWeakObserver(observer);
        setVersions(1, 0);
        refreshSync();
        assertThat(data.get(), is(1));
        observer = null;
        forceGc();
        setVersions(2, 0);
        refreshSync();
        assertThat(data.get(), is(1));
    }

    @Test
    public void addRemoveObserver() throws Exception {
        InvalidationTracker.Observer observer = new LatchObserver(1, "a");
        mTracker.addObserver(observer);
        assertThat(mTracker.mObserverMap.size(), is(1));
        mTracker.removeObserver(new LatchObserver(1, "a"));
        assertThat(mTracker.mObserverMap.size(), is(1));
        mTracker.removeObserver(observer);
        assertThat(mTracker.mObserverMap.size(), is(0));
    }

    private void drainTasks() throws InterruptedException {
        mTaskExecutorRule.drainTasks(200);
    }

    @Test(expected = IllegalArgumentException.class)
    public void badObserver() {
        InvalidationTracker.Observer observer = new LatchObserver(1, "x");
        mTracker.addObserver(observer);
    }

    @Test
    public void refreshReadValues() throws Exception {
        setVersions(1, 0, 2, 1);
        refreshSync();
        assertThat(mTracker.mTableVersions, is(new long[]{1, 2, 0}));

        setVersions(3, 1);
        refreshSync();
        assertThat(mTracker.mTableVersions, is(new long[]{1, 3, 0}));

        setVersions(7, 0);
        refreshSync();
        assertThat(mTracker.mTableVersions, is(new long[]{7, 3, 0}));

        refreshSync();
        assertThat(mTracker.mTableVersions, is(new long[]{7, 3, 0}));
    }

    private void refreshSync() throws InterruptedException {
        mTracker.refreshVersionsAsync();
        drainTasks();
    }

    @Test
    public void refreshCheckTasks() throws Exception {
        when(mRoomDatabase.query(anyString(), any(Object[].class)))
                .thenReturn(mock(Cursor.class));
        mTracker.refreshVersionsAsync();
        mTracker.refreshVersionsAsync();
        verify(mTaskExecutorRule.getTaskExecutor()).executeOnDiskIO(mTracker.mRefreshRunnable);
        drainTasks();

        reset(mTaskExecutorRule.getTaskExecutor());
        mTracker.refreshVersionsAsync();
        verify(mTaskExecutorRule.getTaskExecutor()).executeOnDiskIO(mTracker.mRefreshRunnable);
    }

    @Test
    public void observe1Table() throws Exception {
        LatchObserver observer = new LatchObserver(1, "a");
        mTracker.addObserver(observer);
        setVersions(1, 0, 2, 1);
        refreshSync();
        assertThat(observer.await(), is(true));
        assertThat(observer.getInvalidatedTables().size(), is(1));
        assertThat(observer.getInvalidatedTables(), hasItem("a"));

        setVersions(3, 1);
        observer.reset(1);
        refreshSync();
        assertThat(observer.await(), is(false));

        setVersions(4, 0);
        refreshSync();
        assertThat(observer.await(), is(true));
        assertThat(observer.getInvalidatedTables().size(), is(1));
        assertThat(observer.getInvalidatedTables(), hasItem("a"));
    }

    @Test
    public void observe2Tables() throws Exception {
        LatchObserver observer = new LatchObserver(1, "A", "B");
        mTracker.addObserver(observer);
        setVersions(1, 0, 2, 1);
        refreshSync();
        assertThat(observer.await(), is(true));
        assertThat(observer.getInvalidatedTables().size(), is(2));
        assertThat(observer.getInvalidatedTables(), hasItems("A", "B"));

        setVersions(3, 1);
        observer.reset(1);
        refreshSync();
        assertThat(observer.await(), is(true));
        assertThat(observer.getInvalidatedTables().size(), is(1));
        assertThat(observer.getInvalidatedTables(), hasItem("B"));

        setVersions(4, 0);
        observer.reset(1);
        refreshSync();
        assertThat(observer.await(), is(true));
        assertThat(observer.getInvalidatedTables().size(), is(1));
        assertThat(observer.getInvalidatedTables(), hasItem("A"));

        observer.reset(1);
        refreshSync();
        assertThat(observer.await(), is(false));
    }

    @Test
    public void locale() {
        LatchObserver observer = new LatchObserver(1, "I");
        mTracker.addObserver(observer);
    }

    @Test
    public void closedDb() {
        doReturn(false).when(mRoomDatabase).isOpen();
        doThrow(new IllegalStateException("foo")).when(mOpenHelper).getWritableDatabase();
        mTracker.addObserver(new LatchObserver(1, "a", "b"));
        mTracker.mRefreshRunnable.run();
    }

    // @Test - disabled due to flakiness b/65257997
    public void closedDbAfterOpen() throws InterruptedException {
        setVersions(3, 1);
        mTracker.addObserver(new LatchObserver(1, "a", "b"));
        mTracker.syncTriggers();
        mTracker.mRefreshRunnable.run();
        doThrow(new SQLiteException("foo")).when(mRoomDatabase).query(
                Mockito.eq(InvalidationTracker.SELECT_UPDATED_TABLES_SQL),
                any(Object[].class));
        mTracker.mPendingRefresh.set(true);
        mTracker.mRefreshRunnable.run();
    }

    /**
     * Key value pairs of VERSION, TABLE_ID
     */
    private void setVersions(int... keyValuePairs) throws InterruptedException {
        // mockito does not like multi-threaded access so before setting versions, make sure we
        // sync background tasks.
        drainTasks();
        Cursor cursor = createCursorWithValues(keyValuePairs);
        doReturn(cursor).when(mRoomDatabase).query(
                Mockito.eq(InvalidationTracker.SELECT_UPDATED_TABLES_SQL),
                any(Object[].class)
        );
    }

    private Cursor createCursorWithValues(final int... keyValuePairs) {
        Cursor cursor = mock(Cursor.class);
        final AtomicInteger index = new AtomicInteger(-2);
        when(cursor.moveToNext()).thenAnswer(new Answer<Boolean>() {
            @Override
            public Boolean answer(InvocationOnMock invocation) throws Throwable {
                return index.addAndGet(2) < keyValuePairs.length;
            }
        });
        Answer<Integer> intAnswer = new Answer<Integer>() {
            @Override
            public Integer answer(InvocationOnMock invocation) throws Throwable {
                return keyValuePairs[index.intValue() + (Integer) invocation.getArguments()[0]];
            }
        };
        Answer<Long> longAnswer = new Answer<Long>() {
            @Override
            public Long answer(InvocationOnMock invocation) throws Throwable {
                return (long) keyValuePairs[index.intValue()
                        + (Integer) invocation.getArguments()[0]];
            }
        };
        when(cursor.getInt(anyInt())).thenAnswer(intAnswer);
        when(cursor.getLong(anyInt())).thenAnswer(longAnswer);
        return cursor;
    }

    static class LatchObserver extends InvalidationTracker.Observer {
        private CountDownLatch mLatch;
        private Set<String> mInvalidatedTables;

        LatchObserver(int count, String... tableNames) {
            super(tableNames);
            mLatch = new CountDownLatch(count);
        }

        boolean await() throws InterruptedException {
            return mLatch.await(3, TimeUnit.SECONDS);
        }

        @Override
        public void onInvalidated(@NonNull Set<String> tables) {
            mInvalidatedTables = tables;
            mLatch.countDown();
        }

        void reset(@SuppressWarnings("SameParameterValue") int count) {
            mInvalidatedTables = null;
            mLatch = new CountDownLatch(count);
        }

        Set<String> getInvalidatedTables() {
            return mInvalidatedTables;
        }
    }

    private static void forceGc() {
        // Use a random index in the list to detect the garbage collection each time because
        // .get() may accidentally trigger a strong reference during collection.
        ArrayList<WeakReference<byte[]>> leak = new ArrayList<>();
        do {
            WeakReference<byte[]> arr = new WeakReference<>(new byte[100]);
            leak.add(arr);
        } while (leak.get((int) (Math.random() * leak.size())).get() != null);
    }
}
