186 lines
3.7 KiB
Go
186 lines
3.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/oschwald/maxminddb-golang"
|
|
)
|
|
|
|
// asnRecord is the subset of a MaxMind ASN database record that the service
// decodes. The maxminddb struct tags name the database key each field is
// read from.
type asnRecord struct {
	// ASN is the autonomous system number; authHandler treats 0 as "no match".
	ASN uint `maxminddb:"autonomous_system_number"`
	// Org is the autonomous system organisation string (may be empty).
	Org string `maxminddb:"autonomous_system_organization"`
}
|
|
|
|
type server struct {
|
|
db *maxminddb.Reader
|
|
nrenASNs map[uint]struct{}
|
|
ready atomic.Bool
|
|
versionTag string
|
|
minASN int
|
|
asnCount int
|
|
}
|
|
|
|
// loadASNSet reads the ASN list at path and returns it as a set.
//
// The file is line oriented:
//   - blank lines and everything from a '#' to the end of a line are ignored,
//     so both full-line and trailing comments ("680 # DFN") are allowed;
//   - each remaining entry is one decimal ASN that fits in 32 bits,
//     optionally prefixed with "AS" in any case ("680", "AS680", "as680").
//
// Entries that do not parse are logged and skipped, so that one bad line does
// not prevent the service from starting. If the file cannot be opened or
// read, a nil set and the error are returned; read errors carry the path.
func loadASNSet(path string) (map[uint]struct{}, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	set := make(map[uint]struct{}, 4096)
	sc := bufio.NewScanner(f)
	for lineNo := 1; sc.Scan(); lineNo++ {
		entry, _, _ := strings.Cut(sc.Text(), "#")
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}

		digits := entry
		if len(digits) > 2 && strings.EqualFold(digits[:2], "AS") {
			digits = digits[2:]
		}
		asn, err := strconv.ParseUint(digits, 10, 32)
		if err != nil {
			log.Printf("%s:%d: skipping invalid ASN entry %q", path, lineNo, entry)
			continue
		}
		set[uint(asn)] = struct{}{}
	}
	if err := sc.Err(); err != nil {
		// Do not hand back a silently truncated allow-list.
		return nil, &os.PathError{Op: "read", Path: path, Err: err}
	}
	return set, nil
}
|
|
|
|
// firstForwardedFor returns the left-most entry of the X-Forwarded-For
// request header with surrounding whitespace removed. It returns "" when the
// header is absent or its first entry is empty. Only the first
// X-Forwarded-For header line is consulted (Header.Get).
//
// NOTE(review): the left-most entry is whatever the first hop sent, which a
// client can set itself. It is only trustworthy if the proxy in front of this
// service drops or overwrites client-supplied X-Forwarded-For headers —
// confirm that against the proxy configuration.
func firstForwardedFor(r *http.Request) string {
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	return strings.TrimSpace(first)
}
|
|
|
|
func remoteIP(r *http.Request) string {
|
|
// Prefer XFF (because Traefik is proxy)
|
|
ip := firstForwardedFor(r)
|
|
if ip != "" {
|
|
return ip
|
|
}
|
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err == nil {
|
|
return host
|
|
}
|
|
return r.RemoteAddr
|
|
}
|
|
|
|
func (s *server) authHandler(w http.ResponseWriter, r *http.Request) {
|
|
if !s.ready.Load() {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
ipStr := remoteIP(r)
|
|
parsed := net.ParseIP(ipStr)
|
|
if parsed == nil {
|
|
// Always 200: we enrich, not block
|
|
w.Header().Set("X-NREN", "0")
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
var rec asnRecord
|
|
if err := s.db.Lookup(parsed, &rec); err != nil || rec.ASN == 0 {
|
|
w.Header().Set("X-NREN", "0")
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("X-ASN", strconv.FormatUint(uint64(rec.ASN), 10))
|
|
if rec.Org != "" {
|
|
// optional: keep it short; some org strings can be long
|
|
w.Header().Set("X-ASN-ORG", rec.Org)
|
|
}
|
|
|
|
_, ok := s.nrenASNs[rec.ASN]
|
|
if ok {
|
|
w.Header().Set("X-NREN", "1")
|
|
} else {
|
|
w.Header().Set("X-NREN", "0")
|
|
}
|
|
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
w.Header().Set("X-Service", s.versionTag)
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func main() {
|
|
mmdbPath := getenv("MMDB_PATH", "/data/GeoLite2-ASN.mmdb")
|
|
asnListPath := getenv("ASN_LIST_PATH", "/data/nren_asns.txt")
|
|
addr := getenv("ADDR", ":8080")
|
|
version := getenv("VERSION_TAG", "asn-header-service")
|
|
minASN := getenvInt("MIN_ASN_COUNT", 10)
|
|
|
|
db, err := maxminddb.Open(mmdbPath)
|
|
if err != nil {
|
|
log.Fatalf("failed to open mmdb: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
set, err := loadASNSet(asnListPath)
|
|
if err != nil {
|
|
log.Fatalf("failed to load asn list: %v", err)
|
|
}
|
|
asnCount := len(set)
|
|
|
|
s := &server{
|
|
db: db,
|
|
nrenASNs: set,
|
|
versionTag: version,
|
|
minASN: minASN,
|
|
asnCount: asnCount,
|
|
}
|
|
s.ready.Store(true)
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/auth", s.authHandler)
|
|
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
|
if s.asnCount < s.minASN {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: mux,
|
|
ReadHeaderTimeout: 2 * time.Second,
|
|
}
|
|
|
|
log.Printf("listening on %s (asn_count=%d, min_asn=%d)", addr, asnCount, minASN)
|
|
log.Fatal(srv.ListenAndServe())
|
|
}
|
|
|
|
// getenv looks up environment variable k and returns its value with
// surrounding whitespace stripped. If the variable is unset, or contains only
// whitespace, def is returned instead.
func getenv(k, def string) string {
	if val := strings.TrimSpace(os.Getenv(k)); val != "" {
		return val
	}
	return def
}
|
|
|
|
// getenvInt returns environment variable k, trimmed of surrounding whitespace
// and parsed as a base-10 int. It returns def when the variable is unset or
// blank. It also returns def when the value is not a valid int, but logs that
// case, so that a typo in the deployment configuration is visible instead of
// being silently replaced by the default.
func getenvInt(k string, def int) int {
	v := strings.TrimSpace(os.Getenv(k))
	if v == "" {
		return def
	}
	n, err := strconv.Atoi(v)
	if err != nil {
		log.Printf("ignoring invalid %s=%q, using default %d: %v", k, v, def, err)
		return def
	}
	return n
}
|