Skip to content
This repository was archived by the owner on Jul 12, 2025. It is now read-only.

Commit 3a5618b

Browse files
author
Joe Atzberger
committed
json Marshaller interface for Inet
Note that this does allow the serialization/deserialization between empty string and a Null struct. It does NOT permit invalid addresses or masks. See #79
1 parent 00d516f commit 3a5618b

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

inet.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pgtype
33
import (
44
"database/sql/driver"
55
"net"
6+
"strings"
67

78
errors "golang.org/x/xerrors"
89
)
@@ -122,7 +123,7 @@ func (src *Inet) AssignTo(dst interface{}) error {
122123
return errors.Errorf("cannot decode %#v into %T", src, dst)
123124
}
124125

125-
func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
126+
func (dst *Inet) DecodeText(_ *ConnInfo, src []byte) error {
126127
if src == nil {
127128
*dst = Inet{Status: Null}
128129
return nil
@@ -150,7 +151,7 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
150151
return nil
151152
}
152153

153-
func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error {
154+
func (dst *Inet) DecodeBinary(_ *ConnInfo, src []byte) error {
154155
if src == nil {
155156
*dst = Inet{Status: Null}
156157
return nil
@@ -218,6 +219,27 @@ func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
218219
return append(buf, src.IPNet.IP...), nil
219220
}
220221

222+
// MarshalJSON implements the json.Marshaler interface
223+
func (src Inet) MarshalJSON() ([]byte, error) {
224+
if src.Status != Present {
225+
return []byte(`""`), nil
226+
}
227+
v, err := src.Value()
228+
if err != nil || v == nil {
229+
return []byte(`""`), err
230+
}
231+
return []byte(`"` + v.(string) + `"`), nil
232+
}
233+
234+
// UnmarshalJSON implements the json.Marshaler interface
235+
func (dst *Inet) UnmarshalJSON(data []byte) error {
236+
trimmed := strings.Trim(string(data), `"`)
237+
if trimmed == "" {
238+
return dst.DecodeText(nil, nil)
239+
}
240+
return dst.DecodeText(nil, []byte(trimmed))
241+
}
242+
221243
// Scan implements the database/sql Scanner interface.
222244
func (dst *Inet) Scan(src interface{}) error {
223245
if src == nil {

inet_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,60 @@ func TestInetAssignTo(t *testing.T) {
114114
}
115115
}
116116
}
117+
118+
func TestInetMarshalJSON(t *testing.T) {
119+
successfulTests := []struct {
120+
json string
121+
source pgtype.Inet
122+
}{
123+
{source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1/32"`},
124+
{source: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e/128"`},
125+
{source: pgtype.Inet{Status: pgtype.Null}, json: `""`},
126+
{source: pgtype.Inet{}, json: `""`},
127+
}
128+
129+
for i, tt := range successfulTests {
130+
got, err := tt.source.MarshalJSON()
131+
if err != nil {
132+
t.Errorf("%d: %v", i, err)
133+
}
134+
if !reflect.DeepEqual(got, []byte(tt.json)) {
135+
t.Errorf("%d: expected JSON `%s`, but it was %s", i, tt.json, string(got))
136+
}
137+
}
138+
}
139+
140+
func TestInetUnmarshalJSON(t *testing.T) {
141+
successfulTests := []struct {
142+
json string
143+
expected pgtype.Inet
144+
}{
145+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1/32"`},
146+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1"`},
147+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e/128"`},
148+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e"`},
149+
{expected: pgtype.Inet{Status: pgtype.Null}, json: `""`}, // empty is OK, equivalent to our null struct
150+
}
151+
badJSON := []string{
152+
`"127.0.0.1/"`, // no network
153+
`"444.555.666.777/32"`, // bad addr
154+
`"nonsense"`, // bad everything
155+
}
156+
157+
for i, tt := range successfulTests {
158+
got := pgtype.Inet{}
159+
if err := got.UnmarshalJSON([]byte(tt.json)); err != nil {
160+
t.Errorf("%d: %v", i, err)
161+
}
162+
if !reflect.DeepEqual(got, tt.expected) {
163+
t.Errorf("%d: expected %v from JSON `%s`, but it was %v", i, tt.expected, tt.json, got)
164+
}
165+
}
166+
167+
for i, example := range badJSON {
168+
got := pgtype.Inet{}
169+
if err := got.UnmarshalJSON([]byte(example)); err == nil {
170+
t.Errorf("%d: Expected error for %s, but got none", i, example)
171+
}
172+
}
173+
}

0 commit comments

Comments
 (0)