import type { Database } from "@local/sqlite-wasm-patch";
import {
  iterateRecordMap,
  PointerWithRecord,
  RecordMap,
  RecordPointer,
  RecordTable,
  RecordValue,
  setMapRecord,
  TABLE_NAMES,
  virtualTables,
  getPointer,
  getMapRecords,
  createRecordMapFromPointersWithRecords,
  ClientSingletonRecord,
  ClientSingletonRecordName,
  getMapRecord,
} from "libs/schema";
import { getSqlToUpsertRecord } from "libs/getSqlToUpsertRecord";
import { hasIntersection, isNonNullable } from "libs/predicates";
import { Logger } from "libs/logger";
import { isDecoderSuccess } from "ts-decoders";
import { SqlValue, Statement, sql } from "libs/sql-statement";
import { TransactionConflictError, ValidationError } from "libs/errors";
import { memoize, merge, uniq } from "lodash-comms";
import { GetRecordResult, GetRecordsResult, Query, QueryResult } from "libs/database";
import { ClientDatabaseAdapterApi, DatabaseChange } from "./ClientDatabaseAdapterApi";
import { Observable, combineLatest, map, of, share } from "rxjs";
import { astVisitor, parse, Statement as PgAstStatement } from "pgsql-ast-parser";
import { cacheReplayForTime, startWith } from "libs/rxjs-operators";
import { fromDatabaseDecoders, recordToDatabaseFnMap } from "libs/schema/client/decoders";
import { PartialDeep } from "type-fest";
import { groupPointersByTable } from "../utils";

/**
 * Client-side database adapter backed by a `@local/sqlite-wasm-patch` `Database`.
 *
 * Provides:
 * - record reads (`getRecord`, `getRecords`, `getQuery`);
 * - live observables that re-run when matching changes are emitted (`observeRecord`, `observeRecords`,
 *   `observeQuery`);
 * - writes and deletes (`writeRecordMap`, `deletePointers`);
 * - read tracking used for garbage collection (`markRecordsRead`, `getRecordsLastReadBefore`);
 * - named JSON "singleton" records (`getSingletonRecord`, `updateSingletonRecord`).
 *
 * Change notifications are emitted by this adapter's own write methods and delivered to the callbacks
 * registered with `subscribeToRecordChanges` on the same adapter instance.
 */
export class ClientDatabaseAdapter implements ClientDatabaseAdapterApi {
  /** Callbacks notified by `emitDatabaseChange`. */
  protected changeSubscriptions = new Set<(change: DatabaseChange) => void>();
  /** Live `observeRecord` observables, keyed by `${table}:${id}:${includeSoftDeletes}`. */
  protected observeRecordCache = new Map<string, Observable<GetRecordResult<RecordTable>>>();
  protected db: Database;
  protected logger: Logger;
  /** `true` for the adapter instances created by `transaction()` and handed to its callback. */
  protected isTransaction: boolean;

  constructor(props: { db: Database; logger: Logger; isTransaction?: boolean }) {
    this.db = props.db;
    this.logger = props.logger;
    this.isTransaction = props.isTransaction ?? false;
  }

  /**
   * Executes a single SQL statement and returns its result rows as objects keyed by column name.
   * Errors are logged together with the statement and then rethrown.
   */
  query(statement: Statement) {
    try {
      const rows = this.db.exec({
        sql: statement.text,
        bind: statement.values,
        returnValue: "resultRows",
        rowMode: "object",
      });

      return { rows } as {
        rows: Array<{ [columnName: string]: SqlValue }>;
      };
    } catch (error) {
      this.logger.error({ error, statement }, "[query] error");
      throw error;
    }
  }

  /**
   * Runs `fn` inside `db.transaction` and returns its result.
   *
   * `fn` receives a new adapter bound to the transaction's `db` handle. That adapter has its own, initially
   * empty, set of change subscriptions, so changes emitted through it are not delivered to this adapter's
   * subscribers.
   *
   * @throws Error if called on an adapter that is itself a transaction (nested transactions are unsupported).
   */
  transaction<T>(fn: (tx: ClientDatabaseAdapter) => T): T {
    if (this.isTransaction) {
      throw new Error("ClientDatabaseAdapter: nested transactions not supported");
    }

    return this.db.transaction((db) => {
      const tx = new ClientDatabaseAdapter({ db, logger: this.logger, isTransaction: true });
      return fn(tx);
    });
  }

  /**
   * Returns the id of the most recent row in the `migration` table.
   *
   * Will return -1 if the database is empty, -2 if the database is seemingly in an invalid state
   * (tables exist but there is no `migration` table, or the `migration` table has no rows).
   */
  getSchemaVersion(): number {
    const { rows: tables } = this.query(sql`
      SELECT
        name
      FROM
        sqlite_master
      WHERE
        type = 'table'
    `);

    if (tables.length === 0) return -1;

    // If there are some tables but there isn't a migration table, that indicates
    // that creating the database schema failed partway through and the database is in
    // an invalid state.
    if (!tables.some((table) => table.name === "migration")) {
      return -2;
    }

    const { rows: migrations } = this.query(sql`
      SELECT
        "id"
      FROM
        "migration"
      ORDER BY
        "id" DESC
      LIMIT 1
    `);

    return (migrations[0]?.id as number | undefined) ?? -2;
  }

  /**
   * Fetches a single record, addressed either by `(table, id)` or by a pointer.
   *
   * Returns `[null]` when the row does not exist, when it fails to decode, or when it is soft-deleted
   * (`deleted_at` is set) and `includeSoftDeletes` is not enabled.
   *
   * @throws ValidationError if the table is not one of `TABLE_NAMES`.
   */
  getRecord<T extends RecordTable>(
    table: T,
    id: string,
    options?: { includeSoftDeletes?: boolean },
  ): GetRecordResult<T>;
  getRecord<T extends RecordTable>(
    pointer: RecordPointer<T>,
    options?: { includeSoftDeletes?: boolean },
  ): GetRecordResult<T>;
  getRecord<T extends RecordTable>(
    a: T | RecordPointer<T>,
    b?: string | { includeSoftDeletes?: boolean },
    c?: { includeSoftDeletes?: boolean },
  ): GetRecordResult<T> {
    const pointer = typeof a === "string" ? { table: a, id: b as string } : a;
    const options = typeof b === "object" ? b : c;

    // The table name is interpolated raw into the SQL below, so it must be validated first.
    if (!TABLE_NAMES.includes(pointer.table)) {
      throw new ValidationError(`getRecord: invalid record table "${pointer.table}"`);
    }

    const statement = sql`
      SELECT
        * 
      FROM 
        "${sql.raw(pointer.table)}"
      WHERE 
        "${sql.raw(pointer.table)}".id = ${pointer.id}
      LIMIT 1`;

    const {
      rows: [row],
    } = this.query(statement);

    if (!row) return [null];

    const record = this.decodeRecordFromDatabase(pointer.table, row);

    if (!options?.includeSoftDeletes && record?.deleted_at) {
      return [null];
    }

    return [record];
  }

  /**
   * Performance optimized method to fetch multiple records at once (one query per table, all run in a
   * single transaction, or in the current one if this adapter is already a transaction). Provide the
   * `includeSoftDeletes` option to include deleted records in the results. Note that the order of the
   * records in the result will match the order of the pointers in the input array (though some records
   * may be omitted if they were not found). Pointers to virtual tables are skipped.
   */
  getRecords<Table extends RecordTable>(
    pointers: RecordPointer<Table>[],
    options: { includeSoftDeletes?: boolean } = {},
  ): GetRecordsResult<Table> {
    const groupedPointers = groupPointersByTable(pointers);

    if (groupedPointers.length === 0) return [[]];

    const statements = groupedPointers.reduce(
      (store, [table, pointers]) => {
        if (pointers.length === 0) return store;
        if (virtualTables[table]) return store;

        store.push({
          table: table as RecordTable,
          statement: sql`
            SELECT
              * 
            FROM 
              "${sql.raw(table)}"
            WHERE 
              "id" IN (${sql.join(pointers.map((p) => p.id))})
            ${options.includeSoftDeletes ? sql.EMPTY : sql`AND "deleted_at" IS NULL`};
          `,
        });

        return store;
      },
      [] as { table: RecordTable; statement: Statement }[],
    );

    // Create a map of pointer positions from the input array for stable sorting
    const pointerPositions = new Map(pointers.map((pointer, index) => [`${pointer.table}:${pointer.id}`, index]));

    const txFunction = (tx: ClientDatabaseAdapter) => {
      const pointersWithRecord = statements.flatMap(({ table, statement }) => {
        const { rows } = tx.query(statement);

        return rows.map((row) => {
          return {
            table,
            id: row.id as string,
            record: tx.decodeRecordFromDatabase(table, row),
          } as PointerWithRecord;
        });
      });

      // Sort based on original position in pointers array
      pointersWithRecord.sort((a, b) => {
        const posA = pointerPositions.get(`${a.table}:${a.id}`);
        const posB = pointerPositions.get(`${b.table}:${b.id}`);
        return posA! - posB!;
      });

      return pointersWithRecord;
    };

    try {
      const pointersWithRecord = this.isTransaction ? txFunction(this) : this.transaction(txFunction);
      return [pointersWithRecord as PointerWithRecord<Table>[]];
    } catch (error) {
      this.logger.error({ error, pointers, options }, "[getRecords] transaction error");
      throw error;
    }
  }

  /**
   * Runs every statement of `query` (inside one transaction, or the current one), hands the raw rows to
   * `query.parseQueryResults`, and returns the records of `query.primaryTable` together with a record map
   * of everything the query returned.
   */
  getQuery<Table extends RecordTable>(query: Query<Table>): QueryResult<Table> {
    const txFunction = (tx: ClientDatabaseAdapter) => {
      return query.statements().map((s) => tx.query(s).rows);
    };

    const results = this.isTransaction ? txFunction(this) : this.transaction(txFunction);

    const parsedResults = query.parseQueryResults(results);

    const pointersWithRecord = parsedResults.flat(2) as PointerWithRecord[];

    const recordMap = createRecordMapFromPointersWithRecords(pointersWithRecord);

    const records = getMapRecords(recordMap, query.primaryTable);

    return [records, { recordMap }] as unknown as QueryResult<Table>;
  }

  /**
   * Live version of `getRecord`: emits the current result on subscribe and re-runs the lookup whenever an
   * emitted change contains the observed pointer.
   *
   * Observables are shared per `(table, id, includeSoftDeletes)` through `observeRecordCache`; the cache
   * entry is added and removed by the `cacheReplayForTime` hooks.
   *
   * @throws ValidationError if the table is not one of `TABLE_NAMES`.
   */
  observeRecord<T extends RecordTable>(
    table: T,
    id: string,
    options?: { includeSoftDeletes?: boolean },
  ): Observable<GetRecordResult<T>>;
  observeRecord<T extends RecordTable>(
    pointer: RecordPointer<T>,
    options?: { includeSoftDeletes?: boolean },
  ): Observable<GetRecordResult<T>>;
  observeRecord<T extends RecordTable>(
    a: T | RecordPointer<T>,
    b?: string | { includeSoftDeletes?: boolean },
    c?: { includeSoftDeletes?: boolean },
  ): Observable<GetRecordResult<T>> {
    const pointer = typeof a === "string" ? { table: a, id: b as string } : a;
    const options = typeof b === "object" ? b : c;
    const includeSoftDeletes = options?.includeSoftDeletes ?? false;

    const cacheKey = `${pointer.table}:${pointer.id}:${includeSoftDeletes}`;
    const cachedQuery = this.observeRecordCache.get(cacheKey);

    if (cachedQuery) {
      return cachedQuery as Observable<GetRecordResult<T>>;
    }

    if (!TABLE_NAMES.includes(pointer.table)) {
      throw new ValidationError(`observeRecord: invalid record table "${pointer.table}"`);
    }

    const runQuery = () => this.getRecord(pointer, options);

    const subscribe = (onChange: () => void) =>
      this.subscribeToRecordChanges(({ changes }) => {
        if (!getMapRecord(changes, pointer)) return;
        onChange();
      });

    const observable = getObservableForQuery<GetRecordResult<T>>({
      runQuery,
      subscribe,
    }).pipe(
      cacheReplayForTime({
        timeMs: 100,
        onInit: () => {
          // `observable` is assigned by the time any subscriber triggers this hook.
          this.observeRecordCache.set(cacheKey, observable);
        },
        onCleanup: () => {
          this.observeRecordCache.delete(cacheKey);
        },
      }),
    );

    return observable;
  }

  /**
   * Performance optimized method to subscribe to multiple records at once. Provide the `includeSoftDeletes`
   * option to include deleted records in the results. Note that the order of the records in the result will
   * match the order of the pointers in the input array (though some records may be omitted if they were not
   * found).
   */
  observeRecords<Table extends RecordTable>(
    pointers: RecordPointer<Table>[],
    options?: { includeSoftDeletes?: boolean },
  ): Observable<GetRecordsResult<Table>> {
    if (pointers.length === 0) {
      return of([[]]);
    }

    return combineLatest(
      pointers.map((pointer) =>
        this.observeRecord(pointer, options).pipe(
          map(([record]) => {
            if (!record) return null;
            return { ...pointer, record } as PointerWithRecord<Table>;
          }),
        ),
      ),
    ).pipe(map((records) => [records.filter(isNonNullable)]));
  }

  /**
   * Live version of `getQuery`: emits the current result on subscribe and re-runs the query whenever an
   * emitted change touches any table referenced by the query's statements (as detected by `parseTableNames`).
   *
   * @throws Error if no referenced table names could be determined from the statements.
   */
  observeQuery<Table extends RecordTable>(query: Query<Table>): Observable<QueryResult<Table>> {
    const affectedTableNames = uniq(query.statements().flatMap((s) => parseTableNames(s.text)));

    if (affectedTableNames.length === 0) {
      this.logger.error(
        {
          primaryTable: query.primaryTable,
          statements: query.statements().map((s) => ({ text: s.text, values: s.values })),
        },
        "[observeQuery] [parseTableNames] could not calculate selected table names",
      );

      throw new Error("[observeQuery] [parseTableNames] could not calculate selected table names");
    }

    const runQuery = () => this.getQuery(query);

    const subscribe = (onChange: () => void) =>
      this.subscribeToRecordChanges(({ tableNames }) => {
        if (!hasIntersection(affectedTableNames, tableNames)) return;
        onChange();
      });

    return getObservableForQuery({
      runQuery,
      subscribe,
    });
  }

  /**
   * Upserts records in a single transaction and emits a change for the ones actually written.
   *
   * A record is written when `forceUpdate` is set, when no row exists yet (soft-deleted rows count as
   * existing), or when its `version` is greater than the stored one. Otherwise it is skipped (and logged at
   * verbose level), unless `throwOnVersionMismatch` is set, in which case the whole transaction fails with
   * a `TransactionConflictError`. Every non-virtual input pointer is marked as read, written or not.
   *
   * @returns the record map of records that were written (`{}` for empty input).
   */
  writeRecordMap(
    recordMap: RecordMap | PointerWithRecord[],
    options: { forceUpdate?: boolean; throwOnVersionMismatch?: boolean } = {},
  ) {
    const afterPointersWithRecord = Array.isArray(recordMap) ? recordMap : Array.from(iterateRecordMap(recordMap));

    if (afterPointersWithRecord.length === 0) return {};

    const { forceUpdate = false, throwOnVersionMismatch = false } = options;

    this.logger.verbose({ pointersWithRecord: afterPointersWithRecord }, "writeRecordMap");

    const changeRecordMap: RecordMap = {};

    const ignoredRecords: Array<{
      table: string;
      id: string;
      incoming: RecordValue<any>;
      existing: RecordValue<any> | null;
    }> = [];

    try {
      this.transaction((tx) => {
        const [beforePointersWithRecord] = tx.getRecords(afterPointersWithRecord, { includeSoftDeletes: true });

        const beforeRecordMap = createRecordMapFromPointersWithRecords(beforePointersWithRecord);

        for (const pointerWithRecord of afterPointersWithRecord) {
          if (virtualTables[pointerWithRecord.table]) continue;

          const existingRecord = getMapRecord(beforeRecordMap, pointerWithRecord);

          const oldVersion = existingRecord?.version as number | undefined;

          if (forceUpdate || oldVersion === undefined || pointerWithRecord.record.version > oldVersion) {
            setMapRecord(changeRecordMap, pointerWithRecord, pointerWithRecord.record);

            const record = tx.encodeRecordForDatabase(pointerWithRecord);

            const statement = getSqlToUpsertRecord(
              {
                table: pointerWithRecord.table,
                id: pointerWithRecord.id,
                record,
              } as any,
              forceUpdate,
            );

            tx.query(statement);
          } else if (throwOnVersionMismatch && pointerWithRecord.record.version <= oldVersion) {
            tx.logger.warn({ existingRecord, pointerWithRecord }, `writeRecordMap: record version conflict`);
            throw new TransactionConflictError(`writeRecordMap: record version conflict`);
          } else {
            ignoredRecords.push({
              table: pointerWithRecord.table,
              id: pointerWithRecord.id,
              incoming: pointerWithRecord.record,
              existing: existingRecord,
            });
          }
        }

        // Regardless of whether or not we write the new records to the database,
        // we mark the records as having been read for garbage collection.
        tx.markRecordsRead(afterPointersWithRecord);
      });
    } catch (error) {
      this.logger.error({ error, recordMap, options }, "[writeRecordMap] transaction error");
      throw error;
    }

    if (ignoredRecords.length > 0) {
      this.logger.verbose(
        {
          forceUpdate,
          throwOnVersionMismatch,
          ignoredRecords,
        },
        "[writeRecordMap] ignored records",
      );
    }

    const changedTables = Object.keys(changeRecordMap);

    if (changedTables.length > 0) {
      this.emitDatabaseChange({
        tableNames: changedTables,
        changes: changeRecordMap,
      });
    }

    return changeRecordMap;
  }

  /**
   * Hard-deletes the rows for `pointers` (virtual tables are skipped).
   *
   * Unless `suppressChangeNotifications` is set, the rows are read first (including soft-deleted ones),
   * only rows that exist are deleted, and a change containing their last values is emitted afterwards.
   */
  deletePointers(pointers: RecordPointer[], options: { suppressChangeNotifications?: boolean } = {}) {
    let changeRecordMap: RecordMap | null = null;

    const txFunction = (tx: ClientDatabaseAdapter) => {
      let groupedPointers: ReturnType<typeof groupPointersByTable>;

      if (options.suppressChangeNotifications) {
        groupedPointers = groupPointersByTable(pointers);
      } else {
        const [pointersWithRecord] = tx.getRecords(pointers, { includeSoftDeletes: true });
        changeRecordMap = createRecordMapFromPointersWithRecords(pointersWithRecord);
        groupedPointers = groupPointersByTable(pointersWithRecord.map(getPointer));
      }

      if (groupedPointers.length === 0) return;

      const statements = groupedPointers.reduce(
        (store, [table, pointers]) => {
          if (pointers.length === 0) return store;
          if (virtualTables[table]) return store;

          store.push({
            table,
            statement: sql`
              DELETE FROM
                "${sql.raw(table)}"
              WHERE
                "id" IN (${sql.join(pointers.map((p) => p.id))});
            `,
          });

          return store;
        },
        [] as { table: RecordTable; statement: Statement }[],
      );

      for (const { statement } of statements) {
        tx.query(statement);
      }
    };

    try {
      if (this.isTransaction) {
        txFunction(this);
      } else {
        this.transaction(txFunction);
      }
    } catch (error) {
      this.logger.error({ error, pointers }, "[deletePointers] transaction error");
      throw error;
    }

    if (!changeRecordMap) return;

    const change: DatabaseChange = {
      tableNames: Object.keys(changeRecordMap),
      changes: changeRecordMap,
    };

    if (!change.tableNames.length) return;

    this.emitDatabaseChange(change);
  }

  /**
   * Registers `callback` to be invoked on every change emitted by this adapter.
   * Emissions may contain records which are not changed.
   *
   * @returns a function that unregisters the callback.
   */
  subscribeToRecordChanges(callback: (change: DatabaseChange) => void) {
    this.changeSubscriptions.add(callback);

    return () => {
      this.changeSubscriptions.delete(callback);
    };
  }

  /**
   * Upserts a `client_row_read` row with the current timestamp for every non-virtual pointer.
   * The timestamps are later queried by `getRecordsLastReadBefore`.
   */
  markRecordsRead(pointers: RecordPointer[]): void {
    const nonVirtualPointers = pointers.filter((p) => !virtualTables[p.table]);

    if (nonVirtualPointers.length === 0) return;

    const now = Date.now();

    const statement = sql`
      INSERT INTO "client_row_read" 
        ("row_table", "row_id", "read_at")
      VALUES 
        ${sql.bulk(nonVirtualPointers.map(({ table, id }) => [table, id, now]))}
      ON CONFLICT ("row_table", "row_id")
      DO UPDATE SET
        "read_at" = ${now}
    `;

    this.query(statement);
  }

  /** Returns pointers for every `client_row_read` entry whose `read_at` is strictly less than `time`. */
  getRecordsLastReadBefore(time: number): RecordPointer[] {
    const statement = sql`
      SELECT
        *
      FROM
        "client_row_read"
      WHERE
        "client_row_read"."read_at" < ${time}
    `;

    const { rows } = this.query(statement);

    return rows.map((row) => getPointer(row.row_table as RecordTable, row.row_id as string));
  }

  /**
   * Reads the singleton record `name`, inserting a default (`data: {}`, `version: 1`) if it does not exist.
   * The stored `data` column is JSON and is parsed before being returned.
   */
  getSingletonRecord<Name extends ClientSingletonRecordName>(name: Name): ClientSingletonRecord<Name> {
    const operation = (tx: ClientDatabaseAdapter) => {
      const { rows } = tx.query(
        sql`
          SELECT
            *
          FROM
            "singletons"
          WHERE
            "name" = ${name}
          LIMIT
            1  
        `,
      );

      let record = rows[0];

      if (!record) {
        record = {
          name,
          data: JSON.stringify({}),
          version: 1,
          updated_at: new Date().toISOString(),
        };

        tx.query(
          sql`
            INSERT INTO "singletons" (
              "name",
              "data",
              "version",
              "updated_at"
            ) VALUES (
              ${record.name},
              ${record.data},
              ${record.version},
              ${record.updated_at}
            )
          `,
        );
      }

      return {
        ...record,
        data: JSON.parse(record.data as string),
      } as ClientSingletonRecord<Name>;
    };

    try {
      if (this.isTransaction) return operation(this);
      return this.transaction(operation);
    } catch (error) {
      this.logger.error({ error, singletonName: name }, "[getSingletonRecord] transaction error");
      throw error;
    }
  }

  /**
   * Deep-merges `update` into the singleton's `data` (lodash `merge`), bumps `version` by one, refreshes
   * `updated_at`, and persists the result. The singleton is created first if it does not exist.
   */
  updateSingletonRecord<Name extends ClientSingletonRecordName>(
    name: Name,
    update: PartialDeep<ClientSingletonRecord<Name>["data"]>,
  ): ClientSingletonRecord<Name> {
    const operation = (tx: ClientDatabaseAdapter) => {
      const existing = tx.getSingletonRecord(name);

      // `merge` mutates `existing.data`, which is a freshly parsed object owned by this call.
      const data = merge(existing.data, update);

      const newRecord = {
        name,
        data,
        version: existing.version + 1,
        updated_at: new Date().toISOString(),
      };

      tx.query(
        sql`
          UPDATE "singletons" 
          SET
            "data" = ${JSON.stringify(newRecord.data)},
            "version" = ${newRecord.version},
            "updated_at" = ${newRecord.updated_at}
          WHERE
            "name" = ${newRecord.name};
        `,
      );

      return newRecord;
    };

    try {
      if (this.isTransaction) return operation(this);
      return this.transaction(operation);
    } catch (error) {
      this.logger.error({ error, singletonName: name, update }, "[updateSingletonRecord] transaction error");
      throw error;
    }
  }

  /**
   * Debug helper: logs every row of every table at debug level. Errors (including calling this on a
   * transaction adapter) are logged rather than thrown.
   */
  logDatabaseState() {
    try {
      const tables = this.transaction((tx) => {
        const tables = tx
          .query(sql`SELECT name FROM sqlite_master WHERE type='table';`)
          .rows.map((row) => row.name as string);

        return tables.map((table) => {
          return {
            table,
            rows: tx.query(sql`SELECT * FROM "${sql.raw(table)}";`).rows,
          };
        });
      });

      this.logger.debug({ tables }, "logDatabaseState");
    } catch (error) {
      this.logger.error({ error }, "[logDatabaseState] transaction error");
    }
  }

  /**
   * Closes the underlying database handle.
   *
   * @throws Error if called on a transaction adapter.
   */
  close(): void {
    if (this.isTransaction) {
      throw new Error("ClientDatabaseAdapter: cannot call close inside a transaction");
    }

    this.db.close();
  }

  /**
   * Delivers `change` to every registered change subscriber.
   *
   * Each callback is isolated: if one throws, the error is logged and the remaining callbacks still run.
   * Without this, a single failing subscriber (e.g. a live query whose re-run throws) would skip every
   * later subscriber and make the emitting write method throw even though its transaction already committed.
   */
  protected emitDatabaseChange(change: DatabaseChange) {
    this.logger.verbose({ change }, "emitDatabaseChange");

    for (const callback of this.changeSubscriptions) {
      try {
        callback(change);
      } catch (error) {
        this.logger.error({ error, tableNames: change.tableNames }, "[emitDatabaseChange] subscriber error");
      }
    }
  }

  /**
   * Converts a record into its database column representation using `recordToDatabaseFnMap`.
   *
   * @throws Error if no mapper exists for the table; mapper errors are logged and rethrown.
   */
  protected encodeRecordForDatabase<Table extends RecordTable>(
    pointerWithRecord: PointerWithRecord<Table>,
  ): { [columnName: string]: SqlValue } {
    const mapRecordToDatabase = recordToDatabaseFnMap[pointerWithRecord.table];

    if (!mapRecordToDatabase) {
      throw new Error(`writeRecordMap: could not find record mapper for ${pointerWithRecord.table}`);
    }

    try {
      return mapRecordToDatabase(pointerWithRecord.record);
    } catch (e) {
      this.logger.error({ pointerWithRecord, error: e }, "Error encoding record for database");
      throw e;
    }
  }

  /**
   * Decodes a raw row of `table` with `fromDatabaseDecoders`. Returns `null` (after logging an error) when
   * no decoder exists for the table or decoding fails.
   */
  protected decodeRecordFromDatabase<Table extends RecordTable>(table: Table, row: any): RecordValue<Table> | null {
    const decoder = fromDatabaseDecoders[table];

    if (!decoder) {
      this.logger.error(`decodeRecord: could not find decoder for ${table}`);
      return null;
    }

    const decoded = decoder.decode(row);

    if (isDecoderSuccess(decoded)) {
      return decoded.value as RecordValue<Table>;
    }

    // Include the table so decode failures can be attributed without inspecting the decoder output.
    this.logger.error({ decoded, table }, `decodeRecord error`);

    return null;
  }
}

/**
 * Returns the distinct names of all tables referenced anywhere in a single SQL statement,
 * in the order the AST visitor first encounters them.
 *
 * The parser is `pgsql-ast-parser`, which targets Postgres syntax, so some SQLite-specific
 * syntax may fail to parse. Parse failures are logged to the console and rethrown. Input that
 * does not contain exactly one statement throws.
 *
 * Parsing is relatively expensive (around 20ms for a large query has been observed), so results
 * are memoized by query text.
 */
const parseTableNames = memoize((sqlQuery: string): string[] => {
  let parsedStatements: PgAstStatement[];

  try {
    parsedStatements = parse(sqlQuery);
  } catch (e) {
    console.error(`parseTableNames: Error parsing SQL`, sqlQuery, e);
    throw e;
  }

  if (parsedStatements.length !== 1) {
    throw new Error(`parseTableNames: Must receive exactly one SQL statement`);
  }

  // The length check above guarantees exactly one element.
  const [onlyStatement] = parsedStatements as [PgAstStatement];

  const referencedTables = new Set<string>();

  astVisitor(() => ({
    tableRef: (tableRef) => referencedTables.add(tableRef.name),
  })).statement(onlyStatement);

  return [...referencedTables];
});

/**
 * Wraps a synchronous query in a multicast Observable.
 *
 * A single change listener (registered through `subscribe`) is shared by all subscribers. It re-runs
 * `runQuery` and pushes the result whenever the listener fires. Each new subscriber additionally gets a
 * freshly computed result first, via `startWith(runQuery)`.
 */
function getObservableForQuery<T>(args: {
  runQuery: () => T;
  subscribe: (onChanges: () => void) => () => void;
}): Observable<T> {
  const { runQuery: execute, subscribe: listen } = args;

  const changes$ = new Observable<T>((subscriber) => {
    const stopListening = listen(() => subscriber.next(execute()));

    // Teardown: remove the change listener once the shared subscription ends.
    return () => {
      stopListening();
    };
  });

  // `share` + `startWith` is used instead of `shareReplay` to save memory: no extra copy of the latest
  // query result is retained. The cost is CPU: the query runs again for every new subscriber.
  return changes$.pipe(share({ resetOnRefCountZero: true }), startWith(execute));
}
