Lab2
Jump to navigation
Jump to search
INSERT INTO users (username, password1, org_id) VALUES ('test2', crypt('secret', gen_salt('bf', 10)), 123);
$2a$10$J4TMoF9gBM6wCH5GMiPfzO7a0CZK8PQFMv/k2mEE2TPaqYykEKn5G
package main
import (
"bytes"
"context"
"crypto/tls"
"database/sql"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/signal"
"path"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/golang/protobuf/proto"
"github.com/golang/snappy"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rqlite/gorqlite"
"github.com/sirupsen/logrus"
_ "github.com/lib/pq"
pb "github.com/prometheus/prometheus/prompb"
)
var (
store UserStore
mimirURL *url.URL
mimirUser string
mimirPass string
httpClient *http.Client
log = logrus.New()
userCache = make(map[string]User)
cacheMux = &sync.RWMutex{}
cacheTTL = 30 * time.Second
)
type User struct {
Username string
Password string
OrgID string
}
type UserStore interface {
LoadAll() (map[string]User, error)
}
type PostgresStore struct {
db *sql.DB
}
type RqliteStore struct {
conn *gorqlite.Connection
}
func main() {
initLogger()
validateEnvVars("MIMIR_URL", "MIMIR_USERNAME", "MIMIR_PASSWORD")
store = openStore()
var err error
mimirURL, err = url.Parse(os.Getenv("MIMIR_URL"))
if err != nil {
log.WithError(err).Fatal("Invalid MIMIR_URL")
}
mimirUser = os.Getenv("MIMIR_USERNAME")
mimirPass = os.Getenv("MIMIR_PASSWORD")
if os.Getenv("BACKEND_SKIP_TLS_VERIFY") == "true" {
log.Warn("BACKEND_SKIP_TLS_VERIFY is true — skipping TLS verification for backend requests!")
httpClient = &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
} else {
httpClient = &http.Client{
Timeout: 30 * time.Second,
}
}
if ttlStr := os.Getenv("USER_CACHE_TTL"); ttlStr != "" {
if d, err := time.ParseDuration(ttlStr); err == nil {
cacheTTL = d
}
}
go cacheRefresher()
mux := http.NewServeMux()
mux.HandleFunc("/prometheus/", handlePrometheusQuery)
mux.HandleFunc("/api/v1/push", handlePush)
mux.Handle("/metrics", promhttp.Handler())
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
})
srv := &http.Server{
Addr: ":8080",
Handler: mux,
}
go func() {
log.WithField("addr", srv.Addr).Info("HTTP server running")
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.WithError(err).Fatal("HTTP server error")
}
}()
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
log.Info("Shutting down gracefully...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.WithError(err).Fatal("Server forced to shutdown")
}
log.Info("Server stopped.")
}
func cacheRefresher() {
for {
loadUserCache()
time.Sleep(cacheTTL)
}
}
func loadUserCache() {
log.Info("Loading user cache...")
tmp, err := store.LoadAll()
if err != nil {
log.WithError(err).Error("Failed to load user cache")
return
}
cacheMux.Lock()
userCache = tmp
cacheMux.Unlock()
log.WithField("count", len(userCache)).Info("User cache refreshed")
}
func validateUser(username, password string) bool {
cacheMux.RLock()
defer cacheMux.RUnlock()
u, ok := userCache[username]
return ok && u.Password == password
}
func getOrgID(username string) (string, error) {
cacheMux.RLock()
defer cacheMux.RUnlock()
u, ok := userCache[username]
if !ok {
return "", fmt.Errorf("user not found")
}
return u.OrgID, nil
}
func (p *PostgresStore) LoadAll() (map[string]User, error) {
tmp := make(map[string]User)
rows, err := p.db.Query("SELECT username, password, org_id FROM users")
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var u User
if err := rows.Scan(&u.Username, &u.Password, &u.OrgID); err != nil {
continue
}
tmp[u.Username] = u
}
return tmp, nil
}
func (r *RqliteStore) LoadAll() (map[string]User, error) {
tmp := make(map[string]User)
rs, err := r.conn.QueryOne("SELECT username, password, org_id FROM users")
if err != nil {
return nil, err
}
for rs.Next() {
var u User
if err := rs.Scan(&u.Username, &u.Password, &u.OrgID); err != nil {
log.WithError(err).Warn("failed to scan row")
continue
}
tmp[u.Username] = u
}
return tmp, nil
}
func handlePrometheusQuery(w http.ResponseWriter, r *http.Request) {
if !authenticate(r) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
username, _, ok := r.BasicAuth()
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
orgID, err := getOrgID(username)
if err != nil {
http.Error(w, "Failed to retrieve org_id", http.StatusInternalServerError)
return
}
backendURL := *mimirURL
backendURL.Path = path.Join("/prometheus", strings.TrimPrefix(r.URL.Path, "/prometheus"))
backendURL.RawQuery = r.URL.RawQuery
backendReq, err := http.NewRequestWithContext(r.Context(), r.Method, backendURL.String(), r.Body)
if err != nil {
http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
log.WithError(err).Error("error creating backend request")
return
}
backendReq.Header = r.Header.Clone()
backendReq.SetBasicAuth(mimirUser, mimirPass)
if orgID != "" {
backendReq.Header.Set("X-Scope-OrgID", orgID)
}
resp, err := httpClient.Do(backendReq)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to query Mimir backend: %v", err), http.StatusBadGateway)
log.WithError(err).Error("error querying Mimir")
return
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
http.Error(w, fmt.Sprintf("Mimir returned error: %s (Status Code: %d)", string(body), resp.StatusCode), http.StatusBadGateway)
log.WithFields(logrus.Fields{
"status": resp.StatusCode,
"body": string(body),
"username": username,
"org_id": orgID,
}).Error("Mimir returned error on /prometheus query")
return
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
func handlePush(w http.ResponseWriter, r *http.Request) {
if !authenticate(r) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
username, _, ok := r.BasicAuth()
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
orgID, err := getOrgID(username)
if err != nil {
http.Error(w, "Failed to retrieve org_id", http.StatusInternalServerError)
return
}
data, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
defer r.Body.Close()
var (
compressedData []byte
contentType string
)
if isLikelyText(data) {
log.Info("Received text/plain metric, parsing…")
writeReq, err := parseRawMetric(string(data), username, orgID)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to parse metric: %v", err), http.StatusBadRequest)
log.WithFields(logrus.Fields{
"err": err,
"raw_data": string(data),
}).Warn("failed to parse metric")
return
}
serialized, err := proto.Marshal(writeReq)
if err != nil {
http.Error(w, "Failed to marshal protobuf message", http.StatusInternalServerError)
log.WithError(err).Error("error marshaling protobuf")
return
}
compressedData = snappy.Encode(nil, serialized)
contentType = "text/plain"
} else {
log.Info("Received snappy/protobuf metric, forwarding as-is…")
compressedData = data
contentType = "application/x-protobuf"
}
backendURL := *mimirURL
backendURL.Path = "/api/v1/push"
backendReq, err := http.NewRequestWithContext(r.Context(), "POST", backendURL.String(), bytes.NewReader(compressedData))
if err != nil {
http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
log.WithError(err).Error("error creating backend request")
return
}
backendReq.Header.Set("Content-Encoding", "snappy")
backendReq.Header.Set("Content-Type", contentType)
backendReq.SetBasicAuth(mimirUser, mimirPass)
if orgID != "" {
backendReq.Header.Set("X-Scope-OrgID", orgID)
}
resp, err := httpClient.Do(backendReq)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to push to Mimir backend: %v", err), http.StatusBadGateway)
log.WithError(err).Error("error pushing to Mimir")
return
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
http.Error(w, fmt.Sprintf("Mimir returned error: %s (Status Code: %d)", string(body), resp.StatusCode), http.StatusBadGateway)
log.WithFields(logrus.Fields{
"status": resp.StatusCode,
"body": string(body),
"username": username,
"org_id": orgID,
}).Error("Mimir returned error")
return
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
func parseRawMetric(raw, username, orgID string) (*pb.WriteRequest, error) {
raw = strings.TrimSpace(raw)
idxOpen := strings.Index(raw, "{")
if idxOpen == -1 {
return nil, fmt.Errorf("invalid format: missing '{'")
}
idxClose := strings.Index(raw, "}")
if idxClose == -1 || idxClose < idxOpen {
return nil, fmt.Errorf("invalid format: missing '}'")
}
metricName := strings.TrimSpace(raw[:idxOpen])
labelPart := raw[idxOpen+1 : idxClose]
rest := strings.TrimSpace(raw[idxClose+1:])
parts := strings.Fields(rest)
if len(parts) < 2 {
return nil, fmt.Errorf("invalid format: missing value or timestamp")
}
value, err := strconv.ParseFloat(parts[0], 64)
if err != nil {
return nil, fmt.Errorf("invalid metric value: %v", err)
}
timestamp, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid timestamp: %v", err)
}
labels := []pb.Label{{Name: "__name__", Value: metricName}}
if labelPart != "" {
labelPairs := strings.Split(labelPart, ",")
for _, pair := range labelPairs {
pair = strings.TrimSpace(pair)
if pair == "" {
continue
}
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
return nil, fmt.Errorf("invalid label format: %s", pair)
}
key := strings.TrimSpace(kv[0])
val := strings.Trim(strings.TrimSpace(kv[1]), "\"")
labels = append(labels, pb.Label{Name: key, Value: val})
}
}
labels = append(labels, pb.Label{Name: "username", Value: username})
if orgID != "" {
labels = append(labels, pb.Label{Name: "org_id", Value: orgID})
}
ts := pb.TimeSeries{
Labels: labels,
Samples: []pb.Sample{{Value: value, Timestamp: timestamp}},
}
return &pb.WriteRequest{
Timeseries: []pb.TimeSeries{ts},
}, nil
}
func authenticate(r *http.Request) bool {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Basic ") {
return false
}
decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader, "Basic "))
if err != nil {
return false
}
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 {
return false
}
username, password := parts[0], parts[1]
return validateUser(username, password)
}
func isLikelyText(data []byte) bool {
max := len(data)
if max > 512 {
max = 512
}
for i := 0; i < max; i++ {
c := data[i]
if (c < 32 && c != 9 && c != 10 && c != 13) || c > 126 {
return false
}
}
return true
}
func initLogger() {
levelStr := strings.ToLower(os.Getenv("LOG_LEVEL"))
if levelStr == "" {
levelStr = "info"
}
level, err := logrus.ParseLevel(levelStr)
if err != nil {
level = logrus.InfoLevel
}
log.SetLevel(level)
log.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: time.RFC3339,
})
log.SetOutput(os.Stdout)
log.WithField("level", level).Info("log level set")
}
func validateEnvVars(vars ...string) {
for _, v := range vars {
if os.Getenv(v) == "" {
log.WithField("var", v).Fatal("Environment variable must be set")
}
}
}