Skip to content
158 changes: 158 additions & 0 deletions lib/Echidna/SymExec/Bounds.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
{-# LANGUAGE GADTs #-}

module Echidna.SymExec.Bounds (findApproximateBounds) where

import Control.Monad.IO.Unlift (MonadUnliftIO, liftIO)
import Data.Function ((&))
import Data.Maybe (mapMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Text qualified as T
import Data.List (nub)
import GHC.IORef (IORef)
import Optics.Core ((.~), (%), (%~))
import EVM.Fetch qualified as Fetch
import EVM (loadContract, resetState)
import EVM.Effects (TTY, ReadConfig)
import EVM.Solidity (SolcContract(..), Method(..))
import EVM.Solvers (SolverGroup)
import EVM.SymExec (abstractVM, mkCalldata, verifyInputs, VeriOpts(..), showModel)
import EVM.Types (Addr, VMType(..), Expr(..), W256, Prop(..))
import qualified EVM.Types (VM(..))
import Control.Monad.ST (stToIO, RealWorld)
import Control.Monad.State.Strict (execState, runStateT)

import Echidna.Types.Config (EConfig(..))
import Echidna.Types.Solidity (SolConf(..))
import Echidna.Types.Tx (TxConf(..))
import Echidna.Types.Cache (ContractCache, SlotCache)
import Echidna.Types (fromEVM)
import Echidna.SymExec.Helpers

import Data.Map.Strict qualified as Map
import qualified EVM.Types as EVM
import EVM.ABI (Sig(..))

data ValueRange = VRLEq (Set T.Text) W256
| VRGEq (Set T.Text) W256
| VREq (Set T.Text) W256
deriving (Eq, Ord)

instance Show ValueRange where
show (VRLEq vars lit) = T.unpack $ T.concat [T.intercalate "," (Set.toList vars), " <= ", T.pack (show lit)]
show (VRGEq vars lit) = T.unpack $ T.concat [T.intercalate "," (Set.toList vars), " >= ", T.pack (show lit)]
show (VREq vars lit) = T.unpack $ T.concat [T.intercalate "," (Set.toList vars), " == ", T.pack (show lit)]

data ResolvedInterval = RInterval (Maybe W256) (Maybe W256)
deriving (Eq, Ord)

instance Show ResolvedInterval where
show (RInterval (Just l) (Just u)) | l == u = " == " ++ show l
show (RInterval (Just l) (Just u)) = " in [" ++ show l ++ ", " ++ show u ++ "]"
show (RInterval (Just l) Nothing) = " >= " ++ show l
show (RInterval Nothing (Just u)) = " <= " ++ show u
show (RInterval Nothing Nothing) = ""

resolveRanges :: [ValueRange] -> Map.Map (Set T.Text) ResolvedInterval
resolveRanges ranges =
let
grouped = Map.fromListWith (++) $ map (\case
VRLEq vars lit -> (vars, [VRLEq vars lit])
VRGEq vars lit -> (vars, [VRGEq vars lit])
VREq vars lit -> (vars, [VREq vars lit])) ranges

processGroup :: [ValueRange] -> ResolvedInterval
processGroup rs =
let
acc (RInterval l u) (VRLEq _ lit) = RInterval l (Just $ maybe lit (min lit) u)
acc (RInterval l u) (VRGEq _ lit) = RInterval (Just $ maybe lit (max lit) l) u
acc (RInterval _ _) (VREq _ lit) = RInterval (Just lit) (Just lit)
in foldl acc (RInterval Nothing Nothing) rs

in Map.map processGroup grouped

getProps :: Expr a -> [Prop]
getProps (Failure asserts _ _) = asserts
getProps (Success asserts _ _ _) = asserts
getProps (Partial asserts _ _) = asserts
getProps _ = []

computeRange :: Prop -> Maybe ValueRange
computeRange (PLEq e1 e2) = resolve e1 e2 VRLEq VRGEq
computeRange (PGEq e1 e2) = resolve e1 e2 VRGEq VRLEq
computeRange (PLT e1 e2) = resolve e1 e2 VRLEq VRGEq
computeRange (PGT e1 e2) = resolve e1 e2 VRGEq VRLEq
computeRange (PEq e1 e2) = resolve e1 e2 VREq VREq
computeRange _ = Nothing

-- Decompose an expression into a set of variables and a literal, if possible
decompose :: Expr a -> Maybe (Set T.Text, W256)
decompose (Add e (Lit k)) = decompose e & fmap (\(v, k') -> (v, k' + k))
decompose (Add (Lit k) e) = decompose e & fmap (\(v, k') -> (v, k' + k))
decompose (Sub e (Lit k)) = decompose e & fmap (\(v, k') -> (v, k' - k))
decompose e = case (exprDependsOnArg e, exprLit e) of
(vars, _) | not (Set.null vars) -> Just (vars, 0)
(_, Just lit) -> Just (Set.empty, lit)
_ -> Nothing

resolve :: Expr a -> Expr a -> (Set T.Text -> W256 -> ValueRange) -> (Set T.Text -> W256 -> ValueRange) -> Maybe ValueRange
resolve e1 e2 op op' =
case (decompose e1, decompose e2) of
(Just (vars, k1), Just (vars', k2)) | Set.null vars' -> Just $ op vars (k2 - k1)
(Just (vars, k1), Just (vars', k2)) | Set.null vars -> Just $ op' vars' (k1 - k2)
_ -> Nothing

exprLit :: Expr a -> Maybe W256
exprLit (Lit l) = Just l
exprLit _ = Nothing

exprDependsOnArg :: Expr a -> Set T.Text
exprDependsOnArg (Failure asserts _ _) = Set.unions $ map propDependsOnArg asserts
exprDependsOnArg (Partial asserts _ _) = Set.unions $ map propDependsOnArg asserts
exprDependsOnArg (Var v) = if T.isPrefixOf "arg" v then Set.singleton v else Set.empty
exprDependsOnArg (Lit _) = Set.empty
exprDependsOnArg (SEx _ e) = exprDependsOnArg e
exprDependsOnArg (SLT e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
exprDependsOnArg (Add e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
exprDependsOnArg (Sub e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
exprDependsOnArg _ = Set.empty

propDependsOnArg :: Prop -> Set T.Text
propDependsOnArg (PLEq e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
propDependsOnArg (PGEq e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
propDependsOnArg (PLT e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
propDependsOnArg (PGT e1 e2) = Set.union (exprDependsOnArg e1) (exprDependsOnArg e2)
propDependsOnArg _ = Set.empty

findApproximateBounds :: (MonadUnliftIO m, ReadConfig m, TTY m) =>
Method -> SolcContract -> EVM.Types.VM Concrete RealWorld -> Addr -> EConfig -> VeriOpts -> SolverGroup -> Fetch.RpcInfo -> IORef ContractCache -> IORef SlotCache -> m [T.Text]
findApproximateBounds method contract vm defaultSender conf veriOpts solvers rpcInfo contractCacheRef slotCacheRef = do
calldataSym@(cd, constraints) <- mkCalldata (Just (Sig method.methodSignature (snd <$> method.inputs))) []
let
fetcher = cachedOracle contractCacheRef slotCacheRef solvers rpcInfo
dst = conf.solConf.contractAddr
vmSym = abstractVM calldataSym contract.runtimeCode Nothing False
vmSym' <- liftIO $ stToIO vmSym
vmReset <- liftIO $ snd <$> runStateT (fromEVM resetState) vm
let vm' = vmReset & execState (loadContract (LitAddr dst))
& vmMakeSymbolic conf.txConf.maxTimeDelay conf.txConf.maxBlockDelay
& #constraints %~ (++ constraints)
& #state % #callvalue .~ Lit 0
& #state % #caller .~ LitAddr defaultSender
& #state % #calldata .~ cd
& #env % #contracts .~ (Map.union vmSym'.env.contracts vm.env.contracts)
(_, models, partials) <- verifyInputs solvers veriOpts fetcher vm' (Just nonReverts)
liftIO $ mapM_ (showModel mempty) $ map (\(x, y) -> (y, x)) models
let
modelProps = concatMap (getProps . snd) models
partialExprProps = concatMap (getProps . snd) partials
--partialProps = map (fst . fst) partials
allProps = modelProps ++ partialExprProps -- ++ partialProps
allRanges = nub $ mapMaybe computeRange allProps
resolved = resolveRanges allRanges
allRangesText = map (T.pack . show) allRanges
showResolved (vars, interval) = T.concat [T.intercalate "," (Set.toList vars), T.pack (show interval)]
bounds = if Map.null resolved
then allRangesText
else map showResolved (Map.toList resolved)
return $ if null bounds then [] else (T.pack "Constraints inferred:" : allRangesText) ++ (T.pack "Constraints resolved:" : bounds)
95 changes: 12 additions & 83 deletions lib/Echidna/SymExec/Common.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE GADTs #-}

module Echidna.SymExec.Common where

import Control.Monad.IO.Unlift (MonadUnliftIO, liftIO)
Expand All @@ -8,22 +10,27 @@ import Data.Maybe (fromMaybe, mapMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Text qualified as T
import Data.Text.IO qualified as TIO
import Data.Vector (toList, fromList)
import Data.List (nub)
import GHC.IORef (IORef, readIORef)
import Optics.Core ((.~), (%), (%~))
import EVM.ABI (abiKind, AbiKind(Dynamic), AbiValue(..), AbiType(..), Sig(..), decodeAbiValue)
import EVM.Fetch qualified as Fetch
import EVM (loadContract, resetState, forceLit)
--import EVM.Expr (Expr(..))
import EVM.Effects (TTY, ReadConfig)
import EVM.Solidity (SolcContract(..), Method(..))
import EVM.Solvers (SolverGroup)
import EVM.SymExec (abstractVM, mkCalldata, verifyInputs, VeriOpts(..), checkAssertions)
import EVM.SymExec (abstractVM, mkCalldata, verifyInputs, VeriOpts(..), Postcondition, showModel)
import EVM.Types (Addr, Frame(..), FrameState(..), VMType(..), EType(..), Expr(..), word256Bytes, Block(..), W256, SMTCex(..), ProofResult(..), Prop(..), Query(..))
import qualified EVM.Types (VM(..), Env(..))
import EVM.Format (formatPartial)
import Control.Monad.ST (stToIO, RealWorld)
import Control.Monad.State.Strict (execState, runStateT)

import Echidna.SymExec.Bounds (findApproximateBounds)
import Echidna.SymExec.Helpers
import Echidna.Types (fromEVM)
import Echidna.Types.Config (EConfig(..))
import Echidna.Types.Solidity (SolConf(..))
Expand All @@ -50,71 +57,9 @@ suitableForSymExec m = not $ null m.inputs
&& null (filter (\(_, t) -> abiKind t == Dynamic) m.inputs)
&& not (T.isInfixOf "_no_symexec" m.name)

-- | Sets result to Nothing, and sets gas to ()
vmMakeSymbolic :: W256 -> W256 -> EVM.Types.VM Concrete s -> EVM.Types.VM Symbolic s
vmMakeSymbolic maxTimestampDiff maxNumberDiff vm
= EVM.Types.VM
{ result = Nothing
, state = frameStateMakeSymbolic vm.state
, frames = map frameMakeSymbolic vm.frames
, env = vm.env
, block = blockMakeSymbolic vm.block
, tx = vm.tx
, logs = vm.logs
, traces = vm.traces
, cache = vm.cache
, burned = ()
, iterations = vm.iterations
, constraints = addBlockConstraints maxTimestampDiff maxNumberDiff vm.block vm.constraints
, config = vm.config
, forks = vm.forks
, currentFork = vm.currentFork
, labels = vm.labels
, osEnv = vm.osEnv
, freshVar = vm.freshVar
, exploreDepth = 0
, keccakPreImgs = vm.keccakPreImgs
}

blockMakeSymbolic :: Block -> Block
blockMakeSymbolic b
= b {
timestamp = Var "symbolic_block_timestamp"
, number = Var "symbolic_block_number"
}

addBlockConstraints :: W256 -> W256 -> Block -> [Prop] -> [Prop]
addBlockConstraints maxTimestampDiff maxNumberDiff block cs =
cs ++ [
PGEq (Var "symbolic_block_timestamp") (block.timestamp), PLEq (Sub (Var "symbolic_block_timestamp") (block.timestamp)) $ Lit maxTimestampDiff,
PGEq (Var "symbolic_block_number") (block.number), PLEq (Sub (Var "symbolic_block_number") (block.number)) $ Lit maxNumberDiff
]

senderConstraints :: Set Addr -> [Prop]
senderConstraints as = [foldr (\a b -> POr b (PEq (SymAddr "caller") (LitAddr a))) (PBool False) $ Set.toList as]

frameStateMakeSymbolic :: FrameState Concrete s -> FrameState Symbolic s
frameStateMakeSymbolic fs
= FrameState
{ contract = fs.contract
, codeContract = fs.codeContract
, code = fs.code
, pc = fs.pc
, stack = fs.stack
, memory = fs.memory
, memorySize = fs.memorySize
, calldata = fs.calldata
, callvalue = fs.callvalue
, caller = fs.caller
, gas = ()
, returndata = fs.returndata
, static = fs.static
, overrideCaller = fs.overrideCaller
, resetCaller = fs.resetCaller
}

frameMakeSymbolic :: Frame Concrete s -> Frame Symbolic s
frameMakeSymbolic fr = Frame { context = fr.context, state = frameStateMakeSymbolic fr.state }

-- | Convert a n-bit unsigned integer to a n-bit signed integer.
uintToInt :: W256 -> Integer
Expand Down Expand Up @@ -186,22 +131,6 @@ modelToTx dst oldTimestamp oldNumber method senders fallbackSender result =
r -> error ("Unexpected value in `modelToTx`: " ++ show r)


cachedOracle :: IORef ContractCache -> IORef SlotCache -> SolverGroup -> Fetch.RpcInfo -> Fetch.Fetcher t m s
cachedOracle contractCacheRef slotCacheRef solvers info q = do
case q of
PleaseFetchContract addr _ continue -> do
cache <- liftIO $ readIORef contractCacheRef
case Map.lookup addr cache of
Just (Just contract) -> pure $ continue contract
_ -> oracle q
PleaseFetchSlot addr slot continue -> do
cache <- liftIO $ readIORef slotCacheRef
case Map.lookup addr cache >>= Map.lookup slot of
Just (Just value) -> pure $ continue value
_ -> oracle q
_ -> oracle q

where oracle = Fetch.oracle solvers info

rpcFetcher :: Functor f =>
f a -> Maybe W256 -> f (Fetch.BlockNumber, a)
Expand All @@ -221,10 +150,10 @@ getUnknownLogs = mapMaybe (\case
Error err -> Just $ T.pack err
_ -> Nothing)


exploreMethod :: (MonadUnliftIO m, ReadConfig m, TTY m) =>
Method -> SolcContract -> EVM.Types.VM Concrete RealWorld -> Addr -> EConfig -> VeriOpts -> SolverGroup -> Fetch.RpcInfo -> IORef ContractCache -> IORef SlotCache -> m ([TxOrError], PartialsLogs)

exploreMethod method contract vm defaultSender conf veriOpts solvers rpcInfo contractCacheRef slotCacheRef = do
Method -> SolcContract -> EVM.Types.VM Concrete RealWorld -> Addr -> EConfig -> VeriOpts -> SolverGroup -> Fetch.RpcInfo -> IORef ContractCache -> IORef SlotCache -> Postcondition RealWorld -> m ([TxOrError], PartialsLogs)
exploreMethod method contract vm defaultSender conf veriOpts solvers rpcInfo contractCacheRef slotCacheRef post = do
--liftIO $ putStrLn ("Exploring: " ++ T.unpack method.methodSignature)
--pushWorkerEvent undefined
calldataSym@(cd, constraints) <- mkCalldata (Just (Sig method.methodSignature (snd <$> method.inputs))) []
Expand All @@ -243,7 +172,7 @@ exploreMethod method contract vm defaultSender conf veriOpts solvers rpcInfo con
& #env % #contracts .~ (Map.union vmSym'.env.contracts vm.env.contracts)
-- TODO we might want to switch vm's state.baseState value to to AbstractBase eventually.
-- Doing so might mess up concolic execution.
(_, models, partials) <- verifyInputs solvers veriOpts fetcher vm' (Just $ checkAssertions [0x1])
(_, models, partials) <- verifyInputs solvers veriOpts fetcher vm' (Just post)
let results = map fst models
--liftIO $ mapM_ TIO.putStrLn partials
return (map (modelToTx dst vm.block.timestamp vm.block.number method conf.solConf.sender defaultSender) results, map (formatPartial . fst) partials)
18 changes: 13 additions & 5 deletions lib/Echidna/SymExec/Exploration.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ import Data.Maybe (fromJust)
import Data.Set qualified as Set
import Data.Text (unpack, Text)
import Data.Text.Encoding (encodeUtf8)
import Data.Text.IO qualified as TIO
import Data.List.NonEmpty (fromList)
import EVM.Effects (defaultEnv, defaultConfig, Config(..), Env(..))
import EVM.Solidity (SolcContract(..), Method(..))
import EVM.Solvers (withSolvers)
import EVM.SymExec (IterConfig(..), LoopHeuristic (..), VeriOpts(..), defaultVeriOpts)
import EVM.SymExec (IterConfig(..), LoopHeuristic (..), VeriOpts(..), defaultVeriOpts, checkAssertions)
import EVM.Types (abiKeccak, FunctionSelector, VMType(..))
import qualified EVM.Types (VM(..))
import Control.Monad.ST (RealWorld)
Expand All @@ -31,6 +32,7 @@ import Echidna.Types.Tx (Tx(..), TxCall(..))
import Echidna.Types.Worker (WorkerEvent(..))
import Echidna.Types.Random (rElem)
import Echidna.SymExec.Common (suitableForSymExec, exploreMethod, rpcFetcher, TxOrError(..), PartialsLogs)
import Echidna.SymExec.Bounds (findApproximateBounds)
import Echidna.Worker (pushWorkerEvent)

-- | Uses symbolic execution to find transactions which would increase coverage.
Expand Down Expand Up @@ -82,8 +84,8 @@ filterTarget :: Maybe [Text] -> [FunctionSelector] -> Maybe Tx -> Method -> Bool
filterTarget symExecTargets assertSigs tx method =
case (symExecTargets, tx) of
(Just ms, _) -> method.name `elem` ms
(_, Just (Tx { call = SolCall (methodName, _) })) -> (null assertSigs || methodSig `elem` assertSigs) && method.name == methodName && suitableForSymExec method
_ -> (null assertSigs || methodSig `elem` assertSigs) && suitableForSymExec method
(_, Just (Tx { call = SolCall (methodName, _) })) -> (null assertSigs) && method.name == methodName && suitableForSymExec method
_ -> (null assertSigs) && suitableForSymExec method
where methodSig = abiKeccak $ encodeUtf8 method.methodSignature

exploreContract :: (MonadIO m, MonadThrow m, MonadReader Echidna.Types.Config.Env m, MonadState WorkerState m) => SolcContract -> Method -> EVM.Types.VM Concrete RealWorld -> m (ThreadId, MVar ([TxOrError], PartialsLogs))
Expand Down Expand Up @@ -112,11 +114,17 @@ exploreContract contract method vm = do
-- For now, we will be exploring a single method at a time.
-- In some cases, this methods list will have only one method, but in other cases, it will have several methods.
-- This is to improve the user experience, as it will produce results more often, instead having to wait for exploring several
res <- exploreMethod method contract vm defaultSender conf veriOpts solvers rpcInfo contractCacheRef slotCacheRef
res <- exploreMethod method contract vm defaultSender conf veriOpts solvers rpcInfo contractCacheRef slotCacheRef (checkAssertions [0x1])
liftIO $ putMVar resultChan res
liftIO $ putMVar doneChan ()
liftIO $ putMVar threadIdChan threadId
liftIO $ takeMVar doneChan

threadId <- liftIO $ takeMVar threadIdChan
pure (threadId, resultChan)
let boundsEnv = defaultEnv { config = defaultConfig { maxWidth = 5, maxDepth = Just 8, maxBufSize = 12, promiseNoReent = True, debug = False, dumpQueries = False, numCexFuzz = 10 } }
liftIO $ flip runReaderT boundsEnv $ withSolvers conf.campaignConf.symExecSMTSolver 1 1 (Just 1) $ \solvers -> do
liftIO $ flip runReaderT boundsEnv $ do
bounds <- findApproximateBounds method contract vm defaultSender conf veriOpts solvers rpcInfo contractCacheRef slotCacheRef
liftIO $ mapM_ TIO.putStrLn bounds

pure (threadId, resultChan)
Loading
Loading