Postgresql.hs 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265
  1. {-# LANGUAGE DeriveDataTypeable #-}
  2. {-# LANGUAGE OverloadedStrings #-}
  3. {-# LANGUAGE ScopedTypeVariables #-}
  4. {-# LANGUAGE TypeFamilies #-}
  5. {-# LANGUAGE ViewPatterns #-}
  6. -- | A postgresql backend for persistent.
  7. module Database.Persist.Postgresql
  8. ( withPostgresqlPool
  9. , withPostgresqlPoolWithVersion
  10. , withPostgresqlConn
  11. , withPostgresqlConnWithVersion
  12. , createPostgresqlPool
  13. , createPostgresqlPoolModified
  14. , createPostgresqlPoolModifiedWithVersion
  15. , module Database.Persist.Sql
  16. , ConnectionString
  17. , PostgresConf (..)
  18. , openSimpleConn
  19. , openSimpleConnWithVersion
  20. , tableName
  21. , fieldName
  22. , mockMigration
  23. , migrateEnableExtension
  24. ) where
import qualified Database.PostgreSQL.LibPQ as LibPQ
import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.Internal as PG
import qualified Database.PostgreSQL.Simple.FromField as PGFF
import qualified Database.PostgreSQL.Simple.ToField as PGTF
import qualified Database.PostgreSQL.Simple.Transaction as PG
import qualified Database.PostgreSQL.Simple.Types as PG
import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS
import Database.PostgreSQL.Simple.Ok (Ok (..))
import Control.Arrow
import Control.Exception (Exception, throw, throwIO)
import Control.Exception (onException)
import Control.Monad (forM)
import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO)
import Control.Monad.Logger (MonadLogger, runNoLoggingT)
import Control.Monad.Trans.Reader (runReaderT)
import Control.Monad.Trans.Writer (WriterT(..), runWriterT)
import qualified Blaze.ByteString.Builder.Char8 as BBB
import Data.Acquire (Acquire, mkAcquire, with)
import Data.Aeson
import Data.Aeson.Types (modifyFailure)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B8
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Data
import Data.Either (partitionEithers)
import Data.Fixed (Pico)
import Data.Function (on)
import Data.Int (Int64)
import qualified Data.IntMap as I
import Data.IORef
import Data.List (find, sort, groupBy)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NEL
import qualified Data.Map as Map
import Data.Maybe
import Data.Monoid ((<>))
import Data.Pool (Pool)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Data.Text.Read (rational)
import Data.Time (utc, localTimeToUTC)
import Data.Typeable (Typeable)
import System.Environment (getEnvironment)
import Database.Persist.Sql
import qualified Database.Persist.Sql.Util as Util
  73. -- | A @libpq@ connection string. A simple example of connection
  74. -- string would be @\"host=localhost port=5432 user=test
  75. -- dbname=test password=test\"@. Please read libpq's
  76. -- documentation at
  77. -- <https://www.postgresql.org/docs/current/static/libpq-connect.html>
  78. -- for more details on how to create such strings.
  79. type ConnectionString = ByteString
  80. -- | PostgresServerVersionError exception. This is thrown when persistent
  81. -- is unable to find the version of the postgreSQL server.
  82. data PostgresServerVersionError = PostgresServerVersionError String deriving Data.Typeable.Typeable
  83. instance Show PostgresServerVersionError where
  84. show (PostgresServerVersionError uniqueMsg) =
  85. "Unexpected PostgreSQL server version, got " <> uniqueMsg
  86. instance Exception PostgresServerVersionError
  87. -- | Create a PostgreSQL connection pool and run the given
  88. -- action. The pool is properly released after the action
  89. -- finishes using it. Note that you should not use the given
  90. -- 'ConnectionPool' outside the action since it may already
  91. -- have been released.
  92. withPostgresqlPool :: (MonadLogger m, MonadUnliftIO m)
  93. => ConnectionString
  94. -- ^ Connection string to the database.
  95. -> Int
  96. -- ^ Number of connections to be kept open in
  97. -- the pool.
  98. -> (Pool SqlBackend -> m a)
  99. -- ^ Action to be executed that uses the
  100. -- connection pool.
  101. -> m a
  102. withPostgresqlPool ci = withPostgresqlPoolWithVersion getServerVersion ci
  103. -- | Same as 'withPostgresPool', but takes a callback for obtaining
  104. -- the server version (to work around an Amazon Redshift bug).
  105. --
  106. -- @since 2.6.2
  107. withPostgresqlPoolWithVersion :: (MonadUnliftIO m, MonadLogger m)
  108. => (PG.Connection -> IO (Maybe Double))
  109. -- ^ Action to perform to get the server version.
  110. -> ConnectionString
  111. -- ^ Connection string to the database.
  112. -> Int
  113. -- ^ Number of connections to be kept open in
  114. -- the pool.
  115. -> (Pool SqlBackend -> m a)
  116. -- ^ Action to be executed that uses the
  117. -- connection pool.
  118. -> m a
  119. withPostgresqlPoolWithVersion getVer ci = withSqlPool $ open' (const $ return ()) getVer ci
  120. -- | Create a PostgreSQL connection pool. Note that it's your
  121. -- responsibility to properly close the connection pool when
  122. -- unneeded. Use 'withPostgresqlPool' for an automatic resource
  123. -- control.
  124. createPostgresqlPool :: (MonadUnliftIO m, MonadLogger m)
  125. => ConnectionString
  126. -- ^ Connection string to the database.
  127. -> Int
  128. -- ^ Number of connections to be kept open
  129. -- in the pool.
  130. -> m (Pool SqlBackend)
  131. createPostgresqlPool = createPostgresqlPoolModified (const $ return ())
  132. -- | Same as 'createPostgresqlPool', but additionally takes a callback function
  133. -- for some connection-specific tweaking to be performed after connection
  134. -- creation. This could be used, for example, to change the schema. For more
  135. -- information, see:
  136. --
  137. -- <https://groups.google.com/d/msg/yesodweb/qUXrEN_swEo/O0pFwqwQIdcJ>
  138. --
  139. -- @since 2.1.3
  140. createPostgresqlPoolModified
  141. :: (MonadUnliftIO m, MonadLogger m)
  142. => (PG.Connection -> IO ()) -- ^ Action to perform after connection is created.
  143. -> ConnectionString -- ^ Connection string to the database.
  144. -> Int -- ^ Number of connections to be kept open in the pool.
  145. -> m (Pool SqlBackend)
  146. createPostgresqlPoolModified = createPostgresqlPoolModifiedWithVersion getServerVersion
  147. -- | Same as other similarly-named functions in this module, but takes callbacks for obtaining
  148. -- the server version (to work around an Amazon Redshift bug) and connection-specific tweaking
  149. -- (to change the schema).
  150. --
  151. -- @since 2.6.2
  152. createPostgresqlPoolModifiedWithVersion
  153. :: (MonadUnliftIO m, MonadLogger m)
  154. => (PG.Connection -> IO (Maybe Double)) -- ^ Action to perform to get the server version.
  155. -> (PG.Connection -> IO ()) -- ^ Action to perform after connection is created.
  156. -> ConnectionString -- ^ Connection string to the database.
  157. -> Int -- ^ Number of connections to be kept open in the pool.
  158. -> m (Pool SqlBackend)
  159. createPostgresqlPoolModifiedWithVersion getVer modConn ci =
  160. createSqlPool $ open' modConn getVer ci
  161. -- | Same as 'withPostgresqlPool', but instead of opening a pool
  162. -- of connections, only one connection is opened.
  163. withPostgresqlConn :: (MonadUnliftIO m, MonadLogger m)
  164. => ConnectionString -> (SqlBackend -> m a) -> m a
  165. withPostgresqlConn = withPostgresqlConnWithVersion getServerVersion
  166. -- | Same as 'withPostgresqlConn', but takes a callback for obtaining
  167. -- the server version (to work around an Amazon Redshift bug).
  168. --
  169. -- @since 2.6.2
  170. withPostgresqlConnWithVersion :: (MonadUnliftIO m, MonadLogger m)
  171. => (PG.Connection -> IO (Maybe Double))
  172. -> ConnectionString
  173. -> (SqlBackend -> m a)
  174. -> m a
  175. withPostgresqlConnWithVersion getVer = withSqlConn . open' (const $ return ()) getVer
  176. open'
  177. :: (PG.Connection -> IO ())
  178. -> (PG.Connection -> IO (Maybe Double))
  179. -> ConnectionString -> LogFunc -> IO SqlBackend
  180. open' modConn getVer cstr logFunc = do
  181. conn <- PG.connectPostgreSQL cstr
  182. modConn conn
  183. ver <- getVer conn
  184. smap <- newIORef $ Map.empty
  185. return $ createBackend logFunc ver smap conn
  186. -- | Gets the PostgreSQL server version
  187. getServerVersion :: PG.Connection -> IO (Maybe Double)
  188. getServerVersion conn = do
  189. [PG.Only version] <- PG.query_ conn "show server_version";
  190. let version' = rational version
  191. --- λ> rational "9.8.3"
  192. --- Right (9.8,".3")
  193. --- λ> rational "9.8.3.5"
  194. --- Right (9.8,".3.5")
  195. case version' of
  196. Right (a,_) -> return $ Just a
  197. Left err -> throwIO $ PostgresServerVersionError err
  198. -- | Choose upsert sql generation function based on postgresql version.
  199. -- PostgreSQL version >= 9.5 supports native upsert feature,
  200. -- so depending upon that we have to choose how the sql query is generated.
  201. -- upsertFunction :: Double -> Maybe (EntityDef -> Text -> Text)
  202. upsertFunction :: a -> Double -> Maybe a
  203. upsertFunction f version = if (version >= 9.5)
  204. then Just f
  205. else Nothing
  206. -- | Generate a 'SqlBackend' from a 'PG.Connection'.
  207. openSimpleConn :: LogFunc -> PG.Connection -> IO SqlBackend
  208. openSimpleConn = openSimpleConnWithVersion getServerVersion
  209. -- | Generate a 'SqlBackend' from a 'PG.Connection', but takes a callback for
  210. -- obtaining the server version.
  211. --
  212. -- @since 2.9.1
  213. openSimpleConnWithVersion :: (PG.Connection -> IO (Maybe Double)) -> LogFunc -> PG.Connection -> IO SqlBackend
  214. openSimpleConnWithVersion getVer logFunc conn = do
  215. smap <- newIORef $ Map.empty
  216. serverVersion <- getVer conn
  217. return $ createBackend logFunc serverVersion smap conn
  218. -- | Create the backend given a logging function, server version, mutable statement cell,
  219. -- and connection.
  220. createBackend :: LogFunc -> Maybe Double
  221. -> IORef (Map.Map Text Statement) -> PG.Connection -> SqlBackend
  222. createBackend logFunc serverVersion smap conn = do
  223. SqlBackend
  224. { connPrepare = prepare' conn
  225. , connStmtMap = smap
  226. , connInsertSql = insertSql'
  227. , connInsertManySql = Just insertManySql'
  228. , connUpsertSql = serverVersion >>= upsertFunction upsertSql'
  229. , connPutManySql = serverVersion >>= upsertFunction putManySql
  230. , connClose = PG.close conn
  231. , connMigrateSql = migrate'
  232. , connBegin = \_ mIsolation -> case mIsolation of
  233. Nothing -> PG.begin conn
  234. Just iso -> PG.beginLevel (case iso of
  235. ReadUncommitted -> PG.ReadCommitted -- PG Upgrades uncommitted reads to committed anyways
  236. ReadCommitted -> PG.ReadCommitted
  237. RepeatableRead -> PG.RepeatableRead
  238. Serializable -> PG.Serializable) conn
  239. , connCommit = const $ PG.commit conn
  240. , connRollback = const $ PG.rollback conn
  241. , connEscapeName = escape
  242. , connNoLimit = "LIMIT ALL"
  243. , connRDBMS = "postgresql"
  244. , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
  245. , connLogFunc = logFunc
  246. , connMaxParams = Nothing
  247. , connRepsertManySql = serverVersion >>= upsertFunction repsertManySql
  248. }
  249. prepare' :: PG.Connection -> Text -> IO Statement
  250. prepare' conn sql = do
  251. let query = PG.Query (T.encodeUtf8 sql)
  252. return Statement
  253. { stmtFinalize = return ()
  254. , stmtReset = return ()
  255. , stmtExecute = execute' conn query
  256. , stmtQuery = withStmt' conn query
  257. }
  258. insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
  259. insertSql' ent vals =
  260. let sql = T.concat
  261. [ "INSERT INTO "
  262. , escape $ entityDB ent
  263. , if null (entityFields ent)
  264. then " DEFAULT VALUES"
  265. else T.concat
  266. [ "("
  267. , T.intercalate "," $ map (escape . fieldDB) $ entityFields ent
  268. , ") VALUES("
  269. , T.intercalate "," (map (const "?") $ entityFields ent)
  270. , ")"
  271. ]
  272. ]
  273. in case entityPrimary ent of
  274. Just _pdef -> ISRManyKeys sql vals
  275. Nothing -> ISRSingle (sql <> " RETURNING " <> escape (fieldDB (entityId ent)))
  276. upsertSql' :: EntityDef -> NonEmpty UniqueDef -> Text -> Text
  277. upsertSql' ent uniqs updateVal = T.concat
  278. [ "INSERT INTO "
  279. , escape (entityDB ent)
  280. , "("
  281. , T.intercalate "," $ map (escape . fieldDB) $ entityFields ent
  282. , ") VALUES ("
  283. , T.intercalate "," $ map (const "?") (entityFields ent)
  284. , ") ON CONFLICT ("
  285. , T.intercalate "," $ concat $ map (\x -> map escape (map snd $ uniqueFields x)) (entityUniques ent)
  286. , ") DO UPDATE SET "
  287. , updateVal
  288. , " WHERE "
  289. , wher
  290. , " RETURNING ??"
  291. ]
  292. where
  293. wher = T.intercalate " AND " $ map singleCondition $ NEL.toList uniqs
  294. singleCondition :: UniqueDef -> Text
  295. singleCondition udef = T.intercalate " AND " (map singleClause $ map snd (uniqueFields udef))
  296. singleClause :: DBName -> Text
  297. singleClause field = escape (entityDB ent) <> "." <> (escape field) <> " =?"
  298. -- | SQL for inserting multiple rows at once and returning their primary keys.
  299. insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult
  300. insertManySql' ent valss =
  301. let sql = T.concat
  302. [ "INSERT INTO "
  303. , escape (entityDB ent)
  304. , "("
  305. , T.intercalate "," $ map (escape . fieldDB) $ entityFields ent
  306. , ") VALUES ("
  307. , T.intercalate "),(" $ replicate (length valss) $ T.intercalate "," $ map (const "?") (entityFields ent)
  308. , ") RETURNING "
  309. , Util.commaSeparated $ Util.dbIdColumnsEsc escape ent
  310. ]
  311. in ISRSingle sql
  312. execute' :: PG.Connection -> PG.Query -> [PersistValue] -> IO Int64
  313. execute' conn query vals = PG.execute conn query (map P vals)
  314. withStmt' :: MonadIO m
  315. => PG.Connection
  316. -> PG.Query
  317. -> [PersistValue]
  318. -> Acquire (ConduitM () [PersistValue] m ())
  319. withStmt' conn query vals =
  320. pull `fmap` mkAcquire openS closeS
  321. where
  322. openS = do
  323. -- Construct raw query
  324. rawquery <- PG.formatQuery conn query (map P vals)
  325. -- Take raw connection
  326. (rt, rr, rc, ids) <- PG.withConnection conn $ \rawconn -> do
  327. -- Execute query
  328. mret <- LibPQ.exec rawconn rawquery
  329. case mret of
  330. Nothing -> do
  331. merr <- LibPQ.errorMessage rawconn
  332. fail $ case merr of
  333. Nothing -> "Postgresql.withStmt': unknown error"
  334. Just e -> "Postgresql.withStmt': " ++ B8.unpack e
  335. Just ret -> do
  336. -- Check result status
  337. status <- LibPQ.resultStatus ret
  338. case status of
  339. LibPQ.TuplesOk -> return ()
  340. _ -> PG.throwResultError "Postgresql.withStmt': bad result status " ret status
  341. -- Get number and type of columns
  342. cols <- LibPQ.nfields ret
  343. oids <- forM [0..cols-1] $ \col -> fmap ((,) col) (LibPQ.ftype ret col)
  344. -- Ready to go!
  345. rowRef <- newIORef (LibPQ.Row 0)
  346. rowCount <- LibPQ.ntuples ret
  347. return (ret, rowRef, rowCount, oids)
  348. let getters
  349. = map (\(col, oid) -> getGetter conn oid $ PG.Field rt col oid) ids
  350. return (rt, rr, rc, getters)
  351. closeS (ret, _, _, _) = LibPQ.unsafeFreeResult ret
  352. pull x = do
  353. y <- liftIO $ pullS x
  354. case y of
  355. Nothing -> return ()
  356. Just z -> yield z >> pull x
  357. pullS (ret, rowRef, rowCount, getters) = do
  358. row <- atomicModifyIORef rowRef (\r -> (r+1, r))
  359. if row == rowCount
  360. then return Nothing
  361. else fmap Just $ forM (zip getters [0..]) $ \(getter, col) -> do
  362. mbs <- LibPQ.getvalue' ret row col
  363. case mbs of
  364. Nothing ->
  365. -- getvalue' verified that the value is NULL.
  366. -- However, that does not mean that there are
  367. -- no NULL values inside the value (e.g., if
  368. -- we're dealing with an array of optional values).
  369. return PersistNull
  370. Just bs -> do
  371. ok <- PGFF.runConversion (getter mbs) conn
  372. bs `seq` case ok of
  373. Errors (exc:_) -> throw exc
  374. Errors [] -> error "Got an Errors, but no exceptions"
  375. Ok v -> return v
  376. -- | Avoid orphan instances.
  377. newtype P = P PersistValue
  378. instance PGTF.ToField P where
  379. toField (P (PersistText t)) = PGTF.toField t
  380. toField (P (PersistByteString bs)) = PGTF.toField (PG.Binary bs)
  381. toField (P (PersistInt64 i)) = PGTF.toField i
  382. toField (P (PersistDouble d)) = PGTF.toField d
  383. toField (P (PersistRational r)) = PGTF.Plain $
  384. BBB.fromString $
  385. show (fromRational r :: Pico) -- FIXME: Too Ambigous, can not select precision without information about field
  386. toField (P (PersistBool b)) = PGTF.toField b
  387. toField (P (PersistDay d)) = PGTF.toField d
  388. toField (P (PersistTimeOfDay t)) = PGTF.toField t
  389. toField (P (PersistUTCTime t)) = PGTF.toField t
  390. toField (P PersistNull) = PGTF.toField PG.Null
  391. toField (P (PersistList l)) = PGTF.toField $ listToJSON l
  392. toField (P (PersistMap m)) = PGTF.toField $ mapToJSON m
  393. toField (P (PersistDbSpecific s)) = PGTF.toField (Unknown s)
  394. toField (P (PersistArray a)) = PGTF.toField $ PG.PGArray $ P <$> a
  395. toField (P (PersistObjectId _)) =
  396. error "Refusing to serialize a PersistObjectId to a PostgreSQL value"
  397. newtype Unknown = Unknown { unUnknown :: ByteString }
  398. deriving (Eq, Show, Read, Ord, Typeable)
  399. instance PGFF.FromField Unknown where
  400. fromField f mdata =
  401. case mdata of
  402. Nothing -> PGFF.returnError PGFF.UnexpectedNull f "Database.Persist.Postgresql/PGFF.FromField Unknown"
  403. Just dat -> return (Unknown dat)
  404. instance PGTF.ToField Unknown where
  405. toField (Unknown a) = PGTF.Escape a
  406. type Getter a = PGFF.FieldParser a
  407. convertPV :: PGFF.FromField a => (a -> b) -> Getter b
  408. convertPV f = (fmap f .) . PGFF.fromField
  409. builtinGetters :: I.IntMap (Getter PersistValue)
  410. builtinGetters = I.fromList
  411. [ (k PS.bool, convertPV PersistBool)
  412. , (k PS.bytea, convertPV (PersistByteString . unBinary))
  413. , (k PS.char, convertPV PersistText)
  414. , (k PS.name, convertPV PersistText)
  415. , (k PS.int8, convertPV PersistInt64)
  416. , (k PS.int2, convertPV PersistInt64)
  417. , (k PS.int4, convertPV PersistInt64)
  418. , (k PS.text, convertPV PersistText)
  419. , (k PS.xml, convertPV PersistText)
  420. , (k PS.float4, convertPV PersistDouble)
  421. , (k PS.float8, convertPV PersistDouble)
  422. , (k PS.money, convertPV PersistRational)
  423. , (k PS.bpchar, convertPV PersistText)
  424. , (k PS.varchar, convertPV PersistText)
  425. , (k PS.date, convertPV PersistDay)
  426. , (k PS.time, convertPV PersistTimeOfDay)
  427. , (k PS.timestamp, convertPV (PersistUTCTime. localTimeToUTC utc))
  428. , (k PS.timestamptz, convertPV PersistUTCTime)
  429. , (k PS.bit, convertPV PersistInt64)
  430. , (k PS.varbit, convertPV PersistInt64)
  431. , (k PS.numeric, convertPV PersistRational)
  432. , (k PS.void, \_ _ -> return PersistNull)
  433. , (k PS.json, convertPV (PersistByteString . unUnknown))
  434. , (k PS.jsonb, convertPV (PersistByteString . unUnknown))
  435. , (k PS.unknown, convertPV (PersistByteString . unUnknown))
  436. -- Array types: same order as above.
  437. -- The OIDs were taken from pg_type.
  438. , (1000, listOf PersistBool)
  439. , (1001, listOf (PersistByteString . unBinary))
  440. , (1002, listOf PersistText)
  441. , (1003, listOf PersistText)
  442. , (1016, listOf PersistInt64)
  443. , (1005, listOf PersistInt64)
  444. , (1007, listOf PersistInt64)
  445. , (1009, listOf PersistText)
  446. , (143, listOf PersistText)
  447. , (1021, listOf PersistDouble)
  448. , (1022, listOf PersistDouble)
  449. , (1023, listOf PersistUTCTime)
  450. , (1024, listOf PersistUTCTime)
  451. , (791, listOf PersistRational)
  452. , (1014, listOf PersistText)
  453. , (1015, listOf PersistText)
  454. , (1182, listOf PersistDay)
  455. , (1183, listOf PersistTimeOfDay)
  456. , (1115, listOf PersistUTCTime)
  457. , (1185, listOf PersistUTCTime)
  458. , (1561, listOf PersistInt64)
  459. , (1563, listOf PersistInt64)
  460. , (1231, listOf PersistRational)
  461. -- no array(void) type
  462. , (2951, listOf (PersistDbSpecific . unUnknown))
  463. , (199, listOf (PersistByteString . unUnknown))
  464. , (3807, listOf (PersistByteString . unUnknown))
  465. -- no array(unknown) either
  466. ]
  467. where
  468. k (PGFF.typoid -> i) = PG.oid2int i
  469. -- A @listOf f@ will use a @PGArray (Maybe T)@ to convert
  470. -- the values to Haskell-land. The @Maybe@ is important
  471. -- because the usual way of checking NULLs
  472. -- (c.f. withStmt') won't check for NULL inside
  473. -- arrays---or any other compound structure for that matter.
  474. listOf f = convertPV (PersistList . map (nullable f) . PG.fromPGArray)
  475. where nullable = maybe PersistNull
  476. getGetter :: PG.Connection -> PG.Oid -> Getter PersistValue
  477. getGetter _conn oid
  478. = fromMaybe defaultGetter $ I.lookup (PG.oid2int oid) builtinGetters
  479. where defaultGetter = convertPV (PersistDbSpecific . unUnknown)
  480. unBinary :: PG.Binary a -> a
  481. unBinary (PG.Binary x) = x
  482. doesTableExist :: (Text -> IO Statement)
  483. -> DBName -- ^ table name
  484. -> IO Bool
  485. doesTableExist getter (DBName name) = do
  486. stmt <- getter sql
  487. with (stmtQuery stmt vals) (\src -> runConduit $ src .| start)
  488. where
  489. sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'"
  490. <> " AND schemaname != 'information_schema' AND tablename=?"
  491. vals = [PersistText name]
  492. start = await >>= maybe (error "No results when checking doesTableExist") start'
  493. start' [PersistInt64 0] = finish False
  494. start' [PersistInt64 1] = finish True
  495. start' res = error $ "doesTableExist returned unexpected result: " ++ show res
  496. finish x = await >>= maybe (return x) (error "Too many rows returned in doesTableExist")
  497. migrate' :: [EntityDef]
  498. -> (Text -> IO Statement)
  499. -> EntityDef
  500. -> IO (Either [Text] [(Bool, Text)])
  501. migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do
  502. old <- getColumns getter entity
  503. case partitionEithers old of
  504. ([], old'') -> do
  505. exists <-
  506. if null old
  507. then doesTableExist getter name
  508. else return True
  509. return $ Right $ migrationText exists old''
  510. (errs, _) -> return $ Left errs
  511. where
  512. name = entityDB entity
  513. migrationText exists old'' =
  514. if not exists
  515. then createText newcols fdefs udspair
  516. else let (acs, ats) = getAlters allDefs entity (newcols, udspair) old'
  517. acs' = map (AlterColumn name) acs
  518. ats' = map (AlterTable name) ats
  519. in acs' ++ ats'
  520. where
  521. old' = partitionEithers old''
  522. (newcols', udefs, fdefs) = mkColumns allDefs entity
  523. newcols = filter (not . safeToRemove entity . cName) newcols'
  524. udspair = map udToPair udefs
  525. -- Check for table existence if there are no columns, workaround
  526. -- for https://github.com/yesodweb/persistent/issues/152
  527. createText newcols fdefs udspair =
  528. (addTable newcols entity) : uniques ++ references ++ foreignsAlt
  529. where
  530. uniques = flip concatMap udspair $ \(uname, ucols) ->
  531. [AlterTable name $ AddUniqueConstraint uname ucols]
  532. references = mapMaybe (\c@Column { cName=cname, cReference=Just (refTblName, _) } ->
  533. getAddReference allDefs name refTblName cname (cReference c))
  534. $ filter (isJust . cReference) newcols
  535. foreignsAlt = flip map fdefs (\fdef ->
  536. let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
  537. in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignConstraintNameDBName fdef) childfields (map escape parentfields)))
  538. addTable :: [Column] -> EntityDef -> AlterDB
  539. addTable cols entity = AddTable $ T.concat
  540. -- Lower case e: see Database.Persist.Sql.Migration
  541. [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION!
  542. , escape name
  543. , "("
  544. , idtxt
  545. , if null cols then "" else ","
  546. , T.intercalate "," $ map showColumn cols
  547. , ")"
  548. ]
  549. where
  550. name = entityDB entity
  551. idtxt = case entityPrimary entity of
  552. Just pdef -> T.concat [" PRIMARY KEY (", T.intercalate "," $ map (escape . fieldDB) $ compositeFields pdef, ")"]
  553. Nothing ->
  554. let defText = defaultAttribute $ fieldAttrs $ entityId entity
  555. sType = fieldSqlType $ entityId entity
  556. in T.concat
  557. [ escape $ fieldDB (entityId entity)
  558. , maySerial sType defText
  559. , " PRIMARY KEY UNIQUE"
  560. , mayDefault defText
  561. ]
  562. maySerial :: SqlType -> Maybe Text -> Text
  563. maySerial SqlInt64 Nothing = " SERIAL8 "
  564. maySerial sType _ = " " <> showSqlType sType
  565. mayDefault :: Maybe Text -> Text
  566. mayDefault def = case def of
  567. Nothing -> ""
  568. Just d -> " DEFAULT " <> d
  569. type SafeToRemove = Bool
  570. data AlterColumn = ChangeType SqlType Text
  571. | IsNull | NotNull | Add' Column | Drop SafeToRemove
  572. | Default Text | NoDefault | Update' Text
  573. | AddReference DBName [DBName] [Text] | DropReference DBName
  574. type AlterColumn' = (DBName, AlterColumn)
  575. data AlterTable = AddUniqueConstraint DBName [DBName]
  576. | DropConstraint DBName
  577. data AlterDB = AddTable Text
  578. | AlterColumn DBName AlterColumn'
  579. | AlterTable DBName AlterTable
  580. -- | Returns all of the columns in the given table currently in the database.
  581. getColumns :: (Text -> IO Statement)
  582. -> EntityDef
  583. -> IO [Either Text (Either Column (DBName, [DBName]))]
  584. getColumns getter def = do
  585. let sqlv=T.concat ["SELECT "
  586. ,"column_name "
  587. ,",is_nullable "
  588. ,",COALESCE(domain_name, udt_name)" -- See DOMAINS below
  589. ,",column_default "
  590. ,",numeric_precision "
  591. ,",numeric_scale "
  592. ,",character_maximum_length "
  593. ,"FROM information_schema.columns "
  594. ,"WHERE table_catalog=current_database() "
  595. ,"AND table_schema=current_schema() "
  596. ,"AND table_name=? "
  597. ,"AND column_name <> ?"]
  598. -- DOMAINS Postgres supports the concept of domains, which are data types with optional constraints.
  599. -- An app might make an "email" domain over the varchar type, with a CHECK that the emails are valid
  600. -- In this case the generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN foo TYPE email
  601. -- This code exists to use the domain name (email), instead of the underlying type (varchar).
  602. -- This is tested in EquivalentTypeTest.hs
  603. stmt <- getter sqlv
  604. let vals =
  605. [ PersistText $ unDBName $ entityDB def
  606. , PersistText $ unDBName $ fieldDB (entityId def)
  607. ]
  608. cs <- with (stmtQuery stmt vals) (\src -> runConduit $ src .| helper)
  609. let sqlc = T.concat ["SELECT "
  610. ,"c.constraint_name, "
  611. ,"c.column_name "
  612. ,"FROM information_schema.key_column_usage c, "
  613. ,"information_schema.table_constraints k "
  614. ,"WHERE c.table_catalog=current_database() "
  615. ,"AND c.table_catalog=k.table_catalog "
  616. ,"AND c.table_schema=current_schema() "
  617. ,"AND c.table_schema=k.table_schema "
  618. ,"AND c.table_name=? "
  619. ,"AND c.table_name=k.table_name "
  620. ,"AND c.column_name <> ? "
  621. ,"AND c.constraint_name=k.constraint_name "
  622. ,"AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') "
  623. ,"ORDER BY c.constraint_name, c.column_name"]
  624. stmt' <- getter sqlc
  625. us <- with (stmtQuery stmt' vals) (\src -> runConduit $ src .| helperU)
  626. return $ cs ++ us
  627. where
  628. getAll front = do
  629. x <- CL.head
  630. case x of
  631. Nothing -> return $ front []
  632. Just [PersistText con, PersistText col] -> getAll (front . (:) (con, col))
  633. Just [PersistByteString con, PersistByteString col] -> getAll (front . (:) (T.decodeUtf8 con, T.decodeUtf8 col))
  634. Just o -> error $ "unexpected datatype returned for postgres o="++show o
  635. helperU = do
  636. rows <- getAll id
  637. return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
  638. $ groupBy ((==) `on` fst) rows
  639. helper = do
  640. x <- CL.head
  641. case x of
  642. Nothing -> return []
  643. Just x' -> do
  644. col <- liftIO $ getColumn getter (entityDB def) x'
  645. let col' = case col of
  646. Left e -> Left e
  647. Right c -> Right $ Left c
  648. cols <- helper
  649. return $ col' : cols
  650. -- | Check if a column name is listed as the "safe to remove" in the entity
  651. -- list.
  652. safeToRemove :: EntityDef -> DBName -> Bool
  653. safeToRemove def (DBName colName)
  654. = any (elem "SafeToRemove" . fieldAttrs)
  655. $ filter ((== DBName colName) . fieldDB)
  656. $ entityFields def
  657. getAlters :: [EntityDef]
  658. -> EntityDef
  659. -> ([Column], [(DBName, [DBName])])
  660. -> ([Column], [(DBName, [DBName])])
  661. -> ([AlterColumn'], [AlterTable])
  662. getAlters defs def (c1, u1) (c2, u2) =
  663. (getAltersC c1 c2, getAltersU u1 u2)
  664. where
  665. getAltersC [] old = map (\x -> (cName x, Drop $ safeToRemove def $ cName x)) old
  666. getAltersC (new:news) old =
  667. let (alters, old') = findAlters defs (entityDB def) new old
  668. in alters ++ getAltersC news old'
  669. getAltersU :: [(DBName, [DBName])]
  670. -> [(DBName, [DBName])]
  671. -> [AlterTable]
  672. getAltersU [] old = map DropConstraint $ filter (not . isManual) $ map fst old
  673. getAltersU ((name, cols):news) old =
  674. case lookup name old of
  675. Nothing -> AddUniqueConstraint name cols : getAltersU news old
  676. Just ocols ->
  677. let old' = filter (\(x, _) -> x /= name) old
  678. in if sort cols == sort ocols
  679. then getAltersU news old'
  680. else DropConstraint name
  681. : AddUniqueConstraint name cols
  682. : getAltersU news old'
  683. -- Don't drop constraints which were manually added.
  684. isManual (DBName x) = "__manual_" `T.isPrefixOf` x
  685. getColumn :: (Text -> IO Statement)
  686. -> DBName -> [PersistValue]
  687. -> IO (Either Text Column)
  688. getColumn getter tableName' [PersistText columnName, PersistText isNullable, PersistText typeName, defaultValue, numericPrecision, numericScale, maxlen] =
  689. case d' of
  690. Left s -> return $ Left s
  691. Right d'' ->
  692. let typeStr = case maxlen of
  693. PersistInt64 n -> T.concat [typeName, "(", T.pack (show n), ")"]
  694. _ -> typeName
  695. in case getType typeStr of
  696. Left s -> return $ Left s
  697. Right t -> do
  698. let cname = DBName columnName
  699. ref <- getRef cname
  700. return $ Right Column
  701. { cName = cname
  702. , cNull = isNullable == "YES"
  703. , cSqlType = t
  704. , cDefault = fmap stripSuffixes d''
  705. , cDefaultConstraintName = Nothing
  706. , cMaxLen = Nothing
  707. , cReference = ref
  708. }
  709. where
  710. stripSuffixes t =
  711. loop'
  712. [ "::character varying"
  713. , "::text"
  714. ]
  715. where
  716. loop' [] = t
  717. loop' (p:ps) =
  718. case T.stripSuffix p t of
  719. Nothing -> loop' ps
  720. Just t' -> t'
  721. getRef cname = do
  722. let sql = T.concat
  723. [ "SELECT COUNT(*) FROM "
  724. , "information_schema.table_constraints "
  725. , "WHERE table_catalog=current_database() "
  726. , "AND table_schema=current_schema() "
  727. , "AND table_name=? "
  728. , "AND constraint_type='FOREIGN KEY' "
  729. , "AND constraint_name=?"
  730. ]
  731. let ref = refName tableName' cname
  732. stmt <- getter sql
  733. with (stmtQuery stmt
  734. [ PersistText $ unDBName tableName'
  735. , PersistText $ unDBName ref
  736. ]) (\src -> runConduit $ src .| do
  737. Just [PersistInt64 i] <- CL.head
  738. return $ if i == 0 then Nothing else Just (DBName "", ref))
  739. d' = case defaultValue of
  740. PersistNull -> Right Nothing
  741. PersistText t -> Right $ Just t
  742. _ -> Left $ T.pack $ "Invalid default column: " ++ show defaultValue
  743. getType "int4" = Right SqlInt32
  744. getType "int8" = Right SqlInt64
  745. getType "varchar" = Right SqlString
  746. getType "text" = Right SqlString
  747. getType "date" = Right SqlDay
  748. getType "bool" = Right SqlBool
  749. getType "timestamptz" = Right SqlDayTime
  750. getType "float4" = Right SqlReal
  751. getType "float8" = Right SqlReal
  752. getType "bytea" = Right SqlBlob
  753. getType "time" = Right SqlTime
  754. getType "numeric" = getNumeric numericPrecision numericScale
  755. getType a = Right $ SqlOther a
  756. getNumeric (PersistInt64 a) (PersistInt64 b) = Right $ SqlNumeric (fromIntegral a) (fromIntegral b)
  757. getNumeric PersistNull PersistNull = Left $ T.concat
  758. [ "No precision and scale were specified for the column: "
  759. , columnName
  760. , " in table: "
  761. , unDBName tableName'
  762. , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383,"
  763. , " which is probably not what you intended."
  764. , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
  765. ]
  766. getNumeric a b = Left $ T.concat
  767. [ "Can not get numeric field precision for the column: "
  768. , columnName
  769. , " in table: "
  770. , unDBName tableName'
  771. , ". Expected an integer for both precision and scale, "
  772. , "got: "
  773. , T.pack $ show a
  774. , " and "
  775. , T.pack $ show b
  776. , ", respectively."
  777. , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
  778. ]
  779. getColumn _ _ columnName =
  780. return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show columnName
  781. -- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer
  782. sqlTypeEq :: SqlType -> SqlType -> Bool
  783. sqlTypeEq x y =
  784. T.toCaseFold (showSqlType x) == T.toCaseFold (showSqlType y)
  785. findAlters :: [EntityDef] -> DBName -> Column -> [Column] -> ([AlterColumn'], [Column])
  786. findAlters defs _tablename col@(Column name isNull sqltype def _defConstraintName _maxLen ref) cols =
  787. case filter (\c -> cName c == name) cols of
  788. [] -> ([(name, Add' col)], cols)
  789. Column _ isNull' sqltype' def' _defConstraintName' _maxLen' ref':_ ->
  790. let refDrop Nothing = []
  791. refDrop (Just (_, cname)) = [(name, DropReference cname)]
  792. refAdd Nothing = []
  793. refAdd (Just (tname, a)) =
  794. case find ((==tname) . entityDB) defs of
  795. Just refdef -> [(tname, AddReference a [name] (Util.dbIdColumnsEsc escape refdef))]
  796. Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]"
  797. modRef =
  798. if fmap snd ref == fmap snd ref'
  799. then []
  800. else refDrop ref' ++ refAdd ref
  801. modNull = case (isNull, isNull') of
  802. (True, False) -> [(name, IsNull)]
  803. (False, True) ->
  804. let up = case def of
  805. Nothing -> id
  806. Just s -> (:) (name, Update' s)
  807. in up [(name, NotNull)]
  808. _ -> []
  809. modType
  810. | sqlTypeEq sqltype sqltype' = []
  811. -- When converting from Persistent pre-2.0 databases, we
  812. -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is
  813. -- treated as UTC.
  814. | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" =
  815. [(name, ChangeType sqltype $ T.concat
  816. [ " USING "
  817. , escape name
  818. , " AT TIME ZONE 'UTC'"
  819. ])]
  820. | otherwise = [(name, ChangeType sqltype "")]
  821. modDef =
  822. if def == def'
  823. then []
  824. else case def of
  825. Nothing -> [(name, NoDefault)]
  826. Just s -> [(name, Default s)]
  827. in (modRef ++ modDef ++ modNull ++ modType,
  828. filter (\c -> cName c /= name) cols)
  829. -- | Get the references to be added to a table for the given column.
  830. getAddReference :: [EntityDef] -> DBName -> DBName -> DBName -> Maybe (DBName, DBName) -> Maybe AlterDB
  831. getAddReference allDefs table reftable cname ref =
  832. case ref of
  833. Nothing -> Nothing
  834. Just (s, _) -> Just $ AlterColumn table (s, AddReference (refName table cname) [cname] id_)
  835. where
  836. id_ = fromMaybe (error $ "Could not find ID of entity " ++ show reftable)
  837. $ do
  838. entDef <- find ((== reftable) . entityDB) allDefs
  839. return $ Util.dbIdColumnsEsc escape entDef
  840. showColumn :: Column -> Text
  841. showColumn (Column n nu sqlType' def _defConstraintName _maxLen _ref) = T.concat
  842. [ escape n
  843. , " "
  844. , showSqlType sqlType'
  845. , " "
  846. , if nu then "NULL" else "NOT NULL"
  847. , case def of
  848. Nothing -> ""
  849. Just s -> " DEFAULT " <> s
  850. ]
  851. showSqlType :: SqlType -> Text
  852. showSqlType SqlString = "VARCHAR"
  853. showSqlType SqlInt32 = "INT4"
  854. showSqlType SqlInt64 = "INT8"
  855. showSqlType SqlReal = "DOUBLE PRECISION"
  856. showSqlType (SqlNumeric s prec) = T.concat [ "NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")" ]
  857. showSqlType SqlDay = "DATE"
  858. showSqlType SqlTime = "TIME"
  859. showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE"
  860. showSqlType SqlBlob = "BYTEA"
  861. showSqlType SqlBool = "BOOLEAN"
  862. -- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682
  863. showSqlType (SqlOther (T.toLower -> "integer")) = "INT4"
  864. showSqlType (SqlOther t) = t
  865. showAlterDb :: AlterDB -> (Bool, Text)
  866. showAlterDb (AddTable s) = (False, s)
  867. showAlterDb (AlterColumn t (c, ac)) =
  868. (isUnsafe ac, showAlter t (c, ac))
  869. where
  870. isUnsafe (Drop safeRemove) = not safeRemove
  871. isUnsafe _ = False
  872. showAlterDb (AlterTable t at) = (False, showAlterTable t at)
  873. showAlterTable :: DBName -> AlterTable -> Text
  874. showAlterTable table (AddUniqueConstraint cname cols) = T.concat
  875. [ "ALTER TABLE "
  876. , escape table
  877. , " ADD CONSTRAINT "
  878. , escape cname
  879. , " UNIQUE("
  880. , T.intercalate "," $ map escape cols
  881. , ")"
  882. ]
  883. showAlterTable table (DropConstraint cname) = T.concat
  884. [ "ALTER TABLE "
  885. , escape table
  886. , " DROP CONSTRAINT "
  887. , escape cname
  888. ]
  889. showAlter :: DBName -> AlterColumn' -> Text
  890. showAlter table (n, ChangeType t extra) =
  891. T.concat
  892. [ "ALTER TABLE "
  893. , escape table
  894. , " ALTER COLUMN "
  895. , escape n
  896. , " TYPE "
  897. , showSqlType t
  898. , extra
  899. ]
  900. showAlter table (n, IsNull) =
  901. T.concat
  902. [ "ALTER TABLE "
  903. , escape table
  904. , " ALTER COLUMN "
  905. , escape n
  906. , " DROP NOT NULL"
  907. ]
  908. showAlter table (n, NotNull) =
  909. T.concat
  910. [ "ALTER TABLE "
  911. , escape table
  912. , " ALTER COLUMN "
  913. , escape n
  914. , " SET NOT NULL"
  915. ]
  916. showAlter table (_, Add' col) =
  917. T.concat
  918. [ "ALTER TABLE "
  919. , escape table
  920. , " ADD COLUMN "
  921. , showColumn col
  922. ]
  923. showAlter table (n, Drop _) =
  924. T.concat
  925. [ "ALTER TABLE "
  926. , escape table
  927. , " DROP COLUMN "
  928. , escape n
  929. ]
  930. showAlter table (n, Default s) =
  931. T.concat
  932. [ "ALTER TABLE "
  933. , escape table
  934. , " ALTER COLUMN "
  935. , escape n
  936. , " SET DEFAULT "
  937. , s
  938. ]
  939. showAlter table (n, NoDefault) = T.concat
  940. [ "ALTER TABLE "
  941. , escape table
  942. , " ALTER COLUMN "
  943. , escape n
  944. , " DROP DEFAULT"
  945. ]
  946. showAlter table (n, Update' s) = T.concat
  947. [ "UPDATE "
  948. , escape table
  949. , " SET "
  950. , escape n
  951. , "="
  952. , s
  953. , " WHERE "
  954. , escape n
  955. , " IS NULL"
  956. ]
  957. showAlter table (reftable, AddReference fkeyname t2 id2) = T.concat
  958. [ "ALTER TABLE "
  959. , escape table
  960. , " ADD CONSTRAINT "
  961. , escape fkeyname
  962. , " FOREIGN KEY("
  963. , T.intercalate "," $ map escape t2
  964. , ") REFERENCES "
  965. , escape reftable
  966. , "("
  967. , T.intercalate "," id2
  968. , ")"
  969. ]
  970. showAlter table (_, DropReference cname) = T.concat
  971. [ "ALTER TABLE "
  972. , escape table
  973. , " DROP CONSTRAINT "
  974. , escape cname
  975. ]
  976. -- | Get the SQL string for the table that a PeristEntity represents.
  977. -- Useful for raw SQL queries.
  978. tableName :: (PersistEntity record) => record -> Text
  979. tableName = escape . tableDBName
  980. -- | Get the SQL string for the field that an EntityField represents.
  981. -- Useful for raw SQL queries.
  982. fieldName :: (PersistEntity record) => EntityField record typ -> Text
  983. fieldName = escape . fieldDBName
  984. escape :: DBName -> Text
  985. escape (DBName s) =
  986. T.pack $ '"' : go (T.unpack s) ++ "\""
  987. where
  988. go "" = ""
  989. go ('"':xs) = "\"\"" ++ go xs
  990. go (x:xs) = x : go xs
  991. -- | Information required to connect to a PostgreSQL database
  992. -- using @persistent@'s generic facilities. These values are the
  993. -- same that are given to 'withPostgresqlPool'.
  994. data PostgresConf = PostgresConf
  995. { pgConnStr :: ConnectionString
  996. -- ^ The connection string.
  997. , pgPoolSize :: Int
  998. -- ^ How many connections should be held in the connection pool.
  999. } deriving (Show, Read, Data, Typeable)
  1000. instance FromJSON PostgresConf where
  1001. parseJSON v = modifyFailure ("Persistent: error loading PostgreSQL conf: " ++) $
  1002. flip (withObject "PostgresConf") v $ \o -> do
  1003. database <- o .: "database"
  1004. host <- o .: "host"
  1005. port <- o .:? "port" .!= 5432
  1006. user <- o .: "user"
  1007. password <- o .: "password"
  1008. pool <- o .: "poolsize"
  1009. let ci = PG.ConnectInfo
  1010. { PG.connectHost = host
  1011. , PG.connectPort = port
  1012. , PG.connectUser = user
  1013. , PG.connectPassword = password
  1014. , PG.connectDatabase = database
  1015. }
  1016. cstr = PG.postgreSQLConnectionString ci
  1017. return $ PostgresConf cstr pool
  1018. instance PersistConfig PostgresConf where
  1019. type PersistConfigBackend PostgresConf = SqlPersistT
  1020. type PersistConfigPool PostgresConf = ConnectionPool
  1021. createPoolConfig (PostgresConf cs size) = runNoLoggingT $ createPostgresqlPool cs size -- FIXME
  1022. runPool _ = runSqlPool
  1023. loadConfig = parseJSON
  1024. applyEnv c0 = do
  1025. env <- getEnvironment
  1026. return $ addUser env
  1027. $ addPass env
  1028. $ addDatabase env
  1029. $ addPort env
  1030. $ addHost env c0
  1031. where
  1032. addParam param val c =
  1033. c { pgConnStr = B8.concat [pgConnStr c, " ", param, "='", pgescape val, "'"] }
  1034. pgescape = B8.pack . go
  1035. where
  1036. go ('\'':rest) = '\\' : '\'' : go rest
  1037. go ('\\':rest) = '\\' : '\\' : go rest
  1038. go ( x :rest) = x : go rest
  1039. go [] = []
  1040. maybeAddParam param envvar env =
  1041. maybe id (addParam param) $
  1042. lookup envvar env
  1043. addHost = maybeAddParam "host" "PGHOST"
  1044. addPort = maybeAddParam "port" "PGPORT"
  1045. addUser = maybeAddParam "user" "PGUSER"
  1046. addPass = maybeAddParam "password" "PGPASS"
  1047. addDatabase = maybeAddParam "dbname" "PGDATABASE"
  1048. refName :: DBName -> DBName -> DBName
  1049. refName (DBName table) (DBName column) =
  1050. DBName $ T.concat [table, "_", column, "_fkey"]
  1051. udToPair :: UniqueDef -> (DBName, [DBName])
  1052. udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud)
  1053. mockMigrate :: [EntityDef]
  1054. -> (Text -> IO Statement)
  1055. -> EntityDef
  1056. -> IO (Either [Text] [(Bool, Text)])
  1057. mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do
  1058. case partitionEithers [] of
  1059. ([], old'') -> return $ Right $ migrationText False old''
  1060. (errs, _) -> return $ Left errs
  1061. where
  1062. name = entityDB entity
  1063. migrationText exists old'' =
  1064. if not exists
  1065. then createText newcols fdefs udspair
  1066. else let (acs, ats) = getAlters allDefs entity (newcols, udspair) old'
  1067. acs' = map (AlterColumn name) acs
  1068. ats' = map (AlterTable name) ats
  1069. in acs' ++ ats'
  1070. where
  1071. old' = partitionEithers old''
  1072. (newcols', udefs, fdefs) = mkColumns allDefs entity
  1073. newcols = filter (not . safeToRemove entity . cName) newcols'
  1074. udspair = map udToPair udefs
  1075. -- Check for table existence if there are no columns, workaround
  1076. -- for https://github.com/yesodweb/persistent/issues/152
  1077. createText newcols fdefs udspair =
  1078. (addTable newcols entity) : uniques ++ references ++ foreignsAlt
  1079. where
  1080. uniques = flip concatMap udspair $ \(uname, ucols) ->
  1081. [AlterTable name $ AddUniqueConstraint uname ucols]
  1082. references = mapMaybe (\c@Column { cName=cname, cReference=Just (refTblName, _) } ->
  1083. getAddReference allDefs name refTblName cname (cReference c))
  1084. $ filter (isJust . cReference) newcols
  1085. foreignsAlt = flip map fdefs (\fdef ->
  1086. let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
  1087. in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignConstraintNameDBName fdef) childfields (map escape parentfields)))
  1088. -- | Mock a migration even when the database is not present.
  1089. -- This function performs the same functionality of 'printMigration'
  1090. -- with the difference that an actual database is not needed.
  1091. mockMigration :: Migration -> IO ()
  1092. mockMigration mig = do
  1093. smap <- newIORef $ Map.empty
  1094. let sqlbackend = SqlBackend { connPrepare = \_ -> do
  1095. return Statement
  1096. { stmtFinalize = return ()
  1097. , stmtReset = return ()
  1098. , stmtExecute = undefined
  1099. , stmtQuery = \_ -> return $ return ()
  1100. },
  1101. connInsertManySql = Nothing,
  1102. connInsertSql = undefined,
  1103. connUpsertSql = Nothing,
  1104. connPutManySql = Nothing,
  1105. connStmtMap = smap,
  1106. connClose = undefined,
  1107. connMigrateSql = mockMigrate,
  1108. connBegin = undefined,
  1109. connCommit = undefined,
  1110. connRollback = undefined,
  1111. connEscapeName = escape,
  1112. connNoLimit = undefined,
  1113. connRDBMS = undefined,
  1114. connLimitOffset = undefined,
  1115. connLogFunc = undefined,
  1116. connMaxParams = Nothing,
  1117. connRepsertManySql = Nothing
  1118. }
  1119. result = runReaderT $ runWriterT $ runWriterT mig
  1120. resp <- result sqlbackend
  1121. mapM_ T.putStrLn $ map snd $ snd resp
  1122. putManySql :: EntityDef -> Int -> Text
  1123. putManySql ent n = putManySql' conflictColumns fields ent n
  1124. where
  1125. fields = entityFields ent
  1126. conflictColumns = concatMap (map (escape . snd) . uniqueFields) (entityUniques ent)
  1127. repsertManySql :: EntityDef -> Int -> Text
  1128. repsertManySql ent n = putManySql' conflictColumns fields ent n
  1129. where
  1130. fields = keyAndEntityFields ent
  1131. conflictColumns = escape . fieldDB <$> entityKeyFields ent
  1132. putManySql' :: [Text] -> [FieldDef] -> EntityDef -> Int -> Text
  1133. putManySql' conflictColumns fields ent n = q
  1134. where
  1135. fieldDbToText = escape . fieldDB
  1136. mkAssignment f = T.concat [f, "=EXCLUDED.", f]
  1137. table = escape . entityDB $ ent
  1138. columns = Util.commaSeparated $ map fieldDbToText fields
  1139. placeholders = map (const "?") fields
  1140. updates = map (mkAssignment . fieldDbToText) fields
  1141. q = T.concat
  1142. [ "INSERT INTO "
  1143. , table
  1144. , Util.parenWrapped columns
  1145. , " VALUES "
  1146. , Util.commaSeparated . replicate n
  1147. . Util.parenWrapped . Util.commaSeparated $ placeholders
  1148. , " ON CONFLICT "
  1149. , Util.parenWrapped . Util.commaSeparated $ conflictColumns
  1150. , " DO UPDATE SET "
  1151. , Util.commaSeparated updates
  1152. ]
  1153. -- | Enable a Postgres extension. See https://www.postgresql.org/docs/current/static/contrib.html
  1154. -- for a list.
  1155. migrateEnableExtension :: Text -> Migration
  1156. migrateEnableExtension extName = WriterT $ WriterT $ do
  1157. res :: [Single Int] <-
  1158. rawSql "SELECT COUNT(*) FROM pg_catalog.pg_extension WHERE extname = ?" [PersistText extName]
  1159. if res == [Single 0]
  1160. then return (((), []) , [(False, "CREATe EXTENSION \"" <> extName <> "\"")])
  1161. else return (((), []), [])