reddit-pub/reddit_pub/src/MyLib.hs

132 lines
4.3 KiB
Haskell

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE LambdaCase #-}
module MyLib (defaultMain) where
import Control.Concurrent (threadDelay)
import Control.Exception (bracket)
import Control.Lens
import Control.Monad (forever, void)
import Data.Aeson (FromJSON, ToJSON, Value)
import qualified Data.Aeson as A
import Data.Aeson.Lens (_String, key)
import Data.Bool (bool)
import Data.ByteString.Lazy (ByteString)
import Data.Config
import Data.Deriving.Aeson
import Data.Foldable (for_)
import Data.Functor.Contravariant ((>$<))
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.IO as TI
import qualified Data.Text.Strict.Lens as T
import qualified Database.SQLite.Simple as SQL
import GHC.Generics (Generic)
import qualified Membership
import Network.AMQP
  ( Channel
  , DeliveryMode(Persistent)
  , closeConnection
  , declareExchange
  , exchangeName
  , exchangeType
  , msgBody
  , msgDeliveryMode
  , newExchange
  , newMsg
  , openChannel
  , openConnection
  , publishMsg
  )
import Network.Reddit (RedditId(RedditId), publishEntries)
import Network.Wreq.Session (newSession)
import Reddit.Publish (Publish(..))
import Text.Printf (printf)
-- | How a published entry relates to earlier publications: 'Create' for an
-- id that 'Membership.isSeen' does not report as seen, 'Update' for one it
-- does.
data MessageType
  = Create
  | Update
  deriving stock (Eq, Generic, Show)
  deriving anyclass (FromJSON, ToJSON)
-- | Envelope serialised to JSON and sent to the exchange for every entry
-- whose id could be extracted.
--
-- JSON field names come from the 'AesonCodec' options below; presumably the
-- @message@ prefix is dropped and the remainder camel-cased (e.g. @type@,
-- @identifier@, @content@) — confirm against the codec library.
data Message = Message
  { messageType :: MessageType
    -- ^ 'Create' for an unseen id, 'Update' for an already-seen one
  , messageIdentifier :: RedditId
    -- ^ the entry's @id@ field
  , messageContent :: Value
    -- ^ the complete raw entry as received
  }
  deriving stock (Eq, Generic, Show)
  deriving (FromJSON, ToJSON)
    via AesonCodec (Field (CamelCase <<< DropPrefix "message")) Message
-- | Adapt a sink of parsed messages into a sink of raw JSON entries.
--
-- The entry's top-level string @id@ field becomes the 'RedditId'. When that
-- field is missing or is not a string, the downstream sink receives
-- 'Nothing'. Otherwise the id is checked with 'Membership.isSeen'
-- ('Update' if seen, 'Create' if not) and the whole entry is forwarded as
-- the message content. This adapter only reads membership; it does not
-- record the id itself.
toMessage :: SQL.Connection -> Publish IO (Maybe Message) -> Publish IO Value
toMessage sqlConn (Publish sink) = Publish $ \entry ->
  case entry ^? key "id" . _String of
    Nothing -> sink Nothing
    Just rawId -> do
      let redditId = RedditId rawId
      alreadySeen <- Membership.isSeen sqlConn redditId
      let event = bool Create Update alreadySeen
      sink (Just (Message event redditId entry))
-- | Sink that marks each 'Just' id as seen via 'Membership.recordSeen';
-- 'Nothing' is ignored.
sqlRecorder :: SQL.Connection -> Publish IO (Maybe RedditId)
sqlRecorder conn = Publish $ \mbId -> for_ mbId (Membership.recordSeen conn)
-- | Sink that publishes each 'Just' payload to @exchange@ as a message with
-- persistent delivery mode. 'Nothing' is dropped without touching the
-- channel.
--
-- The routing key is a fixed placeholder. In this module the sink is only
-- used with a fanout exchange, which ignores routing keys.
amqpPublisher :: Channel -> Text -> Publish IO (Maybe ByteString)
amqpPublisher channel exchange = Publish (`for_` send)
  where
    send :: ByteString -> IO ()
    send body = void (publishMsg channel exchange fanoutKey (persistent body))

    fanoutKey :: Text
    fanoutKey = "doesn't matter on fanout"

    persistent body = newMsg
      { msgBody = body
      , msgDeliveryMode = Just Persistent
      }
-- | Sink that writes each log line to standard output.
stdoutPublisher :: Publish IO String
stdoutPublisher = Publish emit
  where
    emit :: String -> IO ()
    emit = putStrLn
-- | Events produced by the main loop and rendered to text by 'fetchToLog'.
data Fetch
  = -- | About to refresh the given fetcher.
    Fetch Fetcher
  | -- | A parsed entry is being published.
    PublishMessage Message
  | -- | An entry had no string @id@ field, so no message was built for it.
    ParseFailed
-- | Render a main-loop event as a single human-readable log line.
fetchToLog :: Fetch -> String
fetchToLog = \case
  Fetch fetcher ->
    printf "Refreshing %s" (show (fetcherSubreddit fetcher))
  PublishMessage msg ->
    printf "Publishing %s as type %s"
      (show (messageIdentifier msg))
      (show (messageType msg))
  ParseFailed ->
    "Failed parsing"
-- | Entry point: run the Reddit-to-AMQP publisher described by the config
-- file at @path@. It never returns normally.
--
-- Setup:
--
-- * Read the config and resolve the AMQP password.
-- * Connect to the broker; 'bracket' closes the connection on exit.
-- * Open the SQLite database and make sure the @membership@ table exists.
-- * Declare the fanout exchange @reddit_posts@.
--
-- Loop, forever:
--
-- * For each configured fetcher, log a refresh line, then feed its entries
--   through 'toMessage' into a combined sink. The sink JSON-encodes the
--   message to the exchange, records its id in SQLite, and logs it.
--   Presumably these run in that order, per the 'Semigroup' instance of
--   'Publish' — confirm.
-- * After all fetchers, sleep for 15 minutes.
defaultMain :: FilePath -> IO ()
defaultMain path = do
  conf <- readConfig path
  pass <- getPassword (conf ^. amqp . password)
  let amqpConf = conf ^. amqp
      postsExchange :: Text
      postsExchange = "reddit_posts"
      -- 'threadDelay' takes microseconds: 15 * 60 s.
      refreshInterval :: Int
      refreshInterval = 15 * 60_000_000
      connectRabbit =
        openConnection
          (amqpConf ^. host . T.unpacked)
          (amqpConf ^. vhost)
          (amqpConf ^. username)
          pass
  bracket connectRabbit closeConnection $ \conn ->
    SQL.withConnection (conf ^. sqlite) $ \sqlConn -> do
      SQL.execute_ sqlConn "create table if not exists membership (reddit_id primary key)"
      chan <- openChannel conn
      declareExchange chan newExchange { exchangeName = postsExchange, exchangeType = "fanout" }
      sess <- newSession
      let logger = fetchToLog >$< stdoutPublisher
          toAmqp = fmap A.encode >$< amqpPublisher chan postsExchange
          toSql = fmap messageIdentifier >$< sqlRecorder sqlConn
          toLog = maybe ParseFailed PublishMessage >$< logger
          publisher = toAmqp <> toSql <> toLog
      forever $ do
        for_ (conf ^. fetchers) $ \fetcher -> do
          publish logger (Fetch fetcher)
          publishEntries (toMessage sqlConn publisher) sess fetcher
        threadDelay refreshInterval
-- | Resolve the AMQP password from the configuration.
--
-- An inline 'Password' is used verbatim. A 'File' is read in full, and any
-- trailing line terminators (@\\n@ / @\\r@) are removed. Without that, a
-- password file created with @echo secret > pw@ would send the newline to
-- the broker as part of the password and authentication would fail.
getPassword :: Password -> IO Text
getPassword (Password p) = pure p
getPassword (File path) = Text.dropWhileEnd isLineEnd <$> TI.readFile path
  where
    isLineEnd :: Char -> Bool
    isLineEnd c = c == '\n' || c == '\r'