main.go (2874B)
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
)

// errEmptyResponse is returned when the endpoint answers successfully but the
// decoded JSON array contains no entries.
var errEmptyResponse = errors.New("endpoint returned no generated text entries")

// Config represents the configuration for the program, loaded from add.json.
type Config struct {
	// Lang is the language identifier used to open the code fence that is
	// appended to the prompt.
	Lang string `json:"lang"`

	// Endpoint is the URL every query is POSTed to.
	Endpoint string `json:"endpoint"`

	// Token is sent as a bearer token in the Authorization header.
	Token string `json:"token"`

	// Parameters is forwarded unchanged as the "parameters" field of the
	// request body.
	Parameters map[string]any `json:"parameters"`

	// Options is forwarded unchanged as the "options" field of the request
	// body.
	Options map[string]any `json:"options"`
}

// main loads add.json and prompt.txt from the working directory, substitutes
// the content of the file named by the first argument into the prompt, and
// then queries the endpoint in a loop, printing each piece of generated text
// as it arrives.
func main() {
	flag.Parse()

	args := flag.Args()
	if len(args) < 1 {
		log.Fatalln("Usage: addgen [source.yaml]")
	}

	// Read the configuration file
	data, err := os.ReadFile("add.json")
	if err != nil {
		log.Fatal(err)
	}

	var config Config
	if err := json.Unmarshal(data, &config); err != nil {
		log.Fatalf("parsing add.json: %v", err)
	}

	// Verify the configuration. The token is a credential, so only report
	// whether it is present instead of echoing its value into the log.
	if config.Lang == "" || config.Endpoint == "" || config.Token == "" {
		log.Fatalf("Invalid configuration: lang=%q, endpoint=%q, token set=%t\n", config.Lang, config.Endpoint, config.Token != "")
	}

	// Read the prompt file
	prompt, err := os.ReadFile("prompt.txt")
	if err != nil {
		log.Fatal(err)
	}

	// Read the yaml file
	inputData, err := os.ReadFile(args[0])
	if err != nil {
		log.Fatal(err)
	}

	input := appendToPrompt(
		// Replace placeholders in the prompt
		strings.ReplaceAll(string(prompt), "{{ content }}", string(inputData)),

		// Append the language identifier to the prompt
		fmt.Sprintf("```%s\n", config.Lang),
	)

	// Query the endpoint repeatedly until there are no more results. Each
	// result is appended to the prompt so the next query continues from it.
	for {
		result, err := query(config, input)
		if err != nil {
			log.Fatal(err)
		}
		if len(result) == 0 {
			break
		}

		input = appendToPrompt(input, result)
	}
}

// appendToPrompt prints s to standard output and returns prompt with s
// appended.
func appendToPrompt(prompt string, s string) string {
	fmt.Print(s)
	return prompt + s
}

// query POSTs prompt, together with the configured parameters and options, to
// config.Endpoint as JSON and returns the generated text of the first entry in
// the response.
//
// It returns a non-nil error when the request body cannot be encoded, the
// request cannot be created or sent, the status code is not 200, the response
// is not valid JSON, or the response contains no entries.
func query(config Config, prompt string) (string, error) {
	body, err := json.Marshal(map[string]any{
		"inputs":     prompt,
		"parameters": config.Parameters,
		"options":    config.Options,
	})
	if err != nil {
		// Must be reported: an empty result with a nil error would make the
		// caller believe generation finished normally.
		return "", fmt.Errorf("encoding request body: %w", err)
	}

	// Set up the request
	req, err := http.NewRequest(http.MethodPost, config.Endpoint, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("creating request: %w", err)
	}

	// Add the authorization header
	req.Header.Set("Authorization", "Bearer "+config.Token)
	req.Header.Set("Content-Type", "application/json")

	// Make the request
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Check the status code
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	return decodeGeneratedText(resp.Body)
}

// decodeGeneratedText decodes a JSON array of objects carrying a
// "generated_text" field from r and returns the text of the first element.
// An empty or null array yields errEmptyResponse instead of panicking.
func decodeGeneratedText(r io.Reader) (string, error) {
	var response []struct {
		GeneratedText string `json:"generated_text"`
	}
	if err := json.NewDecoder(r).Decode(&response); err != nil {
		return "", fmt.Errorf("decoding response: %w", err)
	}
	if len(response) == 0 {
		return "", errEmptyResponse
	}

	// Return the generated text
	return response[0].GeneratedText, nil
}