package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"net"
	"net/http"
	_ "net/http/pprof"
	"runtime/debug"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"hilo-userProxy/common"
	"hilo-userProxy/common/consul"
	"hilo-userProxy/common/mylogrus"
	"hilo-userProxy/protocol/userCenter"
	"hilo-userProxy/protocol/userProxy"

	"github.com/golang/protobuf/proto"
	"github.com/gorilla/websocket"
	consulapi "github.com/hashicorp/consul/api"
	uuid "github.com/satori/go.uuid"
	log "github.com/sirupsen/logrus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/resolver"
)

// serialNum is the sequence number stamped on outgoing messages; there is one
// counter per process. It must only be accessed through sync/atomic.
var serialNum uint64 = 0

// upgrader turns an incoming HTTP request into a websocket connection.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin:     checkOrigin,
}

// ProtoData is one outgoing message queued on a connection's write channel:
// the message type plus the already-serialized payload.
type ProtoData struct {
	MsgType uint32
	Data    []byte
}

// myAddress is the "ip:port" of this proxy's gRPC listener; it is sent to
// userCenter as ProxyAddr on login.
var myAddress = ""

// m guards userChan.
var m sync.RWMutex

// userChan maps uid -> client remote address -> that connection's write channel.
var userChan = make(map[uint64]map[string]chan ProtoData)

// checkOrigin is the Upgrader's CheckOrigin hook, which decides whether a
// cross-origin request is allowed (true) or rejected (false). If requests were
// never cross-origin the hook could be left unset; here they are cross-origin,
// so for convenience every origin is accepted.
func checkOrigin(r *http.Request) bool {
	return true
}

const (
	// notLoginLimit is how many seconds a connection may stay open without
	// logging in before it is kicked.
	notLoginLimit = 10.0
	// heartBeatLossLimit is how many seconds a logged-in connection may go
	// without a heartbeat before it is terminated.
	heartBeatLossLimit = 30.0
)

// ConnectionInfo is the per-connection state tracked by serverWebsocket.
type ConnectionInfo struct {
	RemoteAddr    string `json:"remote_addr"`
	ConnectTime   time.Time
	HeartbeatTime time.Time
	IsLogin       bool
	Uid           uint64 `json:"uid"`
	MsgId         uint64 `json:"msg_id"`
	ExternalId    string `json:"external_id"`
	Token         string `json:"token"`
}

// setupRoutes registers the debug home page and the websocket endpoint on the
// default HTTP mux.
func setupRoutes() {
	http.HandleFunc("/", homePage)
	http.HandleFunc("/ws", serverWebsocket)
}

// getShortToken returns the part of token after its last '.', or the whole
// token when it contains no '.'.
func getShortToken(token string) string {
	// LastIndex returns -1 when there is no '.', so the slice starts at 0.
	return token[strings.LastIndex(token, ".")+1:]
}

// closeWithPolicyViolation sends a policy-violation close frame carrying reason
// (waiting at most timeout for the write) and then closes the connection. A
// failure to write the close frame is logged but never prevents the close.
func closeWithPolicyViolation(ws *websocket.Conn, logger *log.Entry, reason string, timeout time.Duration) {
	frame := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, reason)
	if err := ws.WriteControl(websocket.CloseMessage, frame, time.Now().Add(timeout)); err != nil {
		logger.Println("write close:", err)
	}
	_ = ws.Close()
}

// serverWebsocket upgrades the request to a websocket and then runs two loops:
//
//   - a reader goroutine that decodes client messages (login, heartbeat, biz
//     requests, gift-banner acks) and queues replies on writeChan;
//   - the write loop in this goroutine, which drains writeChan to the socket,
//     and every two seconds checks the login / heartbeat deadlines.
//
// When the reader exits it unregisters the write channel and, if the user is
// still marked as logged in, sends a logout to userCenter.
//
// NOTE(review): ci and logger are read and written by both the reader goroutine
// and the write loop without synchronization. This pre-existing data race is
// left as is here; fixing it needs a mutex (or ownership hand-off) around both.
func serverWebsocket(w http.ResponseWriter, r *http.Request) {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		mylogrus.MyLog.Errorf("upgrade err:%v", err)
		return
	}

	now := time.Now()
	ci := ConnectionInfo{
		RemoteAddr:    ws.RemoteAddr().String(),
		ConnectTime:   now,
		HeartbeatTime: now,
		IsLogin:       false,
		Uid:           0,
		MsgId:         0,
		ExternalId:    "",
		Token:         "",
	}

	traceId := uuid.NewV4().String()
	logger := mylogrus.MyLog.WithFields(log.Fields{
		"addr":    ci.RemoteAddr,
		"token":   ci.Token,
		"traceId": traceId,
	})
	logger.Infof("Websocket client successfully Connected at %s", ci.ConnectTime.String())

	done := make(chan struct{})
	writeChan := make(chan ProtoData, 1000)
	ticker := time.NewTicker(time.Second * 2)
	defer ticker.Stop()

	go func() {
		defer common.CheckGoPanic()
		defer ws.Close()
		defer close(done)
		// writeChan is deliberately not closed here: once reading fails, the
		// write loop is stopped through done instead.
		for {
			messageType, message, err := ws.ReadMessage()
			if err != nil {
				logger.Infof("read error:%v", err)
				logger = logger.WithFields(log.Fields{
					"read_err": err.Error(),
				})
				break
			}
			if messageType != websocket.BinaryMessage {
				logger.Errorf("Unexpected messageType %d", messageType)
				continue
			}

			msgType, msgId, timeStamp, pbData, err := common.DecodeMessage(message)
			if err != nil {
				logger.Errorf("Decode error %s", err.Error())
				continue
			}
			if msgId <= ci.MsgId {
				logger.Infof("Discard outdated message msgId = %d <= %d, msgType = %d", msgId, ci.MsgId, msgType)
				continue
			}
			ci.MsgId = msgId

			// Measure how long the message took to get here.
			recvAt := time.Now()
			us := recvAt.UnixNano()/1000 - int64(timeStamp)
			if us > 200000 {
				logger.Infof("Message take %d us to come here.", us)
			}

			switch msgType {
			case common.MsgTypeLogin:
				if ci.IsLogin {
					logger.Infof("Connection already logged in %+v", ci)
					break
				}
				var rsp *userProxy.LoginRsp
				// Assigns the loop-scoped err so the check after the switch sees it.
				ci.Token, ci.Uid, rsp, err = processLogin(logger, ci.RemoteAddr, pbData)
				if err != nil || rsp == nil {
					break
				}
				logger = logger.WithFields(log.Fields{
					"addr":   ci.RemoteAddr,
					"userId": ci.Uid,
					"sToken": getShortToken(ci.Token),
				})
				if buffer, err := proto.Marshal(rsp); err == nil {
					writeChan <- ProtoData{
						MsgType: common.MsgTypeLoginRsp,
						Data:    buffer,
					}
				}
				if rsp.Status == common.Login_success {
					ci.IsLogin = true
					ci.HeartbeatTime = recvAt
					logger.Infof("Bind to channel %v", writeChan)
					setUserChan(ci.Uid, ci.RemoteAddr, writeChan)
				}

			case common.MsgTypeHeartBeat:
				status, extUid, err := processHeartBeat(pbData)
				logger.Infof("heartbeat")
				if err != nil {
					break
				}
				ci.HeartbeatTime = time.Now()
				if len(ci.ExternalId) == 0 {
					logger.Infof("Received first heartbeat")
					ci.ExternalId = extUid
					logger = logger.WithFields(log.Fields{
						"addr":       ci.RemoteAddr,
						"userId":     ci.Uid,
						"sToken":     getShortToken(ci.Token),
						"externalId": ci.ExternalId,
					})
				}
				msg := &userProxy.HeartBeatRsp{
					Status: status,
				}
				if buffer, err := proto.Marshal(msg); err == nil {
					writeChan <- ProtoData{
						MsgType: common.MsgTypeHeartBeatRsp,
						Data:    buffer,
					}
				}

			case common.MsgTypeGlobalGiftBannerRsp:
				logger = logger.WithFields(log.Fields{
					"addr":       ci.RemoteAddr,
					"userId":     ci.Uid,
					"sToken":     getShortToken(ci.Token),
					"externalId": ci.ExternalId,
				})
				rsp := &userProxy.GlobalGiftBannerRsp{}
				if err := proto.Unmarshal(pbData, rsp); err == nil {
					logger.Infof("GlobalGiftBannerRsp, msgType = %d, %v", msgType, rsp)
				} else {
					logger.Errorf("Unmarshal error")
				}

			case common.MsgTypeBiz:
				logger = logger.WithFields(log.Fields{
					"addr":   ci.RemoteAddr,
					"userId": ci.Uid,
				})
				rsp, err := processBizRequest(logger, ci.Uid, pbData)
				if err == nil && rsp != nil {
					logger.Infof("processBizRequest rsp %+v", rsp)
					if buffer, err := proto.Marshal(rsp); err == nil {
						writeChan <- ProtoData{
							MsgType: common.MsgTypeBizRsp,
							Data:    buffer,
						}
					}
				}

			default:
				logger.Warnf("Unknown message type %d", msgType)
			}

			// Only the login branch assigns this err; the other branches use
			// their own shadowed variables.
			if err != nil {
				logger.Infof("process message error %s", err.Error())
			}
		}

		logger.Infof("exiting read loop for token %s, user %d", ci.Token, ci.Uid)
		sz := removeUserChan(ci.Uid, ci.RemoteAddr)
		b, _ := json.Marshal(ci)
		logger.Infof("exiting read loop for %s, size = %d", string(b), sz)
		if ci.IsLogin {
			err, status := doLogout(ci.RemoteAddr, ci.Uid)
			if err == nil {
				logger.Printf("logout result %d", status)
			} else {
				logger.Printf("Logout failed, %s", err.Error())
			}
			ci.IsLogin = false
		}
	}()

Loop:
	for {
		select {
		case <-done:
			break Loop

		case <-ticker.C:
			timeDiff := time.Since(ci.HeartbeatTime)
			if ci.IsLogin {
				if timeDiff.Seconds() > heartBeatLossLimit/2 {
					logger.Infof("Heartbeat lost for %f seconds", timeDiff.Seconds())
				}
				if timeDiff.Seconds() > heartBeatLossLimit {
					logger.Warnln("Heartbeat lost, terminate!")
					closeWithPolicyViolation(ws, logger, "heartbeat loss", time.Second)
					break Loop
				}
			} else if timeDiff.Seconds() > notLoginLimit {
				logger.Infof("Not loggined for %f seconds, kick!", timeDiff.Seconds())
				closeWithPolicyViolation(ws, logger, "no login", time.Second)
				break Loop
			}

		case d := <-writeChan:
			if d.MsgType == common.MsgTypeKickUser {
				// Special message used to kick the user off this connection.
				logger.Infof("Login from another device, kick!")
				// Clear the login flag so the reader does not send a logout when
				// it exits, which would also wipe the new session's record.
				ci.IsLogin = false
				// Close even if the close frame cannot be written; otherwise
				// the kicked connection would stay open and registered.
				closeWithPolicyViolation(ws, logger, "double login", time.Second*3)
				break Loop
			}
			// Use the value returned by the atomic add; re-reading serialNum
			// would race with other connections.
			seq := atomic.AddUint64(&serialNum, 1)
			logger.Infof("About to send msgType = %d, size = %d", d.MsgType, len(d.Data))
			data := common.EncodeMessage(d.MsgType, seq, d.Data)
			if err := ws.WriteMessage(websocket.BinaryMessage, data); err != nil {
				logger.Println("write close:", err)
			}
		}
	}

	b, _ := json.Marshal(ci)
	logger.Infof("exiting write loop:%s", string(b))
}

// processLogin decodes a Login message and forwards it to userCenter through
// doLogin. It returns the token from the message, the uid reported by
// userCenter, the response to send back to the client, and any decode or RPC
// error. The response is non-nil even when an error is returned.
func processLogin(logger *log.Entry, clientAddr string, pbData []byte) (string, uint64, *userProxy.LoginRsp, error) {
	rsp := &userProxy.LoginRsp{}
	loginMsg := &userProxy.Login{}

	if err := proto.Unmarshal(pbData, loginMsg); err != nil {
		logger.Warn("unmarshal error")
		return loginMsg.Token, 0, rsp, err
	}

	entry := logger.WithFields(log.Fields{
		"addr":  clientAddr,
		"token": loginMsg.Token,
	})
	entry.Info(loginMsg.String())

	var (
		loginUid uint64
		err      error
	)
	err, rsp.Status, loginUid = doLogin(loginMsg.Token, clientAddr)
	if err == nil {
		entry.Infof("login status = %d, uid = %d,proxyAddr:%s,clientAddr:%s", rsp.Status, loginUid, myAddress, clientAddr)
	} else {
		entry.Warnf("login RPC failed for %d, %v", rsp.Status, err)
	}
	return loginMsg.Token, loginUid, rsp, err
}

// processHeartBeat decodes a HeartBeat message and returns the status to reply
// with (always 0), the external uid carried by the heartbeat, and the decode
// error if any.
func processHeartBeat(pbData []byte) (uint32, string, error) {
	heartBeat := &userProxy.HeartBeat{}
	if err := proto.Unmarshal(pbData, heartBeat); err != nil {
		return 0, "", err
	}
	// FIXME: add rate limiting for overly frequent heartbeats.
	return 0, heartBeat.ExternalUid, nil
}

// processBizRequest decodes a BizRequest and forwards it to userCenter through
// doBizRequest. The returned response is non-nil even when an error is returned.
func processBizRequest(logger *log.Entry, uid uint64, pbData []byte) (*userProxy.BizResponse, error) {
	rsp := &userProxy.BizResponse{}
	msg := &userProxy.BizRequest{}

	err := proto.Unmarshal(pbData, msg)
	if err != nil {
		logger.Warn("unmarshal error")
		return rsp, err
	}

	logger.Infof("doBizRequest msgType = %d, msg: %s", msg.Type, msg.PayLoad)
	rsp.Status, err = doBizRequest(uid, msg.Type, msg.PayLoad)
	if err != nil {
		// Argument order matches the verbs: status, error, then uid.
		logger.Warnf("doBizRequest RPC failed for %d, %v, uid = %d", rsp.Status, err, uid)
	}
	return rsp, err
}

// homePage is a debugging endpoint. Given the form values "target", "from" and
// "to", it lists the write channels currently bound to the target user.
func homePage(w http.ResponseWriter, r *http.Request) {
	rsp := "Home Page"
	var targetUserId uint64 = 0
	fromUserId := ""
	toUserId := ""
	if r.ParseForm() == nil {
		// A parse failure leaves targetUserId at 0, which is reported as missing.
		targetUserId, _ = strconv.ParseUint(r.FormValue("target"), 10, 64)
		fromUserId = r.FormValue("from")
		toUserId = r.FormValue("to")
	}

	switch {
	case targetUserId == 0 || len(fromUserId) == 0 || len(toUserId) == 0:
		rsp = "targetUserId or fromUserId or toUserId is missing."
	default:
		chans := getChans(targetUserId)
		if len(chans) == 0 {
			rsp = "User not found"
			break
		}
		for addr, c := range chans {
			// addr is a string key, so it is printed with %s.
			rsp += fmt.Sprintf("address %s bound to channel %v", addr, c)
		}
	}
	fmt.Fprint(w, rsp)
}

// setUserChan registers c as the write channel of user uid's connection from addr.
func setUserChan(uid uint64, addr string, c chan ProtoData) {
	m.Lock()
	defer m.Unlock()
	if userChan[uid] == nil {
		userChan[uid] = make(map[string]chan ProtoData)
	}
	userChan[uid][addr] = c
}

// getUserChan returns the write channel of user uid's connection from addr, or
// nil when no such connection is registered.
func getUserChan(uid uint64, addr string) chan ProtoData {
	m.RLock()
	defer m.RUnlock()
	// Indexing a missing (nil) inner map yields the zero value, i.e. nil.
	return userChan[uid][addr]
}

// getChans returns a snapshot of all write channels registered for uid, keyed
// by client address, or nil when the user has none. A copy is returned because
// callers iterate the result after the lock is released; handing out the live
// map would race with setUserChan and removeUserChan.
func getChans(uid uint64) map[string]chan ProtoData {
	m.RLock()
	defer m.RUnlock()
	live := userChan[uid]
	if len(live) == 0 {
		return nil
	}
	snapshot := make(map[string]chan ProtoData, len(live))
	for addr, c := range live {
		snapshot[addr] = c
	}
	return snapshot
}

// removeUserChan unregisters the connection of user uid from addr and returns
// the number of users that still have at least one registered connection.
func removeUserChan(uid uint64, addr string) int {
	m.Lock()
	defer m.Unlock()
	if chs, ok := userChan[uid]; ok {
		delete(chs, addr)
		if len(chs) == 0 {
			delete(userChan, uid)
		}
	}
	return len(userChan)
}

const (
	// rpcListenPortBase is the first port tried for this proxy's gRPC listener.
	rpcListenPortBase = 50050
)

// server implements the userCenter Router gRPC service.
type server struct {
	userCenter.UnimplementedRouterServer
}

// rpcPeerAddr returns the caller address recorded in ctx, or "unknown" when the
// context carries no peer information.
func rpcPeerAddr(ctx context.Context) string {
	if p, ok := peer.FromContext(ctx); ok && p != nil && p.Addr != nil {
		return p.Addr.String()
	}
	return "unknown"
}

// enqueueProtoData queues d on c, giving up with ctx's error if the RPC is
// cancelled or times out first, so a full channel cannot block a handler forever.
func enqueueProtoData(ctx context.Context, c chan ProtoData, d ProtoData) error {
	select {
	case c <- d:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// Route delivers a message from userCenter to every connection of the target
// user. The response status is ROUTE_CHANNEL_NOT_FOUND when the user has no
// connection on this proxy. If the RPC context ends while a connection's queue
// is full, the context error is returned.
func (s *server) Route(ctx context.Context, in *userCenter.RouteMessage) (*userCenter.RouteMessageRsp, error) {
	defer func() {
		if r := recover(); r != nil {
			// Log the panic together with its stack trace.
			mylogrus.MyLog.Errorf("Route SYSTEM ACTION PANIC: %v, stack: %v", r, string(debug.Stack()))
		}
	}()

	mylogrus.MyLog.Infof("Received Route request from %s: msgType = %d, uid = %d, payload size = %d",
		rpcPeerAddr(ctx), in.GetMsgType(), in.GetUid(), len(in.GetPayLoad()))

	chans := getChans(in.GetUid())
	if len(chans) == 0 {
		mylogrus.MyLog.Warnf("No write channel for user %d", in.GetUid())
		return &userCenter.RouteMessageRsp{Status: common.ROUTE_CHANNEL_NOT_FOUND}, nil
	}

	for addr, c := range chans {
		mylogrus.MyLog.Infof("Route to user %d, addr %s, channel %v", in.Uid, addr, c)
		msg := ProtoData{
			MsgType: in.GetMsgType(),
			Data:    in.GetPayLoad(),
		}
		if err := enqueueProtoData(ctx, c, msg); err != nil {
			mylogrus.MyLog.Warnf("Route to user %d, addr %s aborted: %v", in.Uid, addr, err)
			return nil, err
		}
	}
	return &userCenter.RouteMessageRsp{Status: common.ROUTE_SUCCESS}, nil
}

// KickUser asks the connection of user in.Uid from in.Addr to close by queuing
// a MsgTypeKickUser message on its write channel. A missing connection is only
// logged; the response status is 0 in both cases.
func (s *server) KickUser(ctx context.Context, in *userCenter.KickMessage) (*userCenter.KickMessageRsp, error) {
	logger := mylogrus.MyLog.WithFields(log.Fields{
		"uid":  in.Uid,
		"addr": in.Addr,
	})
	logger.Infof("Received KickUser request from %s", rpcPeerAddr(ctx))

	c := getUserChan(in.Uid, in.Addr)
	if c == nil {
		logger.Infof("No write channel")
		return &userCenter.KickMessageRsp{Status: 0}, nil
	}
	if err := enqueueProtoData(ctx, c, ProtoData{MsgType: common.MsgTypeKickUser}); err != nil {
		logger.Warnf("KickUser aborted: %v", err)
		return nil, err
	}
	return &userCenter.KickMessageRsp{Status: 0}, nil
}

// userClient is the gRPC client for userCenter; it is set once in main.
var userClient userCenter.UserClient

// doLogin calls userCenter's Login RPC with a 3-second timeout and returns the
// RPC error, the login status and the uid. Status and uid are 0 on failure.
func doLogin(token string, clientAddr string) (error, uint32, uint64) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
	defer cancel()

	r, err := userClient.Login(ctx, &userCenter.LoginMessage{
		Token:      token,
		ProxyAddr:  myAddress,
		ClientAddr: clientAddr,
	})
	if err != nil || r == nil {
		return err, 0, 0
	}
	return nil, r.Status, r.Uid
}

// doLogout calls userCenter's Logout RPC with a 3-second timeout and returns
// the RPC error and the logout status (0 on failure).
func doLogout(addr string, uid uint64) (error, uint32) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
	defer cancel()

	r, err := userClient.Logout(ctx, &userCenter.LogoutMessage{
		ClientAddr: addr,
		Uid:        uid,
	})
	if err != nil || r == nil {
		return err, 0
	}
	return nil, r.Status
}

// doBizRequest forwards a business message to userCenter's Transmit RPC with a
// 3-second timeout and returns the resulting status (0 on failure) and error.
func doBizRequest(uid uint64, msgType uint32, payLoad string) (uint32, error) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
	defer cancel()

	r, err := userClient.Transmit(ctx, &userCenter.BizMessage{
		Uid:     uid,
		MsgType: msgType,
		PayLoad: payLoad,
	})
	if err != nil || r == nil {
		return 0, err
	}
	return r.Status, nil
}

// kacp holds the keepalive settings for the gRPC connection to userCenter.
var kacp = keepalive.ClientParameters{
	Time:                10 * time.Second, // send pings every 10 seconds if there is no activity
	Timeout:             time.Second,      // wait 1 second for ping ack before considering the connection dead
	PermitWithoutStream: true,             // send pings even without active streams
}

var (
	defaultUserCenterAddr = "127.0.0.1:50040" // userCenter default addr
	userCenterAddr        = defaultUserCenterAddr
	userCenterConsulName  = "userCenter"
)

// Builder is the gRPC resolver builder used for userCenter service discovery.
// addrs maps a target path to its backend addresses; cc is the client
// connection whose address list is updated when the backends change.
type Builder struct {
	addrs map[string][]string
	cc    resolver.ClientConn
}

// Scheme returns the URL scheme handled by this builder ("uc", for userCenter).
func (b *Builder) Scheme() string {
	return "uc" // userCenter
}
type Resolver struct { } func (r Resolver) ResolveNow(options resolver.ResolveNowOptions) {} func (r Resolver) Close() {} func (b *Builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { r := &Resolver{} paths := b.addrs[target.URL.Path] addrs := make([]resolver.Address, len(paths)) for i, s := range paths { addrs[i] = resolver.Address{Addr: s} } cc.UpdateState(resolver.State{Addresses: addrs}) b.cc = cc return r, nil } func (b *Builder) UpdateState(addrs []string) { as := make([]resolver.Address, len(addrs)) for i, s := range addrs { as[i] = resolver.Address{Addr: s} } b.cc.UpdateState(resolver.State{Addresses: as}) } var certFile = flag.String("certFile", "ws.faceline.live.pem", "cert file") var keyFile = flag.String("keyFile", "ws.faceline.live.key", "key file") var listenPort = flag.String("port", "8081", "listen port") func main() { flag.Parse() client, err := consulapi.NewClient(consulapi.DefaultConfig()) //非默认情况下需要设置实际的参数 if client == nil { mylogrus.MyLog.Fatalln("Fail to get consul client.") } myIp, myNodeName := consul.GetAgentInfo(client) myLocalIp, err := common.GetClientIpV2() if err != nil { mylogrus.MyLog.Fatal(err) } mylogrus.MyLog.Infof("myIp is %s, myNodeName: %s, localIp is %s", myIp, myNodeName, myLocalIp) cataLog := client.Catalog() if cataLog == nil { mylogrus.MyLog.Fatalln("No catalog.") } services, _, err := cataLog.Service(userCenterConsulName, "", nil) if err != nil { mylogrus.MyLog.Fatalln(err) } if len(services) == 0 { mylogrus.MyLog.Fatalln("userCenter not found.") } var addrs []string bd := &Builder{addrs: map[string][]string{"/api": {userCenterAddr}}} for _, s := range services { addrs = append(addrs, fmt.Sprintf("%s:%d", s.ServiceAddress, s.ServicePort)) } if len(addrs) > 0 { bd = &Builder{addrs: map[string][]string{"/api": addrs}} userCenterAddr = "uc:///api" } mylogrus.MyLog.Infof("userCenterAddr:%v,addr:%v", userCenterAddr, addrs) // 服务发现 resolver.Register(bd) go 
func() { consul.RegisterWatcher(userCenterConsulName, func(addr []string) { if len(addr) > 0 { bd.UpdateState(addr) // 更新新的注册名 } }) }() // Set up a connection to the userCenter. conn, err := grpc.Dial(userCenterAddr, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithKeepaliveParams(kacp), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy":"%s"}`, "round_robin"))) if err != nil { mylogrus.MyLog.Fatalf("did not connect: %v", err) } //defer conn.Close() userClient = userCenter.NewUserClient(conn) if userClient == nil { mylogrus.MyLog.Fatalln("userClient null") } go func() { var i int var lis net.Listener var err error for i = 0; i < 10; i++ { addr := ":" + strconv.Itoa(rpcListenPortBase+i) lis, err = net.Listen("tcp", addr) if err == nil { mylogrus.MyLog.Infof("Go RPC listening on %s", lis.Addr().String()) break } } if i >= 10 { mylogrus.MyLog.Fatalln("No RPC Listen port available.") } myAddress = myLocalIp + ":" + strconv.Itoa(rpcListenPortBase+i) mylogrus.MyLog.Infof("My address is %s", myAddress) s := grpc.NewServer() userCenter.RegisterRouterServer(s, &server{}) if err := s.Serve(lis); err != nil { mylogrus.MyLog.Fatalf("failed to serve: %v", err) } }() setupRoutes() addr := ":" + *listenPort mylogrus.MyLog.Infof("Go Websocket listening on %s", addr) // http ws mux := http.NewServeMux() mux.HandleFunc("/ws", serverWebsocket) go http.ListenAndServe(":8082", mux) // https wss mylogrus.MyLog.Infof("certFile = %s, keyFile = %s", *certFile, *keyFile) mylogrus.MyLog.Printf("%s", http.ListenAndServeTLS(addr, *certFile, *keyFile, nil).Error()) }