{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE DeriveAnyClass        #-}
{-# LANGUAGE DerivingStrategies    #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
module PlutusTx.Code where

import PlutusTx.Coverage
import PlutusTx.Lift.Instances ()

import PlutusIR qualified as PIR

import PlutusCore qualified as PLC
import PlutusCore.Pretty qualified as PLC
import UntypedPlutusCore qualified as UPLC

import Control.Exception
import Flat (Flat (..), unflat)
import Flat.Decoder (DecodeException)

import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import ErrorCode
-- We do not use qualified import because the whole module contains off-chain code
import Prelude as Haskell

-- NOTE: any changes to this type must be paralleled by changes
-- in the plugin code that generates values of this type. That is
-- done by code generation so it's not typechecked normally.
-- | A compiled Plutus Tx program, held either in serialized or deserialized form.
--
-- The last type parameter @a@ does not occur in any constructor field; it only
-- records the type of the Haskell expression that was compiled, and hence the
-- type of the compiled code.
--
-- Note: the compiled program does *not* have normalized types; if you want to
-- put it on the chain you must normalize the types first.
-- NOTE(review): the UPLC program stored below is untyped ('UPLC.Program'), so
-- the note above presumably concerns only the optional PIR program — confirm.
data CompiledCodeIn uni fun a
    = -- | Serialized UPLC code, possibly serialized PIR code, and the metadata
      -- used for program coverage.
      SerializedCode BS.ByteString (Maybe BS.ByteString) CoverageIndex
    | -- | Deserialized UPLC program, possibly the deserialized PIR program, and
      -- the metadata used for program coverage.
      DeserializedCode
        (UPLC.Program UPLC.NamedDeBruijn uni fun ())
        (Maybe (PIR.Program PLC.TyName PLC.Name uni fun ()))
        CoverageIndex

-- | 'CompiledCodeIn' instantiated with default built-in types and functions
-- ('PLC.DefaultUni' and 'PLC.DefaultFun'). The last parameter (the type of the
-- compiled Haskell expression) is left unapplied, so this synonym is used as
-- @CompiledCode a@.
type CompiledCode = CompiledCodeIn PLC.DefaultUni PLC.DefaultFun

-- | Apply a compiled function to a compiled argument.
--
-- The result is always a 'DeserializedCode':
--
-- * the UPLC programs of both inputs are combined with 'UPLC.applyProgram';
-- * a PIR program is present only if /both/ inputs carry one, in which case
--   they are combined with 'PIR.applyProgram';
-- * the coverage indices of both inputs are combined with '<>'.
--
-- Obtaining the UPLC/PIR programs goes through 'getPlc' / 'getPir', which may
-- throw 'ImpossibleDeserialisationFailure' for corrupt 'SerializedCode'.
applyCode
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun, uni `PLC.Everywhere` PLC.PrettyConst, PLC.GShow uni, PLC.Pretty fun)
    => CompiledCodeIn uni fun (a -> b) -> CompiledCodeIn uni fun a -> CompiledCodeIn uni fun b
applyCode fun arg =
    DeserializedCode
        (UPLC.applyProgram (getPlc fun) (getPlc arg))
        -- Applicative 'Maybe': yields 'Nothing' unless both sides have PIR.
        (PIR.applyProgram <$> getPir fun <*> getPir arg)
        (getCovIdx fun <> getCovIdx arg)

-- | The size of a 'CompiledCodeIn', in AST nodes.
--
-- This measures the UPLC program returned by 'getPlc'. For a 'SerializedCode'
-- that means the program is deserialized first, which may throw
-- 'ImpossibleDeserialisationFailure'.
sizePlc :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun, uni `PLC.Everywhere` PLC.PrettyConst, PLC.GShow uni, PLC.Pretty fun) => CompiledCodeIn uni fun a -> Integer
sizePlc = UPLC.programSize . getPlc

-- | Flat serialization of a 'CompiledCodeIn' covers only its UPLC program.
--
-- 'encode' and 'size' work on the program returned by 'getPlc', while 'decode'
-- yields a 'DeserializedCode' with no PIR program and an empty coverage index.
-- An encode/decode round trip therefore drops the PIR program and the coverage
-- index.
instance (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun, uni `PLC.Everywhere` PLC.PrettyConst, PLC.GShow uni, PLC.Pretty fun)
    => Flat (CompiledCodeIn uni fun a) where
    encode c = encode (getPlc c)

    decode = do
        p <- decode
        pure $ DeserializedCode p Nothing mempty

    size c = size (getPlc c)

{- Note [Deserializing the AST]
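The types suggest that we can fail to deserialize the AST that we embedded in the program.
However, we serialized that AST ourselves, so a failure should be impossible; if it does happen,
we signal it by throwing an 'ImpossibleDeserialisationFailure' exception rather than reflecting
it in the return type.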
-}
-- | Thrown by 'getPlc' and 'getPir' when serialized code stored in a
-- 'CompiledCodeIn' fails to deserialize. Wraps the underlying decoder error.
-- See Note [Deserializing the AST].
newtype ImpossibleDeserialisationFailure = ImpossibleDeserialisationFailure DecodeException
    deriving anyclass (Exception)
-- | Renders a bug-report message followed by the 'show'n decoder error.
instance Show ImpossibleDeserialisationFailure where
    show (ImpossibleDeserialisationFailure e) =
        "Failed to deserialise our own program! This is a bug, please report it. Caused by: " ++ show e

-- | 'ImpossibleDeserialisationFailure' is reported with error code 40.
instance HasErrorCode ImpossibleDeserialisationFailure where
    errorCode ImpossibleDeserialisationFailure {} = ErrorCode 40

-- | Get the actual Plutus Core program out of a 'CompiledCodeIn'.
--
-- A 'DeserializedCode' is returned as-is. A 'SerializedCode' is decoded with
-- 'unflat'. A decoding failure is thrown as 'ImpossibleDeserialisationFailure'
-- (see Note [Deserializing the AST]), since we produced those bytes ourselves.
getPlc
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun, uni `PLC.Everywhere` PLC.PrettyConst, PLC.GShow uni, PLC.Pretty fun)
    => CompiledCodeIn uni fun a -> UPLC.Program UPLC.NamedDeBruijn uni fun ()
getPlc wrapper = case wrapper of
    SerializedCode plc _ _ ->
        either (throw . ImpossibleDeserialisationFailure) id $ unflat (BSL.fromStrict plc)
    DeserializedCode plc _ _ -> plc

-- | Get the Plutus IR program, if there is one, out of a 'CompiledCodeIn'.
--
-- For a 'DeserializedCode' the stored 'Maybe' is returned as-is. For a
-- 'SerializedCode' the optional PIR bytes are decoded with 'unflat'.
-- A decoding failure is thrown as 'ImpossibleDeserialisationFailure'
-- (see Note [Deserializing the AST]).
getPir
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
    => CompiledCodeIn uni fun a -> Maybe (PIR.Program PIR.TyName PIR.Name uni fun ())
getPir wrapper = case wrapper of
    SerializedCode _ pir _ ->
        -- 'traverse' in the 'Either' applicative: deciding 'Just'/'Nothing'
        -- forces the decode, so a failure is thrown as soon as the result is
        -- inspected rather than hidden inside a 'Just'.
        either (throw . ImpossibleDeserialisationFailure) id $
            traverse (unflat . BSL.fromStrict) pir
    DeserializedCode _ pir _ -> pir

-- | Get the coverage index out of a 'CompiledCodeIn'.
--
-- Both constructors store the index directly, so unlike 'getPlc' and 'getPir'
-- this never deserializes anything.
getCovIdx :: CompiledCodeIn uni fun a -> CoverageIndex
getCovIdx wrapper = case wrapper of
    SerializedCode _ _ idx   -> idx
    DeserializedCode _ _ idx -> idx