diff --git a/src/API/Channels.hs b/src/API/Channels.hs index 2321b83..13507f0 100644 --- a/src/API/Channels.hs +++ b/src/API/Channels.hs @@ -11,6 +11,8 @@ {-# Language FlexibleInstances #-} {-# Language TypeApplications #-} {-# Language DataKinds #-} +{-# Language DuplicateRecordFields #-} +{-# Language NamedFieldPuns #-} module API.Channels (API, handler, JsonChannel(..)) where import Servant @@ -28,17 +30,34 @@ import Data.Generics.Product data JsonChannel = JsonChannel { channel :: Text , visibility :: Visibility } deriving (Show, Generic) +data UpdateChannel = UpdateChannel { identifier :: ChannelID + , channel :: Text + , visibility :: Visibility } + deriving (Show, Generic) instance ToJSON JsonChannel instance FromJSON JsonChannel +instance ToJSON UpdateChannel +instance FromJSON UpdateChannel type API = Auth '[SA.BasicAuth, SA.Cookie, SA.JWT] SafeUser :> BaseAPI -type BaseAPI = "channels" :> ReqBody '[JSON] JsonChannel :> Post '[JSON] JsonChannel +type BaseAPI = "channels" :> ReqBody '[JSON] JsonChannel :> Post '[JSON] UpdateChannel + :<|> "channels" :> Capture "channel_id" ChannelID :> ReqBody '[JSON] UpdateChannel :> Put '[JSON] UpdateChannel :<|> "channels" :> Get '[JSON] [JsonChannel] handler :: ServerT API AppM -handler user = newChannelHandler user :<|> listChannelsHandler user +handler user = newChannelHandler user :<|> updateChannelHandler user :<|> listChannelsHandler user + +requireChannelOwner :: AuthResult SafeUser -> ChannelID -> (SafeUser -> AppM a) -> AppM a +requireChannelOwner auth channelId f = flip requireLoggedIn auth $ \u@SafeUser{username} -> do + unlessM (runDB . channelExists $ channelId) $ throwM err404 + runDB (isChannelOwner channelId username) >>= \o -> if o then f u else throwM err403 + +updateChannelHandler :: AuthResult SafeUser -> ChannelID -> UpdateChannel -> AppM UpdateChannel +updateChannelHandler auth channelId UpdateChannel{visibility} = requireChannelOwner auth channelId $ \_ -> do + mChannel <- fmap toChannel <$> runDB (updateChannelPrivacy channelId visibility) + maybe (throwM err403) return mChannel listChannelsHandler :: AuthResult SafeUser -> AppM [JsonChannel] listChannelsHandler = requireLoggedIn $ \user -> @@ -46,8 +65,11 @@ listChannelsHandler = requireLoggedIn $ \user -> -- use the 'channel' accessor somehow or export it fmap (\Channel{..} -> JsonChannel{..}) <$> runDB (userChannels (view (field @"username") user)) -newChannelHandler :: AuthResult SafeUser -> JsonChannel -> AppM JsonChannel -newChannelHandler auth ch@JsonChannel{..} = flip requireLoggedIn auth $ \user -> do +newChannelHandler :: AuthResult SafeUser -> JsonChannel -> AppM UpdateChannel +newChannelHandler auth JsonChannel{..} = flip requireLoggedIn auth $ \user -> do $logInfo $ "Creating channel for user " <> pack (show user) - runDB (insertChannel (view (field @"username") user) channel visibility) - return ch + mChannel <- fmap toChannel <$> runDB (insertChannel (view (field @"username") user) channel visibility) + maybe (throwM err403{errBody="Could not create the channel"}) return mChannel + +toChannel :: Channel -> UpdateChannel +toChannel Channel{..} = UpdateChannel{..} diff --git a/src/Database/Channel.hs b/src/Database/Channel.hs index 5952f6c..4ef3627 100644 --- a/src/Database/Channel.hs +++ b/src/Database/Channel.hs @@ -4,6 +4,9 @@ module Database.Channel ( userChannels , insertChannel + , channelExists + , isChannelOwner + , updateChannelPrivacy , attachChannel , Visibility(..) , clearChannels @@ -18,6 +21,30 @@ import Database import Database.Selda import Database.Selda.Generic +import Control.Monad.Trans.Maybe + +getChannel :: (MonadSelda m, MonadMask m, MonadIO m) => ChannelID -> m (Maybe Channel) +getChannel identifier = listToMaybe . fromRels <$> query q + where + q = do + ch@(channelId :*: _) <- select (gen channels) + restrict (channelId .== literal identifier) + return ch + +channelExists :: (MonadSelda m, MonadMask m, MonadIO m) => ChannelID -> m Bool +channelExists identifier = not . null <$> getChannel identifier + +isChannelOwner :: (MonadSelda m, MonadIO m, MonadMask m) => ChannelID -> Username -> m Bool +isChannelOwner identifier username = not . null <$> query q + where + q = do + userId :*: _ :*: username' :*: _ <- select (gen users) + channelId :*: _ :*: channelOwner :*: _ <- select (gen channels) + restrict (userId .== channelOwner) + restrict (username' .== literal username) + restrict (channelId .== literal identifier) + return channelId + userChannels :: (MonadMask m, MonadIO m) => Username -> SeldaT m [Channel] userChannels username = fromRels <$> query q where @@ -28,12 +55,25 @@ userChannels username = fromRels <$> query q restrict (username' .== literal username) return channel -insertChannel :: (MonadMask m, MonadIO m) => Username -> Text -> Visibility -> SeldaT m () -insertChannel username channel visibility = do - mUserId <- listToMaybe <$> getUser - void $ forM mUserId $ \userId -> - insertUnless (gen channels) (doesNotExist userId) [ def :*: channel :*: userId :*: visibility ] +updateChannelPrivacy :: (MonadMask m, MonadIO m, MonadSelda m) => ChannelID -> Visibility -> m (Maybe Channel) +updateChannelPrivacy channelId visibility = do + void $ update (gen channels) predicate (\channel -> channel `with` [pVis := literal visibility]) + getChannel channelId where + predicate (channelId' :*: _) = channelId' .== literal channelId + _ :*: _ :*: _ :*: pVis = selectors (gen channels) + +insertChannel :: (MonadMask m, MonadIO m, MonadSelda m) => Username -> Text -> Visibility -> m (Maybe Channel) +insertChannel username channel visibility = runMaybeT $ do + userId <- MaybeT (listToMaybe <$> getUser) + channelId <- toChannelId <$> MaybeT (insertUnless (gen channels) (doesNotExist userId) [ def :*: channel :*: userId :*: visibility ]) + MaybeT (listToMaybe . fromRels <$> query (q channelId)) + where + q channelId = do + ch@(channelId' :*: _) <- select (gen channels) + restrict (channelId' .== literal channelId) + return ch + toChannelId = ChannelID . fromRowId doesNotExist userId (_ :*: channel' :*: userId' :*: _) = channel' .== literal channel .&& userId' .== literal userId getUser = query $ do userId :*: _ :*: user :*: _ <- select (gen users) diff --git a/src/Database/Schema.hs b/src/Database/Schema.hs index cfa07b4..8084168 100644 --- a/src/Database/Schema.hs +++ b/src/Database/Schema.hs @@ -44,7 +44,7 @@ newtype UserID = UserID {unUserID :: Int} deriving (Show) newtype BookID = BookID {unBookID :: Int} deriving (Show, ToJSON, FromJSON, FromHttpApiData, Eq, Ord) -newtype ChannelID = ChannelID {unChannelID :: Int} deriving (Show, ToHttpApiData, FromHttpApiData) +newtype ChannelID = ChannelID {unChannelID :: Int} deriving (Show, ToHttpApiData, FromHttpApiData, ToJSON, FromJSON) newtype TagID = TagID {unTagID :: Int} deriving (Show)