fix(parsers.json_v2): Prevent race condition in parse function (#14149)

This commit is contained in:
Adam Thornton 2023-10-30 01:04:41 -07:00 committed by GitHub
parent 7ec04f8dd6
commit 38b8a1bcde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 63 deletions

View File

@ -65,10 +65,10 @@ type KafkaConsumer struct {
ticker *time.Ticker ticker *time.Ticker
fingerprint string fingerprint string
parserFunc telegraf.ParserFunc parser telegraf.Parser
topicLock sync.Mutex topicLock sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
cancel context.CancelFunc cancel context.CancelFunc
} }
type ConsumerGroup interface { type ConsumerGroup interface {
@ -91,8 +91,8 @@ func (*KafkaConsumer) SampleConfig() string {
return sampleConfig return sampleConfig
} }
func (k *KafkaConsumer) SetParserFunc(fn telegraf.ParserFunc) { func (k *KafkaConsumer) SetParser(parser telegraf.Parser) {
k.parserFunc = fn k.parser = parser
} }
func (k *KafkaConsumer) Init() error { func (k *KafkaConsumer) Init() error {
@ -318,7 +318,7 @@ func (k *KafkaConsumer) Start(acc telegraf.Accumulator) error {
k.startErrorAdder(acc) k.startErrorAdder(acc)
for ctx.Err() == nil { for ctx.Err() == nil {
handler := NewConsumerGroupHandler(acc, k.MaxUndeliveredMessages, k.parserFunc, k.Log) handler := NewConsumerGroupHandler(acc, k.MaxUndeliveredMessages, k.parser, k.Log)
handler.MaxMessageLen = k.MaxMessageLen handler.MaxMessageLen = k.MaxMessageLen
handler.TopicTag = k.TopicTag handler.TopicTag = k.TopicTag
//if message headers list specified, put it as map to handler //if message headers list specified, put it as map to handler
@ -377,12 +377,12 @@ type Message struct {
session sarama.ConsumerGroupSession session sarama.ConsumerGroupSession
} }
func NewConsumerGroupHandler(acc telegraf.Accumulator, maxUndelivered int, fn telegraf.ParserFunc, log telegraf.Logger) *ConsumerGroupHandler { func NewConsumerGroupHandler(acc telegraf.Accumulator, maxUndelivered int, parser telegraf.Parser, log telegraf.Logger) *ConsumerGroupHandler {
handler := &ConsumerGroupHandler{ handler := &ConsumerGroupHandler{
acc: acc.WithTracking(maxUndelivered), acc: acc.WithTracking(maxUndelivered),
sem: make(chan empty, maxUndelivered), sem: make(chan empty, maxUndelivered),
undelivered: make(map[telegraf.TrackingID]Message, maxUndelivered), undelivered: make(map[telegraf.TrackingID]Message, maxUndelivered),
parserFunc: fn, parser: parser,
log: log, log: log,
} }
return handler return handler
@ -394,11 +394,11 @@ type ConsumerGroupHandler struct {
TopicTag string TopicTag string
MsgHeadersToTags map[string]bool MsgHeadersToTags map[string]bool
acc telegraf.TrackingAccumulator acc telegraf.TrackingAccumulator
sem semaphore sem semaphore
parserFunc telegraf.ParserFunc parser telegraf.Parser
wg sync.WaitGroup wg sync.WaitGroup
cancel context.CancelFunc cancel context.CancelFunc
mu sync.Mutex mu sync.Mutex
undelivered map[telegraf.TrackingID]Message undelivered map[telegraf.TrackingID]Message
@ -476,12 +476,7 @@ func (h *ConsumerGroupHandler) Handle(session sarama.ConsumerGroupSession, msg *
len(msg.Value), h.MaxMessageLen) len(msg.Value), h.MaxMessageLen)
} }
parser, err := h.parserFunc() metrics, err := h.parser.Parse(msg.Value)
if err != nil {
return fmt.Errorf("creating parser: %w", err)
}
metrics, err := parser.Parse(msg.Value)
if err != nil { if err != nil {
h.release() h.release()
return err return err

View File

@ -294,15 +294,11 @@ func (c *FakeConsumerGroupClaim) Messages() <-chan *sarama.ConsumerMessage {
func TestConsumerGroupHandler_Lifecycle(t *testing.T) { func TestConsumerGroupHandler_Lifecycle(t *testing.T) {
acc := &testutil.Accumulator{} acc := &testutil.Accumulator{}
parserFunc := func() (telegraf.Parser, error) { parser := value.Parser{
parser := &value.Parser{ MetricName: "cpu",
MetricName: "cpu", DataType: "int",
DataType: "int",
}
err := parser.Init()
return parser, err
} }
cg := NewConsumerGroupHandler(acc, 1, parserFunc, testutil.Logger{}) cg := NewConsumerGroupHandler(acc, 1, &parser, testutil.Logger{})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -330,15 +326,12 @@ func TestConsumerGroupHandler_Lifecycle(t *testing.T) {
func TestConsumerGroupHandler_ConsumeClaim(t *testing.T) { func TestConsumerGroupHandler_ConsumeClaim(t *testing.T) {
acc := &testutil.Accumulator{} acc := &testutil.Accumulator{}
parserFunc := func() (telegraf.Parser, error) { parser := value.Parser{
parser := &value.Parser{ MetricName: "cpu",
MetricName: "cpu", DataType: "int",
DataType: "int",
}
err := parser.Init()
return parser, err
} }
cg := NewConsumerGroupHandler(acc, 1, parserFunc, testutil.Logger{}) require.NoError(t, parser.Init())
cg := NewConsumerGroupHandler(acc, 1, &parser, testutil.Logger{})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -451,15 +444,12 @@ func TestConsumerGroupHandler_Handle(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
acc := &testutil.Accumulator{} acc := &testutil.Accumulator{}
parserFunc := func() (telegraf.Parser, error) { parser := value.Parser{
parser := &value.Parser{ MetricName: "cpu",
MetricName: "cpu", DataType: "int",
DataType: "int",
}
err := parser.Init()
return parser, err
} }
cg := NewConsumerGroupHandler(acc, 1, parserFunc, testutil.Logger{}) require.NoError(t, parser.Init())
cg := NewConsumerGroupHandler(acc, 1, &parser, testutil.Logger{})
cg.MaxMessageLen = tt.maxMessageLen cg.MaxMessageLen = tt.maxMessageLen
cg.TopicTag = tt.topicTag cg.TopicTag = tt.topicTag
@ -573,12 +563,9 @@ func TestKafkaRoundTripIntegration(t *testing.T) {
MaxUndeliveredMessages: 1, MaxUndeliveredMessages: 1,
ConnectionStrategy: tt.connectionStrategy, ConnectionStrategy: tt.connectionStrategy,
} }
parserFunc := func() (telegraf.Parser, error) { parser := &influx.Parser{}
parser := &influx.Parser{} require.NoError(t, parser.Init())
err := parser.Init() input.SetParser(parser)
return parser, err
}
input.SetParserFunc(parserFunc)
require.NoError(t, input.Init()) require.NoError(t, input.Init())
acc := testutil.Accumulator{} acc := testutil.Accumulator{}
@ -634,12 +621,9 @@ func TestExponentialBackoff(t *testing.T) {
}, },
}, },
} }
parserFunc := func() (telegraf.Parser, error) { parser := &influx.Parser{}
parser := &influx.Parser{} require.NoError(t, parser.Init())
err := parser.Init() input.SetParser(parser)
return parser, err
}
input.SetParserFunc(parserFunc)
//time how long initialization (connection) takes //time how long initialization (connection) takes
start := time.Now() start := time.Now()
@ -682,13 +666,9 @@ func TestExponentialBackoffDefault(t *testing.T) {
}, },
}, },
} }
parserFunc := func() (telegraf.Parser, error) { parser := &influx.Parser{}
parser := &influx.Parser{} require.NoError(t, parser.Init())
err := parser.Init() input.SetParser(parser)
return parser, err
}
input.SetParserFunc(parserFunc)
require.NoError(t, input.Init()) require.NoError(t, input.Init())
// We don't need to start the plugin here since we're only testing // We don't need to start the plugin here since we're only testing

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/dimchansky/utfbom" "github.com/dimchansky/utfbom"
@ -35,6 +36,8 @@ type Parser struct {
iterateObjects bool iterateObjects bool
// objectConfig contains the config for an object, some info is needed while iterating over the gjson results // objectConfig contains the config for an object, some info is needed while iterating over the gjson results
objectConfig Object objectConfig Object
// parseMutex is here because Parse() is not threadsafe. If it is made threadsafe at some point, then we won't need it anymore.
parseMutex sync.Mutex
} }
type Config struct { type Config struct {
@ -114,6 +117,19 @@ func (p *Parser) Init() error {
} }
func (p *Parser) Parse(input []byte) ([]telegraf.Metric, error) { func (p *Parser) Parse(input []byte) ([]telegraf.Metric, error) {
// What we've done here is to put the entire former contents of Parse()
// into parseCriticalPath().
//
// As we determine what bits of parseCriticalPath() are and are not
// threadsafe, we can lift the safe pieces back up into Parse(), and
// shrink the scope (or scopes, if the critical sections are disjoint)
// of those pieces that need to be protected with a mutex.
return p.parseCriticalPath(input)
}
func (p *Parser) parseCriticalPath(input []byte) ([]telegraf.Metric, error) {
p.parseMutex.Lock()
defer p.parseMutex.Unlock()
reader := strings.NewReader(string(input)) reader := strings.NewReader(string(input))
body, _ := utfbom.Skip(reader) body, _ := utfbom.Skip(reader)
input, err := io.ReadAll(body) input, err := io.ReadAll(body)