Skip to content

Commit

Permalink
Merge pull request #1 from LayerXcom/feat/support-dumped-generated-co…
Browse files Browse the repository at this point in the history
…lumns

feat: support INSERT statements of tables with generated columns
  • Loading branch information
suguru authored Aug 21, 2023
2 parents dc97dfa + 1dfebfc commit 64ff6be
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
41 changes: 27 additions & 14 deletions canal/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,26 @@ func (h *dumpParseHandler) Data(db string, table string, values []string) error
return errors.Trace(err)
}

vs := make([]interface{}, len(values))
vs := make([]interface{}, 0, len(tableInfo.Columns))

for i, v := range values {
i := 0
for _, c := range tableInfo.Columns {
v := values[i]
if v == "NULL" {
vs[i] = nil
vs = append(vs, nil)
i++
} else if c.IsVirtual {
vs = append(vs, nil)
// do not increment i since this column is virtual one.
} else if v == "_binary ''" {
vs[i] = []byte{}
vs = append(vs, []byte{})
i++
} else if v[0] != '\'' {
if tableInfo.Columns[i].Type == schema.TYPE_NUMBER || tableInfo.Columns[i].Type == schema.TYPE_MEDIUM_INT {
if c.Type == schema.TYPE_NUMBER || c.Type == schema.TYPE_MEDIUM_INT {
var n interface{}
var err error

if tableInfo.Columns[i].IsUnsigned {
if c.IsUnsigned {
n, err = strconv.ParseUint(v, 10, 64)
} else {
n, err = strconv.ParseInt(v, 10, 64)
Expand All @@ -74,38 +81,44 @@ func (h *dumpParseHandler) Data(db string, table string, values []string) error
return fmt.Errorf("parse row %v at %d error %v, int expected", values, i, err)
}

vs[i] = n
} else if tableInfo.Columns[i].Type == schema.TYPE_FLOAT {
vs = append(vs, n)
i++
} else if c.Type == schema.TYPE_FLOAT {
f, err := strconv.ParseFloat(v, 64)
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, float expected", values, i, err)
}
vs[i] = f
} else if tableInfo.Columns[i].Type == schema.TYPE_DECIMAL {
vs = append(vs, f)
i++
} else if c.Type == schema.TYPE_DECIMAL {
if h.c.cfg.UseDecimal {
d, err := decimal.NewFromString(v)
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, decimal expected", values, i, err)
}
vs[i] = d
vs = append(vs, d)
i++
} else {
f, err := strconv.ParseFloat(v, 64)
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, float expected", values, i, err)
}
vs[i] = f
vs = append(vs, f)
i++
}
} else if strings.HasPrefix(v, "0x") {
buf, err := hex.DecodeString(v[2:])
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, hex literal expected", values, i, err)
}
vs[i] = string(buf)
vs = append(vs, string(buf))
i++
} else {
return fmt.Errorf("parse row %v error, invalid type at %d", values, i)
}
} else {
vs[i] = v[1 : len(v)-1]
vs = append(vs, v[1:len(v)-1])
i++
}
}

Expand Down
14 changes: 14 additions & 0 deletions dump/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ type ParseHandler interface {
var binlogExp *regexp.Regexp
var useExp *regexp.Regexp
var valuesExp *regexp.Regexp
var keyedValuesExp *regexp.Regexp
var gtidExp *regexp.Regexp

func init() {
binlogExp = regexp.MustCompile(`^CHANGE MASTER TO MASTER_LOG_FILE='(.+)', MASTER_LOG_POS=(\d+);`)
useExp = regexp.MustCompile("^USE `(.+)`;")
valuesExp = regexp.MustCompile("^INSERT INTO `(.+?)` VALUES \\((.+)\\);$")
keyedValuesExp = regexp.MustCompile("^INSERT INTO `(.+?)` \\((.+)\\) VALUES \\((.+)\\);$")
// The pattern will only match MySQL GTID, as you know SET GLOBAL gtid_slave_pos='0-1-4' is used for MariaDB.
// SET @@GLOBAL.GTID_PURGED='1638041a-0457-11e9-bb9f-00505690b730:1-429405150';
// https://dev.mysql.com/doc/refman/5.7/en/replication-gtids-concepts.html
Expand Down Expand Up @@ -101,6 +103,18 @@ func Parse(r io.Reader, h ParseHandler, parseBinlogPos bool) error {
return errors.Trace(err)
}
}

if m := keyedValuesExp.FindAllStringSubmatch(line, -1); len(m) == 1 {
table := m[0][1]
values, err := parseValues(m[0][3])
if err != nil {
return errors.Errorf("parse values %v err", line)
}

if err = h.Data(db, table, values); err != nil && err != ErrSkip {
return errors.Trace(err)
}
}
}

return nil
Expand Down

0 comments on commit 64ff6be

Please sign in to comment.