MySQL 驱动 go-sql-driver/mysql 的项目地址:
https://github.com/go-sql-driver/mysql
使用示例(基于 database/sql 与该驱动封装的简单 DB 工具类):
package main import ( "database/sql" "fmt" _ "github.com/go-sql-driver/mysql" "strconv" "time" ) /** * class DB */ type DB struct { Username string Password string Network string Server string Port string Database string Context *sql.DB } /** * 初始化方法:确定 mysql 连接参数 */ func (this *DB) Init() { conn := fmt.Sprintf("%s:%s@%s(%s:%s)/%s", this.Username, this.Password, this.Network, this.Server, this.Port, this.Database) var err error this.Context, err = sql.Open("mysql", conn) if err != nil { fmt.Println("connection to mysql failed:", err) return } this.Context.SetConnMaxLifetime(100 * time.Second) // 最大连接周期,超时关闭连接 this.Context.SetMaxOpenConns(100) // 设置最大连接数 } /** * 创建表 * @param sql string 创建表语句 * @return bool true:创建成功 false:创建失败 */ func (this *DB) Create(sql string) bool { if _, err := this.Context.Exec(sql); err != nil { fmt.Println("create table failed:", err) return false } return true } /** * 增 * @param sql string 插入语句 * @param args any 参数 * @return int 新增的ID */ func (this *DB) Insert(sql string, args ...interface{}) int { result, err := this.Context.Exec(sql, args...) if err != nil { fmt.Printf("Insert data failed,err:%v", err) return 0 } lastid64, err := result.LastInsertId() if err != nil { fmt.Printf("Get insert id failed,err:%v", err) return 0 } lastid, _ := strconv.Atoi(strconv.FormatInt(lastid64, 10)) // int64 转 int return lastid } /** * 删 * @param sql string 删除语句 * @param args any 参数 * @return int 影响的行数 */ func (this *DB) Delete(sql string, args ...interface{}) int { result, err := this.Context.Exec(sql, args...) 
if err != nil { fmt.Printf("delete failed,err:%v\n", err) return 0 } rows64, err := result.RowsAffected() if err != nil { fmt.Printf("Get RowsAffected failed,err:%v\n", err) return 0 } rows, _ := strconv.Atoi(strconv.FormatInt(rows64, 10)) // int64 转 int return rows } /** * 改 * @param sql string 更新语句 * @param args any 参数 * @return int 影响的行数 */ func (this *DB) Update(sql string, args ...interface{}) int { result, err := this.Context.Exec(sql, args...) if err != nil { fmt.Printf("Update failed,err:%v\n", err) return 0 } rows64, err := result.RowsAffected() if err != nil { fmt.Printf("Get RowsAffected failed,err:%v\n", err) return 0 } rows, _ := strconv.Atoi(strconv.FormatInt(rows64, 10)) // int64 转 int return rows } /** * 查单行 * @param sql string 查询语句 * @param args *any 参数(指针类型) */ func (this *DB) GetRow(sql string, data ...interface{}) { row := this.Context.QueryRow(sql) if err := row.Scan(data...); err != nil { fmt.Printf("scan failed, err:%v\n", err) return } } /** * 查多行 * @param sql string 查询语句 * @param args any 参数 * @return array 关联数组 [{id:1,name:""},{id:2,name:""}...] */ func (this *DB) Select(query string, args ...interface{}) []map[string]string { data := []map[string]string{} rows, err := this.Context.Query(query, args...) defer func() { if rows != nil { _ = rows.Close() // 关闭掉未scan的sql连接 } }() if err != nil { fmt.Printf("Query failed,err:%v\n", err) return []map[string]string{} } columns, err := rows.Columns() if err != nil { fmt.Println(err.Error()) return []map[string]string{} } values := make([]sql.RawBytes, len(columns)) scanArgs := make([]interface{}, len(values)) for i := range values { scanArgs[i] = &values[i] } for rows.Next() { err = rows.Scan(scanArgs...) 
if err != nil { fmt.Println(err.Error()) return []map[string]string{} } var value string tmp := map[string]string{} for i, col := range values { if col == nil { value = "NULL" } else { value = string(col) } tmp[columns[i]] = value } data = append(data, tmp) } if err = rows.Err(); err != nil { fmt.Println(err.Error()) return []map[string]string{} } return data } /** * 返回原生 DB对象 * @return *sql.DB 可以进行事务等原生操作 */ func (this *DB) Raw() *sql.DB { return this.Context } /** * 辅助函数:返回实例化后的 DB 对象(自动执行 DB.Init()) */ func NewDB(conf map[string]string) DB { db := DB{ Username: conf["username"], Password: conf["password"], Network: conf["network"], Server: conf["server"], Port: conf["port"], Database: conf["database"], } db.Init() // 实例化 DB 对象 return db } func main() { //conf := map[string]string{ // "username": "root", // "password": "123456", // "network": "tcp", // "server": "192.168.0.21", // "port": "3306", // "database": "test", //} //database := NewDB(conf) //// todo: 创建表 //_sql := ` //CREATE TABLE IF NOT EXISTS users( //id INT(4) PRIMARY KEY AUTO_INCREMENT NOT NULL, //username VARCHAR(64), //password VARCHAR(64), //status INT(4), //createtime INT(10) //);` //ret1 := database.Create(_sql) //fmt.Println("创建表:", ret1) //// todo: 增 //ret2 := database.Insert("INSERT INTO users (username,password) VALUES (?,?)", "test", "123456") //fmt.Println("增:", ret2) //// todo: 删 //ret3 := database.Delete("DELETE FROM users WHERE id > ? AND id < ?", 1, 3) //fmt.Println("删:", ret3) //// todo: 改 //ret4 := database.Update("UPDATE users SET username = ? 
WHERE id = ?", "test1", 3) //fmt.Println("改:", ret4) //// todo: 查单行 //var username string //database.GetRow("SELECT username FROM users WHERE id = 3", &username) //fmt.Println("查单行:", username) //// todo: 查多行 //ret5 := database.Select("SELECT username,password FROM users WHERE id < ?", 10) //fmt.Println("查多行:", ret5) //// todo: 原生操作 //raw_db := database.Raw() // 获取对象 //tx, _ := raw_db.Begin() // 开启事务 //tx.Commit() // 提交事务 //tx.Rollback() // 回滚事务 }
文档更新时间: 2024-04-20 10:57 作者:lee