Skip to content

Commit

Permalink
feat: persist protocols sets and fix hanging context
Browse files Browse the repository at this point in the history
  • Loading branch information
kasteph committed Oct 18, 2024
1 parent 917b506 commit 33b069e
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 34 deletions.
34 changes: 27 additions & 7 deletions cmd/honeypot/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"syscall"
"time"

logging "github.com/ipfs/go-log/v2"
"github.com/probe-lab/ants-watch"
Expand All @@ -25,6 +26,7 @@ func main() {
flag.Parse()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
Expand All @@ -40,14 +42,32 @@ func main() {
panic(err)
}

go queen.Run(ctx)

errChan := make(chan error, 1)
go func() {
sig := <-sigChan
logger.Infof("Received signal: %s, shutting down...", sig)
cancel()
logger.Debugln("Starting Queen.Run")
errChan <- queen.Run(ctx)
logger.Debugln("Queen.Run completed")
}()

<-ctx.Done()
logger.Info("Context canceled, queen stopped")
select {
case err := <-errChan:
if err != nil {
logger.Errorf("Queen.Run returned an error: %v", err)
} else {
logger.Debugln("Queen.Run completed successfully")
}
case sig := <-sigChan:
logger.Infof("Received signal: %v, initiating shutdown...", sig)
}

cancel()

select {
case <-errChan:
logger.Debugln("Queen.Run stopped after context cancellation")
case <-time.After(30 * time.Second):
logger.Warnln("Timeout waiting for Queen.Run to stop")
}

logger.Debugln("Work is done")
}
18 changes: 9 additions & 9 deletions db/client_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,12 +595,12 @@ func BulkInsertRequests(ctx context.Context, db *sql.DB, requests []models.Reque
i := 1

for _, request := range requests {
valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5, i+6))
valueArgs = append(valueArgs, request.RequestStartedAt, request.RequestType, request.AntMultihash, request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, request.AgentVersion)
i += 7
valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5, i+6, i+7))
valueArgs = append(valueArgs, request.RequestStartedAt, request.RequestType, request.AntMultihash, request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, request.AgentVersion, request.Protocols)
i += 8
}

stmt := fmt.Sprintf("INSERT INTO requests_denormalized (request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version) VALUES %s RETURNING id;",
stmt := fmt.Sprintf("INSERT INTO requests_denormalized (request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version, protocols) VALUES %s RETURNING id;",
strings.Join(valueStrings, ", "))

rows, err := queries.Raw(stmt, valueArgs...).QueryContext(ctx, db)
Expand All @@ -613,15 +613,15 @@ func BulkInsertRequests(ctx context.Context, db *sql.DB, requests []models.Reque
}

func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) error {
rows, err := db.Query("SELECT id, request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version FROM requests_denormalized WHERE normalized_at IS NULL")
rows, err := db.QueryContext(ctx, "SELECT id, request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version, protocols FROM requests_denormalized WHERE normalized_at IS NULL")
if err != nil {
return err
}
defer rows.Close()

for rows.Next() {
var request models.RequestsDenormalized
if err := rows.Scan(&request.ID, &request.RequestStartedAt, &request.RequestType, &request.AntMultihash, &request.PeerMultihash, &request.KeyMultihash, &request.MultiAddresses, &request.AgentVersion); err != nil {
if err := rows.Scan(&request.ID, &request.RequestStartedAt, &request.RequestType, &request.AntMultihash, &request.PeerMultihash, &request.KeyMultihash, &request.MultiAddresses, &request.AgentVersion, &request.Protocols); err != nil {
return err
}

Expand All @@ -633,14 +633,14 @@ func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) erro
request.PeerMultihash,
request.KeyMultihash,
request.MultiAddresses,
request.AgentVersion, // agent versions
nil, // protocol sets
request.AgentVersion,
request.Protocols,
)
if err != nil {
return fmt.Errorf("failed to normalize request ID %d: %w, timestamp: %v", request.ID, err, request.RequestStartedAt)
}

_, err = db.Exec("UPDATE requests_denormalized SET normalized_at = NOW() WHERE id = $1", request.ID)
_, err = db.ExecContext(ctx, "UPDATE requests_denormalized SET normalized_at = NOW() WHERE id = $1", request.ID)
if err != nil {
return fmt.Errorf("failed to update normalized_at for request ID %d: %w", request.ID, err)
}
Expand Down
2 changes: 2 additions & 0 deletions db/migrations/000019_add_protocols_to_requests_table.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE requests
DROP COLUMN IF EXISTS protocols_set_id;
7 changes: 7 additions & 0 deletions db/migrations/000019_add_protocols_to_requests_table.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
BEGIN;

ALTER TABLE requests
ADD COLUMN protocols_set_id INT,
ADD CONSTRAINT fk_requests_protocols_set_id FOREIGN KEY (protocols_set_id) REFERENCES protocols_sets (id) ON DELETE SET NULL;

COMMIT;
5 changes: 5 additions & 0 deletions db/migrations/000020_alter_insert_requests_function.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
BEGIN;

DROP FUNCTION IF EXISTS insert_request;

COMMIT;
59 changes: 59 additions & 0 deletions db/migrations/000020_alter_insert_requests_function.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
BEGIN;

CREATE OR REPLACE FUNCTION insert_request(
new_timestamp TIMESTAMPTZ,
new_request_type message_type,
new_ant TEXT,
new_multi_hash TEXT,
new_key_multi_hash TEXT,
new_multi_addresses TEXT[],
new_agent_version_id INT,
new_protocols_set_id INT
) RETURNS RECORD AS
$insert_request$
DECLARE
new_multi_addresses_ids INT[];
new_request_id INT;
new_peer_id INT;
new_ant_id INT;
new_key_id INT;
BEGIN
SELECT upsert_peer(
new_multi_hash,
new_agent_version_id,
new_protocols_set_id,
new_timestamp
) INTO new_peer_id;

SELECT id INTO new_ant_id
FROM peers
WHERE multi_hash = new_ant;

SELECT insert_key(new_key_multi_hash) INTO new_key_id;

SELECT array_agg(id) FROM upsert_multi_addresses(new_multi_addresses) INTO new_multi_addresses_ids;

DELETE
FROM peers_x_multi_addresses pxma
WHERE peer_id = new_peer_id;

INSERT INTO peers_x_multi_addresses (peer_id, multi_address_id)
SELECT new_peer_id, new_multi_address_id
FROM unnest(new_multi_addresses_ids) new_multi_address_id
ON CONFLICT DO NOTHING;

INSERT INTO requests (timestamp, request_type, ant_id, peer_id, key_id, multi_address_ids, protocols_set_id)
SELECT new_timestamp,
new_request_type,
new_ant_id,
new_peer_id,
new_key_id,
new_multi_addresses_ids,
new_protocols_set_id
RETURNING id INTO new_request_id;

RETURN ROW(new_peer_id, new_request_id, new_key_id);
END;
$insert_request$ LANGUAGE plpgsql;

COMMIT;
12 changes: 10 additions & 2 deletions db/models/requests.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 40 additions & 16 deletions queen.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/probe-lab/go-libdht/kad"
"github.com/probe-lab/go-libdht/kad/key"
Expand Down Expand Up @@ -83,7 +84,7 @@ func NewQueen(ctx context.Context, dbConnString string, keysDbPath string, nPort
mmc: mmc,
uclient: getUdgerClient(),
resolveBatchSize: getBatchSize(),
resolveBatchTime: getBatchSize(),
resolveBatchTime: getBatchTime(),
}

if nPorts != 0 {
Expand Down Expand Up @@ -182,7 +183,10 @@ func (q *Queen) freePort(port uint16) {
}
}

func (q *Queen) Run(ctx context.Context) {
func (q *Queen) Run(ctx context.Context) error {
logger.Debugln("Queen.Run started")
defer logger.Debugln("Queen.Run completing")

go q.consumeAntsLogs(ctx)

crawlTime := time.NewTicker(CRAWL_INTERVAL)
Expand All @@ -195,14 +199,17 @@ func (q *Queen) Run(ctx context.Context) {

for {
select {
case <-ctx.Done():
logger.Debugln("Queen.Run done..")
q.persistLiveAntsKeys()
return ctx.Err()
case <-crawlTime.C:
q.routine(ctx)
case <-normalizationTime.C:
go q.normalizeRequests(ctx)
// time.Sleep(10 * time.Second)
case <-ctx.Done():
q.persistLiveAntsKeys()
return
default:
// busy-loop guard
time.Sleep(100 * time.Millisecond)
}
}
}
Expand All @@ -213,8 +220,24 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
ticker := time.NewTicker(time.Duration(q.resolveBatchTime) * time.Second)
defer ticker.Stop()

// Add a fail-safe timer
failSafeTimer := time.NewTimer(30 * time.Minute)
defer failSafeTimer.Stop()

for {
select {

case <-ctx.Done():
logger.Debugln("Gracefully shutting down ants...")
logger.Debugln("Number of requests remaining to be inserted:", len(requests))
if len(requests) > 0 {
err := db.BulkInsertRequests(context.Background(), q.dbc.Handler, requests)
if err != nil {
logger.Fatalf("Error inserting remaining requests: %v", err)
}
}
return

case log := <-q.antsLogs:
reqType := kadpb.Message_MessageType(log.Type).String()
maddrs := q.peerstore.Addrs(log.Requester)
Expand All @@ -225,7 +248,9 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
} else {
agent = peerstoreAgent.(string)
}
// protocols, _ := q.peerstore.GetProtocols(log.Requester)

protocols, _ := q.peerstore.GetProtocols(log.Requester)
protocolsAsStr := protocol.ConvertToStrings(protocols)

request := models.RequestsDenormalized{
RequestStartedAt: log.Timestamp,
Expand All @@ -235,12 +260,13 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
KeyMultihash: log.Target.B58String(),
MultiAddresses: db.MaddrsToAddrs(maddrs),
AgentVersion: null.StringFrom(agent),
Protocols: protocolsAsStr,
}
requests = append(requests, request)
if len(requests) >= q.resolveBatchSize {
err = db.BulkInsertRequests(ctx, q.dbc.Handler, requests)
if err != nil {
logger.Fatalf("Error inserting requests: %v", err)
logger.Errorf("Error inserting requests: %v", err)
}
requests = requests[:0]
}
Expand All @@ -254,16 +280,12 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
requests = requests[:0]
}

case <-ctx.Done():
if len(requests) > 0 {
err := db.BulkInsertRequests(ctx, q.dbc.Handler, requests)
if err != nil {
logger.Fatalf("Error inserting remaining requests: %v", err)
}
}
return
default:
// against busy-looping since <-q.antsLogs is a busy chan
time.Sleep(10 * time.Millisecond)
}
}

}

func (q *Queen) normalizeRequests(ctx context.Context) {
Expand All @@ -281,11 +303,13 @@ func (q *Queen) normalizeRequests(ctx context.Context) {
}

func (q *Queen) persistLiveAntsKeys() {
logger.Debugln("Persisting live ants keys")
antsKeys := make([]crypto.PrivKey, 0, len(q.ants))
for _, ant := range q.ants {
antsKeys = append(antsKeys, ant.Host.Peerstore().PrivKey(ant.Host.ID()))
}
q.keysDB.MatchingKeys(nil, antsKeys)
logger.Debugf("Number of antsKeys persisted: %d", len(antsKeys))
}

func (q *Queen) routine(ctx context.Context) {
Expand Down

0 comments on commit 33b069e

Please sign in to comment.