From 2ac01d13110383b826ad94d75222990efb86f6c2 Mon Sep 17 00:00:00 2001 From: Daniel Larraz Date: Thu, 20 Nov 2025 16:50:35 -0600 Subject: [PATCH] Pass ctx to _py2expr when it is available --- cvc5_pythonic_api/cvc5_pythonic.py | 95 ++++++++++++++++-------------- 1 file changed, 51 insertions(+), 44 deletions(-) diff --git a/cvc5_pythonic_api/cvc5_pythonic.py b/cvc5_pythonic_api/cvc5_pythonic.py index 15e0b9a..44c0d3b 100644 --- a/cvc5_pythonic_api/cvc5_pythonic.py +++ b/cvc5_pythonic_api/cvc5_pythonic.py @@ -213,6 +213,16 @@ def _get_ctx(ctx): return ctx +def _get_ctx2(a, b, ctx=None): + if is_expr(a): + return a.ctx + if is_expr(b): + return b.ctx + if ctx is None: + ctx = main_ctx() + return ctx + + def get_ctx(ctx): """ Returns `ctx` if it is not `None`, and the default context otherwise. @@ -2128,9 +2138,8 @@ def Length(s, ctx=None): >>> simplify(l) 1 """ - s = _py2expr(s) - ctx = _get_ctx(ctx) - return ArithRef(ctx.tm.mkTerm(Kind.SEQ_LENGTH, s.ast), ctx) + s = _py2expr(s, ctx) + return ArithRef(s.ctx.tm.mkTerm(Kind.SEQ_LENGTH, s.ast), s.ctx) def SubString(s, offset, length, ctx=None): @@ -2141,16 +2150,15 @@ def SubString(s, offset, length, ctx=None): >>> simplify(SubString(StringVal('hello'),3,2)) "lo" """ - ctx = _get_ctx(ctx) - s = _py2expr(s) - offset = _py2expr(offset) - length = _py2expr(length) + s = _py2expr(s, ctx) + offset = _py2expr(offset, s.ctx) + length = _py2expr(length, s.ctx) return StringRef( - ctx.tm.mkTerm(Kind.STRING_SUBSTR, s.ast, offset.ast, length.ast), ctx + s.ctx.tm.mkTerm(Kind.STRING_SUBSTR, s.ast, offset.ast, length.ast), s.ctx ) -def SubSeq(s, offset, length): +def SubSeq(s, offset, length, ctx=None): """Extract subsequence starting at offset >>> seq = Concat(Unit(IntVal(1)),Unit(IntVal(2))) @@ -2159,9 +2167,9 @@ def SubSeq(s, offset, length): >>> simplify(SubSeq(seq,1,0)) (as seq.empty (Seq Int))() """ - s = _py2expr(s) - offset = _py2expr(offset) - length = _py2expr(length) + s = _py2expr(s, ctx) + offset = _py2expr(offset, s.ctx) + length = _py2expr(length, s.ctx) return SeqRef( s.ctx.tm.mkTerm(Kind.SEQ_EXTRACT, s.ast, offset.ast, length.ast), s.ctx ) @@ -2178,7 +2186,7 @@ def SeqUpdate(s, t, i): >>> simplify(SeqUpdate(lst,Unit(IntVal(1)),4)) (seq.++ (seq.unit 1) (seq.unit 2) (seq.unit 3))() """ - i = _py2expr(i) + i = _py2expr(i, t.ctx) return SeqRef(t.ctx.tm.mkTerm(Kind.SEQ_UPDATE, s.ast, i.ast, t.ast), t.ctx) @@ -2206,9 +2214,9 @@ def Contains(a, b, ctx=None): >>> simplify(s) True """ - ctx = _get_ctx(ctx) - a = _py2expr(a) - b = _py2expr(b) + ctx = _get_ctx2(a, b, ctx) + a = _py2expr(a, ctx) + b = _py2expr(b, ctx) if is_string(a) and is_string(b): return BoolRef(ctx.tm.mkTerm(Kind.STRING_CONTAINS, a.ast, b.ast), ctx) return BoolRef(ctx.tm.mkTerm(Kind.SEQ_CONTAINS, a.ast, b.ast), ctx) @@ -2224,9 +2232,9 @@ def PrefixOf(a, b, ctx=None): >>> simplify(s2) False """ - ctx = _get_ctx(ctx) - a = _py2expr(a) - b = _py2expr(b) + ctx = _get_ctx2(a, b, ctx) + a = _py2expr(a, ctx) + b = _py2expr(b, ctx) if is_string(a) and is_string(b): return BoolRef(ctx.tm.mkTerm(Kind.STRING_PREFIX, a.ast, b.ast), ctx) return BoolRef(ctx.tm.mkTerm(Kind.SEQ_PREFIX, a.ast, b.ast), ctx) @@ -2242,9 +2250,9 @@ def SuffixOf(a, b, ctx=None): >>> simplify(s2) True """ - ctx = _get_ctx(ctx) - a = _py2expr(a) - b = _py2expr(b) + ctx = _get_ctx2(a, b, ctx) + a = _py2expr(a, ctx) + b = _py2expr(b, ctx) if is_string(a) and is_string(b): return BoolRef(ctx.tm.mkTerm(Kind.STRING_SUFFIX, a.ast, b.ast), ctx) return BoolRef(ctx.tm.mkTerm(Kind.SEQ_SUFFIX, a.ast, b.ast), ctx) @@ -2260,11 +2268,11 @@ def IndexOf(s, substr, offset=None): >>> simplify(IndexOf("abcabc", "bc", 2)) 4 """ + ctx = _get_ctx2(s, substr) if offset is None: - offset = IntVal(0) - s = _py2expr(s) - substr = _py2expr(substr) - ctx = _get_ctx(None) + offset = IntVal(0, ctx) + s = _py2expr(s, ctx) + substr = _py2expr(substr, ctx) if _is_int(offset): offset = IntVal(offset, ctx) if is_string(s) and is_string(substr): @@ -2285,10 +2293,12 @@ def Replace(s, src, dst): >>> simplify(Replace(seq,Unit(IntVal(1)),Unit(IntVal(5)))) (seq.++ (seq.unit 5) (seq.unit 2))() """ - s = _py2expr(s) - src = _py2expr(src) - dst = _py2expr(dst) - ctx = _get_ctx(None) + ctx = _get_ctx2(dst, s) + if ctx is None and is_expr(src): + ctx = src.ctx + s = _py2expr(s, ctx) + src = _py2expr(src, ctx) + dst = _py2expr(dst, ctx) if is_string(s) and is_string(src) and is_string(dst): return StringRef( ctx.tm.mkTerm(Kind.STRING_REPLACE, s.ast, src.ast, dst.ast), ctx @@ -2307,8 +2317,7 @@ def StrToInt(s): 123 """ s = _py2expr(s) - ctx = _get_ctx(s.ctx) - return ArithRef(ctx.tm.mkTerm(Kind.STRING_TO_INT, s.ast), ctx) + return ArithRef(s.ctx.tm.mkTerm(Kind.STRING_TO_INT, s.ast), s.ctx) def IntToStr(s): @@ -2399,7 +2408,7 @@ def Re(s, ctx=None): >>> simplify(InRe('b',re)) False """ - s = _py2expr(s) + s = _py2expr(s, ctx) return ReRef(s.ctx.tm.mkTerm(Kind.STRING_TO_REGEXP, s.ast), s.ctx) @@ -2426,7 +2435,9 @@ def InRe(s, re): >>> print (simplify(InRe("c", re))) False """ - s = _py2expr(s) + ctx = _get_ctx2(s, re) + s = _py2expr(s, ctx) + re = _py2expr(re, ctx) return BoolRef(s.ctx.tm.mkTerm(Kind.STRING_IN_REGEXP, s.ast, re.ast), s.ctx) @@ -2556,6 +2567,7 @@ def Range(lo, hi, ctx=None): >>> print(simplify(InRe("bb", range))) False """ + ctx = _get_ctx2(lo, hi) lo = _py2expr(lo, ctx) hi = _py2expr(hi, ctx) return ReRef(lo.ctx.tm.mkTerm(Kind.REGEXP_RANGE, lo.ast, hi.ast), lo.ctx) @@ -4766,12 +4778,8 @@ def Concat(*args): sz = len(args) if debugging(): _assert(sz >= 2, "At least two arguments expected.") - args = [_py2expr(s) for s in args] - ctx = _get_ctx(None) - for a in args: - if is_expr(a): - ctx = a.ctx - break + ctx = _get_ctx(_ctx_from_ast_arg_list(args)) + args = [_py2expr(s, ctx) for s in args] if debugging(): _assert( all([is_bv(a) or is_string(a) or is_seq(a) or is_re(a) for a in args]), @@ -5860,15 +5868,14 @@ def SetComplement(s): return ArrayRef(ctx.tm.mkTerm(Kind.SET_COMPLEMENT, s.ast), ctx) -def Singleton(s): +def Singleton(s, ctx=None): """The single element set of just e >>> Singleton(IntVal(1)) Singleton(1) """ - s = _py2expr(s) - ctx = s.ctx - return SetRef(ctx.tm.mkTerm(Kind.SET_SINGLETON, s.ast), ctx) + s = _py2expr(s, ctx) + return SetRef(s.ctx.tm.mkTerm(Kind.SET_SINGLETON, s.ast), s.ctx) def SetDifference(a, b):