Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 51 additions & 44 deletions cvc5_pythonic_api/cvc5_pythonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ def _get_ctx(ctx):
return ctx


def _get_ctx2(a, b, ctx=None):
if is_expr(a):
return a.ctx
if is_expr(b):
return b.ctx
if ctx is None:
ctx = main_ctx()
return ctx


def get_ctx(ctx):
"""
Returns `ctx` if it is not `None`, and the default context otherwise.
Expand Down Expand Up @@ -2128,9 +2138,8 @@ def Length(s, ctx=None):
>>> simplify(l)
1
"""
s = _py2expr(s)
ctx = _get_ctx(ctx)
return ArithRef(ctx.tm.mkTerm(Kind.SEQ_LENGTH, s.ast), ctx)
s = _py2expr(s, ctx)
return ArithRef(s.ctx.tm.mkTerm(Kind.SEQ_LENGTH, s.ast), s.ctx)


def SubString(s, offset, length, ctx=None):
Expand All @@ -2141,16 +2150,15 @@ def SubString(s, offset, length, ctx=None):
>>> simplify(SubString(StringVal('hello'),3,2))
"lo"
"""
ctx = _get_ctx(ctx)
s = _py2expr(s)
offset = _py2expr(offset)
length = _py2expr(length)
s = _py2expr(s, ctx)
offset = _py2expr(offset, s.ctx)
length = _py2expr(length, s.ctx)
return StringRef(
ctx.tm.mkTerm(Kind.STRING_SUBSTR, s.ast, offset.ast, length.ast), ctx
s.ctx.tm.mkTerm(Kind.STRING_SUBSTR, s.ast, offset.ast, length.ast), s.ctx
)


def SubSeq(s, offset, length):
def SubSeq(s, offset, length, ctx=None):
"""Extract subsequence starting at offset

>>> seq = Concat(Unit(IntVal(1)),Unit(IntVal(2)))
Expand All @@ -2159,9 +2167,9 @@ def SubSeq(s, offset, length):
>>> simplify(SubSeq(seq,1,0))
(as seq.empty (Seq Int))()
"""
s = _py2expr(s)
offset = _py2expr(offset)
length = _py2expr(length)
s = _py2expr(s, ctx)
offset = _py2expr(offset, s.ctx)
length = _py2expr(length, s.ctx)
return SeqRef(
s.ctx.tm.mkTerm(Kind.SEQ_EXTRACT, s.ast, offset.ast, length.ast), s.ctx
)
Expand All @@ -2178,7 +2186,7 @@ def SeqUpdate(s, t, i):
>>> simplify(SeqUpdate(lst,Unit(IntVal(1)),4))
(seq.++ (seq.unit 1) (seq.unit 2) (seq.unit 3))()
"""
i = _py2expr(i)
i = _py2expr(i, t.ctx)
return SeqRef(t.ctx.tm.mkTerm(Kind.SEQ_UPDATE, s.ast, i.ast, t.ast), t.ctx)


Expand Down Expand Up @@ -2206,9 +2214,9 @@ def Contains(a, b, ctx=None):
>>> simplify(s)
True
"""
ctx = _get_ctx(ctx)
a = _py2expr(a)
b = _py2expr(b)
ctx = _get_ctx2(a, b, ctx)
a = _py2expr(a, ctx)
b = _py2expr(b, ctx)
if is_string(a) and is_string(b):
return BoolRef(ctx.tm.mkTerm(Kind.STRING_CONTAINS, a.ast, b.ast), ctx)
return BoolRef(ctx.tm.mkTerm(Kind.SEQ_CONTAINS, a.ast, b.ast), ctx)
Expand All @@ -2224,9 +2232,9 @@ def PrefixOf(a, b, ctx=None):
>>> simplify(s2)
False
"""
ctx = _get_ctx(ctx)
a = _py2expr(a)
b = _py2expr(b)
ctx = _get_ctx2(a, b, ctx)
a = _py2expr(a, ctx)
b = _py2expr(b, ctx)
if is_string(a) and is_string(b):
return BoolRef(ctx.tm.mkTerm(Kind.STRING_PREFIX, a.ast, b.ast), ctx)
return BoolRef(ctx.tm.mkTerm(Kind.SEQ_PREFIX, a.ast, b.ast), ctx)
Expand All @@ -2242,9 +2250,9 @@ def SuffixOf(a, b, ctx=None):
>>> simplify(s2)
True
"""
ctx = _get_ctx(ctx)
a = _py2expr(a)
b = _py2expr(b)
ctx = _get_ctx2(a, b, ctx)
a = _py2expr(a, ctx)
b = _py2expr(b, ctx)
if is_string(a) and is_string(b):
return BoolRef(ctx.tm.mkTerm(Kind.STRING_SUFFIX, a.ast, b.ast), ctx)
return BoolRef(ctx.tm.mkTerm(Kind.SEQ_SUFFIX, a.ast, b.ast), ctx)
Expand All @@ -2260,11 +2268,11 @@ def IndexOf(s, substr, offset=None):
>>> simplify(IndexOf("abcabc", "bc", 2))
4
"""
ctx = _get_ctx2(s, substr)
if offset is None:
offset = IntVal(0)
s = _py2expr(s)
substr = _py2expr(substr)
ctx = _get_ctx(None)
offset = IntVal(0, ctx)
s = _py2expr(s, ctx)
substr = _py2expr(substr, ctx)
if _is_int(offset):
offset = IntVal(offset, ctx)
if is_string(s) and is_string(substr):
Expand All @@ -2285,10 +2293,12 @@ def Replace(s, src, dst):
>>> simplify(Replace(seq,Unit(IntVal(1)),Unit(IntVal(5))))
(seq.++ (seq.unit 5) (seq.unit 2))()
"""
s = _py2expr(s)
src = _py2expr(src)
dst = _py2expr(dst)
ctx = _get_ctx(None)
ctx = _get_ctx2(dst, s)
if ctx is None and is_expr(src):
ctx = src.ctx
s = _py2expr(s, ctx)
src = _py2expr(src, ctx)
dst = _py2expr(dst, ctx)
if is_string(s) and is_string(src) and is_string(dst):
return StringRef(
ctx.tm.mkTerm(Kind.STRING_REPLACE, s.ast, src.ast, dst.ast), ctx
Expand All @@ -2307,8 +2317,7 @@ def StrToInt(s):
123
"""
s = _py2expr(s)
ctx = _get_ctx(s.ctx)
return ArithRef(ctx.tm.mkTerm(Kind.STRING_TO_INT, s.ast), ctx)
return ArithRef(s.ctx.tm.mkTerm(Kind.STRING_TO_INT, s.ast), s.ctx)


def IntToStr(s):
Expand Down Expand Up @@ -2399,7 +2408,7 @@ def Re(s, ctx=None):
>>> simplify(InRe('b',re))
False
"""
s = _py2expr(s)
s = _py2expr(s, ctx)
return ReRef(s.ctx.tm.mkTerm(Kind.STRING_TO_REGEXP, s.ast), s.ctx)


Expand All @@ -2426,7 +2435,9 @@ def InRe(s, re):
>>> print (simplify(InRe("c", re)))
False
"""
s = _py2expr(s)
ctx = _get_ctx2(s, re)
s = _py2expr(s, ctx)
re = _py2expr(re, ctx)
return BoolRef(s.ctx.tm.mkTerm(Kind.STRING_IN_REGEXP, s.ast, re.ast), s.ctx)


Expand Down Expand Up @@ -2556,6 +2567,7 @@ def Range(lo, hi, ctx=None):
>>> print(simplify(InRe("bb", range)))
False
"""
ctx = _get_ctx2(lo, hi)
lo = _py2expr(lo, ctx)
hi = _py2expr(hi, ctx)
return ReRef(lo.ctx.tm.mkTerm(Kind.REGEXP_RANGE, lo.ast, hi.ast), lo.ctx)
Expand Down Expand Up @@ -4766,12 +4778,8 @@ def Concat(*args):
sz = len(args)
if debugging():
_assert(sz >= 2, "At least two arguments expected.")
args = [_py2expr(s) for s in args]
ctx = _get_ctx(None)
for a in args:
if is_expr(a):
ctx = a.ctx
break
ctx = _get_ctx(_ctx_from_ast_arg_list(args))
args = [_py2expr(s, ctx) for s in args]
if debugging():
_assert(
all([is_bv(a) or is_string(a) or is_seq(a) or is_re(a) for a in args]),
Expand Down Expand Up @@ -5860,15 +5868,14 @@ def SetComplement(s):
return ArrayRef(ctx.tm.mkTerm(Kind.SET_COMPLEMENT, s.ast), ctx)


def Singleton(s):
def Singleton(s, ctx=None):
"""The single element set of just e

>>> Singleton(IntVal(1))
Singleton(1)
"""
s = _py2expr(s)
ctx = s.ctx
return SetRef(ctx.tm.mkTerm(Kind.SET_SINGLETON, s.ast), ctx)
s = _py2expr(s, ctx)
return SetRef(s.ctx.tm.mkTerm(Kind.SET_SINGLETON, s.ast), s.ctx)


def SetDifference(a, b):
Expand Down