{-# LANGUAGE CPP                        #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE ScopedTypeVariables        #-}

module Persistence.Database (DBMonad, onDatabase) where

import Persistence.DBConfig
import Persistence.DatabaseConnection
import Persistence.Schema

import Control.Exception.Lifted
import Control.Monad (when)
import Control.Monad.IO.Class
import Control.Monad.Trans.Control
import Control.Monad.Trans.Reader
import Control.Monad.Logger
import Control.Monad.IO.Unlift
import qualified Control.Monad.Fail as Fail

import qualified Data.List
import Data.Text (Text, pack)

import Database.Persist.Sql

-- | A database action: a computation in @m@ with read access to a
-- persistent 'SqlBackend'.
type DBMonad m a = ReaderT SqlBackend m a

-- | Run a database action against the database described by the given
-- configuration.
--
-- 'getConnection' yields a function that supplies a connection pool to a
-- continuation. The action is run on that pool with 'runSqlPool', and
-- logging is discarded ('runNoLoggingT'). Before the action itself,
-- 'runFullMigrationSet' is run, which migrates the schema and creates the
-- indexes when 'doMigrate' is set in the configuration.
--
-- NOTE(review): persistent documents that 'runSqlPool' runs the whole action
-- in one transaction, so the migrations and @f@ would share it. Confirm this
-- for the persistent version in use.
onDatabase :: ( MonadIO m
              , MonadBaseControl IO m
              , MonadUnliftIO m
              , Fail.MonadFail m
              )
           => DBConfig
           -> DBMonad (NoLoggingT m) a
           -> m a
onDatabase dbConfig f = do
  -- getConnection runs in IO; its result is a pool continuation runner.
  connection <- liftIO $ getConnection dbConfig
  runNoLoggingT $ connection $ runSqlPool $ do
    runFullMigrationSet dbConfig
    f

-- | Bring the database schema up to date when 'doMigrate' is set in the
-- configuration; otherwise do nothing.
--
-- The steps are:
--
-- 1. On the @postgresql@ adapter, set @client_min_messages@ to @error@ so
--    that notices (e.g. about indexes that already exist) are not printed.
--
-- 2. Run the persistent migration 'migrateAll' silently. Its returned
--    messages are discarded.
--
-- 3. Execute each statement from 'indexesSQL'. Any exception it raises is
--    passed to @ignoreIndexExistsError@.
runFullMigrationSet :: forall m . ( MonadBaseControl IO m
                                  , MonadIO m
                                  , MonadUnliftIO m
                                  , Fail.MonadFail m
                                  )
                    => DBConfig -> DBMonad m ()
runFullMigrationSet dbConfig =
  when (doMigrate dbConfig) $ do
    disableNotices -- PostgreSQL prints notices if an index already exists
    -- The migration's messages are not needed; discard them explicitly
    -- instead of through an implicit unused do-bind.
    _ <- runMigrationSilent migrateAll
    mapM_ (handle ignoreIndexExistsError . runRawSql) $ indexesSQL dbConfig
  where
    -- Set PostgreSQL's client_min_messages to "error". No-op for every other
    -- adapter.
    disableNotices :: MonadIO m => DBMonad m ()
    disableNotices =
      when (adapter dbConfig == Just "postgresql") $ do
        _ <- runRawSql "SET client_min_messages = error;"
        return ()

    -- Run a parameterless raw SQL statement and return its rows.
    -- NOTE(review): the dummy SELECT appended by withDummySelect presumably
    -- gives rawSql a result set to read even for statements like SET or
    -- CREATE INDEX that return none. Confirm this.
    runRawSql :: MonadIO m => String -> DBMonad m [Single (Maybe Text)]
    runRawSql sql = rawSql (pack (withDummySelect sql)) []

    -- Append " SELECT ('dummy');" to the statement. The one exception is the
    -- MySQL adapter in MySQL builds, where the statement is sent unchanged.
    withDummySelect :: String -> String
    withDummySelect sql
#ifdef MYSQL
      | isMySql dbConfig = sql
#endif
      | otherwise = sql ++ " SELECT ('dummy');"

    -- MySQL does not have "CREATE INDEX IF NOT EXISTS", so we work around this
    -- by trying to create the index in any case. That throws an error if the
    -- index already exists, and this error is then ignored.
    ignoreIndexExistsError :: MonadIO m
                           => SomeException -> DBMonad m [Single (Maybe Text)]
    ignoreIndexExistsError = ignoreError "'ix_"

    -- In MySQL builds on the MySQL adapter, an exception whose shown message
    -- contains the given infix is swallowed and no rows are returned. Every
    -- other exception, and every exception in non-MySQL builds, is reported
    -- through Fail.fail with the exception's 'show' text as the message.
    -- NOTE(review): this replaces the original exception with a new failure,
    -- so its type is lost. It also catches SomeException, which includes
    -- asynchronous exceptions.
    ignoreError :: MonadIO m
                => String -> SomeException -> DBMonad m [Single (Maybe Text)]
#ifdef MYSQL
    ignoreError searchInfix exception
      | isMySql dbConfig && searchInfix `Data.List.isInfixOf` message =
          return []
      | otherwise = Fail.fail message
      where message = show exception
#else
    ignoreError _ exception = Fail.fail (show exception)
#endif

-- | One @CREATE INDEX@ statement for every entry of
-- 'Persistence.Schema.indexes', rendered for the configured adapter.
--
-- Each entry is a table name together with the columns to index. The index
-- is named @ix_\<table\>__\<col1\>__\<col2\>...@ and covers the columns in
-- the given order.
--
-- * In MySQL builds, the @mysql@ and @mysql2@ adapters get a plain
--   @CREATE INDEX@ with the index name truncated to 64 characters.
--
-- * Every other adapter (@postgresql@, @sqlite@, @sqlite3@ or anything
--   else) gets @CREATE INDEX IF NOT EXISTS@.
indexesSQL :: DBConfig -> [String]
indexesSQL dbConfig =
  map sqlString Persistence.Schema.indexes
  where
    sqlString :: (String, [String]) -> String
    sqlString (table, columns) =
      case adapter dbConfig of
#ifdef MYSQL
        Just "mysql" -> mysqlString
        Just "mysql2" -> mysqlString
#endif
        -- "postgresql", "sqlite", "sqlite3" and unknown adapters all use the
        -- generic statement.
        _ -> genericString
      where
        indexName :: String
        indexName =
          "ix_" ++ table ++ "__" ++ Data.List.intercalate "__" columns

        indexedColumns :: String
        indexedColumns = Data.List.intercalate ", " columns

        -- Assemble "<prefix><name> ON <table>(<columns>);".
        createIndex :: String -> String -> String
        createIndex prefix name =
          prefix ++ name ++ " ON " ++ table ++ "(" ++ indexedColumns ++ ");"

#ifdef MYSQL
        -- MySQL has no IF NOT EXISTS form for indexes. The name is cut to 64
        -- characters (presumably MySQL's identifier length limit).
        mysqlString :: String
        mysqlString = createIndex "CREATE INDEX " (take 64 indexName)
#endif

        genericString :: String
        genericString = createIndex "CREATE INDEX IF NOT EXISTS " indexName