Skip to content

Commit fd8a7bb

Browse files
committed
AI bug fixes
1 parent c1b24e9 commit fd8a7bb

9 files changed

Lines changed: 55 additions & 23 deletions

File tree

certgraph.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func main() {
165165

166166
// create the output directory if it does not exist
167167
if len(config.savePath) > 0 {
168-
err := os.MkdirAll(config.savePath, 0777)
168+
err := os.MkdirAll(config.savePath, 0755)
169169
if err != nil {
170170
fmt.Fprintln(os.Stderr, err)
171171
return
@@ -215,7 +215,7 @@ func getDriverSingle(name string) (driver.Driver, error) {
215215
case "censys":
216216
d, err = censys.Driver(config.savePath, config.includeCTSubdomains, config.includeCTExpired)
217217
default:
218-
return nil, fmt.Errorf("unknown driver name: %s", config.driver)
218+
return nil, fmt.Errorf("unknown driver name: %s", name)
219219
}
220220
return d, err
221221
}
@@ -272,8 +272,7 @@ func breathFirstSearch(roots []string) {
272272
}()
273273
// thread to start all other threads from DomainChan
274274
go func() {
275-
for {
276-
domainNode := <-domainNodeInputChan
275+
for domainNode := range domainNodeInputChan {
277276

278277
// depth check
279278
if domainNode.Depth > config.maxDepth {
@@ -341,6 +340,7 @@ func breathFirstSearch(roots []string) {
341340
}()
342341

343342
wg.Wait() // wait for querying to finish
343+
close(domainNodeInputChan) // close input channel to signal goroutine to exit
344344
close(domainNodeOutputChan)
345345
<-done // wait for save to finish
346346
}

dns/ns.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ package dns
44
import (
55
"context"
66
"net"
7+
"sync"
78
"time"
89
)
910

1011
var (
11-
dnsCache = make(map[string]bool)
12+
dnsCache = &sync.Map{}
1213
dnsResolver = &net.Resolver{}
1314
)
1415

@@ -76,13 +77,12 @@ func HasRecordsCache(domain string, timeout time.Duration) (bool, error) {
7677
if err != nil {
7778
return false, err
7879
}
79-
hasDNS, found := dnsCache[domain]
80-
if found {
81-
return hasDNS, nil
80+
if cached, found := dnsCache.Load(domain); found {
81+
return cached.(bool), nil
8282
}
8383
hasRecords, err := HasRecords(domain, timeout)
84-
if err != nil {
85-
dnsCache[domain] = hasRecords
84+
if err == nil {
85+
dnsCache.Store(domain, hasRecords)
8686
}
8787
return hasRecords, err
8888
}

driver/censys/censys.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,11 @@ func (d *censys) QueryDomain(domain string) (driver.Result, error) {
203203
}
204204

205205
for _, r := range resp.Results {
206-
fp := fingerprint.FromHexHash(r.Fingerprint)
206+
fp, err := fingerprint.FromHexHash(r.Fingerprint)
207+
if err != nil {
208+
log.Printf("censys: invalid fingerprint %s: %v", r.Fingerprint, err)
209+
continue
210+
}
207211
results.fingerprints.Add(domain, fp)
208212
}
209213

driver/crtsh/crtsh.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ func Driver(maxQueryResults int, timeout time.Duration, savePath string, include
7979
return nil, err
8080
}
8181

82+
// Configure connection pool to prevent resource leaks
83+
d.db.SetMaxOpenConns(10)
84+
d.db.SetMaxIdleConns(2)
85+
d.db.SetConnMaxLifetime(time.Hour)
86+
8287
err = d.setSQLTimeout(d.timeout.Seconds())
8388

8489
return d, err
@@ -88,6 +93,14 @@ func (d *crtsh) GetName() string {
8893
return driverName
8994
}
9095

96+
// Close closes the database connection
97+
func (d *crtsh) Close() error {
98+
if d.db != nil {
99+
return d.db.Close()
100+
}
101+
return nil
102+
}
103+
91104
func (d *crtsh) setSQLTimeout(sec float64) error {
92105
_, err := d.db.Exec(fmt.Sprintf("SET statement_timeout TO %f;", (1000 * sec)))
93106
return err
@@ -161,6 +174,7 @@ func (d *crtsh) QueryDomain(domain string) (driver.Result, error) {
161174
if err != nil {
162175
return results, err
163176
}
177+
defer func() { _ = rows.Close() }()
164178

165179
for rows.Next() {
166180
var hash []byte
@@ -202,6 +216,7 @@ func (d *crtsh) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error
202216
if err != nil {
203217
return certNode, err
204218
}
219+
defer func() { _ = rows.Close() }()
205220

206221
for rows.Next() {
207222
var domain string
@@ -214,7 +229,7 @@ func (d *crtsh) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error
214229

215230
if d.save {
216231
var rawCert []byte
217-
queryStr = `SELECT certificate FORM certificate_and_identities WHERE digest(certificate, 'sha256') = $1;`
232+
queryStr = `SELECT certificate FROM certificate_and_identities WHERE digest(certificate, 'sha256') = $1;`
218233
row := d.db.QueryRow(queryStr, fp[:])
219234
err = row.Scan(&rawCert)
220235
if err != nil {

driver/http/http.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ func (c *httpCertDriver) dialTLS(network, addr string) (net.Conn, error) {
142142
connState := conn.ConnectionState()
143143

144144
// only look at leaf certificate which is valid for domain, rest of cert chain is ignored
145+
if len(connState.PeerCertificates) == 0 {
146+
return conn, fmt.Errorf("no peer certificates found")
147+
}
145148
certResult := driver.NewCertResult(connState.PeerCertificates[0])
146149
c.certs[certResult.Fingerprint] = certResult
147150
host, _, err := net.SplitHostPort(addr)

driver/smtp/smtp.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ func (d *smtpDriver) QueryDomain(host string) (driver.Result, error) {
129129
}
130130

131131
// only look at leaf certificate which is valid for domain, rest of cert chain is ignored
132+
if len(certs) == 0 {
133+
return results, fmt.Errorf("no certificates found")
134+
}
132135
certResult := driver.NewCertResult(certs[0])
133136
results.certs[certResult.Fingerprint] = certResult
134137
results.fingerprints.Add(host, certResult.Fingerprint)

fingerprint/fingerprint.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@ func FromRawCertBytes(data []byte) Fingerprint {
3535
}
3636

3737
// FromB64Hash returns a Fingerprint from a base64 encoded hash string
38-
func FromB64Hash(hash string) Fingerprint {
38+
func FromB64Hash(hash string) (Fingerprint, error) {
3939
data, err := base64.StdEncoding.DecodeString(hash)
4040
if err != nil {
41-
panic(err)
41+
return Fingerprint{}, err
4242
}
43-
return FromHashBytes(data)
43+
return FromHashBytes(data), nil
4444
}
4545

4646
// FromHexHash returns a Fingerprint from a hex encoded hash string
47-
func FromHexHash(hash string) Fingerprint {
47+
func FromHexHash(hash string) (Fingerprint, error) {
4848
decoded, err := hex.DecodeString(hash)
4949
if err != nil {
50-
panic(err)
50+
return Fingerprint{}, err
5151
}
52-
return FromHashBytes(decoded)
52+
return FromHashBytes(decoded), nil
5353
}
5454

5555
// B64Encode returns the b64 string of a Fingerprint

fingerprint/fp_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ func TestFromRawCertBytes(t *testing.T) {
6161

6262
func TestFromB64Hash(t *testing.T) {
6363

64-
fp := fingerprint.FromB64Hash(fpHashB64)
64+
fp, err := fingerprint.FromB64Hash(fpHashB64)
65+
if err != nil {
66+
t.Fatalf("FromB64Hash failed: %v", err)
67+
}
6568

6669
uppercaseHash := strings.ToUpper(fpHashHex)
6770
hashHex := fp.HexString()
@@ -79,7 +82,10 @@ func TestFromB64Hash(t *testing.T) {
7982

8083
func TestFromHexHash(t *testing.T) {
8184

82-
fp := fingerprint.FromHexHash(fpHashHex)
85+
fp, err := fingerprint.FromHexHash(fpHashHex)
86+
if err != nil {
87+
t.Fatalf("FromHexHash failed: %v", err)
88+
}
8389

8490
hashB64 := fp.B64Encode()
8591

graph/graph.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package graph
44
import (
55
"strings"
66
"sync"
7+
"sync/atomic"
78

89
"github.com/lanrat/certgraph/fingerprint"
910
)
@@ -12,7 +13,7 @@ import (
1213
type CertGraph struct {
1314
domains sync.Map
1415
certs sync.Map
15-
numDomains int
16+
numDomains int64
1617
depth uint
1718
}
1819

@@ -31,7 +32,7 @@ func (graph *CertGraph) AddCert(certNode *CertNode) {
3132

3233
// AddDomain add a DomainNode to the graph
3334
func (graph *CertGraph) AddDomain(domainNode *DomainNode) {
34-
graph.numDomains++
35+
atomic.AddInt64(&graph.numDomains, 1)
3536
// save the new maximum depth if greather then current
3637
if domainNode.Depth > graph.depth {
3738
graph.depth = domainNode.Depth
@@ -44,7 +45,7 @@ func (graph *CertGraph) AddDomain(domainNode *DomainNode) {
4445

4546
// NumDomains returns the number of domains in the graph
4647
func (graph *CertGraph) NumDomains() int {
47-
return graph.numDomains
48+
return int(atomic.LoadInt64(&graph.numDomains))
4849
}
4950

5051
// DomainDepth returns the maximum depth of the graph from the initial root domains

0 commit comments

Comments
 (0)