// Copyright 2020 Shivaram Lingamneni // Copyright 2009 The Go Authors package flatip import ( "bytes" "errors" "net" ) var ( v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff} IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} ErrInvalidIPString = errors.New("String could not be interpreted as an IP address") ) // packed versions of net.IP and net.IPNet; these are pure value types, // so they can be compared with == and used as map keys. // IP is the 128-bit representation of the IPv6 address, using the 4-in-6 mapping // if necessary: type IP [16]byte // IPNet is a IP network. In a valid value, all bits after PrefixLen are zeroes. type IPNet struct { IP PrefixLen uint8 } // NetIP converts an IP into a net.IP. func (ip IP) NetIP() (result net.IP) { result = make(net.IP, 16) copy(result[:], ip[:]) return } // FromNetIP converts a net.IP into an IP. func FromNetIP(ip net.IP) (result IP) { if len(ip) == 16 { copy(result[:], ip[:]) } else { result[10] = 0xff result[11] = 0xff copy(result[12:], ip[:]) } return } // IPv4 returns the IP address representation of a.b.c.d func IPv4(a, b, c, d byte) (result IP) { copy(result[:12], v4InV6Prefix) result[12] = a result[13] = b result[14] = c result[15] = d return } // ParseIP parses a string representation of an IP address into an IP. // Unlike net.ParseIP, it returns an error instead of a zero value on failure, // since the zero value of `IP` is a representation of a valid IP (::0, the // IPv6 "unspecified address"). func ParseIP(ipstr string) (ip IP, err error) { // TODO reimplement this without net.ParseIP netip := net.ParseIP(ipstr) if netip == nil { err = ErrInvalidIPString return } netip = netip.To16() copy(ip[:], netip) return } // String returns the string representation of an IP func (ip IP) String() string { // TODO reimplement this without using (net.IP).String() return (net.IP)(ip[:]).String() } // IsIPv4 returns whether the IP is an IPv4 address. func (ip IP) IsIPv4() bool { return bytes.Equal(ip[:12], v4InV6Prefix) } // IsLoopback returns whether the IP is a loopback address. func (ip IP) IsLoopback() bool { if ip.IsIPv4() { return ip[12] == 127 } else { return ip == IPv6loopback } } func rawCidrMask(length int) (m IP) { n := uint(length) for i := 0; i < 16; i++ { if n >= 8 { m[i] = 0xff n -= 8 continue } m[i] = ^byte(0xff >> n) return } return } func (ip IP) applyMask(mask IP) (result IP) { for i := 0; i < 16; i += 1 { result[i] = ip[i] & mask[i] } return } func cidrMask(ones, bits int) (result IP) { switch bits { case 32: return rawCidrMask(96 + ones) case 128: return rawCidrMask(ones) default: return } } // Mask returns the result of masking ip with the CIDR mask of // length 'ones', out of a total of 'bits' (which must be either // 32 for an IPv4 subnet or 128 for an IPv6 subnet). func (ip IP) Mask(ones, bits int) (result IP) { return ip.applyMask(cidrMask(ones, bits)) } // ToNetIPNet converts an IPNet into a net.IPNet. func (cidr IPNet) ToNetIPNet() (result net.IPNet) { return net.IPNet{ IP: cidr.IP.NetIP(), Mask: net.CIDRMask(int(cidr.PrefixLen), 128), } } // Contains retuns whether the network contains `ip`. func (cidr IPNet) Contains(ip IP) bool { maskedIP := ip.Mask(int(cidr.PrefixLen), 128) return cidr.IP == maskedIP } // FromNetIPnet converts a net.IPNet into an IPNet. func FromNetIPNet(network net.IPNet) (result IPNet) { ones, _ := network.Mask.Size() if len(network.IP) == 16 { copy(result.IP[:], network.IP[:]) } else { result.IP[10] = 0xff result.IP[11] = 0xff copy(result.IP[12:], network.IP[:]) ones += 96 } // perform masking so that equal CIDRs are == result.IP = result.IP.Mask(ones, 128) result.PrefixLen = uint8(ones) return } // String returns a string representation of an IPNet. func (cidr IPNet) String() string { ip := make(net.IP, 16) copy(ip[:], cidr.IP[:]) ipnet := net.IPNet{ IP: ip, Mask: net.CIDRMask(int(cidr.PrefixLen), 128), } return ipnet.String() } // ParseCIDR parses a string representation of an IP network in CIDR notation, // then returns it as an IPNet (along with the original, unmasked address). func ParseCIDR(netstr string) (ip IP, ipnet IPNet, err error) { // TODO reimplement this without net.ParseCIDR nip, nipnet, err := net.ParseCIDR(netstr) if err != nil { return } return FromNetIP(nip), FromNetIPNet(*nipnet), nil } // begin ad-hoc utilities // ParseToNormalizedNet attempts to interpret a string either as an IP // network in CIDR notation, returning an IPNet, or as an IP address, // returning an IPNet that contains only that address. func ParseToNormalizedNet(netstr string) (ipnet IPNet, err error) { _, ipnet, err = ParseCIDR(netstr) if err == nil { return } ip, err := ParseIP(netstr) if err == nil { ipnet.IP = ip ipnet.PrefixLen = 128 } return } // IPInNets is a convenience function for testing whether an IP is contained // in any member of a slice of IPNet's. func IPInNets(addr IP, nets []IPNet) bool { for _, net := range nets { if net.Contains(addr) { return true } } return false }