fix(inputs.modbus): Avoid overflow when calculating with uint16 addresses (#15146)
This commit is contained in:
parent
598dd1d6bc
commit
80891a6413
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"math"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
|
@ -249,6 +250,11 @@ func (c *ConfigurationPerMetric) newField(def metricFieldDefinition, mdef metric
|
|||
}
|
||||
}
|
||||
|
||||
// Check for address overflow
|
||||
if def.Address > math.MaxUint16-fieldLength {
|
||||
return field{}, fmt.Errorf("%w for field %q", errAddressOverflow, def.Name)
|
||||
}
|
||||
|
||||
// Initialize the field
|
||||
f := field{
|
||||
measurement: mdef.Measurement,
|
||||
|
|
|
|||
|
|
@ -318,3 +318,30 @@ func TestMetricResult(t *testing.T) {
|
|||
actual := acc.GetTelegrafMetrics()
|
||||
testutil.RequireMetricsEqual(t, expected, actual, testutil.IgnoreTime())
|
||||
}
|
||||
|
||||
func TestMetricAddressOverflow(t *testing.T) {
|
||||
logger := &testutil.CaptureLogger{}
|
||||
plugin := Modbus{
|
||||
Name: "Test",
|
||||
Controller: "tcp://localhost:1502",
|
||||
ConfigurationType: "metric",
|
||||
Log: logger,
|
||||
Workarounds: ModbusWorkarounds{ReadCoilsStartingAtZero: true},
|
||||
}
|
||||
plugin.Metrics = []metricDefinition{
|
||||
{
|
||||
SlaveID: 1,
|
||||
ByteOrder: "ABCD",
|
||||
Measurement: "test",
|
||||
Fields: []metricFieldDefinition{
|
||||
{
|
||||
Name: "field",
|
||||
Address: uint16(65534),
|
||||
InputType: "UINT64",
|
||||
RegisterType: "holding",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.ErrorIs(t, plugin.Init(), errAddressOverflow)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1189,3 +1189,73 @@ func TestRegisterReadMultipleHoldingRegisterLimit(t *testing.T) {
|
|||
|
||||
testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime())
|
||||
}
|
||||
|
||||
func TestRegisterHighAddresses(t *testing.T) {
|
||||
// Test case for issue https://github.com/influxdata/telegraf/issues/15138
|
||||
|
||||
// Setup a server
|
||||
serv := mbserver.NewServer()
|
||||
require.NoError(t, serv.ListenTCP("localhost:1502"))
|
||||
defer serv.Close()
|
||||
|
||||
handler := mb.NewTCPClientHandler("localhost:1502")
|
||||
require.NoError(t, handler.Connect())
|
||||
defer handler.Close()
|
||||
client := mb.NewClient(handler)
|
||||
|
||||
// Write the register values
|
||||
data := []byte{
|
||||
0x4d, 0x6f, 0x64, 0x62, 0x75, 0x73, 0x20, 0x53,
|
||||
0x74, 0x72, 0x69, 0x6e, 0x67, 0x20, 0x48, 0x65,
|
||||
0x6c, 0x6c, 0x6f, 0x00,
|
||||
}
|
||||
_, err := client.WriteMultipleRegisters(65524, 10, data)
|
||||
require.NoError(t, err)
|
||||
_, err = client.WriteMultipleRegisters(65534, 1, []byte{0x10, 0x92})
|
||||
require.NoError(t, err)
|
||||
|
||||
modbus := Modbus{
|
||||
Name: "Issue-15138",
|
||||
Controller: "tcp://localhost:1502",
|
||||
Log: testutil.Logger{},
|
||||
}
|
||||
modbus.SlaveID = 1
|
||||
modbus.HoldingRegisters = []fieldDefinition{
|
||||
{
|
||||
Name: "DeviceName",
|
||||
ByteOrder: "AB",
|
||||
DataType: "STRING",
|
||||
Address: []uint16{65524, 65525, 65526, 65527, 65528, 65529, 65530, 65531, 65532, 65533},
|
||||
},
|
||||
{
|
||||
Name: "DeviceConnectionStatus",
|
||||
ByteOrder: "AB",
|
||||
DataType: "UINT16",
|
||||
Address: []uint16{65534},
|
||||
Scale: 1,
|
||||
},
|
||||
}
|
||||
|
||||
expected := []telegraf.Metric{
|
||||
testutil.MustMetric(
|
||||
"modbus",
|
||||
map[string]string{
|
||||
"type": cHoldingRegisters,
|
||||
"slave_id": strconv.Itoa(int(modbus.SlaveID)),
|
||||
"name": modbus.Name,
|
||||
},
|
||||
map[string]interface{}{
|
||||
"DeviceName": "Modbus String Hello",
|
||||
"DeviceConnectionStatus": uint16(4242),
|
||||
},
|
||||
time.Unix(0, 0),
|
||||
),
|
||||
}
|
||||
|
||||
var acc testutil.Accumulator
|
||||
require.NoError(t, modbus.Init())
|
||||
require.NotEmpty(t, modbus.requests)
|
||||
require.Len(t, modbus.requests[1].holding, 1)
|
||||
require.NoError(t, modbus.Gather(&acc))
|
||||
testutil.RequireMetricsEqual(t, expected, acc.GetTelegrafMetrics(), testutil.IgnoreTime())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"math"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/models"
|
||||
|
|
@ -295,6 +296,11 @@ func (c *ConfigurationPerRequest) newFieldFromDefinition(def requestFieldDefinit
|
|||
}
|
||||
}
|
||||
|
||||
// Check for address overflow
|
||||
if def.Address > math.MaxUint16-fieldLength {
|
||||
return field{}, fmt.Errorf("%w for field %q", errAddressOverflow, def.Name)
|
||||
}
|
||||
|
||||
// Initialize the field
|
||||
f := field{
|
||||
measurement: def.Measurement,
|
||||
|
|
|
|||
|
|
@ -3303,3 +3303,28 @@ func TestRequestOverlap(t *testing.T) {
|
|||
require.Len(t, plugin.requests, 1)
|
||||
require.Len(t, plugin.requests[1].holding, 1)
|
||||
}
|
||||
|
||||
func TestRequestAddressOverflow(t *testing.T) {
|
||||
logger := &testutil.CaptureLogger{}
|
||||
plugin := Modbus{
|
||||
Name: "Test",
|
||||
Controller: "tcp://localhost:1502",
|
||||
ConfigurationType: "request",
|
||||
Log: logger,
|
||||
Workarounds: ModbusWorkarounds{ReadCoilsStartingAtZero: true},
|
||||
}
|
||||
plugin.Requests = []requestDefinition{
|
||||
{
|
||||
SlaveID: 1,
|
||||
RegisterType: "holding",
|
||||
Fields: []requestFieldDefinition{
|
||||
{
|
||||
Name: "field",
|
||||
InputType: "UINT64",
|
||||
Address: uint16(65534),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.ErrorIs(t, plugin.Init(), errAddressOverflow)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ var sampleConfigStart string
|
|||
//go:embed sample_general_end.conf
|
||||
var sampleConfigEnd string
|
||||
|
||||
var errAddressOverflow = errors.New("address overflow")
|
||||
|
||||
type ModbusWorkarounds struct {
|
||||
AfterConnectPause config.Duration `toml:"pause_after_connect"`
|
||||
PollPause config.Duration `toml:"pause_between_requests"`
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package modbus
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
|
|
@ -33,10 +34,16 @@ func splitMaxBatchSize(g request, maxBatchSize uint16) []request {
|
|||
|
||||
// Initialize the end to a safe value avoiding infinite loops
|
||||
end := g.address + g.length
|
||||
var batchEnd uint16
|
||||
if start >= math.MaxUint16-maxBatchSize {
|
||||
batchEnd = math.MaxUint16
|
||||
} else {
|
||||
batchEnd = start + maxBatchSize
|
||||
}
|
||||
for _, f := range g.fields[idx:] {
|
||||
// If the current field exceeds the batch size we need to split
|
||||
// the request here
|
||||
if f.address+f.length > start+maxBatchSize {
|
||||
if f.address+f.length > batchEnd {
|
||||
break
|
||||
}
|
||||
// End of field still fits into the batch so add it to the request
|
||||
|
|
|
|||
Loading…
Reference in New Issue