From 3c0a1c505cdd7b662ad7212bbea3ce3c78371d37 Mon Sep 17 00:00:00 2001 From: Triston Armstrong Date: Sat, 11 Jan 2025 14:43:21 -0500 Subject: [PATCH] refactor some things --- src/commit_handler.rs | 7 +++- src/git_grabber.rs | 31 +++++++------- src/main.rs | 55 ++++++------------------- src/mind_bridge.rs | 1 - src/{transporter.rs => ml_interface.rs} | 33 +++++++-------- src/pr_handler.rs | 7 +++- 6 files changed, 56 insertions(+), 78 deletions(-) delete mode 100644 src/mind_bridge.rs rename src/{transporter.rs => ml_interface.rs} (80%) diff --git a/src/commit_handler.rs b/src/commit_handler.rs index 3475b72..5d46159 100644 --- a/src/commit_handler.rs +++ b/src/commit_handler.rs @@ -1,7 +1,10 @@ +use crate::git_grabber::GitGrabber; + pub struct CommitHandler {} impl CommitHandler { - pub fn new() -> Option { - Some(String::from("create a short commit message from this diff, with format Task(): . only respond with the commit message")) + pub fn new() -> Option<(String, String)> { + let dirs = String::from("create a short commit message from this diff, with format Task(): . only respond with the commit message"); + Some((dirs, GitGrabber::get_diff())) } } diff --git a/src/git_grabber.rs b/src/git_grabber.rs index 1f4ff3f..14f61e9 100644 --- a/src/git_grabber.rs +++ b/src/git_grabber.rs @@ -13,25 +13,28 @@ impl GitGrabber { GitGrabber {} } - pub fn get_diff(&self) -> String { - // just to print - let staged_files_output = Command::new("git") - .args(["diff", "--staged", "--stat"]) + pub fn get_current_branch() -> String { + let branch = Command::new("git") + .args(["branch", "--show-current"]) .output() - .expect("Failed to get diff"); - let _ = std::io::stdout().write_all(&staged_files_output.stdout); + .expect("Failed to get branch"); - // actual output - let output = Command::new("git") - .args(["diff", "--staged", "--", ".", "':(exclude)*lock*'"]) - .output() - .expect("Failed to execute process"); - - let b = from_utf8(&output.stdout).unwrap(); + let b = from_utf8(&branch.stdout).unwrap(); String::from(b) } - pub fn generate_repo_desc(&self, origin_branch: &str, local_branch: &str) -> String { + pub fn get_diff() -> String { + // just to print + let staged_files_output = Command::new("git") + .args(["diff", "--staged"]) + .output() + .expect("Failed to get diff"); + + let b = from_utf8(&staged_files_output.stdout).unwrap(); + String::from(b) + } + + pub fn generate_repo_desc(origin_branch: &str, local_branch: &str) -> String { let output = Command::new("git") .args([ "rev-list", diff --git a/src/main.rs b/src/main.rs index 3577747..e9d7bd9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,16 @@ mod arg_parser; mod commit_handler; +mod git_grabber; +mod ml_interface; mod pr_handler; -// mod git_grabber; -// mod mind_bridge; -// mod transporter; use arg_parser::ArgParser; use commit_handler::CommitHandler; +use ml_interface::{MlBody, MlInterface, MlResponse}; use pr_handler::PrHandler; -// use core::panic; -// use ferrum_input::ArgParser; -// use git_grabber::GitGrabber; -// use mind_bridge::*; -// use transporter::Transporter; fn main() { - let prompt: Option = match ArgParser::parse() { + let prompt: Option<(String, String)> = match ArgParser::parse() { Some(arg_parser::ParsedArg::Commit) => CommitHandler::new(), Some(arg_parser::ParsedArg::PullRequest) => PrHandler::new(), None => None, @@ -26,37 +21,13 @@ fn main() { return; } - println!("{prompt:?}") - - // let mut transporter = Transporter::new(); - // let gg = GitGrabber::new(); - // let diff = gg.get_diff(); - // let commits_msg = String::from("create a short commit message from this diff, with format Task(): . only respond with the commit message"); - // let mind_gen_text = MindGen::new(commits_msg, format!("input: {}; branch: {}", diff, "dev")); - // let res_text = transporter - // .make_request(mind_gen_text) - // .unwrap() - // .text() - // .unwrap(); - // let response: Result = serde_json::from_str(&res_text); - // if response.is_err() { - // panic!("oop something went wrong: {:?}", response.err()); - // } - // println!("{:#?}", response.unwrap().response); - - // let z = gg.generate_repo_desc("master", "dev"); - // let pr_msg = String::from( - // "create a pull request description in markdown from the following revlog output. Only respond with pr description, nothing else.", - // ); - // let mind_gen_repo = MindGen::new(pr_msg, z); - // let repo_text = transporter - // .make_request(mind_gen_repo) - // .unwrap() - // .text() - // .unwrap(); - // let repo_response: Result = serde_json::from_str(&repo_text); - // if repo_response.is_err() { - // panic!("oop something went wrong: {:?}", repo_response.err()); - // } - // println!("{:#?}", repo_response.unwrap().response); + let mut ml = MlInterface::new(); + let (directions, content) = prompt.unwrap(); + let mind_gen_text = MlBody::new(content, directions); + let res_text = ml.make_request(mind_gen_text).unwrap().text().unwrap(); + let response: Result = serde_json::from_str(&res_text); + if response.is_err() { + panic!("oop something went wrong: {:?}", response.err()); + } + println!("{:#?}", response.unwrap()); } diff --git a/src/mind_bridge.rs b/src/mind_bridge.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/mind_bridge.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/transporter.rs b/src/ml_interface.rs similarity index 80% rename from src/transporter.rs rename to src/ml_interface.rs index e982a59..96bf7b4 100644 --- a/src/transporter.rs +++ b/src/ml_interface.rs @@ -1,18 +1,12 @@ -use crate::MindGen; use reqwest::blocking::Client; use serde::{Deserialize, Serialize}; #[allow(unused)] pub static OLLAMA_ENDP: &str = "http://localhost:11434/api/generate"; -#[allow(unused)] -pub struct Transporter { - pub client: Client, -} - #[derive(Debug, Deserialize)] #[allow(unused)] -pub struct GenRes { +pub struct MlResponse { model: String, created_at: String, pub response: String, @@ -26,7 +20,7 @@ pub struct GenRes { } #[derive(Debug, Serialize)] -struct GenOptions { +struct MlOptions { temperature: f32, num_predict: u8, repeat_last_n: u8, @@ -35,26 +29,26 @@ struct GenOptions { } #[derive(Debug, Serialize)] -pub struct MindGen { +pub struct MlBody { model: String, prompt: String, stream: bool, raw: bool, system: String, - options: GenOptions, + options: MlOptions, } -impl MindGen { +impl MlBody { #[allow(unused)] - pub fn new(directions: String, input: String) -> Self { + pub fn new(content: String, directions: String) -> Self { Self { model: String::from("llama3.1"), stream: false, raw: false, - prompt: input, + prompt: content, system: directions, - options: GenOptions { - temperature: 0.1, + options: MlOptions { + temperature: 0.5, num_predict: 0, repeat_last_n: 0, top_k: 10, @@ -65,7 +59,12 @@ impl MindGen { } #[allow(unused)] -impl Transporter { +pub struct MlInterface { + pub client: Client, +} + +#[allow(unused)] +impl MlInterface { #[allow(unused)] pub fn new() -> Self { Self { @@ -76,7 +75,7 @@ impl Transporter { #[allow(unused)] pub fn make_request( &mut self, - gen_data: MindGen, + gen_data: MlBody, ) -> Result { let json_body = serde_json::to_string(&gen_data).unwrap(); self.client.post(OLLAMA_ENDP).body(json_body).send() diff --git a/src/pr_handler.rs b/src/pr_handler.rs index 2f6885f..ae5c931 100644 --- a/src/pr_handler.rs +++ b/src/pr_handler.rs @@ -1,6 +1,9 @@ +use crate::git_grabber::GitGrabber; + pub struct PrHandler {} impl PrHandler { - pub fn new() -> Option { - Some("Pull Request Handler".to_string()) + pub fn new() -> Option<(String, String)> { + let dirs = String::from("Pull Request Handler"); + Some((dirs, GitGrabber::generate_repo_desc("", ""))) } }