feat: persist protocols sets #10

Merged
merged 4 commits on Oct 29, 2024
34 changes: 27 additions & 7 deletions cmd/honeypot/main.go
@@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"syscall"
"time"

logging "github.com/ipfs/go-log/v2"
"github.com/probe-lab/ants-watch"
@@ -25,6 +26,7 @@ func main() {
flag.Parse()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
@@ -40,14 +42,32 @@ func main() {
panic(err)
}

go queen.Run(ctx)

errChan := make(chan error, 1)
go func() {
sig := <-sigChan
logger.Infof("Received signal: %s, shutting down...", sig)
cancel()
logger.Debugln("Starting Queen.Run")
errChan <- queen.Run(ctx)
logger.Debugln("Queen.Run completed")
}()

<-ctx.Done()
logger.Info("Context canceled, queen stopped")
select {
case err := <-errChan:
if err != nil {
logger.Errorf("Queen.Run returned an error: %v", err)
} else {
logger.Debugln("Queen.Run completed successfully")
}
case sig := <-sigChan:
logger.Infof("Received signal: %v, initiating shutdown...", sig)
}

cancel()

select {
case <-errChan:
logger.Debugln("Queen.Run stopped after context cancellation")
case <-time.After(30 * time.Second):
logger.Warnln("Timeout waiting for Queen.Run to stop")
}

logger.Debugln("Work is done")
}
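For reference, a minimal standalone sketch of the shutdown flow this file now uses: the worker runs in a goroutine and reports its error on a buffered channel, the main goroutine waits for either that error or an OS signal, cancels the context, and then gives the worker a bounded window to exit. The run function and log messages below are illustrative stand-ins, not code from this repository.

package main

import (
	"context"
	"errors"
	"log"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// run stands in for a long-running worker such as queen.Run.
func run(ctx context.Context) error {
	<-ctx.Done() // pretend to work until canceled
	return ctx.Err()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	errChan := make(chan error, 1) // buffered so the goroutine never blocks on send
	go func() { errChan <- run(ctx) }()

	// Wait for the worker to stop on its own or for an OS signal.
	select {
	case err := <-errChan:
		if err != nil && !errors.Is(err, context.Canceled) {
			log.Printf("worker returned an error: %v", err)
		}
	case sig := <-sigChan:
		log.Printf("received signal %v, initiating shutdown", sig)
	}

	cancel()

	// Give the worker a bounded amount of time to wind down after cancellation.
	select {
	case <-errChan:
		log.Println("worker stopped")
	case <-time.After(30 * time.Second):
		log.Println("timeout waiting for worker to stop")
	}
}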
18 changes: 9 additions & 9 deletions db/client_db.go
@@ -595,12 +595,12 @@ func BulkInsertRequests(ctx context.Context, db *sql.DB, requests []models.Reque
i := 1

for _, request := range requests {
valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5, i+6))
valueArgs = append(valueArgs, request.RequestStartedAt, request.RequestType, request.AntMultihash, request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, request.AgentVersion)
i += 7
valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5, i+6, i+7))
valueArgs = append(valueArgs, request.RequestStartedAt, request.RequestType, request.AntMultihash, request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, request.AgentVersion, request.Protocols)
i += 8
}

stmt := fmt.Sprintf("INSERT INTO requests_denormalized (request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version) VALUES %s RETURNING id;",
stmt := fmt.Sprintf("INSERT INTO requests_denormalized (request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version, protocols) VALUES %s RETURNING id;",
strings.Join(valueStrings, ", "))

rows, err := queries.Raw(stmt, valueArgs...).QueryContext(ctx, db)
@@ -613,15 +613,15 @@ func BulkInsertRequests(ctx context.Context, db *sql.DB, requests []models.Reque
}

func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) error {
rows, err := db.Query("SELECT id, request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version FROM requests_denormalized WHERE normalized_at IS NULL")
rows, err := db.QueryContext(ctx, "SELECT id, request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version, protocols FROM requests_denormalized WHERE normalized_at IS NULL")
if err != nil {
return err
}
defer rows.Close()

for rows.Next() {
var request models.RequestsDenormalized
if err := rows.Scan(&request.ID, &request.RequestStartedAt, &request.RequestType, &request.AntMultihash, &request.PeerMultihash, &request.KeyMultihash, &request.MultiAddresses, &request.AgentVersion); err != nil {
if err := rows.Scan(&request.ID, &request.RequestStartedAt, &request.RequestType, &request.AntMultihash, &request.PeerMultihash, &request.KeyMultihash, &request.MultiAddresses, &request.AgentVersion, &request.Protocols); err != nil {
return err
}

@@ -633,14 +633,14 @@ func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) erro
request.PeerMultihash,
request.KeyMultihash,
request.MultiAddresses,
request.AgentVersion, // agent versions
nil, // protocol sets
request.AgentVersion,
request.Protocols,
)
if err != nil {
return fmt.Errorf("failed to normalize request ID %d: %w, timestamp: %v", request.ID, err, request.RequestStartedAt)
}

_, err = db.Exec("UPDATE requests_denormalized SET normalized_at = NOW() WHERE id = $1", request.ID)
_, err = db.ExecContext(ctx, "UPDATE requests_denormalized SET normalized_at = NOW() WHERE id = $1", request.ID)
if err != nil {
return fmt.Errorf("failed to update normalized_at for request ID %d: %w", request.ID, err)
}
Expand Down
2 changes: 2 additions & 0 deletions db/migrations/000019_add_protocols_to_requests_table.down.sql
@@ -0,0 +1,2 @@
ALTER TABLE requests
DROP COLUMN IF EXISTS protocols_set_id;
7 changes: 7 additions & 0 deletions db/migrations/000019_add_protocols_to_requests_table.up.sql
@@ -0,0 +1,7 @@
BEGIN;

ALTER TABLE requests
ADD COLUMN protocols_set_id INT,
ADD CONSTRAINT fk_requests_protocols_set_id FOREIGN KEY (protocols_set_id) REFERENCES protocols_sets (id) ON DELETE SET NULL;

COMMIT;
5 changes: 5 additions & 0 deletions db/migrations/000020_alter_insert_requests_function.down.sql
@@ -0,0 +1,5 @@
BEGIN;

DROP FUNCTION IF EXISTS insert_request;

COMMIT;
59 changes: 59 additions & 0 deletions db/migrations/000020_alter_insert_requests_function.up.sql
@@ -0,0 +1,59 @@
BEGIN;

CREATE OR REPLACE FUNCTION insert_request(
new_timestamp TIMESTAMPTZ,
new_request_type message_type,
new_ant TEXT,
new_multi_hash TEXT,
new_key_multi_hash TEXT,
new_multi_addresses TEXT[],
new_agent_version_id INT,
new_protocols_set_id INT
) RETURNS RECORD AS
$insert_request$
DECLARE
new_multi_addresses_ids INT[];
new_request_id INT;
new_peer_id INT;
new_ant_id INT;
new_key_id INT;
BEGIN
SELECT upsert_peer(
new_multi_hash,
new_agent_version_id,
new_protocols_set_id,
new_timestamp
) INTO new_peer_id;

SELECT id INTO new_ant_id
FROM peers
WHERE multi_hash = new_ant;

SELECT insert_key(new_key_multi_hash) INTO new_key_id;

SELECT array_agg(id) FROM upsert_multi_addresses(new_multi_addresses) INTO new_multi_addresses_ids;

DELETE
FROM peers_x_multi_addresses pxma
WHERE peer_id = new_peer_id;

INSERT INTO peers_x_multi_addresses (peer_id, multi_address_id)
SELECT new_peer_id, new_multi_address_id
FROM unnest(new_multi_addresses_ids) new_multi_address_id
ON CONFLICT DO NOTHING;

INSERT INTO requests (timestamp, request_type, ant_id, peer_id, key_id, multi_address_ids, protocols_set_id)
SELECT new_timestamp,
new_request_type,
new_ant_id,
new_peer_id,
new_key_id,
new_multi_addresses_ids,
new_protocols_set_id
RETURNING id INTO new_request_id;

RETURN ROW(new_peer_id, new_request_id, new_key_id);
END;
$insert_request$ LANGUAGE plpgsql;

COMMIT;
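For illustration only, a rough Go caller for the updated function: the new new_protocols_set_id is passed as the eighth argument and may be NULL, matching the nullable protocols_set_id column from migration 000019. The connection string, enum value, and multihashes below are placeholders, and the composite result is scanned as raw text; none of this is code from the PR.

package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	"github.com/lib/pq" // PostgreSQL driver, assumed here for illustration
)

func main() {
	db, err := sql.Open("postgres", "postgres://user:pass@localhost/ants?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// A zero-valued NullInt64 (Valid == false) is sent as SQL NULL, so the request
	// is stored without an associated protocols set.
	var protocolsSetID sql.NullInt64

	var result string // the (peer_id, request_id, key_id) record in its text form
	err = db.QueryRowContext(context.Background(),
		`SELECT insert_request($1, $2, $3, $4, $5, $6, $7, $8)`,
		time.Now(),       // new_timestamp
		"FIND_NODE",      // new_request_type (message_type enum value, assumed)
		"ant-multihash",  // new_ant (placeholder)
		"peer-multihash", // new_multi_hash (placeholder)
		"key-multihash",  // new_key_multi_hash (placeholder)
		pq.Array([]string{"/ip4/127.0.0.1/tcp/4001"}), // new_multi_addresses
		1,              // new_agent_version_id (placeholder)
		protocolsSetID, // new_protocols_set_id, nullable
	).Scan(&result)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("insert_request returned", result)
}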
12 changes: 10 additions & 2 deletions db/models/requests.go

Some generated files are not rendered by default.

58 changes: 36 additions & 22 deletions queen.go
@@ -16,6 +16,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/probe-lab/go-libdht/kad"
"github.com/probe-lab/go-libdht/kad/key"
@@ -83,7 +84,7 @@ func NewQueen(ctx context.Context, dbConnString string, keysDbPath string, nPort
mmc: mmc,
uclient: getUdgerClient(),
resolveBatchSize: getBatchSize(),
resolveBatchTime: getBatchSize(),
resolveBatchTime: getBatchTime(),
}

if nPorts != 0 {
Expand Down Expand Up @@ -182,7 +183,10 @@ func (q *Queen) freePort(port uint16) {
}
}

func (q *Queen) Run(ctx context.Context) {
func (q *Queen) Run(ctx context.Context) error {
logger.Debugln("Queen.Run started")
defer logger.Debugln("Queen.Run completing")

go q.consumeAntsLogs(ctx)

crawlTime := time.NewTicker(CRAWL_INTERVAL)
@@ -195,14 +199,17 @@

for {
select {
case <-ctx.Done():
logger.Debugln("Queen.Run done..")
q.persistLiveAntsKeys()
return ctx.Err()
case <-crawlTime.C:
q.routine(ctx)
case <-normalizationTime.C:
go q.normalizeRequests(ctx)
// time.Sleep(10 * time.Second)
case <-ctx.Done():
q.persistLiveAntsKeys()
return
default:
// busy-loop guard
time.Sleep(100 * time.Millisecond)
}
}
}
@@ -215,6 +222,18 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {

for {
select {

case <-ctx.Done():
logger.Debugln("Gracefully shutting down ants...")
logger.Debugln("Number of requests remaining to be inserted:", len(requests))
if len(requests) > 0 {
err := db.BulkInsertRequests(context.Background(), q.dbc.Handler, requests)
if err != nil {
logger.Fatalf("Error inserting remaining requests: %v", err)
}
}
return

case log := <-q.antsLogs:
reqType := kadpb.Message_MessageType(log.Type).String()
maddrs := q.peerstore.Addrs(log.Requester)
@@ -225,12 +244,9 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
} else {
agent = peerstoreAgent.(string)
}
// TODO: uncomment when we need to track protocols
// protocols, _ := q.peerstore.GetProtocols(log.Requester)
// protocolsStr := make([]string, len(protocols))
// for i, p := range protocols {
// protocolsStr[i] = string(p)
// }

protocols, _ := q.peerstore.GetProtocols(log.Requester)
protocolsAsStr := protocol.ConvertToStrings(protocols)

request := models.RequestsDenormalized{
RequestStartedAt: log.Timestamp,
@@ -240,13 +256,13 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
KeyMultihash: log.Target.B58String(),
MultiAddresses: db.MaddrsToAddrs(maddrs),
AgentVersion: null.StringFrom(agent),
// Protocols: protocolsStr,
Protocols: protocolsAsStr,
}
requests = append(requests, request)
if len(requests) >= q.resolveBatchSize {
err = db.BulkInsertRequests(ctx, q.dbc.Handler, requests)
if err != nil {
logger.Fatalf("Error inserting requests: %v", err)
logger.Errorf("Error inserting requests: %v", err)
}
requests = requests[:0]
}
@@ -260,16 +276,12 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) {
requests = requests[:0]
}

case <-ctx.Done():
if len(requests) > 0 {
err := db.BulkInsertRequests(ctx, q.dbc.Handler, requests)
if err != nil {
logger.Fatalf("Error inserting remaining requests: %v", err)
}
}
return
default:
// against busy-looping since <-q.antsLogs is a busy chan
time.Sleep(10 * time.Millisecond)
}
}

}
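For context on the peerstore lookup above, a small self-contained sketch (assumed usage, not code from this PR) of recording protocols for a peer and converting them to plain strings with protocol.ConvertToStrings, which is what consumeAntsLogs now persists per request:

package main

import (
	"fmt"

	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
	"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
)

func main() {
	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		panic(err)
	}
	defer ps.Close()

	// Illustrative peer ID; in the honeypot it is the requester observed by an ant.
	pid := peer.ID("example-peer")

	// The identify protocol normally fills the peerstore; here we add entries by hand.
	if err := ps.AddProtocols(pid, "/ipfs/kad/1.0.0", "/ipfs/id/1.0.0"); err != nil {
		panic(err)
	}

	protos, _ := ps.GetProtocols(pid)              // []protocol.ID
	asStrings := protocol.ConvertToStrings(protos) // []string, ready for the DB layer
	fmt.Println(asStrings)
}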

func (q *Queen) normalizeRequests(ctx context.Context) {
@@ -287,11 +299,13 @@ func (q *Queen) normalizeRequests(ctx context.Context) {
}

func (q *Queen) persistLiveAntsKeys() {
logger.Debugln("Persisting live ants keys")
antsKeys := make([]crypto.PrivKey, 0, len(q.ants))
for _, ant := range q.ants {
antsKeys = append(antsKeys, ant.Host.Peerstore().PrivKey(ant.Host.ID()))
}
q.keysDB.MatchingKeys(nil, antsKeys)
logger.Debugf("Number of antsKeys persisted: %d", len(antsKeys))
}

func (q *Queen) routine(ctx context.Context) {