-
Notifications
You must be signed in to change notification settings - Fork 4
/
service.go
115 lines (105 loc) · 3.3 KB
/
service.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package main
import (
"context"
"encoding/base64"
"errors"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
"github.com/gin-gonic/gin"
)
// tokenCleanupInterval is how often expired tokens are purged from tokensMap.
const tokenCleanupInterval = 10 * time.Minute

// run is the service body. It loads the configuration, starts the gin-based
// reverse proxy on config.Host:config.Port, and starts a background goroutine
// that periodically removes expired tokens from tokensMap. It then blocks until
// SIGINT or SIGTERM is received, and after that until a value is received
// from p.exit.
//
// A configuration error or a failure of the HTTP listener terminates the
// process with exit status 1.
func (p *program) run() {
	if err := loadConfig(); err != nil {
		logger.Errorf("ERR-loadConfig: %v", err.Error())
		os.Exit(1)
	}
	addr := net.JoinHostPort(config.Host, config.Port)

	gin.SetMode(gin.ReleaseMode)
	r := gin.Default()
	r.Any("/*proxyPath", proxy)
	go func() {
		// Run only returns when the listener fails (e.g. the port is already in
		// use). Without this check the service would keep running, having
		// logged "started", while serving nothing.
		if err := r.Run(addr); err != nil {
			logger.Errorf("ERR-ginRun: %v", err.Error())
			os.Exit(1)
		}
	}()
	logger.Infof("basicToOauth version: %s, started on: %s\r\n", VERSION, addr)

	go func() { // Purge expired tokens immediately, then every tokenCleanupInterval.
		ticker := time.NewTicker(tokenCleanupInterval)
		defer ticker.Stop()
		for {
			tokensMap.delExpired()
			<-ticker.C
		}
	}()

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
	<-sigs
	<-p.exit
}
// proxy is the catch-all gin handler. It forwards each request to
// config.ProxyURL, taking the upstream path from the "proxyPath" route
// wildcard. If the request carries an Authorization header, that header is
// passed through getAuthHeader, which may replace a Basic credential with a
// Bearer token. An unparsable config.ProxyURL yields 500 Internal Server Error.
func proxy(c *gin.Context) {
	remote, err := url.Parse(config.ProxyURL)
	if err != nil {
		logger.Errorf("ERR-remoteURL: %v", err.Error())
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	rp := httputil.NewSingleHostReverseProxy(remote)
	rp.Director = func(req *http.Request) {
		// req is the clone ReverseProxy.ServeHTTP made of c.Request, with its own
		// Header map. Edit it in place: aliasing c.Request.Header here would let
		// the proxy's hop-by-hop removal and X-Forwarded-For insertion mutate
		// the inbound request's headers.
		//
		// Only touch Authorization when the client sent one. Calling Set with an
		// empty value would forward an empty "Authorization:" header upstream.
		if auth := req.Header.Get("Authorization"); auth != "" {
			req.Header.Set("Authorization", getAuthHeader(auth))
		}
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Param("proxyPath")
	}
	rp.ServeHTTP(c.Writer, c.Request)
}
// getAuthHeader converts a "Basic <base64>" Authorization header value into
// "Bearer <access token>", using a token obtained by getAzureToken for the
// carried credentials. Tokens are cached in tokensMap, keyed by the base64
// credential string. A cached token is reused while it has more than one
// minute of validity left; otherwise a new one is requested and cached.
//
// Any value that is not a Basic credential is returned unchanged, including an
// empty string. If a new token cannot be obtained, the cached token is used
// when it has not yet expired. Otherwise the original header is returned.
func getAuthHeader(authHeader string) string {
	parts := strings.SplitN(authHeader, " ", 2)
	// The auth-scheme name is case-insensitive (RFC 7235 §2.1).
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
		return authHeader
	}
	key := strings.TrimSpace(parts[1])
	if key == "" {
		return authHeader
	}

	// Read the map once and keep the pointer. A second lookup could observe
	// the entry removed by the concurrent delExpired goroutine and return nil.
	cached := tokensMap.get(key)
	if cached != nil && time.Until(cached.expire) > time.Minute {
		return "Bearer " + cached.token
	}

	fresh, err := getAzureToken(key)
	if err != nil {
		logger.Warningf("ERR-getAzureToken: %v", err.Error())
		if cached != nil && time.Now().Before(cached.expire) {
			// Refresh failed, but the cached token is still valid for a few seconds.
			return "Bearer " + cached.token
		}
		return authHeader
	}
	tokensMap.add(key, fresh)
	return "Bearer " + fresh.token
}
// getAzureToken requests an Azure AD access token with the username/password
// (ROPC) flow of the public client config.ClientID. It uses the authority
// config.AuthorityURL+config.TenantID and the scopes config.Scopes.
//
// baseKey is the credential part of a Basic Authorization header, i.e.
// base64("username:password"). As specified by RFC 7617, only the first colon
// separates the user-id from the password, so passwords containing ':' are
// accepted. The function returns the base64 decoding error, a
// "basicAuthParseFailed" error when the decoded value has no colon, or any
// error from the MSAL client.
func getAzureToken(baseKey string) (*tToken, error) {
	decoded, err := base64.StdEncoding.DecodeString(baseKey)
	if err != nil {
		return nil, err
	}
	creds := strings.SplitN(string(decoded), ":", 2)
	if len(creds) != 2 {
		return nil, errors.New("basicAuthParseFailed")
	}
	username, password := creds[0], creds[1]

	publicClientApp, err := public.New(config.ClientID, public.WithAuthority(config.AuthorityURL+config.TenantID))
	if err != nil {
		return nil, err
	}
	result, err := publicClientApp.AcquireTokenByUsernamePassword(context.Background(), config.Scopes, username, password)
	if err != nil {
		return nil, err
	}
	return &tToken{token: result.AccessToken, expire: result.ExpiresOn}, nil
}