diff --git a/internal/search/api_search.go b/internal/search/api_search.go index c77dc828..0243ff1f 100644 --- a/internal/search/api_search.go +++ b/internal/search/api_search.go @@ -4,9 +4,12 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "strings" + "time" "github.com/moov-io/base/log" + "github.com/moov-io/watchman/pkg/address" "github.com/moov-io/watchman/pkg/search" "github.com/gorilla/mux" @@ -80,41 +83,128 @@ func (c *controller) search(w http.ResponseWriter, r *http.Request) { func readSearchRequest(r *http.Request) (search.Entity[search.Value], error) { q := r.URL.Query() + var err error var req search.Entity[search.Value] req.Name = strings.TrimSpace(q.Get("name")) - req.Type = search.EntityType(strings.TrimSpace(strings.ToLower(q.Get("entityType")))) + req.Type = search.EntityType(strings.TrimSpace(strings.ToLower(q.Get("type")))) req.Source = search.SourceAPIRequest req.SourceID = strings.TrimSpace(q.Get("requestID")) switch req.Type { - case search.EntityPerson: // "person" + case search.EntityPerson: + req.Person = &search.Person{ + Name: req.Name, + AltNames: q["altNames"], + Gender: search.Gender(strings.TrimSpace(q.Get("gender"))), + BirthDate: readDate(q.Get("birthDate")), + DeathDate: readDate(q.Get("deathDate")), + Titles: q["titles"], + // GovernmentIDs []GovernmentID `json:"governmentIDs"` + } + + case search.EntityBusiness: + req.Business = &search.Business{ + Name: req.Name, + Created: readDate(q.Get("created")), + Dissolved: readDate(q.Get("dissolved")), + // Identifier []Identifier `json:"identifier"` + } + + case search.EntityOrganization: + req.Organization = &search.Organization{ + Name: req.Name, + Created: readDate(q.Get("created")), + Dissolved: readDate(q.Get("dissolved")), + // Identifier []Identifier `json:"identifier"` + } + + case search.EntityAircraft: + req.Aircraft = &search.Aircraft{ + Name: req.Name, + Type: search.AircraftType(q.Get("aircraftType")), + Flag: q.Get("flag"), + Built: readDate("built"), + ICAOCode: q.Get("icaoCode"), + Model: q.Get("model"), + SerialNumber: q.Get("serialNumber"), + } + + case search.EntityVessel: + req.Vessel = &search.Vessel{ + Name: req.Name, + IMONumber: q.Get("imoNumber"), + Type: search.VesselType(q.Get("vesselType")), + Flag: q.Get("flag"), + Built: readDate("built"), + Model: q.Get("model"), + MMSI: q.Get("mmsi"), + CallSign: q.Get("callSign"), + Owner: q.Get("owner"), + } + req.Vessel.Tonnage, err = readInt(q.Get("tonnage")) + if err != nil { + return req, fmt.Errorf("reading vessel tonnage: %w", err) + } + req.Vessel.GrossRegisteredTonnage, err = readInt(q.Get("grossRegisteredTonnage")) + if err != nil { + return req, fmt.Errorf("reading vessel GrossRegisteredTonnage: %w", err) + } + } + + req.CryptoAddresses = readCryptoCurrencyAddresses(q["cryptoAddress"]) + req.Addresses = readAddresses(q["address"]) - case search.EntityBusiness: // "business" + // TODO(adam): + // Affiliations []Affiliation `json:"affiliations"` + // SanctionsInfo *SanctionsInfo `json:"sanctionsInfo"` + // HistoricalInfo []HistoricalInfo `json:"historicalInfo"` - case search.EntityAircraft: // "aircraft" + return req, nil +} - case search.EntityVessel: // "vessel" +var ( + allowedDateFormats = []string{"2006-01-02", "2006-01", "2006"} +) - default: - return req, fmt.Errorf("unsupported entityType: %v", req.Type) +func readDate(input string) *time.Time { + if input == "" { + return nil } - // Person *Person `json:"person"` - // Business *Business `json:"business"` - // Organization *Organization `json:"organization"` - // Aircraft *Aircraft `json:"aircraft"` - // Vessel *Vessel `json:"vessel"` - - // CryptoAddresses []CryptoAddress `json:"cryptoAddresses"` - // TODO(adam): support multiple values? How does Go handle that? + for _, format := range allowedDateFormats { + tt, err := time.Parse(format, input) + if err == nil { + return &tt + } + } + return nil +} - // Addresses []Address `json:"addresses"` +func readInt(input string) (int, error) { + n, err := strconv.ParseInt(input, 10, 32) + return int(n), err +} - // Affiliations []Affiliation `json:"affiliations"` - // SanctionsInfo *SanctionsInfo `json:"sanctionsInfo"` - // HistoricalInfo []HistoricalInfo `json:"historicalInfo"` - // Titles []string `json:"titles"` +func readCryptoCurrencyAddresses(inputs []string) []search.CryptoAddress { + var out []search.CryptoAddress + for _, input := range inputs { + // Query param looks like: cryptoAddress=XBT:x123456 + parts := strings.Split(input, ":") + if len(parts) == 2 { + out = append(out, search.CryptoAddress{ + Currency: parts[0], + Address: parts[1], + }) + } + } + return out +} - return req, nil +func readAddresses(inputs []string) []search.Address { + var out []search.Address + for _, input := range inputs { + out = append(out, address.ParseAddress(input)) + } + return out } diff --git a/pkg/search/models.go b/pkg/search/models.go index cd68b329..47c5ff94 100644 --- a/pkg/search/models.go +++ b/pkg/search/models.go @@ -188,6 +188,9 @@ var ( VesselTypeCargo VesselType = "cargo" ) +// CryptoAddress +// +// &cryptoAddress=XBT:x123456 type CryptoAddress struct { Currency string `json:"currency"` Address string `json:"address"` diff --git a/pkg/search/similarity.go b/pkg/search/similarity.go index f19eb455..41cb7edf 100644 --- a/pkg/search/similarity.go +++ b/pkg/search/similarity.go @@ -2,6 +2,7 @@ package search import ( "fmt" + "io" "math" "strings" "time" @@ -11,26 +12,30 @@ import ( // Similarity calculates a match score between a query and an index entity. func Similarity[Q any, I any](query Entity[Q], index Entity[I]) float64 { + return DebugSimilarity[Q, I](nil, query, index) +} + +func DebugSimilarity[Q any, I any](w io.Writer, query Entity[Q], index Entity[I]) float64 { pieces := make([]scorePiece, 0) // Primary identifiers (IMO number, Call Sign, etc.) - highest weight exactMatchWeight := 50.0 - pieces = append(pieces, compareExactIdentifiers(query, index, exactMatchWeight)) + pieces = append(pieces, compareExactIdentifiers(w, query, index, exactMatchWeight)) // Name match is critical nameWeight := 30.0 - pieces = append(pieces, compareName(query, index, nameWeight)) + pieces = append(pieces, compareName(w, query, index, nameWeight)) // Entity-specific comparisons (type, flag, etc) entityWeight := 15.0 - pieces = append(pieces, compareEntitySpecific(query, index, entityWeight)) + pieces = append(pieces, compareEntitySpecific(w, query, index, entityWeight)) // Supporting information (addresses, sanctions, etc) supportingWeight := 5.0 - pieces = append(pieces, compareSupportingInfo(query, index, supportingWeight)) + pieces = append(pieces, compareSupportingInfo(w, query, index, supportingWeight)) // Compute final score with coverage logic - return calculateFinalScore(pieces, index) + return calculateFinalScore(w, pieces, index) } // scorePiece is a partial scoring result from one comparison function @@ -44,7 +49,7 @@ type scorePiece struct { pieceType string // e.g. "name", "entity", "identifiers", etc. } -func compareExactIdentifiers[Q any, I any](query Entity[Q], index Entity[I], weight float64) scorePiece { +func compareExactIdentifiers[Q any, I any](w io.Writer, query Entity[Q], index Entity[I], weight float64) scorePiece { matches := 0 totalWeight := 0.0 score := 0.0 @@ -107,7 +112,7 @@ func compareExactIdentifiers[Q any, I any](query Entity[Q], index Entity[I], wei } } -func compareName[Q any, I any](query Entity[Q], index Entity[I], weight float64) scorePiece { +func compareName[Q any, I any](w io.Writer, query Entity[Q], index Entity[I], weight float64) scorePiece { qName := strings.TrimSpace(strings.ToLower(query.Name)) iName := strings.TrimSpace(strings.ToLower(index.Name)) @@ -163,11 +168,7 @@ func compareName[Q any, I any](query Entity[Q], index Entity[I], weight float64) } } -const ( - debugEntitySpecific = true -) - -func compareEntitySpecific[Q any, I any](query Entity[Q], index Entity[I], weight float64) scorePiece { +func compareEntitySpecific[Q any, I any](w io.Writer, query Entity[Q], index Entity[I], weight float64) scorePiece { // If types don't match, it's an immediate 0 if query.Type != index.Type { return scorePiece{ @@ -184,19 +185,18 @@ func compareEntitySpecific[Q any, I any](query Entity[Q], index Entity[I], weigh switch query.Type { case EntityVessel: - typeScore, matched, fieldsCompared = compareVesselFields(query.Vessel, index.Vessel) + typeScore, matched, fieldsCompared = compareVesselFields(w, query.Vessel, index.Vessel) case EntityPerson: - typeScore, matched, fieldsCompared = comparePersonFields(query.Person, index.Person) + typeScore, matched, fieldsCompared = comparePersonFields(w, query.Person, index.Person) case EntityAircraft: - typeScore, matched, fieldsCompared = compareAircraftFields(query.Aircraft, index.Aircraft) + typeScore, matched, fieldsCompared = compareAircraftFields(w, query.Aircraft, index.Aircraft) case EntityBusiness: - typeScore, matched, fieldsCompared = compareBusinessFields(query.Business, index.Business) + typeScore, matched, fieldsCompared = compareBusinessFields(w, query.Business, index.Business) case EntityOrganization: - typeScore, matched, fieldsCompared = compareOrganizationFields(query.Organization, index.Organization) - } - if debugEntitySpecific { - debug("%v typeScore=%.4f matched=%v fieldsCompared=%v\n", query.Type, typeScore, matched, fieldsCompared) + typeScore, matched, fieldsCompared = compareOrganizationFields(w, query.Organization, index.Organization) } + debug(w, "compareEntitySpecific\n ") + debug(w, "%v typeScore=%.4f matched=%v fieldsCompared=%v\n", query.Type, typeScore, matched, fieldsCompared) return scorePiece{ score: typeScore, @@ -216,7 +216,7 @@ const ( // ------------------------------- // Person-Specific Fields // ------------------------------- -func comparePersonFields(query *Person, index *Person) (float64, bool, int) { +func comparePersonFields(w io.Writer, query *Person, index *Person) (float64, bool, int) { if query == nil || index == nil { return 0, false, 0 } @@ -273,13 +273,13 @@ func comparePersonFields(query *Person, index *Person) (float64, bool, int) { } avg := sum / float64(len(scores)) - return avg, avg > 0.5, fieldsCompared + return avg, avg > 0.9, fieldsCompared } // ------------------------------- // Vessel-Specific Fields // ------------------------------- -func compareVesselFields(query *Vessel, index *Vessel) (float64, bool, int) { +func compareVesselFields(w io.Writer, query *Vessel, index *Vessel) (float64, bool, int) { if query == nil || index == nil { return 0, false, 0 } @@ -302,6 +302,7 @@ func compareVesselFields(query *Vessel, index *Vessel) (float64, bool, int) { scores = append(scores, fieldScore{0.0, 4.0}) } } + if query.IMONumber != "" && index.IMONumber != "" { fieldsCompared++ if strings.EqualFold(query.IMONumber, index.IMONumber) { @@ -310,12 +311,14 @@ func compareVesselFields(query *Vessel, index *Vessel) (float64, bool, int) { scores = append(scores, fieldScore{0.0, 4.0}) } } + if query.Owner != "" { fieldsCompared++ ownerTerms := strings.Fields(strings.ToLower(query.Owner)) ownerScore := stringscore.BestPairsJaroWinkler(ownerTerms, strings.ToLower(index.Owner)) scores = append(scores, fieldScore{ownerScore, 2.0}) } + if query.Flag != "" { fieldsCompared++ if strings.EqualFold(query.Flag, index.Flag) { @@ -324,6 +327,7 @@ func compareVesselFields(query *Vessel, index *Vessel) (float64, bool, int) { scores = append(scores, fieldScore{0.0, 1.5}) } } + if query.Type != "" { fieldsCompared++ if strings.EqualFold(string(query.Type), string(index.Type)) { @@ -332,12 +336,14 @@ func compareVesselFields(query *Vessel, index *Vessel) (float64, bool, int) { scores = append(scores, fieldScore{0.0, 1.0}) } } + if query.Tonnage > 0 && index.Tonnage > 0 { fieldsCompared++ diff := math.Abs(float64(query.Tonnage - index.Tonnage)) s := vesselTonnageScore(diff) scores = append(scores, fieldScore{s, 1.0}) } + if query.GrossRegisteredTonnage > 0 && index.GrossRegisteredTonnage > 0 { fieldsCompared++ diff := math.Abs(float64(query.GrossRegisteredTonnage - index.GrossRegisteredTonnage)) @@ -356,7 +362,7 @@ func compareVesselFields(query *Vessel, index *Vessel) (float64, bool, int) { } avgScore := totalScore / totalWeight - return avgScore, avgScore > 0.5, fieldsCompared + return avgScore, avgScore > 0.9, fieldsCompared } // Helper for vessel tonnage diffs @@ -376,7 +382,7 @@ func vesselTonnageScore(diff float64) float64 { // ------------------------------- // Aircraft-Specific Fields // ------------------------------- -func compareAircraftFields(query *Aircraft, index *Aircraft) (float64, bool, int) { +func compareAircraftFields(w io.Writer, query *Aircraft, index *Aircraft) (float64, bool, int) { if query == nil || index == nil { return 0, false, 0 } @@ -384,7 +390,8 @@ func compareAircraftFields(query *Aircraft, index *Aircraft) (float64, bool, int var scores []float64 fieldsCompared := 0 - // ICAO + debug(w, "compareAircraftFields\n ") + if query.ICAOCode != "" { fieldsCompared++ if strings.EqualFold(query.ICAOCode, index.ICAOCode) { @@ -392,8 +399,9 @@ func compareAircraftFields(query *Aircraft, index *Aircraft) (float64, bool, int } else { scores = append(scores, 0.0) } + debug(w, " .ICAOCode") } - // Model + if query.Model != "" { fieldsCompared++ if strings.EqualFold(query.Model, index.Model) { @@ -404,8 +412,9 @@ func compareAircraftFields(query *Aircraft, index *Aircraft) (float64, bool, int modelScore := stringscore.BestPairsJaroWinkler(qTerms, strings.ToLower(index.Model)) scores = append(scores, modelScore) } + debug(w, " .Model") } - // Flag + if query.Flag != "" { fieldsCompared++ if strings.EqualFold(query.Flag, index.Flag) { @@ -413,18 +422,26 @@ func compareAircraftFields(query *Aircraft, index *Aircraft) (float64, bool, int } else { scores = append(scores, 0.0) } + debug(w, " .Flag") } if len(scores) == 0 { return 0, false, fieldsCompared } + debug(w, " (Scores: %v)", scores) + sum := 0.0 for _, s := range scores { sum += s } + + debug(w, " [totalScore=%v ", sum) + avg := sum / float64(len(scores)) + debug(w, "avgScore=%.4f fieldsCompared=%v]\n", avg, fieldsCompared) + return avg, avg > 0.5, fieldsCompared } @@ -432,7 +449,7 @@ func compareAircraftFields(query *Aircraft, index *Aircraft) (float64, bool, int // Business-Specific Fields // ------------------------------- // compareBusinessFields compares fields for the Business entity -func compareBusinessFields(query *Business, index *Business) (float64, bool, int) { +func compareBusinessFields(w io.Writer, query *Business, index *Business) (float64, bool, int) { if query == nil || index == nil { return 0, false, 0 } @@ -445,7 +462,7 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int var scores []fieldScore fieldsCompared := 0 - fmt.Println("compareBusinessFields") + debug(w, "compareBusinessFields\n ") // 1) Primary Name check (fuzzy or exact) if query.Name != "" { @@ -460,8 +477,8 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int nameScore := stringscore.BestPairsJaroWinkler(qTerms, iName) scores = append(scores, fieldScore{score: nameScore, weight: 4.0}) } + debug(w, " .Name") } - fmt.Printf(".Name scores=%v fieldsCompared=%v\n", scores, fieldsCompared) // 2) AltNames check // If the query has alt names, let's see if any overlap. Or, if the index has alt names, @@ -483,8 +500,9 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int } // Weight alt names a bit lower than primary name scores = append(scores, fieldScore{score: bestAltScore, weight: 2.0}) + + debug(w, " .AltName") } - fmt.Printf(".AltName scores=%v fieldsCompared=%v\n", scores, fieldsCompared) // 3) Created date if query.Created != nil && index.Created != nil { @@ -503,8 +521,9 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int scores = append(scores, fieldScore{score: 0.0, weight: 1.0}) } } + + debug(w, " .Created") } - fmt.Printf(".Created scores=%v fieldsCompared=%v\n", scores, fieldsCompared) // 4) Dissolved date if query.Dissolved != nil && index.Dissolved != nil { @@ -523,8 +542,9 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int scores = append(scores, fieldScore{score: 0.0, weight: 1.0}) } } + + debug(w, " .Dissolved") } - fmt.Printf(".Dissolved scores=%v fieldsCompared=%v\n", scores, fieldsCompared) // 5) Identifiers // If you have multiple IDs in each, you might do a best match approach. @@ -559,8 +579,9 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int } // Weight ID matches strongly scores = append(scores, fieldScore{score: bestIDScore, weight: 5.0}) + + debug(w, " .Identifier") } - fmt.Printf(".Identifier scores=%v fieldsCompared=%v\n", scores, fieldsCompared) if len(scores) == 0 { return 0, false, fieldsCompared @@ -574,14 +595,17 @@ func compareBusinessFields(query *Business, index *Business) (float64, bool, int } avgScore := totalScore / totalWeight - // We'll say it's "matched" if > 0.5 on average // TODO(adam): why so low? - return avgScore, avgScore > 0.5, fieldsCompared + debug(w, " (Scores: %v)", scores) + debug(w, " [totalScore=%v totalWeight=%v avgScore=%.4f fieldsCompared=%v]\n", totalScore, totalWeight, avgScore, fieldsCompared) + + // We'll say it's "matched" if > 0.5 on average / + return avgScore, avgScore > 0.9, fieldsCompared } // ------------------------------- // Organization-Specific Fields // ------------------------------- -func compareOrganizationFields(query *Organization, index *Organization) (float64, bool, int) { +func compareOrganizationFields(w io.Writer, query *Organization, index *Organization) (float64, bool, int) { if query == nil || index == nil { return 0, false, 0 } @@ -589,6 +613,8 @@ func compareOrganizationFields(query *Organization, index *Organization) (float6 fieldsCompared := 0 scores := make([]float64, 0) + debug(w, "compareOrganizationFields\n ") + // Created date if query.Created != nil && index.Created != nil { fieldsCompared++ @@ -605,48 +631,60 @@ func compareOrganizationFields(query *Organization, index *Organization) (float6 scores = append(scores, 0.0) } } + debug(w, " .Created") } if len(scores) == 0 { return 0, false, fieldsCompared } + debug(w, " (Scores: %v)", scores) + sum := 0.0 for _, s := range scores { sum += s } + + debug(w, " [totalScore=%v ", sum) + avg := sum / float64(len(scores)) - return avg, avg > 0.5, fieldsCompared + debug(w, "avgScore=%.4f fieldsCompared=%v]\n", avg, fieldsCompared) + + return avg, avg > 0.9, fieldsCompared } // ------------------------------- // Supporting Info (addresses, etc.) // ------------------------------- -func compareSupportingInfo[Q any, I any](query Entity[Q], index Entity[I], weight float64) scorePiece { +func compareSupportingInfo[Q any, I any](w io.Writer, query Entity[Q], index Entity[I], weight float64) scorePiece { var pieces []float64 fieldsCompared := 0 + debug(w, "compareSupportingInfo\n ") + // Compare addresses if len(query.Addresses) > 0 && len(index.Addresses) > 0 { bestAddress := 0.0 fieldsCompared++ for _, qAddr := range query.Addresses { for _, iAddr := range index.Addresses { - addrScore := compareAddress(qAddr, iAddr) + addrScore := compareAddress(w, qAddr, iAddr) if addrScore > bestAddress { bestAddress = addrScore } } } pieces = append(pieces, bestAddress) + debug(w, " .Addresses") } // Compare sanctions programs if query.SanctionsInfo != nil && index.SanctionsInfo != nil { fieldsCompared++ - programScore := compareSanctionsPrograms(query.SanctionsInfo, index.SanctionsInfo) + programScore := compareSanctionsPrograms(w, query.SanctionsInfo, index.SanctionsInfo) pieces = append(pieces, programScore) + debug(w, " .SanctionsInfo") } // Compare crypto addresses (exact matches only) @@ -663,19 +701,26 @@ func compareSupportingInfo[Q any, I any](query Entity[Q], index Entity[I], weigh } score := float64(matches) / float64(len(query.CryptoAddresses)) pieces = append(pieces, score) + debug(w, " .CryptoAddresses") } if len(pieces) == 0 { return scorePiece{score: 0, weight: 0, fieldsCompared: 0, pieceType: "supporting"} } + debug(w, " (Scores: %v)", pieces) + // Average of these pieces sum := 0.0 for _, s := range pieces { sum += s } + debug(w, " [totalScore=%v ", sum) + avgScore := sum / float64(len(pieces)) + debug(w, "avgScore=%.4f fieldsCompared=%v]\n", avgScore, fieldsCompared) + return scorePiece{ score: avgScore, weight: weight, @@ -690,18 +735,21 @@ func compareSupportingInfo[Q any, I any](query Entity[Q], index Entity[I], weigh // ------------------------------- // Address comparison // ------------------------------- -func compareAddress(query Address, index Address) float64 { +func compareAddress(w io.Writer, query Address, index Address) float64 { var ( pieces []float64 weights []float64 ) + debug(w, "compareAddress\n ") + // Line1 if query.Line1 != "" { qTerms := strings.Fields(query.Line1) score := stringscore.BestPairsJaroWinkler(qTerms, index.Line1) pieces = append(pieces, score) weights = append(weights, 3.0) + debug(w, ".Line1") } // Line2 if query.Line2 != "" { @@ -709,6 +757,7 @@ func compareAddress(query Address, index Address) float64 { score := stringscore.BestPairsJaroWinkler(qTerms, index.Line2) pieces = append(pieces, score) weights = append(weights, 1.0) + debug(w, ".Line2") } // City if query.City != "" { @@ -716,6 +765,7 @@ func compareAddress(query Address, index Address) float64 { score := stringscore.BestPairsJaroWinkler(qTerms, index.City) pieces = append(pieces, score) weights = append(weights, 2.0) + debug(w, ".City") } // State (exact) if query.State != "" { @@ -725,6 +775,7 @@ func compareAddress(query Address, index Address) float64 { pieces = append(pieces, 0.0) } weights = append(weights, 1.0) + debug(w, ".State") } // Postal code (exact) if query.PostalCode != "" { @@ -734,6 +785,7 @@ func compareAddress(query Address, index Address) float64 { pieces = append(pieces, 0.0) } weights = append(weights, 1.5) + debug(w, ".PosalCode") } // Country (exact) if query.Country != "" { @@ -743,21 +795,26 @@ func compareAddress(query Address, index Address) float64 { pieces = append(pieces, 0.0) } weights = append(weights, 2.0) + debug(w, ".Country") } if len(pieces) == 0 { return 0 } + debug(w, " (Scores: %v)", pieces) + var totalScore, totalWeight float64 for i := range pieces { totalScore += pieces[i] * weights[i] totalWeight += weights[i] } + debug(w, " [totalScore=%v totalWeight=%v]\n", totalScore, totalWeight) + return totalScore / totalWeight } -func compareSanctionsPrograms(query *SanctionsInfo, index *SanctionsInfo) float64 { +func compareSanctionsPrograms(w io.Writer, query *SanctionsInfo, index *SanctionsInfo) float64 { if query == nil || index == nil { return 0 } @@ -878,12 +935,8 @@ func countIndexUniqueFields[I any](index Entity[I]) int { return count } -const ( - debugFinalScores = true -) - // calculateFinalScore applies coverage logic and final adjustments. -func calculateFinalScore[I any](pieces []scorePiece, index Entity[I]) float64 { +func calculateFinalScore[I any](w io.Writer, pieces []scorePiece, index Entity[I]) float64 { if len(pieces) == 0 { return 0 } @@ -895,17 +948,11 @@ func calculateFinalScore[I any](pieces []scorePiece, index Entity[I]) float64 { hasNameMatch bool ) - if debugFinalScores { - defer debug("\n") - } + debug(w, "\ncalculateFinalScore\n") // Sum up the piece scores for _, piece := range pieces { - if debugFinalScores { - if debugFinalScores { - debug("%#v\n", piece) - } - } + debug(w, "%#v\n", piece) // Skip zero-weight pieces entirely if piece.weight <= 0 { @@ -914,7 +961,7 @@ func calculateFinalScore[I any](pieces []scorePiece, index Entity[I]) float64 { // If "entity" piece has score=0 but fieldsCompared=1, that indicates a type mismatch => overall 0 // if piece.pieceType == "entity" && piece.fieldsCompared == 1 && piece.score == 0 { - // debug("entity - mismatch") + // debug(w, "entity - mismatch") // return 0 // } @@ -938,9 +985,7 @@ func calculateFinalScore[I any](pieces []scorePiece, index Entity[I]) float64 { } baseScore := totalScore / totalWeight - if debugFinalScores { - debug("baseScore=%.4f ", baseScore) - } + debug(w, "baseScore=%.4f ", baseScore) // Coverage check: only count fields relevant to the index type coveragePenalty := 1.0 @@ -949,9 +994,7 @@ func calculateFinalScore[I any](pieces []scorePiece, index Entity[I]) float64 { for _, p := range pieces { fieldsCompared += p.fieldsCompared } - if debugFinalScores { - debug("fieldsCompared=%d ", fieldsCompared) - } + debug(w, "fieldsCompared=%d ", fieldsCompared) if indexUniqueCount > 0 { coverage := float64(fieldsCompared) / float64(indexUniqueCount) @@ -963,25 +1006,21 @@ func calculateFinalScore[I any](pieces []scorePiece, index Entity[I]) float64 { } finalScore := baseScore * coveragePenalty - if debugFinalScores { - debug("coveragePenalty=%.2f ", coveragePenalty) - } + debug(w, "coveragePenalty=%.2f ", coveragePenalty) // Perfect match boost: only if coverage wasn't penalized if hasExactMatch && hasNameMatch && finalScore > 0.9 && coveragePenalty == 1.0 { - if debugFinalScores { - debug("PERFECT MATCH BOOST ") - } + debug(w, "PERFECT MATCH BOOST ") finalScore = math.Min(1.0, finalScore*1.15) } - if debugFinalScores { - debug("finalScore=%.2f", finalScore) - } + debug(w, "finalScore=%.2f", finalScore) return finalScore } -func debug(pattern string, args ...any) { - fmt.Printf(pattern, args...) //nolint:forbidigo +func debug(w io.Writer, pattern string, args ...any) { + if w != nil { + fmt.Fprintf(w, pattern, args...) + } } diff --git a/pkg/search/similarity_ofac_test.go b/pkg/search/similarity_ofac_test.go index f320232d..c93c54ae 100644 --- a/pkg/search/similarity_ofac_test.go +++ b/pkg/search/similarity_ofac_test.go @@ -1,6 +1,9 @@ package search_test import ( + "bytes" + "fmt" + "io" "testing" "time" @@ -190,17 +193,15 @@ func TestSimilarity_OFAC_SDN_Vessel(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - score := search.Similarity(tc.query, indexEntity) + score := search.DebugSimilarity(debug(t), tc.query, indexEntity) require.InDelta(t, tc.expected, score, 0.02) // Additional assertions for specific score thresholds if tc.expected >= 0.95 { - require.GreaterOrEqual(t, score, 0.95, - "High confidence matches should score >= 0.95") + require.GreaterOrEqual(t, score, 0.95, "High confidence matches should score >= 0.95") } if tc.expected <= 0.40 { - require.LessOrEqual(t, score, 0.40, - "Clear mismatches should score <= 0.40") + require.LessOrEqual(t, score, 0.40, "Clear mismatches should score <= 0.40") } }) } @@ -330,7 +331,7 @@ func TestSimilarity_OFAC_SDN_Person(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - score := search.Similarity(tc.query, indexEntity) + score := search.DebugSimilarity(debug(t), tc.query, indexEntity) require.InDelta(t, tc.expected, score, 0.02) }) } @@ -472,7 +473,7 @@ func TestSimilarity_OFAC_SDN_Business(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - score := search.Similarity(tc.query, indexEntity) + score := search.DebugSimilarity(debug(t), tc.query, indexEntity) require.InDelta(t, tc.expected, score, 0.05) }) } @@ -515,8 +516,22 @@ func TestSimilarity_Edge_Cases(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - score := search.Similarity(tc.query, indexEntity) + score := search.DebugSimilarity(debug(t), tc.query, indexEntity) require.InDelta(t, tc.expected, score, 0.02) }) } } + +func debug(t *testing.T) io.Writer { + t.Helper() + + if testing.Verbose() { + buf := new(bytes.Buffer) + t.Cleanup(func() { + fmt.Printf("\n%s\n", buf.String()) + }) + return buf + } + + return nil +}