Files
education-flagger/main.go

275 lines
5.6 KiB
Go

package main
import (
"bufio"
"encoding/json"
"log"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/oschwald/maxminddb-golang"
)
// asnRecord is the part of an ASN MaxMind DB record this service decodes
// (the default MMDB_PATH points at a GeoLite2-ASN database). The struct tags
// name the keys maxminddb.Reader.Lookup fills in.
//
// The handlers treat ASN == 0 as "no ASN data for this address".
type asnRecord struct {
	ASN uint   `maxminddb:"autonomous_system_number"`
	Org string `maxminddb:"autonomous_system_organization"`
}
// lookupResponse is the JSON body written by the /lookup handler.
//
// When NREN is true, ASN, ASNOrg and MatchedIP describe the resolved address
// whose ASN is on the NREN list. Otherwise ASN/ASNOrg describe the first
// resolved address that had ASN data; all three are omitted when none did.
// IPs serializes as null on the error paths that return before DNS
// resolution succeeds. Error is set only on those error paths.
type lookupResponse struct {
	Domain    string   `json:"domain"`
	NREN      bool     `json:"nren"`
	ASN       *uint    `json:"asn,omitempty"`
	ASNOrg    string   `json:"asn_org,omitempty"`
	IPs       []string `json:"ips"`
	MatchedIP string   `json:"matched_ip,omitempty"`
	Error     string   `json:"error,omitempty"`
}
// server holds the state shared by the HTTP handlers. main builds it once
// before the listener starts; nothing visible here mutates it afterwards
// except the ready flag.
type server struct {
	// db is the opened MaxMind ASN database used for IP -> ASN lookups.
	db *maxminddb.Reader
	// nrenASNs is the set of ASNs reported as NREN (X-NREN: 1, "nren": true).
	nrenASNs map[uint]struct{}
	// ready gates /auth and /lookup; both answer 503 while it is false.
	ready atomic.Bool
	// versionTag is echoed back in the X-Service header by /auth.
	versionTag string
	// minASN is the smallest acceptable asnCount; /healthz fails below it.
	minASN int
	// asnCount is len(nrenASNs) as loaded at startup.
	asnCount int
}
// loadASNSet reads the NREN ASN list at path and returns it as a set.
//
// The file holds one ASN per line. Blank lines and everything after a '#'
// are ignored, and an optional "AS" prefix (any case) is accepted, so
// "1103", "AS1103" and "as1103 # SURF" all yield 1103. A line that still
// does not parse as an unsigned 32-bit integer is skipped and logged with
// its line number, rather than silently dropped. A bad entry therefore
// does not abort the load, but it is visible.
//
// The returned error comes from opening or scanning the file; the set is
// nil whenever the error is non-nil.
func loadASNSet(path string) (map[uint]struct{}, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	set := make(map[uint]struct{}, 4096)
	sc := bufio.NewScanner(f)
	for lineNo := 1; sc.Scan(); lineNo++ {
		line := sc.Text()
		// Strip comments first so both full-line and trailing comments go.
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		digits := line
		if len(digits) > 2 && strings.EqualFold(digits[:2], "AS") {
			digits = strings.TrimSpace(digits[2:])
		}
		// bitSize 32: ASNs are 32-bit (RFC 6793); larger values are invalid.
		v, err := strconv.ParseUint(digits, 10, 32)
		if err != nil {
			log.Printf("%s:%d: skipping invalid ASN %q", path, lineNo, line)
			continue
		}
		set[uint(v)] = struct{}{}
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return set, nil
}
// firstForwardedFor returns the left-most entry of the X-Forwarded-For
// header, trimmed of whitespace, or "" when the header is absent or that
// entry is empty. If the entry carries a port ("203.0.113.7:5123" or
// "[2001:db8::1]:443"), only the host part is returned so the result can be
// passed to net.ParseIP. Bare IPv4/IPv6 entries are returned unchanged.
//
// NOTE(review): the left-most entry is whatever the original client sent
// unless the fronting proxy strips or overwrites incoming X-Forwarded-For.
// Confirm the Traefik config does so; otherwise a client can pick the IP
// that gets classified.
func firstForwardedFor(r *http.Request) string {
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	first = strings.TrimSpace(first)
	if host, _, err := net.SplitHostPort(first); err == nil {
		return host
	}
	return first
}

// remoteIP returns the client address for r as a string. It prefers the
// first X-Forwarded-For entry, because the service expects to sit behind
// Traefik. Otherwise it falls back to the host part of r.RemoteAddr, or
// RemoteAddr verbatim if it has no port. The result is not validated;
// callers must parse it.
func remoteIP(r *http.Request) string {
	if ip := firstForwardedFor(r); ip != "" {
		return ip
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// writeJSON writes payload as JSON with the given status code and
// Content-Type application/json. The output is compact JSON followed by a
// newline, the same bytes json.Encoder.Encode produces.
//
// The payload is encoded before anything is sent. A value that cannot be
// marshaled therefore yields a logged 500 instead of the requested status
// with a truncated body.
func writeJSON(w http.ResponseWriter, status int, payload any) {
	body, err := json.Marshal(payload)
	if err != nil {
		log.Printf("writeJSON: encoding %T: %v", payload, err)
		http.Error(w, "internal server error", http.StatusInternalServerError)
		return
	}
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already committed. A write error (typically the
	// client hanging up) cannot be reported to anyone, so it is dropped.
	_, _ = w.Write(body)
}
// authHandler answers /auth. It enriches rather than blocks: once the
// server is ready, every response is 200 and carries these headers:
//   - Cache-Control: no-store and X-Service: s.versionTag, on every 200 so
//     no response variant can be cached or go untagged;
//   - X-NREN: "1" if the client IP's ASN is in s.nrenASNs, otherwise "0",
//     including when the IP does not parse or has no ASN data;
//   - X-ASN and X-ASN-ORG (when non-empty), only when the database returns
//     a non-zero ASN for the IP.
//
// Only a server whose ready flag is false answers 503. The client IP comes
// from remoteIP.
// NOTE(review): presumably called by a proxy forward-auth hook that copies
// these headers upstream; confirm against the deployment config.
func (s *server) authHandler(w http.ResponseWriter, r *http.Request) {
	if !s.ready.Load() {
		w.WriteHeader(http.StatusServiceUnavailable)
		return
	}

	h := w.Header()
	h.Set("Cache-Control", "no-store")
	h.Set("X-Service", s.versionTag)
	// Default verdict; overwritten only on a positive ASN match below.
	h.Set("X-NREN", "0")

	ip := net.ParseIP(remoteIP(r))
	var rec asnRecord
	if ip == nil || s.db.Lookup(ip, &rec) != nil || rec.ASN == 0 {
		// Always 200: we enrich, not block.
		w.WriteHeader(http.StatusOK)
		return
	}

	h.Set("X-ASN", strconv.FormatUint(uint64(rec.ASN), 10))
	if rec.Org != "" {
		// optional: keep it short; some org strings can be long
		h.Set("X-ASN-ORG", rec.Org)
	}
	if _, ok := s.nrenASNs[rec.ASN]; ok {
		h.Set("X-NREN", "1")
	}
	w.WriteHeader(http.StatusOK)
}
// lookupHandler serves /lookup?domain=NAME. It resolves NAME and reports,
// as a JSON lookupResponse, whether any resolved address belongs to an ASN
// in s.nrenASNs:
//   - 503 "service not ready" while the ready flag is false;
//   - 400 "missing domain" when the parameter is empty after trimming;
//   - 200 "domain lookup failed" when resolution errors or returns nothing;
//   - 200 otherwise, with IPs listing every resolved address. The first
//     address whose ASN is listed sets NREN, ASN, ASNOrg and MatchedIP. If
//     none match, ASN/ASNOrg describe the first address that had ASN data.
//
// Resolution is bound to the request context, so DNS work is abandoned
// when the client goes away.
func (s *server) lookupHandler(w http.ResponseWriter, r *http.Request) {
	if !s.ready.Load() {
		writeJSON(w, http.StatusServiceUnavailable, lookupResponse{
			Error: "service not ready",
		})
		return
	}

	domain := strings.TrimSpace(r.URL.Query().Get("domain"))
	if domain == "" {
		writeJSON(w, http.StatusBadRequest, lookupResponse{
			Error: "missing domain",
		})
		return
	}

	ips, err := net.DefaultResolver.LookupIP(r.Context(), "ip", domain)
	if err != nil || len(ips) == 0 {
		writeJSON(w, http.StatusOK, lookupResponse{
			Domain: domain,
			Error:  "domain lookup failed",
		})
		return
	}

	resp := lookupResponse{
		Domain: domain,
		IPs:    make([]string, 0, len(ips)),
	}
	// List every resolved address, regardless of where a match occurs.
	for _, ip := range ips {
		resp.IPs = append(resp.IPs, ip.String())
	}

	var fallback *asnRecord // first address with ASN data, if none match
	for i, ip := range ips {
		var rec asnRecord
		if err := s.db.Lookup(ip, &rec); err != nil || rec.ASN == 0 {
			continue // no usable ASN data for this address
		}
		if _, ok := s.nrenASNs[rec.ASN]; ok {
			resp.NREN = true
			resp.ASN = &rec.ASN
			resp.ASNOrg = rec.Org
			resp.MatchedIP = resp.IPs[i]
			break
		}
		if fallback == nil {
			fallback = &rec // rec is fresh each iteration, so this is safe
		}
	}
	if !resp.NREN && fallback != nil {
		resp.ASN = &fallback.ASN
		resp.ASNOrg = fallback.Org
	}
	writeJSON(w, http.StatusOK, resp)
}
// main reads configuration from the environment, opens the ASN database,
// loads the NREN ASN list and serves /auth, /lookup and /healthz until
// ListenAndServe fails.
//
// Environment (defaults in parentheses): MMDB_PATH
// (/data/GeoLite2-ASN.mmdb), ASN_LIST_PATH (/data/nren_asns.txt), ADDR
// (:8080), VERSION_TAG (asn-header-service), MIN_ASN_COUNT (10).
// /healthz returns 503 when the loaded list has fewer than MIN_ASN_COUNT
// entries.
func main() {
	mmdbPath := getenv("MMDB_PATH", "/data/GeoLite2-ASN.mmdb")
	asnListPath := getenv("ASN_LIST_PATH", "/data/nren_asns.txt")
	addr := getenv("ADDR", ":8080")
	version := getenv("VERSION_TAG", "asn-header-service")
	minASN := getenvInt("MIN_ASN_COUNT", 10)

	db, err := maxminddb.Open(mmdbPath)
	if err != nil {
		log.Fatalf("failed to open mmdb: %v", err)
	}
	// log.Fatal/Fatalf exit via os.Exit, which skips deferred calls. This
	// Close therefore only runs if main ever returns normally.
	defer db.Close()

	set, err := loadASNSet(asnListPath)
	if err != nil {
		log.Fatalf("failed to load asn list: %v", err)
	}
	asnCount := len(set)
	if asnCount < minASN {
		// The service still starts, but make the cause of a failing
		// health check obvious in the logs.
		log.Printf("warning: %s has %d ASNs, fewer than MIN_ASN_COUNT=%d; /healthz will report 503",
			asnListPath, asnCount, minASN)
	}

	s := &server{
		db:         db,
		nrenASNs:   set,
		versionTag: version,
		minASN:     minASN,
		asnCount:   asnCount,
	}
	s.ready.Store(true)

	mux := http.NewServeMux()
	mux.HandleFunc("/auth", s.authHandler)
	mux.HandleFunc("/lookup", s.lookupHandler)
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
		if s.asnCount < s.minASN {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
	})

	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 2 * time.Second,
		// Bound full request reads, responses (/lookup waits on DNS before
		// writing) and idle keep-alives so stalled or idle clients cannot
		// hold connections open indefinitely.
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	log.Printf("listening on %s (asn_count=%d, min_asn=%d)", addr, asnCount, minASN)
	log.Fatal(srv.ListenAndServe())
}
// getenv looks up environment variable k and returns it with leading and
// trailing whitespace removed. If the trimmed value is empty (the variable
// is unset or blank), def is returned instead.
func getenv(k, def string) string {
	if trimmed := strings.TrimSpace(os.Getenv(k)); trimmed != "" {
		return trimmed
	}
	return def
}
// getenvInt returns environment variable k, trimmed of whitespace and parsed
// as a base-10 int. It returns def when the variable is unset or blank. It
// also returns def when the value does not parse, but logs a warning in that
// case, so a typo such as MIN_ASN_COUNT=1O is not silently replaced by the
// default.
func getenvInt(k string, def int) int {
	v := strings.TrimSpace(os.Getenv(k))
	if v == "" {
		return def
	}
	n, err := strconv.Atoi(v)
	if err != nil {
		log.Printf("ignoring invalid %s=%q (using default %d): %v", k, v, def, err)
		return def
	}
	return n
}