MySQL.hs 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315
  1. {-# LANGUAGE ExistentialQuantification #-}
  2. {-# LANGUAGE FlexibleContexts #-}
  3. {-# LANGUAGE GADTs #-}
  4. {-# LANGUAGE OverloadedStrings #-}
  5. {-# LANGUAGE PatternSynonyms #-}
  6. {-# LANGUAGE TypeFamilies #-}
  7. -- | A MySQL backend for @persistent@.
  8. module Database.Persist.MySQL
  9. ( withMySQLPool
  10. , withMySQLConn
  11. , createMySQLPool
  12. , module Database.Persist.Sql
  13. , MySQL.ConnectInfo(..)
  14. , MySQLBase.SSLInfo(..)
  15. , MySQL.defaultConnectInfo
  16. , MySQLBase.defaultSSLInfo
  17. , MySQLConf(..)
  18. , mockMigration
  19. -- * @ON DUPLICATE KEY UPDATE@ Functionality
  20. , insertOnDuplicateKeyUpdate
  21. , insertManyOnDuplicateKeyUpdate
  22. , HandleUpdateCollision
  23. , pattern SomeField
  24. , SomeField
  25. , copyField
  26. , copyUnlessNull
  27. , copyUnlessEmpty
  28. , copyUnlessEq
  29. ) where
import qualified Blaze.ByteString.Builder.Char8 as BBB
import qualified Blaze.ByteString.Builder.ByteString as BBS
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Logger (MonadLogger, runNoLoggingT)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (runExceptT, throwE)
import Control.Monad.Trans.Reader (runReaderT, ReaderT)
import Control.Monad.Trans.Writer (runWriterT)
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Acquire (Acquire, mkAcquire, with)
import Data.Aeson
import Data.Aeson.Types (modifyFailure)
import Data.ByteString (ByteString)
import Data.Either (partitionEithers)
import Data.Fixed (Pico)
import Data.Function (on)
import Data.Int (Int64)
import Data.IORef
import Data.List (find, intercalate, sort, groupBy)
import qualified Data.Map as Map
import Data.Monoid ((<>))
import qualified Data.Monoid as Monoid
import Data.Pool (Pool)
import Data.Text (Text, pack)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Text.Read (readMaybe)
import System.Environment (getEnvironment)
import Database.Persist.Sql
import Database.Persist.Sql.Types.Internal (makeIsolationLevelStatement)
import qualified Database.Persist.Sql.Util as Util
import qualified Database.MySQL.Base as MySQLBase
import qualified Database.MySQL.Base.Types as MySQLBase
import qualified Database.MySQL.Simple as MySQL
import qualified Database.MySQL.Simple.Param as MySQL
import qualified Database.MySQL.Simple.Result as MySQL
import qualified Database.MySQL.Simple.Types as MySQL
  72. -- | Create a MySQL connection pool and run the given action.
  73. -- The pool is properly released after the action finishes using
  74. -- it. Note that you should not use the given 'ConnectionPool'
  75. -- outside the action since it may be already been released.
  76. withMySQLPool :: (MonadLogger m, MonadUnliftIO m)
  77. => MySQL.ConnectInfo
  78. -- ^ Connection information.
  79. -> Int
  80. -- ^ Number of connections to be kept open in the pool.
  81. -> (Pool SqlBackend -> m a)
  82. -- ^ Action to be executed that uses the connection pool.
  83. -> m a
  84. withMySQLPool ci = withSqlPool $ open' ci
  85. -- | Create a MySQL connection pool. Note that it's your
  86. -- responsibility to properly close the connection pool when
  87. -- unneeded. Use 'withMySQLPool' for automatic resource control.
  88. createMySQLPool :: (MonadUnliftIO m, MonadLogger m)
  89. => MySQL.ConnectInfo
  90. -- ^ Connection information.
  91. -> Int
  92. -- ^ Number of connections to be kept open in the pool.
  93. -> m (Pool SqlBackend)
  94. createMySQLPool ci = createSqlPool $ open' ci
  95. -- | Same as 'withMySQLPool', but instead of opening a pool
  96. -- of connections, only one connection is opened.
  97. withMySQLConn :: (MonadUnliftIO m, MonadLogger m)
  98. => MySQL.ConnectInfo
  99. -- ^ Connection information.
  100. -> (SqlBackend -> m a)
  101. -- ^ Action to be executed that uses the connection.
  102. -> m a
  103. withMySQLConn = withSqlConn . open'
  104. -- | Internal function that opens a connection to the MySQL
  105. -- server.
  106. open' :: MySQL.ConnectInfo -> LogFunc -> IO SqlBackend
  107. open' ci logFunc = do
  108. conn <- MySQL.connect ci
  109. MySQLBase.autocommit conn False -- disable autocommit!
  110. smap <- newIORef $ Map.empty
  111. return $ SqlBackend
  112. { connPrepare = prepare' conn
  113. , connStmtMap = smap
  114. , connInsertSql = insertSql'
  115. , connInsertManySql = Nothing
  116. , connUpsertSql = Nothing
  117. , connPutManySql = Just putManySql
  118. , connClose = MySQL.close conn
  119. , connMigrateSql = migrate' ci
  120. , connBegin = \_ mIsolation -> do
  121. forM_ mIsolation $ \iso -> MySQL.execute_ conn (makeIsolationLevelStatement iso)
  122. MySQL.execute_ conn "start transaction" >> return ()
  123. , connCommit = const $ MySQL.commit conn
  124. , connRollback = const $ MySQL.rollback conn
  125. , connEscapeName = pack . escapeDBName
  126. , connNoLimit = "LIMIT 18446744073709551615"
  127. -- This noLimit is suggested by MySQL's own docs, see
  128. -- <http://dev.mysql.com/doc/refman/5.5/en/select.html>
  129. , connRDBMS = "mysql"
  130. , connLimitOffset = decorateSQLWithLimitOffset "LIMIT 18446744073709551615"
  131. , connLogFunc = logFunc
  132. , connMaxParams = Nothing
  133. , connRepsertManySql = Just repsertManySql
  134. }
  135. -- | Prepare a query. We don't support prepared statements, but
  136. -- we'll do some client-side preprocessing here.
  137. prepare' :: MySQL.Connection -> Text -> IO Statement
  138. prepare' conn sql = do
  139. let query = MySQL.Query (T.encodeUtf8 sql)
  140. return Statement
  141. { stmtFinalize = return ()
  142. , stmtReset = return ()
  143. , stmtExecute = execute' conn query
  144. , stmtQuery = withStmt' conn query
  145. }
  146. -- | SQL code to be executed when inserting an entity.
  147. insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
  148. insertSql' ent vals =
  149. let sql = pack $ concat
  150. [ "INSERT INTO "
  151. , escapeDBName $ entityDB ent
  152. , "("
  153. , intercalate "," $ map (escapeDBName . fieldDB) $ entityFields ent
  154. , ") VALUES("
  155. , intercalate "," (map (const "?") $ entityFields ent)
  156. , ")"
  157. ]
  158. in case entityPrimary ent of
  159. Just _ -> ISRManyKeys sql vals
  160. Nothing -> ISRInsertGet sql "SELECT LAST_INSERT_ID()"
  161. -- | Execute an statement that doesn't return any results.
  162. execute' :: MySQL.Connection -> MySQL.Query -> [PersistValue] -> IO Int64
  163. execute' conn query vals = MySQL.execute conn query (map P vals)
  164. -- | Execute an statement that does return results. The results
  165. -- are fetched all at once and stored into memory.
  166. withStmt' :: MonadIO m
  167. => MySQL.Connection
  168. -> MySQL.Query
  169. -> [PersistValue]
  170. -> Acquire (ConduitM () [PersistValue] m ())
  171. withStmt' conn query vals = do
  172. result <- mkAcquire createResult MySQLBase.freeResult
  173. return $ fetchRows result >>= CL.sourceList
  174. where
  175. createResult = do
  176. -- Execute the query
  177. formatted <- MySQL.formatQuery conn query (map P vals)
  178. MySQLBase.query conn formatted
  179. MySQLBase.storeResult conn
  180. fetchRows result = liftIO $ do
  181. -- Find out the type of the columns
  182. fields <- MySQLBase.fetchFields result
  183. let getters = [ maybe PersistNull (getGetter f f . Just) | f <- fields]
  184. convert = use getters
  185. where use (g:gs) (col:cols) =
  186. let v = g col
  187. vs = use gs cols
  188. in v `seq` vs `seq` (v:vs)
  189. use _ _ = []
  190. -- Ready to go!
  191. let go acc = do
  192. row <- MySQLBase.fetchRow result
  193. case row of
  194. [] -> return (acc [])
  195. _ -> let converted = convert row
  196. in converted `seq` go (acc . (converted:))
  197. go id
  198. -- | @newtype@ around 'PersistValue' that supports the
  199. -- 'MySQL.Param' type class.
  200. newtype P = P PersistValue
  201. instance MySQL.Param P where
  202. render (P (PersistText t)) = MySQL.render t
  203. render (P (PersistByteString bs)) = MySQL.render bs
  204. render (P (PersistInt64 i)) = MySQL.render i
  205. render (P (PersistDouble d)) = MySQL.render d
  206. render (P (PersistBool b)) = MySQL.render b
  207. render (P (PersistDay d)) = MySQL.render d
  208. render (P (PersistTimeOfDay t)) = MySQL.render t
  209. render (P (PersistUTCTime t)) = MySQL.render t
  210. render (P PersistNull) = MySQL.render MySQL.Null
  211. render (P (PersistList l)) = MySQL.render $ listToJSON l
  212. render (P (PersistMap m)) = MySQL.render $ mapToJSON m
  213. render (P (PersistRational r)) =
  214. MySQL.Plain $ BBB.fromString $ show (fromRational r :: Pico)
  215. -- FIXME: Too Ambigous, can not select precision without information about field
  216. render (P (PersistDbSpecific s)) = MySQL.Plain $ BBS.fromByteString s
  217. render (P (PersistArray a)) = MySQL.render (P (PersistList a))
  218. render (P (PersistObjectId _)) =
  219. error "Refusing to serialize a PersistObjectId to a MySQL value"
  220. -- | @Getter a@ is a function that converts an incoming value
  221. -- into a data type @a@.
  222. type Getter a = MySQLBase.Field -> Maybe ByteString -> a
  223. -- | Helper to construct 'Getter'@s@ using 'MySQL.Result'.
  224. convertPV :: MySQL.Result a => (a -> b) -> Getter b
  225. convertPV f = (f .) . MySQL.convert
  226. -- | Get the corresponding @'Getter' 'PersistValue'@ depending on
  227. -- the type of the column.
  228. getGetter :: MySQLBase.Field -> Getter PersistValue
  229. getGetter field = go (MySQLBase.fieldType field)
  230. (MySQLBase.fieldLength field)
  231. (MySQLBase.fieldCharSet field)
  232. where
  233. -- Bool
  234. go MySQLBase.Tiny 1 _ = convertPV PersistBool
  235. go MySQLBase.Tiny _ _ = convertPV PersistInt64
  236. -- Int64
  237. go MySQLBase.Int24 _ _ = convertPV PersistInt64
  238. go MySQLBase.Short _ _ = convertPV PersistInt64
  239. go MySQLBase.Long _ _ = convertPV PersistInt64
  240. go MySQLBase.LongLong _ _ = convertPV PersistInt64
  241. -- Double
  242. go MySQLBase.Float _ _ = convertPV PersistDouble
  243. go MySQLBase.Double _ _ = convertPV PersistDouble
  244. go MySQLBase.Decimal _ _ = convertPV PersistDouble
  245. go MySQLBase.NewDecimal _ _ = convertPV PersistDouble
  246. -- ByteString and Text
  247. -- The MySQL C client (and by extension the Haskell mysql package) doesn't distinguish between binary and non-binary string data at the type level.
  248. -- (e.g. both BLOB and TEXT have the MySQLBase.Blob type).
  249. -- Instead, the character set distinguishes them. Binary data uses character set number 63.
  250. -- See https://dev.mysql.com/doc/refman/5.6/en/c-api-data-structures.html (Search for "63")
  251. go MySQLBase.VarChar _ 63 = convertPV PersistByteString
  252. go MySQLBase.VarString _ 63 = convertPV PersistByteString
  253. go MySQLBase.String _ 63 = convertPV PersistByteString
  254. go MySQLBase.VarChar _ _ = convertPV PersistText
  255. go MySQLBase.VarString _ _ = convertPV PersistText
  256. go MySQLBase.String _ _ = convertPV PersistText
  257. go MySQLBase.Blob _ 63 = convertPV PersistByteString
  258. go MySQLBase.TinyBlob _ 63 = convertPV PersistByteString
  259. go MySQLBase.MediumBlob _ 63 = convertPV PersistByteString
  260. go MySQLBase.LongBlob _ 63 = convertPV PersistByteString
  261. go MySQLBase.Blob _ _ = convertPV PersistText
  262. go MySQLBase.TinyBlob _ _ = convertPV PersistText
  263. go MySQLBase.MediumBlob _ _ = convertPV PersistText
  264. go MySQLBase.LongBlob _ _ = convertPV PersistText
  265. -- Time-related
  266. go MySQLBase.Time _ _ = convertPV PersistTimeOfDay
  267. go MySQLBase.DateTime _ _ = convertPV PersistUTCTime
  268. go MySQLBase.Timestamp _ _ = convertPV PersistUTCTime
  269. go MySQLBase.Date _ _ = convertPV PersistDay
  270. go MySQLBase.NewDate _ _ = convertPV PersistDay
  271. go MySQLBase.Year _ _ = convertPV PersistDay
  272. -- Null
  273. go MySQLBase.Null _ _ = \_ _ -> PersistNull
  274. -- Controversial conversions
  275. go MySQLBase.Set _ _ = convertPV PersistText
  276. go MySQLBase.Enum _ _ = convertPV PersistText
  277. -- Conversion using PersistDbSpecific
  278. go MySQLBase.Geometry _ _ = \_ m ->
  279. case m of
  280. Just g -> PersistDbSpecific g
  281. Nothing -> error "Unexpected null in database specific value"
  282. -- Unsupported
  283. go other _ _ = error $ "MySQL.getGetter: type " ++
  284. show other ++ " not supported."
----------------------------------------------------------------------

-- | Create the migration plan for the given 'PersistEntity'
-- @val@.
--
-- The table's current state is read with 'getColumns', and the desired
-- state is built with 'mkColumns'. Three outcomes are possible:
--
-- * Nothing is found (no ID column and no other columns/constraints): the
--   plan creates the table, then adds its unique constraints, the foreign
--   keys implied by column references, and the explicit foreign
--   definitions.
-- * Something is found without errors: the plan is the column and table
--   changes computed by 'getAlters'.
-- * 'getColumns' reported errors: they are returned as 'Left'.
migrate' :: MySQL.ConnectInfo
         -> [EntityDef]
         -> (Text -> IO Statement)
         -> EntityDef
         -> IO (Either [Text] [(Bool, Text)])
migrate' connectInfo allDefs getter val = do
    let name = entityDB val
    (idClmn, old) <- getColumns connectInfo getter val
    let (newcols, udefs, fdefs) = mkColumns allDefs val
    let udspair = map udToPair udefs
    -- NOTE(review): only 'old' is partitioned into errors/results; entries of
    -- 'idClmn' are only checked for emptiness and are never inspected.
    case (idClmn, old, partitionEithers old) of
      -- Nothing found, create everything
      ([], [], _) -> do
        let uniques = flip concatMap udspair $ \(uname, ucols) ->
                [ AlterTable name $
                  AddUniqueConstraint uname $
                  map (findTypeAndMaxLen name) ucols ]
        -- Foreign keys implied by columns that reference another table.
        let foreigns = do
              Column { cName=cname, cReference=Just (refTblName, _a) } <- newcols
              return $ AlterColumn name (refTblName, addReference allDefs (refName name cname) refTblName cname)

        -- Foreign keys declared explicitly on the entity ('ForeignDef's).
        let foreignsAlt = map (\fdef ->
                let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
                in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignRefTableDBName fdef) (foreignConstraintNameDBName fdef) childfields parentfields)) fdefs

        return $ Right $ map showAlterDb $ (addTable newcols val): uniques ++ foreigns ++ foreignsAlt
      -- No errors and something found, migrate
      (_, _, ([], old')) -> do
        -- A column reference whose constraint name matches one of the
        -- entity's explicit foreign definitions is cleared before diffing.
        let excludeForeignKeys (xs,ys) = (map (\c -> case cReference c of
                  Just (_,fk) -> case find (\f -> fk == foreignConstraintNameDBName f) fdefs of
                    Just _ -> c { cReference = Nothing }
                    Nothing -> c
                  Nothing -> c) xs,ys)
            (acs, ats) = getAlters allDefs name (newcols, udspair) $ excludeForeignKeys $ partitionEithers old'
            acs' = map (AlterColumn name) acs
            ats' = map (AlterTable name) ats
        return $ Right $ map showAlterDb $ acs' ++ ats'
      -- Errors
      (_, _, (errs, _)) -> return $ Left errs
  where
    -- (column name, field type, max length) triple for a unique constraint.
    findTypeAndMaxLen tblName col =
        let (col', ty) = findTypeOfColumn allDefs tblName col
            (_, ml) = findMaxLenOfColumn allDefs tblName col
         in (col', ty, ml)
  323. (_, _, (errs, _)) -> return $ Left errs
  324. where
  325. findTypeAndMaxLen tblName col = let (col', ty) = findTypeOfColumn allDefs tblName col
  326. (_, ml) = findMaxLenOfColumn allDefs tblName col
  327. in (col', ty, ml)
  328. addTable :: [Column] -> EntityDef -> AlterDB
  329. addTable cols entity = AddTable $ concat
  330. -- Lower case e: see Database.Persist.Sql.Migration
  331. [ "CREATe TABLE "
  332. , escapeDBName name
  333. , "("
  334. , idtxt
  335. , if null cols then [] else ","
  336. , intercalate "," $ map showColumn cols
  337. , ")"
  338. ]
  339. where
  340. name = entityDB entity
  341. idtxt = case entityPrimary entity of
  342. Just pdef -> concat [" PRIMARY KEY (", intercalate "," $ map (escapeDBName . fieldDB) $ compositeFields pdef, ")"]
  343. Nothing ->
  344. let defText = defaultAttribute $ fieldAttrs $ entityId entity
  345. sType = fieldSqlType $ entityId entity
  346. autoIncrementText = case (sType, defText) of
  347. (SqlInt64, Nothing) -> " AUTO_INCREMENT"
  348. _ -> ""
  349. maxlen = findMaxLenOfField (entityId entity)
  350. in concat
  351. [ escapeDBName $ fieldDB $ entityId entity
  352. , " " <> showSqlType sType maxlen False
  353. , " NOT NULL"
  354. , autoIncrementText
  355. , " PRIMARY KEY"
  356. ]
  357. -- | Find out the type of a column.
  358. findTypeOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, FieldType)
  359. findTypeOfColumn allDefs name col =
  360. maybe (error $ "Could not find type of column " ++
  361. show col ++ " on table " ++ show name ++
  362. " (allDefs = " ++ show allDefs ++ ")")
  363. ((,) col) $ do
  364. entDef <- find ((== name) . entityDB) allDefs
  365. fieldDef <- find ((== col) . fieldDB) (entityFields entDef)
  366. return (fieldType fieldDef)
  367. -- | Find out the maxlen of a column (default to 200)
  368. findMaxLenOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, Integer)
  369. findMaxLenOfColumn allDefs name col =
  370. maybe (col, 200)
  371. ((,) col) $ do
  372. entDef <- find ((== name) . entityDB) allDefs
  373. fieldDef <- find ((== col) . fieldDB) (entityFields entDef)
  374. findMaxLenOfField fieldDef
  375. -- | Find out the maxlen of a field
  376. findMaxLenOfField :: FieldDef -> Maybe Integer
  377. findMaxLenOfField fieldDef = do
  378. maxLenAttr <- find ((T.isPrefixOf "maxlen=") . T.toLower) (fieldAttrs fieldDef)
  379. readMaybe . T.unpack . T.drop 7 $ maxLenAttr
  380. -- | Helper for 'AddReference' that finds out the which primary key columns to reference.
  381. addReference :: [EntityDef] -> DBName -> DBName -> DBName -> AlterColumn
  382. addReference allDefs fkeyname reftable cname = AddReference reftable fkeyname [cname] referencedColumns
  383. where
  384. referencedColumns = maybe (error $ "Could not find ID of entity " ++ show reftable
  385. ++ " (allDefs = " ++ show allDefs ++ ")")
  386. id $ do
  387. entDef <- find ((== reftable) . entityDB) allDefs
  388. return $ map fieldDB $ entityKeyFields entDef
  389. data AlterColumn = Change Column
  390. | Add' Column
  391. | Drop
  392. | Default String
  393. | NoDefault
  394. | Update' String
  395. -- | See the definition of the 'showAlter' function to see how these fields are used.
  396. | AddReference
  397. DBName -- Referenced table
  398. DBName -- Foreign key name
  399. [DBName] -- Referencing columns
  400. [DBName] -- Referenced columns
  401. | DropReference DBName
  402. type AlterColumn' = (DBName, AlterColumn)
  403. data AlterTable = AddUniqueConstraint DBName [(DBName, FieldType, Integer)]
  404. | DropUniqueConstraint DBName
  405. data AlterDB = AddTable String
  406. | AlterColumn DBName AlterColumn'
  407. | AlterTable DBName AlterTable
  408. udToPair :: UniqueDef -> (DBName, [DBName])
  409. udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud)
----------------------------------------------------------------------

-- | Returns all of the 'Column'@s@ in the given table currently
-- in the database.
--
-- Three @INFORMATION_SCHEMA@ queries are issued. Each is filtered on the
-- connection's database, the entity's table name and its ID column name
-- (see @vals@):
--
--   1. the ID column itself;
--   2. every other column of the table;
--   3. key constraints on non-ID columns, excluding @PRIMARY@ and
--      foreign-key constraints (@REFERENCED_TABLE_SCHEMA IS NULL@).
--
-- Column rows are converted with 'getColumn'. Constraint rows are grouped
-- by constraint name into @(constraintName, columnNames)@ pairs. The result
-- is @(idColumn, otherColumns ++ constraints)@, where each element is either
-- an error message ('Left') or a column / constraint ('Right').
getColumns :: MySQL.ConnectInfo
           -> (Text -> IO Statement)
           -> EntityDef
           -> IO ( [Either Text (Either Column (DBName, [DBName]))] -- ID column
                 , [Either Text (Either Column (DBName, [DBName]))] -- everything else
                 )
getColumns connectInfo getter def = do
    -- Find out ID column.
    -- NOTE(review): this query selects 4 columns, but 'getColumn' only
    -- matches rows of 8 values, so every row here becomes
    -- @Left "Invalid result from INFORMATION_SCHEMA: ..."@. 'migrate'' only
    -- checks this list for emptiness.
    stmtIdClmn <- getter $ T.concat
      [ "SELECT COLUMN_NAME, "
      , "IS_NULLABLE, "
      , "DATA_TYPE, "
      , "COLUMN_DEFAULT "
      , "FROM INFORMATION_SCHEMA.COLUMNS "
      , "WHERE TABLE_SCHEMA = ? "
      , "AND TABLE_NAME = ? "
      , "AND COLUMN_NAME = ?"
      ]
    -- Rows are fully read into a list first, because 'getColumn' issues a
    -- query of its own for each row.
    inter1 <- with (stmtQuery stmtIdClmn vals) (\src -> runConduit $ src .| CL.consume)
    ids <- runConduitRes $ CL.sourceList inter1 .| helperClmns -- avoid nested queries

    -- Find out all columns.
    stmtClmns <- getter $ T.concat
      [ "SELECT COLUMN_NAME, "
      , "IS_NULLABLE, "
      , "DATA_TYPE, "
      , "COLUMN_TYPE, "
      , "CHARACTER_MAXIMUM_LENGTH, "
      , "NUMERIC_PRECISION, "
      , "NUMERIC_SCALE, "
      , "COLUMN_DEFAULT "
      , "FROM INFORMATION_SCHEMA.COLUMNS "
      , "WHERE TABLE_SCHEMA = ? "
      , "AND TABLE_NAME = ? "
      , "AND COLUMN_NAME <> ?"
      ]
    inter2 <- with (stmtQuery stmtClmns vals) (\src -> runConduitRes $ src .| CL.consume)
    cs <- runConduitRes $ CL.sourceList inter2 .| helperClmns -- avoid nested queries

    -- Find out the constraints.
    stmtCntrs <- getter $ T.concat
      [ "SELECT CONSTRAINT_NAME, "
      , "COLUMN_NAME "
      , "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
      , "WHERE TABLE_SCHEMA = ? "
      , "AND TABLE_NAME = ? "
      , "AND COLUMN_NAME <> ? "
      , "AND CONSTRAINT_NAME <> 'PRIMARY' "
      , "AND REFERENCED_TABLE_SCHEMA IS NULL "
      , "ORDER BY CONSTRAINT_NAME, "
      , "COLUMN_NAME"
      ]
    us <- with (stmtQuery stmtCntrs vals) (\src -> runConduitRes $ src .| helperCntrs)

    -- Return both
    return (ids, cs ++ us)
  where
    -- Query parameters: database name, table name, ID column name.
    vals = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
           , PersistText $ unDBName $ entityDB def
           , PersistText $ unDBName $ fieldDB $ entityId def ]

    -- Convert each column row with 'getColumn', tagging successes as columns.
    helperClmns = CL.mapM getIt .| CL.consume
        where
          getIt = fmap (either Left (Right . Left)) .
                  liftIO .
                  getColumn connectInfo getter (entityDB def)

    -- Group (constraint, column) rows by constraint name. This relies on the
    -- query's ORDER BY CONSTRAINT_NAME. An unexpected row shape aborts via
    -- 'fail'.
    helperCntrs = do
      let check [ PersistText cntrName
                , PersistText clmnName] = return ( cntrName, clmnName )
          check other = fail $ "helperCntrs: unexpected " ++ show other
      rows <- mapM check =<< CL.consume
      return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
             $ groupBy ((==) `on` fst) rows
  482. -- | Get the information about a column in a table.
  483. getColumn :: MySQL.ConnectInfo
  484. -> (Text -> IO Statement)
  485. -> DBName
  486. -> [PersistValue]
  487. -> IO (Either Text Column)
  488. getColumn connectInfo getter tname [ PersistText cname
  489. , PersistText null_
  490. , PersistText dataType
  491. , PersistText colType
  492. , colMaxLen
  493. , colPrecision
  494. , colScale
  495. , default'] =
  496. fmap (either (Left . pack) Right) $
  497. runExceptT $ do
  498. -- Default value
  499. default_ <- case default' of
  500. PersistNull -> return Nothing
  501. PersistText t -> return (Just t)
  502. PersistByteString bs ->
  503. case T.decodeUtf8' bs of
  504. Left exc -> fail $ "Invalid default column: " ++
  505. show default' ++ " (error: " ++
  506. show exc ++ ")"
  507. Right t -> return (Just t)
  508. _ -> fail $ "Invalid default column: " ++ show default'
  509. -- Foreign key (if any)
  510. stmt <- lift . getter $ T.concat
  511. [ "SELECT REFERENCED_TABLE_NAME, "
  512. , "CONSTRAINT_NAME, "
  513. , "ORDINAL_POSITION "
  514. , "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
  515. , "WHERE TABLE_SCHEMA = ? "
  516. , "AND TABLE_NAME = ? "
  517. , "AND COLUMN_NAME = ? "
  518. , "AND REFERENCED_TABLE_SCHEMA = ? "
  519. , "ORDER BY CONSTRAINT_NAME, "
  520. , "COLUMN_NAME"
  521. ]
  522. let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
  523. , PersistText $ unDBName $ tname
  524. , PersistText cname
  525. , PersistText $ pack $ MySQL.connectDatabase connectInfo ]
  526. cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
  527. ref <- case cntrs of
  528. [] -> return Nothing
  529. [[PersistText tab, PersistText ref, PersistInt64 pos]] ->
  530. return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
  531. _ -> fail "MySQL.getColumn/getRef: never here"
  532. let colMaxLen' = case colMaxLen of
  533. PersistInt64 l -> Just (fromIntegral l)
  534. _ -> Nothing
  535. ci = ColumnInfo
  536. { ciColumnType = colType
  537. , ciMaxLength = colMaxLen'
  538. , ciNumericPrecision = colPrecision
  539. , ciNumericScale = colScale
  540. }
  541. (typ, maxLen) <- parseColumnType dataType ci
  542. -- Okay!
  543. return Column
  544. { cName = DBName $ cname
  545. , cNull = null_ == "YES"
  546. , cSqlType = typ
  547. , cDefault = default_
  548. , cDefaultConstraintName = Nothing
  549. , cMaxLen = maxLen
  550. , cReference = ref
  551. }
  552. getColumn _ _ _ x =
  553. return $ Left $ pack $ "Invalid result from INFORMATION_SCHEMA: " ++ show x
  554. -- | Extra column information from MySQL schema
  555. data ColumnInfo = ColumnInfo
  556. { ciColumnType :: Text
  557. , ciMaxLength :: Maybe Integer
  558. , ciNumericPrecision :: PersistValue
  559. , ciNumericScale :: PersistValue
  560. }
  561. -- | Parse the type of column as returned by MySQL's
  562. -- @INFORMATION_SCHEMA@ tables.
  563. parseColumnType :: Monad m => Text -> ColumnInfo -> m (SqlType, Maybe Integer)
  564. -- Ints
  565. parseColumnType "tinyint" ci | ciColumnType ci == "tinyint(1)" = return (SqlBool, Nothing)
  566. parseColumnType "int" ci | ciColumnType ci == "int(11)" = return (SqlInt32, Nothing)
  567. parseColumnType "bigint" ci | ciColumnType ci == "bigint(20)" = return (SqlInt64, Nothing)
  568. -- Double
  569. parseColumnType x@("double") ci | ciColumnType ci == x = return (SqlReal, Nothing)
  570. parseColumnType "decimal" ci =
  571. case (ciNumericPrecision ci, ciNumericScale ci) of
  572. (PersistInt64 p, PersistInt64 s) ->
  573. return (SqlNumeric (fromIntegral p) (fromIntegral s), Nothing)
  574. _ ->
  575. fail "missing DECIMAL precision in DB schema"
  576. -- Text
  577. parseColumnType "varchar" ci = return (SqlString, ciMaxLength ci)
  578. parseColumnType "text" _ = return (SqlString, Nothing)
  579. -- ByteString
  580. parseColumnType "varbinary" ci = return (SqlBlob, ciMaxLength ci)
  581. parseColumnType "blob" _ = return (SqlBlob, Nothing)
  582. -- Time-related
  583. parseColumnType "time" _ = return (SqlTime, Nothing)
  584. parseColumnType "datetime" _ = return (SqlDayTime, Nothing)
  585. parseColumnType "date" _ = return (SqlDay, Nothing)
  586. parseColumnType _ ci = return (SqlOther (ciColumnType ci), Nothing)
  587. ----------------------------------------------------------------------
  588. -- | @getAlters allDefs tblName new old@ finds out what needs to
  589. -- be changed from @old@ to become @new@.
  590. getAlters :: [EntityDef]
  591. -> DBName
  592. -> ([Column], [(DBName, [DBName])])
  593. -> ([Column], [(DBName, [DBName])])
  594. -> ([AlterColumn'], [AlterTable])
  595. getAlters allDefs tblName (c1, u1) (c2, u2) =
  596. (getAltersC c1 c2, getAltersU u1 u2)
  597. where
  598. getAltersC [] old = concatMap dropColumn old
  599. getAltersC (new:news) old =
  600. let (alters, old') = findAlters tblName allDefs new old
  601. in alters ++ getAltersC news old'
  602. dropColumn col =
  603. map ((,) (cName col)) $
  604. [DropReference n | Just (_, n) <- [cReference col]] ++
  605. [Drop]
  606. getAltersU [] old = map (DropUniqueConstraint . fst) old
  607. getAltersU ((name, cols):news) old =
  608. case lookup name old of
  609. Nothing ->
  610. AddUniqueConstraint name (map findTypeAndMaxLen cols) : getAltersU news old
  611. Just ocols ->
  612. let old' = filter (\(x, _) -> x /= name) old
  613. in if sort cols == ocols
  614. then getAltersU news old'
  615. else DropUniqueConstraint name
  616. : AddUniqueConstraint name (map findTypeAndMaxLen cols)
  617. : getAltersU news old'
  618. where
  619. findTypeAndMaxLen col = let (col', ty) = findTypeOfColumn allDefs tblName col
  620. (_, ml) = findMaxLenOfColumn allDefs tblName col
  621. in (col', ty, ml)
  622. -- | @findAlters newColumn oldColumns@ finds out what needs to be
  623. -- changed in the columns @oldColumns@ for @newColumn@ to be
  624. -- supported.
  625. findAlters :: DBName -> [EntityDef] -> Column -> [Column] -> ([AlterColumn'], [Column])
  626. findAlters tblName allDefs col@(Column name isNull type_ def _defConstraintName maxLen ref) cols =
  627. case filter ((name ==) . cName) cols of
  628. -- new fkey that didnt exist before
  629. [] -> case ref of
  630. Nothing -> ([(name, Add' col)],[])
  631. Just (tname, _b) -> let cnstr = [addReference allDefs (refName tblName name) tname name]
  632. in (map ((,) tname) (Add' col : cnstr), cols)
  633. Column _ isNull' type_' def' _defConstraintName' maxLen' ref':_ ->
  634. let -- Foreign key
  635. refDrop = case (ref == ref', ref') of
  636. (False, Just (_, cname)) -> [(name, DropReference cname)]
  637. _ -> []
  638. refAdd = case (ref == ref', ref) of
  639. (False, Just (tname, _cname)) -> [(tname, addReference allDefs (refName tblName name) tname name)]
  640. _ -> []
  641. -- Type and nullability
  642. modType | showSqlType type_ maxLen False `ciEquals` showSqlType type_' maxLen' False && isNull == isNull' = []
  643. | otherwise = [(name, Change col)]
  644. -- Default value
  645. -- Avoid DEFAULT NULL, since it is always unnecessary, and is an error for text/blob fields
  646. modDef | def == def' = []
  647. | otherwise = case def of
  648. Nothing -> [(name, NoDefault)]
  649. Just s -> if T.toUpper s == "NULL" then []
  650. else [(name, Default $ T.unpack s)]
  651. in ( refDrop ++ modType ++ modDef ++ refAdd
  652. , filter ((name /=) . cName) cols )
  653. where
  654. ciEquals x y = T.toCaseFold (T.pack x) == T.toCaseFold (T.pack y)
  655. ----------------------------------------------------------------------
  656. -- | Prints the part of a @CREATE TABLE@ statement about a given
  657. -- column.
  658. showColumn :: Column -> String
  659. showColumn (Column n nu t def _defConstraintName maxLen ref) = concat
  660. [ escapeDBName n
  661. , " "
  662. , showSqlType t maxLen True
  663. , " "
  664. , if nu then "NULL" else "NOT NULL"
  665. , case def of
  666. Nothing -> ""
  667. Just s -> -- Avoid DEFAULT NULL, since it is always unnecessary, and is an error for text/blob fields
  668. if T.toUpper s == "NULL" then ""
  669. else " DEFAULT " ++ T.unpack s
  670. , case ref of
  671. Nothing -> ""
  672. Just (s, _) -> " REFERENCES " ++ escapeDBName s
  673. ]
  674. -- | Renders an 'SqlType' in MySQL's format.
  675. showSqlType :: SqlType
  676. -> Maybe Integer -- ^ @maxlen@
  677. -> Bool -- ^ include character set information?
  678. -> String
  679. showSqlType SqlBlob Nothing _ = "BLOB"
  680. showSqlType SqlBlob (Just i) _ = "VARBINARY(" ++ show i ++ ")"
  681. showSqlType SqlBool _ _ = "TINYINT(1)"
  682. showSqlType SqlDay _ _ = "DATE"
  683. showSqlType SqlDayTime _ _ = "DATETIME"
  684. showSqlType SqlInt32 _ _ = "INT(11)"
  685. showSqlType SqlInt64 _ _ = "BIGINT"
  686. showSqlType SqlReal _ _ = "DOUBLE"
  687. showSqlType (SqlNumeric s prec) _ _ = "NUMERIC(" ++ show s ++ "," ++ show prec ++ ")"
  688. showSqlType SqlString Nothing True = "TEXT CHARACTER SET utf8"
  689. showSqlType SqlString Nothing False = "TEXT"
  690. showSqlType SqlString (Just i) True = "VARCHAR(" ++ show i ++ ") CHARACTER SET utf8"
  691. showSqlType SqlString (Just i) False = "VARCHAR(" ++ show i ++ ")"
  692. showSqlType SqlTime _ _ = "TIME"
  693. showSqlType (SqlOther t) _ _ = T.unpack t
  694. -- | Render an action that must be done on the database.
  695. showAlterDb :: AlterDB -> (Bool, Text)
  696. showAlterDb (AddTable s) = (False, pack s)
  697. showAlterDb (AlterColumn t (c, ac)) =
  698. (isUnsafe ac, pack $ showAlter t (c, ac))
  699. where
  700. isUnsafe Drop = True
  701. isUnsafe _ = False
  702. showAlterDb (AlterTable t at) = (False, pack $ showAlterTable t at)
  703. -- | Render an action that must be done on a table.
  704. showAlterTable :: DBName -> AlterTable -> String
  705. showAlterTable table (AddUniqueConstraint cname cols) = concat
  706. [ "ALTER TABLE "
  707. , escapeDBName table
  708. , " ADD CONSTRAINT "
  709. , escapeDBName cname
  710. , " UNIQUE("
  711. , intercalate "," $ map escapeDBName' cols
  712. , ")"
  713. ]
  714. where
  715. escapeDBName' (name, (FTTypeCon _ "Text" ), maxlen) = escapeDBName name ++ "(" ++ show maxlen ++ ")"
  716. escapeDBName' (name, (FTTypeCon _ "String" ), maxlen) = escapeDBName name ++ "(" ++ show maxlen ++ ")"
  717. escapeDBName' (name, (FTTypeCon _ "ByteString"), maxlen) = escapeDBName name ++ "(" ++ show maxlen ++ ")"
  718. escapeDBName' (name, _ , _) = escapeDBName name
  719. showAlterTable table (DropUniqueConstraint cname) = concat
  720. [ "ALTER TABLE "
  721. , escapeDBName table
  722. , " DROP INDEX "
  723. , escapeDBName cname
  724. ]
  725. -- | Render an action that must be done on a column.
  726. showAlter :: DBName -> AlterColumn' -> String
  727. showAlter table (oldName, Change (Column n nu t def defConstraintName maxLen _ref)) =
  728. concat
  729. [ "ALTER TABLE "
  730. , escapeDBName table
  731. , " CHANGE "
  732. , escapeDBName oldName
  733. , " "
  734. , showColumn (Column n nu t def defConstraintName maxLen Nothing)
  735. ]
  736. showAlter table (_, Add' col) =
  737. concat
  738. [ "ALTER TABLE "
  739. , escapeDBName table
  740. , " ADD COLUMN "
  741. , showColumn col
  742. ]
  743. showAlter table (n, Drop) =
  744. concat
  745. [ "ALTER TABLE "
  746. , escapeDBName table
  747. , " DROP COLUMN "
  748. , escapeDBName n
  749. ]
  750. showAlter table (n, Default s) =
  751. concat
  752. [ "ALTER TABLE "
  753. , escapeDBName table
  754. , " ALTER COLUMN "
  755. , escapeDBName n
  756. , " SET DEFAULT "
  757. , s
  758. ]
  759. showAlter table (n, NoDefault) =
  760. concat
  761. [ "ALTER TABLE "
  762. , escapeDBName table
  763. , " ALTER COLUMN "
  764. , escapeDBName n
  765. , " DROP DEFAULT"
  766. ]
  767. showAlter table (n, Update' s) =
  768. concat
  769. [ "UPDATE "
  770. , escapeDBName table
  771. , " SET "
  772. , escapeDBName n
  773. , "="
  774. , s
  775. , " WHERE "
  776. , escapeDBName n
  777. , " IS NULL"
  778. ]
  779. showAlter table (_, AddReference reftable fkeyname t2 id2) = concat
  780. [ "ALTER TABLE "
  781. , escapeDBName table
  782. , " ADD CONSTRAINT "
  783. , escapeDBName fkeyname
  784. , " FOREIGN KEY("
  785. , intercalate "," $ map escapeDBName t2
  786. , ") REFERENCES "
  787. , escapeDBName reftable
  788. , "("
  789. , intercalate "," $ map escapeDBName id2
  790. , ")"
  791. ]
  792. showAlter table (_, DropReference cname) = concat
  793. [ "ALTER TABLE "
  794. , escapeDBName table
  795. , " DROP FOREIGN KEY "
  796. , escapeDBName cname
  797. ]
  798. refName :: DBName -> DBName -> DBName
  799. refName (DBName table) (DBName column) =
  800. DBName $ T.concat [table, "_", column, "_fkey"]
  801. ----------------------------------------------------------------------
  802. escape :: DBName -> Text
  803. escape = T.pack . escapeDBName
  804. -- | Escape a database name to be included on a query.
  805. escapeDBName :: DBName -> String
  806. escapeDBName (DBName s) = '`' : go (T.unpack s)
  807. where
  808. go ('`':xs) = '`' : '`' : go xs
  809. go ( x :xs) = x : go xs
  810. go "" = "`"
-- | Information required to connect to a MySQL database
-- using @persistent@'s generic facilities. These values are the
-- same that are given to 'withMySQLPool'.
--
-- A value can be decoded from JSON via the 'FromJSON' instance, and
-- adjusted from @MYSQL_*@ environment variables via 'applyEnv'.
data MySQLConf = MySQLConf
    { myConnInfo :: MySQL.ConnectInfo
      -- ^ The connection information.
    , myPoolSize :: Int
      -- ^ How many connections should be held on the connection pool.
    } deriving Show
  820. instance FromJSON MySQLConf where
  821. parseJSON v = modifyFailure ("Persistent: error loading MySQL conf: " ++) $
  822. flip (withObject "MySQLConf") v $ \o -> do
  823. database <- o .: "database"
  824. host <- o .: "host"
  825. port <- o .: "port"
  826. path <- o .:? "path"
  827. user <- o .: "user"
  828. password <- o .: "password"
  829. pool <- o .: "poolsize"
  830. let ci = MySQL.defaultConnectInfo
  831. { MySQL.connectHost = host
  832. , MySQL.connectPort = port
  833. , MySQL.connectPath = case path of
  834. Just p -> p
  835. Nothing -> MySQL.connectPath MySQL.defaultConnectInfo
  836. , MySQL.connectUser = user
  837. , MySQL.connectPassword = password
  838. , MySQL.connectDatabase = database
  839. }
  840. return $ MySQLConf ci pool
  841. instance PersistConfig MySQLConf where
  842. type PersistConfigBackend MySQLConf = SqlPersistT
  843. type PersistConfigPool MySQLConf = ConnectionPool
  844. createPoolConfig (MySQLConf cs size) = runNoLoggingT $ createMySQLPool cs size -- FIXME
  845. runPool _ = runSqlPool
  846. loadConfig = parseJSON
  847. applyEnv conf = do
  848. env <- getEnvironment
  849. let maybeEnv old var = maybe old id $ lookup ("MYSQL_" ++ var) env
  850. return conf
  851. { myConnInfo =
  852. case myConnInfo conf of
  853. MySQL.ConnectInfo
  854. { MySQL.connectHost = host
  855. , MySQL.connectPort = port
  856. , MySQL.connectPath = path
  857. , MySQL.connectUser = user
  858. , MySQL.connectPassword = password
  859. , MySQL.connectDatabase = database
  860. } -> (myConnInfo conf)
  861. { MySQL.connectHost = maybeEnv host "HOST"
  862. , MySQL.connectPort = read $ maybeEnv (show port) "PORT"
  863. , MySQL.connectPath = maybeEnv path "PATH"
  864. , MySQL.connectUser = maybeEnv user "USER"
  865. , MySQL.connectPassword = maybeEnv password "PASSWORD"
  866. , MySQL.connectDatabase = maybeEnv database "DATABASE"
  867. }
  868. }
  869. mockMigrate :: MySQL.ConnectInfo
  870. -> [EntityDef]
  871. -> (Text -> IO Statement)
  872. -> EntityDef
  873. -> IO (Either [Text] [(Bool, Text)])
  874. mockMigrate _connectInfo allDefs _getter val = do
  875. let name = entityDB val
  876. let (newcols, udefs, fdefs) = mkColumns allDefs val
  877. let udspair = map udToPair udefs
  878. case () of
  879. -- Nothing found, create everything
  880. () -> do
  881. let uniques = flip concatMap udspair $ \(uname, ucols) ->
  882. [ AlterTable name $
  883. AddUniqueConstraint uname $
  884. map (findTypeAndMaxLen name) ucols ]
  885. let foreigns = do
  886. Column { cName=cname, cReference=Just (refTblName, _a) } <- newcols
  887. return $ AlterColumn name (refTblName, addReference allDefs (refName name cname) refTblName cname)
  888. let foreignsAlt = map (\fdef -> let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
  889. in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignRefTableDBName fdef) (foreignConstraintNameDBName fdef) childfields parentfields)) fdefs
  890. return $ Right $ map showAlterDb $ (addTable newcols val): uniques ++ foreigns ++ foreignsAlt
  891. {- FIXME redundant, why is this here? The whole case expression is weird
  892. -- No errors and something found, migrate
  893. (_, _, ([], old')) -> do
  894. let excludeForeignKeys (xs,ys) = (map (\c -> case cReference c of
  895. Just (_,fk) -> case find (\f -> fk == foreignConstraintNameDBName f) fdefs of
  896. Just _ -> c { cReference = Nothing }
  897. Nothing -> c
  898. Nothing -> c) xs,ys)
  899. (acs, ats) = getAlters allDefs name (newcols, udspair) $ excludeForeignKeys $ partitionEithers old'
  900. acs' = map (AlterColumn name) acs
  901. ats' = map (AlterTable name) ats
  902. return $ Right $ map showAlterDb $ acs' ++ ats'
  903. -- Errors
  904. (_, _, (errs, _)) -> return $ Left errs
  905. -}
  906. where
  907. findTypeAndMaxLen tblName col = let (col', ty) = findTypeOfColumn allDefs tblName col
  908. (_, ml) = findMaxLenOfColumn allDefs tblName col
  909. in (col', ty, ml)
  910. -- | Mock a migration even when the database is not present.
  911. -- This function will mock the migration for a database even when
  912. -- the actual database isn't already present in the system.
  913. mockMigration :: Migration -> IO ()
  914. mockMigration mig = do
  915. smap <- newIORef $ Map.empty
  916. let sqlbackend = SqlBackend { connPrepare = \_ -> do
  917. return Statement
  918. { stmtFinalize = return ()
  919. , stmtReset = return ()
  920. , stmtExecute = undefined
  921. , stmtQuery = \_ -> return $ return ()
  922. },
  923. connInsertManySql = Nothing,
  924. connInsertSql = undefined,
  925. connStmtMap = smap,
  926. connClose = undefined,
  927. connMigrateSql = mockMigrate undefined,
  928. connBegin = undefined,
  929. connCommit = undefined,
  930. connRollback = undefined,
  931. connEscapeName = undefined,
  932. connNoLimit = undefined,
  933. connRDBMS = undefined,
  934. connLimitOffset = undefined,
  935. connLogFunc = undefined,
  936. connUpsertSql = undefined,
  937. connPutManySql = undefined,
  938. connMaxParams = Nothing,
  939. connRepsertManySql = Nothing
  940. }
  941. result = runReaderT . runWriterT . runWriterT $ mig
  942. resp <- result sqlbackend
  943. mapM_ T.putStrLn $ map snd $ snd resp
  944. -- | MySQL specific 'upsert_'. This will prevent multiple queries, when one will
  945. -- do. The record will be inserted into the database. In the event that the
  946. -- record already exists in the database, the record will have the
  947. -- relevant updates performed.
  948. insertOnDuplicateKeyUpdate
  949. :: ( backend ~ PersistEntityBackend record
  950. , PersistEntity record
  951. , MonadIO m
  952. , PersistStore backend
  953. , BackendCompatible SqlBackend backend
  954. )
  955. => record
  956. -> [Update record]
  957. -> ReaderT backend m ()
  958. insertOnDuplicateKeyUpdate record =
  959. insertManyOnDuplicateKeyUpdate [record] []
-- | This type is used to determine how to update rows using MySQL's
-- @INSERT ... ON DUPLICATE KEY UPDATE@ functionality, exposed via
-- 'insertManyOnDuplicateKeyUpdate' in this library.
--
-- Values can be built with 'copyField', 'copyUnlessEq', 'copyUnlessNull'
-- and 'copyUnlessEmpty'.
--
-- @since 2.8.0
data HandleUpdateCollision record where
    -- | Copy the field directly from the record.
    CopyField :: EntityField record typ -> HandleUpdateCollision record
    -- | Only copy the field if it is not equal to the provided value.
    CopyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record
-- | An alias for 'HandleUpdateCollision'. The type previously was only
-- used to copy a single value, but was expanded to handle more complex
-- queries.
--
-- @since 2.6.2
type SomeField = HandleUpdateCollision

-- | Backwards-compatible stand-in for the former @SomeField@ constructor;
-- it is a synonym for 'CopyField'.
pattern SomeField :: EntityField record typ -> SomeField record
pattern SomeField x = CopyField x
{-# DEPRECATED SomeField "The type SomeField is deprecated. Use the type HandleUpdateCollision instead, and use the function copyField instead of the data constructor." #-}
  979. -- | Copy the field into the database only if the value in the
  980. -- corresponding record is non-@NULL@.
  981. --
  982. -- @since 2.6.2
  983. copyUnlessNull :: PersistField typ => EntityField record (Maybe typ) -> HandleUpdateCollision record
  984. copyUnlessNull field = CopyUnlessEq field Nothing
  985. -- | Copy the field into the database only if the value in the
  986. -- corresponding record is non-empty, where "empty" means the Monoid
  987. -- definition for 'mempty'. Useful for 'Text', 'String', 'ByteString', etc.
  988. --
  989. -- The resulting 'HandleUpdateCollision' type is useful for the
  990. -- 'insertManyOnDuplicateKeyUpdate' function.
  991. --
  992. -- @since 2.6.2
  993. copyUnlessEmpty :: (Monoid.Monoid typ, PersistField typ) => EntityField record typ -> HandleUpdateCollision record
  994. copyUnlessEmpty field = CopyUnlessEq field Monoid.mempty
  995. -- | Copy the field into the database only if the field is not equal to the
  996. -- provided value. This is useful to avoid copying weird nullary data into
  997. -- the database.
  998. --
  999. -- The resulting 'HandleUpdateCollision' type is useful for the
  1000. -- 'insertManyOnDuplicateKeyUpdate' function.
  1001. --
  1002. -- @since 2.6.2
  1003. copyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record
  1004. copyUnlessEq = CopyUnlessEq
  1005. -- | Copy the field directly from the record.
  1006. --
  1007. -- @since 3.0
  1008. copyField :: PersistField typ => EntityField record typ -> HandleUpdateCollision record
  1009. copyField = CopyField
  1010. -- | Do a bulk insert on the given records in the first parameter. In the event
  1011. -- that a key conflicts with a record currently in the database, the second and
  1012. -- third parameters determine what will happen.
  1013. --
  1014. -- The second parameter is a list of fields to copy from the original value.
  1015. -- This allows you to specify which fields to copy from the record you're trying
  1016. -- to insert into the database to the preexisting row.
  1017. --
  1018. -- The third parameter is a list of updates to perform that are independent of
  1019. -- the value that is provided. You can use this to increment a counter value.
  1020. -- These updates only occur if the original record is present in the database.
  1021. --
  1022. -- === __More details on 'HandleUpdateCollision' usage__
  1023. --
  1024. -- The @['HandleUpdateCollision']@ parameter allows you to specify which fields (and
  1025. -- under which conditions) will be copied from the inserted rows. For
  1026. -- a brief example, consider the following data model and existing data set:
  1027. --
  1028. -- @
  1029. -- Item
  1030. -- name Text
  1031. -- description Text
  1032. -- price Double Maybe
  1033. -- quantity Int Maybe
  1034. --
  1035. -- Primary name
  1036. -- @
  1037. --
  1038. -- > items:
  1039. -- > +------+-------------+-------+----------+
  1040. -- > | name | description | price | quantity |
  1041. -- > +------+-------------+-------+----------+
  1042. -- > | foo | very good | | 3 |
  1043. -- > | bar | | 3.99 | |
  1044. -- > +------+-------------+-------+----------+
  1045. --
  1046. -- This record type has a single natural key on @itemName@. Let's suppose
  1047. -- that we download a CSV of new items to store into the database. Here's
  1048. -- our CSV:
  1049. --
  1050. -- > name,description,price,quantity
  1051. -- > foo,,2.50,6
  1052. -- > bar,even better,,5
  1053. -- > yes,wow,,
  1054. --
  1055. -- We parse that into a list of Haskell records:
  1056. --
  1057. -- @
  1058. -- records =
  1059. -- [ Item { itemName = "foo", itemDescription = ""
  1060. -- , itemPrice = Just 2.50, itemQuantity = Just 6
  1061. -- }
  1062. -- , Item "bar" "even better" Nothing (Just 5)
  1063. -- , Item "yes" "wow" Nothing Nothing
  1064. -- ]
  1065. -- @
  1066. --
  1067. -- The new CSV data is partial. It only includes __updates__ from the
  1068. -- upstream vendor. Our CSV library parses the missing description field as
  1069. -- an empty string. We don't want to override the existing description. So
  1070. -- we can use the 'copyUnlessEmpty' function to say: "Don't update when the
  1071. -- value is empty."
  1072. --
  1073. -- Likewise, the new row for @bar@ includes a quantity, but no price. We do
  1074. -- not want to overwrite the existing price in the database with a @NULL@
  1075. -- value. So we can use 'copyUnlessNull' to only copy the existing values
  1076. -- in.
  1077. --
  1078. -- The final code looks like this:
  1079. -- @
  1080. -- 'insertManyOnDuplicateKeyUpdate' records
  1081. -- [ 'copyUnlessEmpty' ItemDescription
  1082. -- , 'copyUnlessNull' ItemPrice
  1083. -- , 'copyUnlessNull' ItemQuantity
  1084. -- ]
  1085. -- []
  1086. -- @
  1087. --
  1088. -- Once we run that code on the datahase, the new data set looks like this:
  1089. --
  1090. -- > items:
  1091. -- > +------+-------------+-------+----------+
  1092. -- > | name | description | price | quantity |
  1093. -- > +------+-------------+-------+----------+
  1094. -- > | foo | very good | 2.50 | 6 |
  1095. -- > | bar | even better | 3.99 | 5 |
  1096. -- > | yes | wow | | |
  1097. -- > +------+-------------+-------+----------+
  1098. insertManyOnDuplicateKeyUpdate
  1099. :: forall record backend m.
  1100. ( backend ~ PersistEntityBackend record
  1101. , BackendCompatible SqlBackend backend
  1102. , PersistEntity record
  1103. , MonadIO m
  1104. )
  1105. => [record] -- ^ A list of the records you want to insert, or update
  1106. -> [HandleUpdateCollision record] -- ^ A list of the fields you want to copy over.
  1107. -> [Update record] -- ^ A list of the updates to apply that aren't dependent on the record being inserted.
  1108. -> ReaderT backend m ()
  1109. insertManyOnDuplicateKeyUpdate [] _ _ = return ()
  1110. insertManyOnDuplicateKeyUpdate records fieldValues updates =
  1111. uncurry rawExecute
  1112. $ mkBulkInsertQuery records fieldValues updates
  1113. -- | This creates the query for 'bulkInsertOnDuplicateKeyUpdate'. If you
  1114. -- provide an empty list of updates to perform, then it will generate
  1115. -- a dummy/no-op update using the first field of the record. This avoids
  1116. -- duplicate key exceptions.
  1117. mkBulkInsertQuery
  1118. :: PersistEntity record
  1119. => [record] -- ^ A list of the records you want to insert, or update
  1120. -> [HandleUpdateCollision record] -- ^ A list of the fields you want to copy over.
  1121. -> [Update record] -- ^ A list of the updates to apply that aren't dependent on the record being inserted.
  1122. -> (Text, [PersistValue])
  1123. mkBulkInsertQuery records fieldValues updates =
  1124. (q, recordValues <> updsValues <> copyUnlessValues)
  1125. where
  1126. mfieldDef x = case x of
  1127. CopyField rec -> Right (fieldDbToText (persistFieldDef rec))
  1128. CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val)
  1129. (fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues
  1130. fieldDbToText = T.pack . escapeDBName . fieldDB
  1131. entityDef' = entityDef records
  1132. firstField = case entityFieldNames of
  1133. [] -> error "The entity you're trying to insert does not have any fields."
  1134. (field:_) -> field
  1135. entityFieldNames = map fieldDbToText (entityFields entityDef')
  1136. tableName = T.pack . escapeDBName . entityDB $ entityDef'
  1137. copyUnlessValues = map snd fieldsToMaybeCopy
  1138. recordValues = concatMap (map toPersistValue . toPersistFields) records
  1139. recordPlaceholders = Util.commaSeparated $ map (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) records
  1140. mkCondFieldSet n _ = T.concat
  1141. [ n
  1142. , "=COALESCE("
  1143. , "NULLIF("
  1144. , "VALUES(", n, "),"
  1145. , "?"
  1146. , "),"
  1147. , n
  1148. , ")"
  1149. ]
  1150. condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy
  1151. fieldSets = map (\n -> T.concat [n, "=VALUES(", n, ")"]) updateFieldNames
  1152. upds = map (Util.mkUpdateText' (pack . escapeDBName) id) updates
  1153. updsValues = map (\(Update _ val _) -> toPersistValue val) updates
  1154. updateText = case fieldSets <> upds <> condFieldSets of
  1155. [] -> T.concat [firstField, "=", firstField]
  1156. xs -> Util.commaSeparated xs
  1157. q = T.concat
  1158. [ "INSERT INTO "
  1159. , tableName
  1160. , " ("
  1161. , Util.commaSeparated entityFieldNames
  1162. , ") "
  1163. , " VALUES "
  1164. , recordPlaceholders
  1165. , " ON DUPLICATE KEY UPDATE "
  1166. , updateText
  1167. ]
  1168. putManySql :: EntityDef -> Int -> Text
  1169. putManySql ent n = putManySql' fields ent n
  1170. where
  1171. fields = entityFields ent
  1172. repsertManySql :: EntityDef -> Int -> Text
  1173. repsertManySql ent n = putManySql' fields ent n
  1174. where
  1175. fields = keyAndEntityFields ent
  1176. putManySql' :: [FieldDef] -> EntityDef -> Int -> Text
  1177. putManySql' fields ent n = q
  1178. where
  1179. fieldDbToText = escape . fieldDB
  1180. mkAssignment f = T.concat [f, "=VALUES(", f, ")"]
  1181. table = escape . entityDB $ ent
  1182. columns = Util.commaSeparated $ map fieldDbToText fields
  1183. placeholders = map (const "?") fields
  1184. updates = map (mkAssignment . fieldDbToText) fields
  1185. q = T.concat
  1186. [ "INSERT INTO "
  1187. , table
  1188. , Util.parenWrapped columns
  1189. , " VALUES "
  1190. , Util.commaSeparated . replicate n
  1191. . Util.parenWrapped . Util.commaSeparated $ placeholders
  1192. , " ON DUPLICATE KEY UPDATE "
  1193. , Util.commaSeparated updates
  1194. ]